Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/connection.py: 23%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

834 statements  

1import asyncio 

2import copy 

3import inspect 

4import socket 

5import sys 

6import time 

7import warnings 

8import weakref 

9from abc import ABC, abstractmethod 

10from itertools import chain 

11from types import MappingProxyType 

12from typing import ( 

13 Any, 

14 Callable, 

15 Iterable, 

16 List, 

17 Mapping, 

18 Optional, 

19 Protocol, 

20 Set, 

21 Tuple, 

22 Type, 

23 TypedDict, 

24 TypeVar, 

25 Union, 

26) 

27from urllib.parse import ParseResult, parse_qs, unquote, urlparse 

28 

29from ..observability.attributes import ( 

30 DB_CLIENT_CONNECTION_POOL_NAME, 

31 DB_CLIENT_CONNECTION_STATE, 

32 AttributeBuilder, 

33 ConnectionState, 

34 get_pool_name, 

35) 

36from ..utils import SSL_AVAILABLE, deprecated_function 

37 

38if SSL_AVAILABLE: 

39 import ssl 

40 from ssl import SSLContext, TLSVersion, VerifyFlags 

41else: 

42 ssl = None 

43 TLSVersion = None 

44 SSLContext = None 

45 VerifyFlags = None 

46 

47from ..auth.token import TokenInterface 

48from ..driver_info import DriverInfo, resolve_driver_info 

49from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher 

50from ..utils import deprecated_args, format_error_message 

51 

52# the functionality is available in 3.11.x but has a major issue before 

53# 3.11.3. See https://github.com/redis/redis-py/issues/2633 

54if sys.version_info >= (3, 11, 3): 

55 from asyncio import timeout as async_timeout 

56else: 

57 from async_timeout import timeout as async_timeout 

58 

59from redis.asyncio.observability.recorder import ( 

60 record_connection_closed, 

61 record_connection_count, 

62 record_connection_create_time, 

63 record_connection_wait_time, 

64 record_error_count, 

65) 

66from redis.asyncio.retry import Retry 

67from redis.backoff import NoBackoff 

68from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider 

69from redis.exceptions import ( 

70 AuthenticationError, 

71 AuthenticationWrongNumberOfArgsError, 

72 ConnectionError, 

73 DataError, 

74 MaxConnectionsError, 

75 RedisError, 

76 ResponseError, 

77 TimeoutError, 

78) 

79from redis.observability.metrics import CloseReason 

80from redis.typing import EncodableT 

81from redis.utils import ( 

82 DEFAULT_RESP_VERSION, 

83 HIREDIS_AVAILABLE, 

84 SENTINEL, 

85 str_if_bytes, 

86) 

87 

88from .._defaults import ( 

89 DEFAULT_SOCKET_CONNECT_TIMEOUT, 

90 DEFAULT_SOCKET_READ_SIZE, 

91 DEFAULT_SOCKET_TIMEOUT, 

92 get_default_socket_keepalive_options, 

93) 

94from .._parsers import ( 

95 BaseParser, 

96 Encoder, 

97 _AsyncHiredisParser, 

98 _AsyncRESP2Parser, 

99 _AsyncRESP3Parser, 

100) 

101 

# RESP wire-protocol tokens used when packing commands.
SYM_STAR = b"*"  # prefix of an array header ("*<count>\r\n")
SYM_DOLLAR = b"$"  # prefix of a bulk-string header ("$<length>\r\n")
SYM_CRLF = b"\r\n"  # RESP line terminator
SYM_LF = b"\n"
SYM_EMPTY = b""  # joiner for assembling byte buffers


# Parser used when the caller does not pass ``parser_class``: the
# hiredis-backed parser if hiredis is importable, otherwise the Python RESP3
# parser (connections may still swap it for the RESP2 parser to match the
# negotiated protocol).
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]] = (
    _AsyncHiredisParser if HIREDIS_AVAILABLE else _AsyncRESP3Parser
)

114 

115 

class ConnectCallbackProtocol(Protocol):
    """Structural type for a synchronous on-connect hook."""

    def __call__(self, connection: "AbstractConnection"):
        """Invoked with the connection that was just established."""


class AsyncConnectCallbackProtocol(Protocol):
    """Structural type for an ``async`` on-connect hook."""

    async def __call__(self, connection: "AbstractConnection"):
        """Awaited with the connection that was just established."""


# Either flavour of hook is accepted wherever a connect callback is expected.
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]

125 

126 

class AbstractConnection:
    """Manages communication to and from a Redis server.

    This base class implements the transport-independent parts of a
    connection:
    - the connect/handshake flow (AUTH or HELLO, CLIENT SETNAME,
      CLIENT SETINFO, SELECT);
    - PING health checks;
    - RESP command packing;
    - reading replies through a pluggable parser.

    Subclasses supply the transport by implementing ``_connect`` (which must
    set ``_reader``/``_writer``), ``_host_error`` (an address string for
    error messages) and ``repr_pieces`` (the ``key=value`` pairs shown by
    ``repr``).

    Note: this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers document the contract but are not enforced
    at instantiation time.
    """

    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        # Keep a per-instance __dict__ for attributes not listed above
        # (e.g. retry, driver_info, _event_dispatcher, _re_auth_token).
        "__dict__",
    )

    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        *,
        db: str | int = 0,
        password: str | None = None,
        socket_timeout: float | None = DEFAULT_SOCKET_TIMEOUT,
        socket_connect_timeout: float | None = DEFAULT_SOCKET_CONNECT_TIMEOUT,
        retry_on_timeout: bool = False,
        retry_on_error: list | object = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class: Type[BaseParser] = DefaultParser,
        socket_read_size: int = DEFAULT_SOCKET_READ_SIZE,
        health_check_interval: float = 0,
        client_name: str | None = None,
        lib_name: str | object | None = SENTINEL,
        lib_version: str | object | None = SENTINEL,
        driver_info: DriverInfo | object | None = SENTINEL,
        username: str | None = None,
        retry: Retry | None = None,
        redis_connect_func: ConnectCallbackT | None = None,
        encoder_class: Type[Encoder] = Encoder,
        credential_provider: CredentialProvider | None = None,
        protocol: int | None = None,
        legacy_responses: bool = True,
        event_dispatcher: EventDispatcher | None = None,
    ):
        """
        Initialize a new async Connection.

        Parameters
        ----------
        socket_connect_timeout : float, optional
            Timeout for establishing (and closing) the connection. Falls back
            to ``socket_timeout`` when None.
        retry_on_timeout : bool
            If True, ``TimeoutError``, ``socket.timeout`` and
            ``asyncio.TimeoutError`` are added to the errors that are retried.
        retry_on_error : list, optional
            Exception types that trigger a retry. The list passed in is never
            modified; when timeout errors must be added, a new list is built.
        retry : Retry, optional
            Retry policy. It is deep-copied and extended with
            ``retry_on_error``. If omitted, a ``Retry(NoBackoff(), 1)`` is
            used when there are errors to retry on, otherwise
            ``Retry(NoBackoff(), 0)``.
        protocol : int, optional
            RESP version, 2 or 3. ``DEFAULT_RESP_VERSION`` is used when None.
        driver_info : DriverInfo, optional
            Driver metadata for CLIENT SETINFO. If provided, lib_name and
            lib_version are ignored. If not provided, a DriverInfo will be
            created from lib_name and lib_version. Explicit None disables
            CLIENT SETINFO.
        lib_name : str, optional
            **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
        lib_version : str, optional
            **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.

        Raises
        ------
        DataError
            If ``username``/``password`` are combined with ``credential_provider``.
        ConnectionError
            If ``protocol`` is not an integer, or is neither 2 nor 3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.db = db
        self.client_name = client_name

        # Handle driver_info: if provided, use it; otherwise create it from
        # lib_name/lib_version.
        self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)

        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_error = []
        if retry_on_timeout:
            # Build a new list rather than appending in place. The caller's
            # list may be shared (e.g. the same keyword arguments reused for
            # every connection), and in-place appends would mutate it and
            # grow it on every construction.
            retry_on_error = list(retry_on_error) + [
                TimeoutError,
                socket.timeout,
                asyncio.TimeoutError,
            ]
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if not retry:
                # No policy supplied but there are errors to retry on;
                # presumably a single retry without backoff delay.
                self.retry = Retry(NoBackoff(), 1)
            else:
                # Deep-copy the Retry object: it is mutable, and the supported
                # errors are updated below.
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors.
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        # Event-loop time after which check_health() sends a PING. -1 is
        # presumably always below loop.time(), so the first check runs.
        self.next_health_check: float = -1
        self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
        self.redis_connect_func = redis_connect_func
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None
        self._socket_read_size = socket_read_size
        self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
        # Pieces longer than this many bytes are emitted as separate chunks
        # by pack_command/pack_commands instead of being concatenated.
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        self._should_reconnect = False

        try:
            p = int(protocol)
        except TypeError:
            # e.g. protocol=None: fall back to the library default.
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        else:
            if p < 2 or p > 3:
                raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p
        self.legacy_responses = legacy_responses
        if parser_class != _AsyncHiredisParser:
            # The Python parsers are protocol-specific; hiredis supports both.
            if self.protocol == 3 and parser_class == _AsyncRESP2Parser:
                parser_class = _AsyncRESP3Parser
            elif self.protocol == 2 and parser_class == _AsyncRESP3Parser:
                parser_class = _AsyncRESP2Parser
        self.set_parser(parser_class)

    def __del__(self, _warnings: Any = warnings):
        # The individual streams are not reliably garbage-collected with their
        # own ResourceWarning, so emit one here in the same style as the
        # stdlib. ``_warnings`` is presumably bound as a default argument so
        # it stays reachable while module globals are torn down at shutdown.
        #
        # ``_writer`` may be unset (``__init__`` raised before assigning it)
        # or None (never connected / already closed). In either case there
        # is nothing to warn about or close. Returning early also avoids an
        # AttributeError from reading the unassigned slot in ``_close``.
        if not getattr(self, "_writer", None):
            return
        _warnings.warn(f"unclosed Connection {self!r}", ResourceWarning, source=self)

        try:
            # Only close synchronously when an event loop is running;
            # get_running_loop() raises RuntimeError otherwise.
            asyncio.get_running_loop()
            self._close()
        except RuntimeError:
            # No running loop (e.g. the pool or loop is already closed):
            # take no action.
            pass

    def _close(self):
        """
        Close the writer without awaiting ``wait_closed`` and drop both
        stream references.
        """
        if self._writer:
            self._writer.close()
        self._writer = self._reader = None

    def __repr__(self):
        repr_args = ",".join(f"{k}={v}" for k, v in self.repr_pieces())
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return ``(name, value)`` pairs to display in ``repr``."""
        pass

    @property
    def is_connected(self):
        """True when both the reader and writer streams are set."""
        return self._reader is not None and self._writer is not None

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            # Not registered: deregistration is a no-op.
            pass

    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Create a new instance of ``parser_class`` using this connection's
        ``_socket_read_size`` and install it as the connection's parser.

        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    async def connect(self):
        """Connect to the Redis server if not already connected.

        A single attempt covers the socket connect plus the handshake. The
        whole flow is retried according to ``self.retry``, disconnecting
        after each failed attempt.
        """
        await self.retry.call_with_retry(
            lambda: self.connect_check_health(
                check_health=True, retry_socket_connect=False
            ),
            lambda error, failure_count: self.disconnect(
                error=error, failure_count=failure_count
            ),
            with_failure_count=True,
        )

    async def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Open the transport, run the handshake and fire connect callbacks.

        Does nothing if already connected.

        :param check_health: passed to the handshake commands
            (``on_connect_check_health``).
        :param retry_socket_connect: if True, only the socket connect step is
            retried under ``self.retry``; otherwise it is attempted once.
        :raises TimeoutError: the socket connect timed out.
        :raises ConnectionError: the socket connect failed.
        :raises RedisError: the handshake failed; the connection is
            disconnected first.
        """
        if self.is_connected:
            return
        # Track actual retry attempts for error reporting.
        actual_retry_attempts = 0

        def failure_callback(error, failure_count):
            nonlocal actual_retry_attempts
            actual_retry_attempts = failure_count
            return self.disconnect(error=error, failure_count=failure_count)

        try:
            if retry_socket_connect:
                await self.retry.call_with_retry(
                    lambda: self._connect(),
                    failure_callback,
                    with_failure_count=True,
                )
            else:
                await self._connect()
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            e = TimeoutError("Timeout connecting to server")
            await record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise e
        except OSError as e:
            e = ConnectionError(self._error_message(e))
            await record_error_count(
                server_address=getattr(self, "host", None),
                server_port=getattr(self, "port", None),
                network_peer_address=getattr(self, "host", None),
                network_peer_port=getattr(self, "port", None),
                error_type=e,
                retry_attempts=actual_retry_attempts,
                is_internal=False,
            )
            raise e
        except Exception as exc:
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function.
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func; it may be
                # either a plain function or a coroutine function.
                (
                    await self.redis_connect_func(self)
                    if asyncio.iscoroutinefunction(self.redis_connect_func)
                    else self.redis_connect_func(self)
                )
        except RedisError:
            # Clean up after any error in on_connect.
            await self.disconnect()
            raise

        # Run any user callbacks. Right now the only internal callback
        # is for pubsub channel/pattern resubscription.
        # First, remove any dead weakrefs.
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task

    def mark_for_reconnect(self):
        """Flag this connection to be reconnected."""
        self._should_reconnect = True

    def should_reconnect(self):
        """Return True if ``mark_for_reconnect`` was called since the last reset."""
        return self._should_reconnect

    def reset_should_reconnect(self):
        """Clear the reconnect flag (also done by ``disconnect``)."""
        self._should_reconnect = False

    @abstractmethod
    async def _connect(self):
        """Open the transport and set ``_reader``/``_writer``."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return an address string used in error messages."""
        pass

    def _error_message(self, exception: BaseException) -> str:
        return format_error_message(self._host_error(), exception)

    def get_protocol(self):
        """Return the RESP protocol version (2 or 3) configured for this connection."""
        return self.protocol

    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)

    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """Run the connection handshake.

        Steps:
        - Authenticate: via ``HELLO <proto> AUTH ...`` for RESP3, plain
          ``AUTH`` for RESP2, or ``HELLO <proto>`` alone for RESP3 without
          credentials.
        - Set the client name.
        - Send CLIENT SETINFO.
        - SELECT the database.

        The SETINFO and SELECT commands are pipelined: all are sent before
        any reply is read.

        :param check_health: forwarded to the post-authentication commands.
            It is never used for the authentication commands themselves,
            because a PING would fail before AUTH.
        :raises ConnectionError: HELLO+AUTH negotiated a different RESP
            version, CLIENT SETNAME failed, or SELECT failed.
        :raises AuthenticationError: AUTH did not reply ``OK``.
        """
        self._parser.on_connect(self)
        # Keep the original parser so its EXCEPTION_CLASSES can be carried
        # over if the parser is swapped for RESP3 below.
        parser = self._parser

        auth_args = None
        # If a credential provider or username and/or password are set,
        # authenticate.
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # If a non-RESP2 version is specified and we have auth args,
        # we need to send them via HELLO.
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # Update cluster exception classes.
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # HELLO AUTH always takes a username.
                auth_args = ["default", auth_args[0]]
            # Avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH.
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            # The reply map may be keyed by bytes or str depending on decoding.
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # Avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH.
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # A username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. Retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # If a non-RESP2 version is specified, switch to it.
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # Update cluster exception classes.
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            # NOTE: unlike the HELLO+AUTH branch above, the negotiated "proto"
            # in this reply is read but not validated.
            await self.read_response()

        # If a client_name is given, set it.
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info; pipeline for
        # lower startup latency.
        lib_name_sent = False
        lib_version_sent = False

        if self.driver_info and self.driver_info.formatted_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            lib_name_sent = True

        if self.driver_info and self.driver_info.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            lib_version_sent = True

        # If a database is specified, switch to it. Also pipeline this.
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # Read responses from the pipeline. SETINFO errors are ignored
        # (ResponseError is swallowed), so a server that rejects SETINFO does
        # not fail the handshake.
        for _ in range(sum([lib_name_sent, lib_version_sent])):
            try:
                await self.read_response()
            except ResponseError:
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    async def disconnect(
        self,
        nowait: bool = False,
        error: Optional[Exception] = None,
        failure_count: Optional[int] = None,
        health_check_failed: bool = False,
    ) -> None:
        """Disconnect from the Redis server.

        :param nowait: close the writer without awaiting ``wait_closed``.
        :param error: the error that caused the disconnect, if any; used for
            the close reason and error metrics.
        :param failure_count: failures so far in a retry loop. An error-count
            metric is recorded once it exceeds ``self.retry.get_retries()``.
        :param health_check_failed: record the close reason as a failed
            health check rather than a generic error.
        :raises TimeoutError: closing took longer than ``socket_connect_timeout``.

        Metrics are only recorded when an open connection is actually closed;
        the early-return paths below record nothing.
        """
        # On Python 3.13+, asyncio.timeout() raises RuntimeError when called
        # outside a running Task (e.g. during GC finalization or event-loop
        # callbacks). In that context we fall back to a synchronous close.
        # See https://github.com/redis/redis-py/issues/3856
        if asyncio.current_task() is None:
            self._parser.on_disconnect()
            self.reset_should_reconnect()
            self._close()
            return

        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                # Reset the reconnect flag.
                self.reset_should_reconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # Wait for close to finish, except when handling errors
                    # and forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None

        if error:
            if health_check_failed:
                close_reason = CloseReason.HEALTHCHECK_FAILED
            else:
                close_reason = CloseReason.ERROR

            if failure_count is not None and failure_count > self.retry.get_retries():
                await record_error_count(
                    server_address=getattr(self, "host", None),
                    server_port=getattr(self, "port", None),
                    network_peer_address=getattr(self, "host", None),
                    network_peer_port=getattr(self, "port", None),
                    error_type=error,
                    retry_attempts=failure_count,
                )

            await record_connection_closed(
                close_reason=close_reason,
                error_type=error,
            )
        else:
            await record_connection_closed(
                close_reason=CloseReason.APPLICATION_CLOSE,
            )

    async def _send_ping(self):
        """Send PING, expect PONG in return"""
        await self.send_command("PING", check_health=False)
        if str_if_bytes(await self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    async def _ping_failed(self, error, failure_count):
        """Function to call when PING fails"""
        await self.disconnect(
            error=error, failure_count=failure_count, health_check_failed=True
        )

    async def check_health(self):
        """Check the health of the connection with a PING/PONG.

        Only runs when ``health_check_interval`` is set and the deadline
        stored in ``next_health_check`` (advanced by ``read_response``) has
        passed. The PING is retried under ``self.retry``.
        """
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(
                self._send_ping, self._ping_failed, with_failure_count=True
            )

    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        """Write all chunks of ``command`` and wait for the buffer to drain."""
        self._writer.writelines(command)
        await self._writer.drain()

    async def send_packed_command(
        self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
    ) -> None:
        """Write an already-packed command to the server.

        Connects first if needed and optionally runs a health check.
        ``command`` may be a ``str`` (encoded with ``str.encode()``), a
        ``bytes`` payload, or an iterable of ``bytes`` chunks as produced by
        ``pack_command``. When ``socket_timeout`` is set, the write is bounded
        by it. Any failure disconnects the connection (without waiting for
        the close).

        :raises TimeoutError: the write did not finish within ``socket_timeout``.
        :raises ConnectionError: the write failed with an ``OSError``.
        """
        if not self.is_connected:
            await self.connect_check_health(check_health=False)
        if check_health:
            await self.check_health()

        try:
            if isinstance(command, str):
                command = command.encode()
            if isinstance(command, bytes):
                command = [command]
            if self.socket_timeout:
                await asyncio.wait_for(
                    self._send_packed_command(command), self.socket_timeout
                )
            else:
                await self._send_packed_command(command)
        except asyncio.TimeoutError:
            await self.disconnect(nowait=True)
            raise TimeoutError("Timeout writing to socket") from None
        except OSError as e:
            await self.disconnect(nowait=True)
            if len(e.args) == 1:
                err_no, errmsg = "UNKNOWN", e.args[0]
            else:
                err_no = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(
                f"Error {err_no} while writing to socket. {errmsg}."
            ) from e
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            await self.disconnect(nowait=True)
            raise

    async def send_command(self, *args: Any, **kwargs: Any) -> None:
        """Pack and send a command to the Redis server.

        Only the ``check_health`` keyword (default True) is honoured; other
        keyword arguments are ignored.
        """
        await self.send_packed_command(
            self.pack_command(*args), check_health=kwargs.get("check_health", True)
        )

    @deprecated_function(
        version="8.0.0", reason="Use can_read() instead", name="can_read_destructive"
    )
    async def can_read_destructive(self) -> bool:
        """Check the socket to see if there's data loaded in the buffer."""
        try:
            return await self._parser.can_read()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    async def can_read(self) -> bool:
        """Check the socket to see if there's data loaded in the buffer."""
        # TODO: Rename this API; it detects pending data or dirty/closed
        # connection state, not only whether application data can be read.
        try:
            return await self._parser.can_read()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command.

        :param disable_decoding: forwarded to the parser.
        :param timeout: per-call read timeout. When given and it expires,
            ``None`` is returned instead of raising (the read can be retried).
            When omitted, ``socket_timeout`` applies and expiry raises.
        :param disconnect_on_error: disconnect on socket timeout, OSError or
            any other BaseException raised while reading.
        :param push_request: forwarded to the parser only for RESP3.
        :raises TimeoutError: ``socket_timeout`` expired.
        :raises ConnectionError: the read failed with an ``OSError``.
        :raises ResponseError: the server replied with an error.
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        # ``push_request`` is only understood by the RESP3 read path.
        parser_kwargs: dict = {"disable_decoding": disable_decoding}
        if self.protocol in ["3", 3]:
            parser_kwargs["push_request"] = push_request
        try:
            if read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(**parser_kwargs)
            else:
                response = await self._parser.read_response(**parser_kwargs)
        except asyncio.TimeoutError:
            if timeout is not None:
                # User requested timeout, return None. Operation can be retried.
                return None
            # It was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        if self.health_check_interval:
            # A successful read proves liveness; push the next PING back.
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        if isinstance(response, ResponseError):
            raise response from None
        return response

    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol.

        Returns a list of byte chunks to be written in order. Small arguments
        are concatenated into one buffer. Arguments longer than
        ``_buffer_cutoff`` bytes, and memoryviews, are emitted as separate
        chunks to avoid large copies.
        """
        output = []
        # The client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # To avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews.
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output

    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol.

        Chunks from consecutive commands are merged into shared buffers of
        roughly ``_buffer_cutoff`` bytes. Oversized chunks and memoryviews
        are passed through as their own entries.
        """
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    # Flush the accumulated small pieces first to keep order.
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def _socket_is_empty(self):
        """Check if the socket is empty"""
        # NOTE(review): relies on the private ``StreamReader._buffer``
        # attribute; this only reflects data already read into the stream
        # buffer, not data still in the OS socket.
        return len(self._reader._buffer) == 0

    async def process_invalidation_messages(self):
        """Read push messages until the stream buffer is empty."""
        while not self._socket_is_empty():
            await self.read_response(push_request=True)

    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be sent by the next ``re_auth`` call."""
        self._re_auth_token = token

    async def re_auth(self):
        """Send AUTH with the pending re-auth token, if any, then clear it.

        The reply is read (so server errors raise) but its value is not
        otherwise checked here.
        """
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None

895 

896 

class Connection(AbstractConnection):
    """Manages TCP communication to and from a Redis server."""

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: str | int = 6379,
        socket_keepalive: bool = True,
        socket_keepalive_options: Mapping[int, int | bytes] | object | None = SENTINEL,
        socket_type: int = 0,
        **kwargs,
    ):
        """
        Set up a TCP connection; nothing is opened until ``connect``.

        Parameters
        ----------
        host : str
            Server host name or address.
        port : str | int
            Server port; converted with ``int()``.
        socket_keepalive : bool
            Whether SO_KEEPALIVE (plus the options below) is applied to the
            TCP socket after it connects.
        socket_keepalive_options : Mapping[int, int | bytes] | object | None
            Mapping from TCP keepalive socket-option constants to their values,
            e.g. `{socket.TCP_KEEPIDLE: 30}`. When not supplied, redis-py's
            defaults are used while `socket_keepalive` is on: 30 seconds idle,
            5 second interval, 3 probes. Options the platform lacks are
            skipped. `None` or `{}` means no extra keepalive options are set.
        socket_type : int
            Stored on the instance as ``socket_type``.
        **kwargs
            Forwarded unchanged to ``AbstractConnection.__init__``.
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        # SENTINEL means "not given": fall back to the library defaults;
        # a falsy value (None / {}) collapses to an empty mapping.
        self.socket_keepalive_options = (
            get_default_socket_keepalive_options()
            if socket_keepalive_options is SENTINEL
            else socket_keepalive_options
        ) or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return host, port, db and (when set) client_name for ``repr``."""
        optional = [("client_name", self.client_name)] if self.client_name else []
        return [("host", self.host), ("port", self.port), ("db", self.db)] + optional

    def _connection_arguments(self) -> Mapping:
        """Keyword arguments passed to ``asyncio.open_connection``."""
        return dict(host=self.host, port=self.port)

    async def _connect(self):
        """Open the TCP stream pair and tune the underlying socket."""
        async with async_timeout(self.socket_connect_timeout):
            stream_reader, stream_writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        self._reader, self._writer = stream_reader, stream_writer

        raw_sock = stream_writer.transport.get_extra_info("socket")
        if not raw_sock:
            return

        raw_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        try:
            if self.socket_keepalive:
                raw_sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                for option, value in self.socket_keepalive_options.items():
                    raw_sock.setsockopt(socket.SOL_TCP, option, value)
        except (OSError, TypeError):
            # A bad entry in `socket_keepalive_options` can make setsockopt
            # fail; close the stream rather than leave it half-configured.
            stream_writer.close()
            raise

    def _host_error(self) -> str:
        """Return ``host:port`` for error messages."""
        return f"{self.host}:{self.port}"

969 

970 

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)

    All ``ssl_*`` arguments are forwarded to a :class:`RedisSSLContext`, which
    lazily builds the ``ssl.SSLContext`` passed to ``asyncio.open_connection``.
    The remaining keyword arguments go to :class:`Connection`.

    Raises:
        RedisError: if Python was built without SSL support.
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        # Quoted so the annotation is not evaluated at definition time: without
        # SSL support the module sets ``ssl = None`` and an unquoted
        # ``ssl.VerifyMode`` would raise AttributeError on import.
        ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
        ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_ca_path: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        ssl_password: Optional[str] = None,
        **kwargs,
    ):
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            include_verify_flags=ssl_include_verify_flags,
            exclude_verify_flags=ssl_exclude_verify_flags,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            ca_path=ssl_ca_path,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
            password=ssl_password,
        )
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        """Add the (lazily built) SSL context to the TCP connection arguments."""
        kwargs = super()._connection_arguments()
        kwargs["ssl"] = self.ssl_context.get()
        return kwargs

    # Read-only views of the settings held by the wrapped RedisSSLContext.
    # ``password`` is intentionally not exposed.

    @property
    def keyfile(self):
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        return self.ssl_context.cert_reqs

    @property
    def include_verify_flags(self):
        return self.ssl_context.include_verify_flags

    @property
    def exclude_verify_flags(self):
        return self.ssl_context.exclude_verify_flags

    @property
    def ca_certs(self):
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        return self.ssl_context.ca_data

    @property
    def ca_path(self):
        return self.ssl_context.ca_path

    @property
    def check_hostname(self):
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        return self.ssl_context.min_version

    @property
    def ciphers(self):
        return self.ssl_context.ciphers

1052 

1053 

class RedisSSLContext:
    """Holds SSL settings and lazily builds the matching ``ssl.SSLContext``.

    ``cert_reqs`` may be an ``ssl.VerifyMode`` or one of the strings
    ``"none"``, ``"optional"``, ``"required"``; ``None`` means
    ``ssl.CERT_NONE``. Hostname checking is forced off when certificate
    verification is disabled.

    Raises:
        RedisError: if Python was built without SSL support, or if
            ``cert_reqs`` is an unknown string.
    """

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "include_verify_flags",
        "exclude_verify_flags",
        "ca_certs",
        "ca_data",
        "ca_path",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
        "password",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        # Quoted: ``ssl`` is None when SSL support is missing, and an unquoted
        # ``ssl.VerifyMode`` would fail while the class body executes.
        cert_reqs: Optional[Union[str, "ssl.VerifyMode"]] = None,
        include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        ca_path: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
        password: Optional[str] = None,
    ):
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            cert_reqs = CERT_REQS[cert_reqs]
        self.cert_reqs = cert_reqs
        self.include_verify_flags = include_verify_flags
        self.exclude_verify_flags = exclude_verify_flags
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        self.ca_path = ca_path
        # Hostname checking requires certificate verification, so it is
        # disabled whenever verification is off.
        self.check_hostname = (
            check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.min_version = min_version
        self.ciphers = ciphers
        self.password = password
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """Return the cached ``ssl.SSLContext``, building it on first use."""
        if self.context is None:
            context = ssl.create_default_context()
            # check_hostname must be set before verify_mode: lowering
            # verify_mode to CERT_NONE is rejected while check_hostname is on.
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            if self.include_verify_flags:
                for flag in self.include_verify_flags:
                    context.verify_flags |= flag
            if self.exclude_verify_flags:
                for flag in self.exclude_verify_flags:
                    context.verify_flags &= ~flag
            if self.certfile or self.keyfile:
                context.load_cert_chain(
                    certfile=self.certfile,
                    keyfile=self.keyfile,
                    password=self.password,
                )
            if self.ca_certs or self.ca_data or self.ca_path:
                context.load_verify_locations(
                    cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
                )
            if self.min_version is not None:
                context.minimum_version = self.min_version
            if self.ciphers is not None:
                context.set_ciphers(self.ciphers)
            self.context = context
        return self.context

1145 

1146 

class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, *, path: str = "", **kwargs):
        """Create a connection to the Unix domain socket at ``path``.

        Remaining keyword arguments are forwarded to ``AbstractConnection``.
        """
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    async def _connect(self):
        """Open the Unix domain socket stream.

        Like ``Connection._connect``, this only establishes the transport and
        does not run the handshake itself. TCP and SSL connections never call
        ``on_connect()`` here, so the generic connect path must perform it.
        Calling it again here made UDS connections send the handshake twice
        and bypassed a custom ``redis_connect_func``.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_unix_connection(path=self.path)
        self._reader = reader
        self._writer = writer

    def _host_error(self) -> str:
        return self.path

1169 

1170 

# Upper-cased spellings that to_bool() interprets as False for string input.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """Coerce a URL query-string value to a boolean.

    Returns ``None`` for ``None`` or the empty string. For strings, returns
    ``False`` when the upper-cased text is listed in ``FALSE_STRINGS`` and
    ``True`` otherwise. Any other value is converted with ``bool()``.
    """
    if value is None or value == "":
        return None
    if isinstance(value, str):
        return value.upper() not in FALSE_STRINGS
    return bool(value)

1180 

1181 

def parse_ssl_verify_flags(value):
    """Parse ``ssl.VerifyFlags`` member names from a URL query value.

    The value is a comma-separated list of flag names, optionally wrapped in
    square brackets, e.g. ``"[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]"``.
    Surrounding whitespace is ignored, and blank entries (``"[]"``, trailing
    commas) are skipped.

    Returns:
        list of ``ssl.VerifyFlags`` members, in the order given.

    Raises:
        ValueError: if an entry is not the name of a ``VerifyFlags`` member
            (always the case for non-blank entries when SSL is unavailable).
    """
    verify_flags_str = value.replace("[", "").replace("]", "")
    # Validate against the real enum members. hasattr()/getattr() on the
    # class would also accept unrelated attributes such as "__doc__" or "mro".
    # VerifyFlags is None when SSL support is missing, hence the default.
    members = getattr(VerifyFlags, "__members__", {})

    verify_flags = []
    for flag in verify_flags_str.split(","):
        flag = flag.strip()
        if not flag:
            continue
        if flag not in members:
            raise ValueError(f"Invalid ssl verify flag: {flag}")
        verify_flags.append(members[flag])
    return verify_flags

1194 

1195 

# Converters parse_url() applies to known query-string options. Options not
# listed here are passed through as the unquoted string value.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    dict(
        db=int,
        socket_timeout=float,
        socket_connect_timeout=float,
        socket_read_size=int,
        socket_keepalive=to_bool,
        retry_on_timeout=to_bool,
        max_connections=int,
        health_check_interval=int,
        ssl_check_hostname=to_bool,
        ssl_include_verify_flags=parse_ssl_verify_flags,
        ssl_exclude_verify_flags=parse_ssl_verify_flags,
        ssl_min_version=int,
        timeout=float,
        protocol=int,
        legacy_responses=to_bool,
    )
)

1215 

1216 

# Keyword arguments produced by parse_url(). Every key is optional
# (total=False); only the keys present in the URL are set.
ConnectKwargs = TypedDict(
    "ConnectKwargs",
    {
        "username": str,
        "password": str,
        "connection_class": Type[AbstractConnection],
        "host": str,
        "port": int,
        "db": int,
        "path": str,
    },
    total=False,
)

1225 

1226 

def parse_url(url: str) -> ConnectKwargs:
    """Parse a ``redis://``, ``rediss://`` or ``unix://`` URL into pool kwargs.

    Query-string options are unquoted and, when listed in
    ``URL_QUERY_ARGUMENT_PARSERS``, converted with the matching parser;
    otherwise they are kept as strings. Only the first value of a repeated
    option is used. Username and password come from the URL's userinfo.

    For ``unix://`` the path becomes ``path`` and the connection class is
    ``UnixDomainSocketConnection``. For ``redis://``/``rediss://`` the host,
    port and (unless ``db`` was given in the query) a numeric path become
    ``host``, ``port`` and ``db``. ``rediss://`` selects ``SSLConnection``.

    Raises:
        ValueError: if a known query option cannot be converted, or the URL
            scheme is not one of the supported ones.
    """
    parsed: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    for name, value_list in parse_qs(parsed.query).items():
        if not value_list:
            continue
        value = unquote(value_list[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if parser:
            try:
                kwargs[name] = parser(value)
            except (TypeError, ValueError) as err:
                # Chain the original error so the underlying cause is kept.
                raise ValueError(
                    f"Invalid value for '{name}' in connection URL."
                ) from err
        else:
            kwargs[name] = value

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    elif parsed.scheme in ("redis", "rediss"):
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                # A non-numeric path is ignored rather than rejected.
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    else:
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    return kwargs

1278 

1279 

# Bound type variable so ConnectionPool.from_url returns the subclass it is
# called on.
_CP = TypeVar("_CP", bound="ConnectionPool")


class ConnectionPoolInterface(ABC):
    """Abstract interface implemented by the async connection pools."""

    @abstractmethod
    def get_protocol(self):
        """Return the configured RESP protocol version (or None)."""

    @abstractmethod
    def reset(self) -> None:
        """Reset the pool's connection bookkeeping."""

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(
        self, command_name: Optional[str] = None, *keys: Any, **options: Any
    ) -> "AbstractConnection":
        """Acquire a connection from the pool."""

    @abstractmethod
    def get_encoder(self) -> "Encoder":
        """Return the encoder matching the pool's encoding settings."""

    @abstractmethod
    async def release(self, connection: "AbstractConnection") -> None:
        """Return ``connection`` to the pool."""

    @abstractmethod
    async def disconnect(self, inuse_connections: bool = True) -> None:
        """Disconnect pooled connections (optionally including in-use ones)."""

    @abstractmethod
    async def aclose(self) -> None:
        """Close the pool."""

    @abstractmethod
    def set_retry(self, retry: "Retry") -> None:
        """Apply ``retry`` to the pool's connections."""

    @abstractmethod
    async def re_auth_callback(self, token: TokenInterface) -> None:
        """Re-authenticate the pool's connections with ``token``."""

    @abstractmethod
    def get_connection_count(self) -> List[Tuple[int, dict]]:
        """
        Returns a connection count (both idle and in use).
        """

1333 

1334 

class ConnectionPool(ConnectionPoolInterface):
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.ConnectionError` when the pool's
    limit is reached. ``max_connections`` falls back to 100 when not given.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)
        # URL options override explicit keyword arguments.
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        # None (or 0) falls back to the default limit of 100.
        max_connections = max_connections or 100
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        self._available_connections: List[AbstractConnection] = []
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    # Keys that should be redacted in __repr__ to avoid exposing sensitive information
    SENSITIVE_REPR_KEYS = frozenset(
        {
            "password",
            "username",
            "ssl_password",
            "credential_provider",
        }
    )

    def __repr__(self):
        conn_kwargs = ",".join(
            [
                f"{k}={'<REDACTED>' if k in self.SENSITIVE_REPR_KEYS else v}"
                for k, v in self.connection_kwargs.items()
            ]
        )
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def get_protocol(self):
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        return self.connection_kwargs.get("protocol", None)

    def _record_removed_connection_counts(self) -> None:
        """Decrement the IDLE/USED connection-count metrics for all tracked connections.

        Does nothing if the tracking containers do not exist yet, or if both
        are empty. A synchronous recorder is used because both callers
        (``reset`` and ``__del__``) are synchronous.
        """
        if not (
            hasattr(self, "_available_connections")
            and hasattr(self, "_in_use_connections")
        ):
            return
        idle_count = len(self._available_connections)
        in_use_count = len(self._in_use_connections)
        if idle_count == 0 and in_use_count == 0:
            return

        pool_name = get_pool_name(self)
        from redis.observability.recorder import (
            record_connection_count as sync_record_connection_count,
        )

        for state, count in (
            (ConnectionState.IDLE, idle_count),
            (ConnectionState.USED, in_use_count),
        ):
            if count > 0:
                sync_record_connection_count(
                    pool_name=pool_name,
                    connection_state=state,
                    counter=-count,
                )

    def reset(self):
        """Forget all tracked connections without disconnecting them.

        The connection-count metrics are decremented for the connections
        being dropped before the containers are replaced.
        """
        self._record_removed_connection_counts()
        self._available_connections = []
        self._in_use_connections = weakref.WeakSet()

    def __del__(self) -> None:
        """Clean up connection pool and record metrics when garbage collected."""
        try:
            self._record_removed_connection_counts()
        except Exception:
            # Best effort only: finalizers may run when module state or the
            # metrics backend is no longer usable, and must not raise.
            pass

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        return bool(
            self._available_connections
            or len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Get a connected connection from the pool"""
        async with self._lock:
            # Compare the total tracked count before and after to detect
            # whether a brand-new connection had to be created (reusing one
            # just moves it between the two containers).
            connections_before = len(self._available_connections) + len(
                self._in_use_connections
            )
            start_time_created = time.monotonic()
            connection = self.get_available_connection()
            connections_after = len(self._available_connections) + len(
                self._in_use_connections
            )
            is_created = connections_after > connections_before

        # Record state transition for observability
        # This ensures counters stay balanced if ensure_connection() fails and release() is called
        pool_name = get_pool_name(self)
        if is_created:
            # New connection created and acquired: just USED +1
            await record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.USED,
                counter=1,
            )
        else:
            # Existing connection acquired from pool: IDLE -> USED
            await record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.IDLE,
                counter=-1,
            )
            await record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.USED,
                counter=1,
            )

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)

            if is_created:
                await record_connection_create_time(
                    connection_pool=self,
                    duration_seconds=time.monotonic() - start_time_created,
                )

            return connection
        except BaseException:
            await self.release(connection)
            raise

    def get_available_connection(self):
        """Get a connection from the pool, without making sure it is connected"""
        try:
            connection = self._available_connections.pop()
        except IndexError:
            if len(self._in_use_connections) >= self.max_connections:
                raise MaxConnectionsError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        # Note: We don't record IDLE here because async uses a sync make_connection
        # but async record_connection_count. The recording is handled in get_connection.
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid"""
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool"""
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)

        if connection.should_reconnect():
            await connection.disconnect()

        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

        # Record state transition: USED -> IDLE
        pool_name = get_pool_name(self)
        await record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.USED,
            counter=-1,
        )
        await record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.IDLE,
            counter=1,
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.

        All disconnects are attempted; the first exception raised by any of
        them is re-raised afterwards.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )

        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc:
            raise exc

    async def update_active_connections_for_reconnect(self):
        """
        Mark all active connections for reconnect.
        """
        async with self._lock:
            for conn in self._in_use_connections:
                conn.mark_for_reconnect()

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Assign ``retry`` to every idle and in-use connection."""
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate with ``token``.

        Idle connections send AUTH immediately (with their retry policy);
        in-use connections only store the token for a later ``re_auth()``.
        """
        async with self._lock:
            for conn in self._available_connections:
                # The lambdas close over the loop variable, but each
                # call_with_retry is awaited within the same iteration, so
                # they presumably always see the current ``conn``.
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass

    def get_connection_count(self) -> List[tuple[int, dict]]:
        """
        Returns a connection count (both idle and in use).
        """
        attributes = AttributeBuilder.build_base_attributes()
        attributes[DB_CLIENT_CONNECTION_POOL_NAME] = get_pool_name(self)
        free_connections_attributes = attributes.copy()
        in_use_connections_attributes = attributes.copy()

        free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.IDLE.value
        )
        in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.USED.value
        )

        return [
            (len(self._available_connections), free_connections_attributes),
            (len(self._in_use_connections), in_use_connections_attributes),
        ]

1729 

1730 

class BlockingConnectionPool(ConnectionPool):
    """
    A blocking connection pool::

        >>> from redis.asyncio import Redis, BlockingConnectionPool
        >>> client = Redis.from_pool(BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple async redis clients.

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation does), it
    blocks the current `Task` for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections: int = 50,
        timeout: Optional[float] = 20,
        connection_class: Type[AbstractConnection] = Connection,
        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,  # deprecated
        **connection_kwargs,
    ):
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
        # Waiters block on this condition; release() notifies one of them.
        self._condition = asyncio.Condition()
        self.timeout = timeout

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Gets a connection from the pool, blocking until one is available.

        Raises:
            ConnectionError: if no connection became available within
                ``self.timeout`` seconds.
        """
        # Start timing for wait time observability
        start_time_acquired = time.monotonic()

        try:
            async with self._condition:
                async with async_timeout(self.timeout):
                    await self._condition.wait_for(self.can_get_connection)
                    # Track connection count before to detect if a new connection is created
                    connections_before = len(self._available_connections) + len(
                        self._in_use_connections
                    )
                    start_time_created = time.monotonic()
                    connection = super().get_available_connection()
                    connections_after = len(self._available_connections) + len(
                        self._in_use_connections
                    )
                    is_created = connections_after > connections_before
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # Record the acquire-side state transition, mirroring
        # ConnectionPool.get_connection. The inherited release() always
        # records USED -1 / IDLE +1 (including on the failure path below),
        # so without this the counters drift on every acquire/release.
        pool_name = get_pool_name(self)
        if is_created:
            # New connection created and acquired: just USED +1
            await record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.USED,
                counter=1,
            )
        else:
            # Existing connection acquired from pool: IDLE -> USED
            await record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.IDLE,
                counter=-1,
            )
            await record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.USED,
                counter=1,
            )

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)

            if is_created:
                await record_connection_create_time(
                    connection_pool=self,
                    duration_seconds=time.monotonic() - start_time_created,
                )

            await record_connection_wait_time(
                pool_name=pool_name,
                duration_seconds=time.monotonic() - start_time_acquired,
            )

            return connection
        except BaseException:
            await self.release(connection)
            raise

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool and wakes one waiter."""
        async with self._condition:
            await super().release(connection)
            self._condition.notify()