Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/connection.py: 23%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

662 statements  

1import asyncio 

2import copy 

3import enum 

4import inspect 

5import socket 

6import sys 

7import warnings 

8import weakref 

9from abc import abstractmethod 

10from itertools import chain 

11from types import MappingProxyType 

12from typing import ( 

13 Any, 

14 Callable, 

15 Iterable, 

16 List, 

17 Mapping, 

18 Optional, 

19 Protocol, 

20 Set, 

21 Tuple, 

22 Type, 

23 TypedDict, 

24 TypeVar, 

25 Union, 

26) 

27from urllib.parse import ParseResult, parse_qs, unquote, urlparse 

28 

29from ..utils import SSL_AVAILABLE 

30 

31if SSL_AVAILABLE: 

32 import ssl 

33 from ssl import SSLContext, TLSVersion 

34else: 

35 ssl = None 

36 TLSVersion = None 

37 SSLContext = None 

38 

39from ..auth.token import TokenInterface 

40from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher 

41from ..utils import deprecated_args, format_error_message 

42 

43# the functionality is available in 3.11.x but has a major issue before 

44# 3.11.3. See https://github.com/redis/redis-py/issues/2633 

45if sys.version_info >= (3, 11, 3): 

46 from asyncio import timeout as async_timeout 

47else: 

48 from async_timeout import timeout as async_timeout 

49 

50from redis.asyncio.retry import Retry 

51from redis.backoff import NoBackoff 

52from redis.connection import DEFAULT_RESP_VERSION 

53from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider 

54from redis.exceptions import ( 

55 AuthenticationError, 

56 AuthenticationWrongNumberOfArgsError, 

57 ConnectionError, 

58 DataError, 

59 RedisError, 

60 ResponseError, 

61 TimeoutError, 

62) 

63from redis.typing import EncodableT 

64from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes 

65 

66from .._parsers import ( 

67 BaseParser, 

68 Encoder, 

69 _AsyncHiredisParser, 

70 _AsyncRESP2Parser, 

71 _AsyncRESP3Parser, 

72) 

73 

# Byte tokens used by pack_command()/pack_commands() to frame commands.
SYM_STAR = b"*"  # prefixes the argument count of a packed command
SYM_DOLLAR = b"$"  # prefixes the byte length of each packed argument
SYM_CRLF = b"\r\n"  # terminator after counts, lengths and argument payloads
SYM_LF = b"\n"
SYM_EMPTY = b""  # joiner used to concatenate the pieces above

79 

80 

81class _Sentinel(enum.Enum): 

82 sentinel = object() 

83 

84 

85SENTINEL = _Sentinel.sentinel 

86 

87 

# Use the hiredis-based parser when hiredis is available; otherwise fall back
# to the RESP2 parser.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]] = (
    _AsyncHiredisParser if HIREDIS_AVAILABLE else _AsyncRESP2Parser
)

93 

94 

class ConnectCallbackProtocol(Protocol):
    """Plain callable that receives the freshly connected connection."""

    def __call__(self, connection: "AbstractConnection"): ...


class AsyncConnectCallbackProtocol(Protocol):
    """Coroutine function that receives the freshly connected connection."""

    async def __call__(self, connection: "AbstractConnection"): ...


# Either flavour is accepted, e.g. as ``redis_connect_func``.
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]

104 

105 

class AbstractConnection:
    """Manages communication to and from a Redis server

    Transport-agnostic base class: subclasses provide ``_connect`` (open the
    streams), ``_host_error`` (peer description for error messages) and
    ``repr_pieces`` (fields shown by ``__repr__``).
    """

    # Declared attribute slots. "__dict__" is deliberately included so that
    # attributes not listed here (e.g. ``retry``, ``_event_dispatcher`` and
    # ``_re_auth_token`` set in ``__init__``, or subclass attributes such as
    # ``host``/``path``) can still be assigned.
    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        "__dict__",
    )

137 

138 def __init__( 

139 self, 

140 *, 

141 db: Union[str, int] = 0, 

142 password: Optional[str] = None, 

143 socket_timeout: Optional[float] = None, 

144 socket_connect_timeout: Optional[float] = None, 

145 retry_on_timeout: bool = False, 

146 retry_on_error: Union[list, _Sentinel] = SENTINEL, 

147 encoding: str = "utf-8", 

148 encoding_errors: str = "strict", 

149 decode_responses: bool = False, 

150 parser_class: Type[BaseParser] = DefaultParser, 

151 socket_read_size: int = 65536, 

152 health_check_interval: float = 0, 

153 client_name: Optional[str] = None, 

154 lib_name: Optional[str] = "redis-py", 

155 lib_version: Optional[str] = get_lib_version(), 

156 username: Optional[str] = None, 

157 retry: Optional[Retry] = None, 

158 redis_connect_func: Optional[ConnectCallbackT] = None, 

159 encoder_class: Type[Encoder] = Encoder, 

160 credential_provider: Optional[CredentialProvider] = None, 

161 protocol: Optional[int] = 2, 

162 event_dispatcher: Optional[EventDispatcher] = None, 

163 ): 

164 if (username or password) and credential_provider is not None: 

165 raise DataError( 

166 "'username' and 'password' cannot be passed along with 'credential_" 

167 "provider'. Please provide only one of the following arguments: \n" 

168 "1. 'password' and (optional) 'username'\n" 

169 "2. 'credential_provider'" 

170 ) 

171 if event_dispatcher is None: 

172 self._event_dispatcher = EventDispatcher() 

173 else: 

174 self._event_dispatcher = event_dispatcher 

175 self.db = db 

176 self.client_name = client_name 

177 self.lib_name = lib_name 

178 self.lib_version = lib_version 

179 self.credential_provider = credential_provider 

180 self.password = password 

181 self.username = username 

182 self.socket_timeout = socket_timeout 

183 if socket_connect_timeout is None: 

184 socket_connect_timeout = socket_timeout 

185 self.socket_connect_timeout = socket_connect_timeout 

186 self.retry_on_timeout = retry_on_timeout 

187 if retry_on_error is SENTINEL: 

188 retry_on_error = [] 

189 if retry_on_timeout: 

190 retry_on_error.append(TimeoutError) 

191 retry_on_error.append(socket.timeout) 

192 retry_on_error.append(asyncio.TimeoutError) 

193 self.retry_on_error = retry_on_error 

194 if retry or retry_on_error: 

195 if not retry: 

196 self.retry = Retry(NoBackoff(), 1) 

197 else: 

198 # deep-copy the Retry object as it is mutable 

199 self.retry = copy.deepcopy(retry) 

200 # Update the retry's supported errors with the specified errors 

201 self.retry.update_supported_errors(retry_on_error) 

202 else: 

203 self.retry = Retry(NoBackoff(), 0) 

204 self.health_check_interval = health_check_interval 

205 self.next_health_check: float = -1 

206 self.encoder = encoder_class(encoding, encoding_errors, decode_responses) 

207 self.redis_connect_func = redis_connect_func 

208 self._reader: Optional[asyncio.StreamReader] = None 

209 self._writer: Optional[asyncio.StreamWriter] = None 

210 self._socket_read_size = socket_read_size 

211 self.set_parser(parser_class) 

212 self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = [] 

213 self._buffer_cutoff = 6000 

214 self._re_auth_token: Optional[TokenInterface] = None 

215 

216 try: 

217 p = int(protocol) 

218 except TypeError: 

219 p = DEFAULT_RESP_VERSION 

220 except ValueError: 

221 raise ConnectionError("protocol must be an integer") 

222 finally: 

223 if p < 2 or p > 3: 

224 raise ConnectionError("protocol must be either 2 or 3") 

225 self.protocol = protocol 

226 

227 def __del__(self, _warnings: Any = warnings): 

228 # For some reason, the individual streams don't get properly garbage 

229 # collected and therefore produce no resource warnings. We add one 

230 # here, in the same style as those from the stdlib. 

231 if getattr(self, "_writer", None): 

232 _warnings.warn( 

233 f"unclosed Connection {self!r}", ResourceWarning, source=self 

234 ) 

235 

236 try: 

237 asyncio.get_running_loop() 

238 self._close() 

239 except RuntimeError: 

240 # No actions been taken if pool already closed. 

241 pass 

242 

243 def _close(self): 

244 """ 

245 Internal method to silently close the connection without waiting 

246 """ 

247 if self._writer: 

248 self._writer.close() 

249 self._writer = self._reader = None 

250 

251 def __repr__(self): 

252 repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) 

253 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" 

254 

255 @abstractmethod 

256 def repr_pieces(self): 

257 pass 

258 

259 @property 

260 def is_connected(self): 

261 return self._reader is not None and self._writer is not None 

262 

263 def register_connect_callback(self, callback): 

264 """ 

265 Register a callback to be called when the connection is established either 

266 initially or reconnected. This allows listeners to issue commands that 

267 are ephemeral to the connection, for example pub/sub subscription or 

268 key tracking. The callback must be a _method_ and will be kept as 

269 a weak reference. 

270 """ 

271 wm = weakref.WeakMethod(callback) 

272 if wm not in self._connect_callbacks: 

273 self._connect_callbacks.append(wm) 

274 

275 def deregister_connect_callback(self, callback): 

276 """ 

277 De-register a previously registered callback. It will no-longer receive 

278 notifications on connection events. Calling this is not required when the 

279 listener goes away, since the callbacks are kept as weak methods. 

280 """ 

281 try: 

282 self._connect_callbacks.remove(weakref.WeakMethod(callback)) 

283 except ValueError: 

284 pass 

285 

286 def set_parser(self, parser_class: Type[BaseParser]) -> None: 

287 """ 

288 Creates a new instance of parser_class with socket size: 

289 _socket_read_size and assigns it to the parser for the connection 

290 :param parser_class: The required parser class 

291 """ 

292 self._parser = parser_class(socket_read_size=self._socket_read_size) 

293 

294 async def connect(self): 

295 """Connects to the Redis server if not already connected""" 

296 await self.connect_check_health(check_health=True) 

297 

298 async def connect_check_health( 

299 self, check_health: bool = True, retry_socket_connect: bool = True 

300 ): 

301 if self.is_connected: 

302 return 

303 try: 

304 if retry_socket_connect: 

305 await self.retry.call_with_retry( 

306 lambda: self._connect(), lambda error: self.disconnect() 

307 ) 

308 else: 

309 await self._connect() 

310 except asyncio.CancelledError: 

311 raise # in 3.7 and earlier, this is an Exception, not BaseException 

312 except (socket.timeout, asyncio.TimeoutError): 

313 raise TimeoutError("Timeout connecting to server") 

314 except OSError as e: 

315 raise ConnectionError(self._error_message(e)) 

316 except Exception as exc: 

317 raise ConnectionError(exc) from exc 

318 

319 try: 

320 if not self.redis_connect_func: 

321 # Use the default on_connect function 

322 await self.on_connect_check_health(check_health=check_health) 

323 else: 

324 # Use the passed function redis_connect_func 

325 ( 

326 await self.redis_connect_func(self) 

327 if asyncio.iscoroutinefunction(self.redis_connect_func) 

328 else self.redis_connect_func(self) 

329 ) 

330 except RedisError: 

331 # clean up after any error in on_connect 

332 await self.disconnect() 

333 raise 

334 

335 # run any user callbacks. right now the only internal callback 

336 # is for pubsub channel/pattern resubscription 

337 # first, remove any dead weakrefs 

338 self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()] 

339 for ref in self._connect_callbacks: 

340 callback = ref() 

341 task = callback(self) 

342 if task and inspect.isawaitable(task): 

343 await task 

344 

345 @abstractmethod 

346 async def _connect(self): 

347 pass 

348 

349 @abstractmethod 

350 def _host_error(self) -> str: 

351 pass 

352 

353 def _error_message(self, exception: BaseException) -> str: 

354 return format_error_message(self._host_error(), exception) 

355 

356 def get_protocol(self): 

357 return self.protocol 

358 

359 async def on_connect(self) -> None: 

360 """Initialize the connection, authenticate and select a database""" 

361 await self.on_connect_check_health(check_health=True) 

362 

    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """Run the post-connect handshake on a freshly opened transport.

        In order:
        1. authenticate, via ``HELLO <proto> AUTH ...`` when a protocol other
           than 2 is configured, otherwise via ``AUTH``;
        2. without credentials but with a protocol other than 2, send ``HELLO``;
        3. ``CLIENT SETNAME`` when ``client_name`` is set;
        4. pipeline ``CLIENT SETINFO`` (LIB-NAME / LIB-VER) and ``SELECT``,
           then read their replies.

        :param check_health: forwarded to send_command() for every command
            except the authentication ones, which never health-check.
        :raises AuthenticationError: AUTH did not reply "OK".
        :raises ConnectionError: HELLO AUTH reported a different protocol,
            CLIENT SETNAME failed, or SELECT did not reply "OK".
        """
        self._parser.on_connect(self)
        # Keep a handle on the original parser: if it is replaced by a RESP3
        # parser below, its EXCEPTION_CLASSES are carried over.
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            # Swap a RESP2 parser for a RESP3 one (presumably because the
            # HELLO reply for protocol 3 is not readable by the RESP2 parser).
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            # Password-only credentials: HELLO ... AUTH takes a username too,
            # so "default" is used.
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            # The reply's "proto" key is checked in both its bytes and str
            # forms, since responses may or may not be decoded.
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        elif auth_args:
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _AsyncRESP2Parser):
                self.set_parser(_AsyncRESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            response = await self.read_response()
            # NOTE: unlike the authenticated branch above, the HELLO reply's
            # "proto" field is not validated here; the check is disabled:
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # set the library name and version, pipeline for lower startup latency
        if self.lib_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.lib_name,
                check_health=check_health,
            )
        if self.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.lib_version,
                check_health=check_health,
            )
        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline
        # One reply per SETINFO actually sent, in send order. Error replies
        # are ignored (presumably servers without CLIENT SETINFO support).
        for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
            try:
                await self.read_response()
            except ResponseError:
                pass

        # The SELECT reply comes last because SELECT was sent last.
        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

470 

471 async def disconnect(self, nowait: bool = False) -> None: 

472 """Disconnects from the Redis server""" 

473 try: 

474 async with async_timeout(self.socket_connect_timeout): 

475 self._parser.on_disconnect() 

476 if not self.is_connected: 

477 return 

478 try: 

479 self._writer.close() # type: ignore[union-attr] 

480 # wait for close to finish, except when handling errors and 

481 # forcefully disconnecting. 

482 if not nowait: 

483 await self._writer.wait_closed() # type: ignore[union-attr] 

484 except OSError: 

485 pass 

486 finally: 

487 self._reader = None 

488 self._writer = None 

489 except asyncio.TimeoutError: 

490 raise TimeoutError( 

491 f"Timed out closing connection after {self.socket_connect_timeout}" 

492 ) from None 

493 

494 async def _send_ping(self): 

495 """Send PING, expect PONG in return""" 

496 await self.send_command("PING", check_health=False) 

497 if str_if_bytes(await self.read_response()) != "PONG": 

498 raise ConnectionError("Bad response from PING health check") 

499 

500 async def _ping_failed(self, error): 

501 """Function to call when PING fails""" 

502 await self.disconnect() 

503 

504 async def check_health(self): 

505 """Check the health of the connection with a PING/PONG""" 

506 if ( 

507 self.health_check_interval 

508 and asyncio.get_running_loop().time() > self.next_health_check 

509 ): 

510 await self.retry.call_with_retry(self._send_ping, self._ping_failed) 

511 

512 async def _send_packed_command(self, command: Iterable[bytes]) -> None: 

513 self._writer.writelines(command) 

514 await self._writer.drain() 

515 

516 async def send_packed_command( 

517 self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True 

518 ) -> None: 

519 if not self.is_connected: 

520 await self.connect_check_health(check_health=False) 

521 if check_health: 

522 await self.check_health() 

523 

524 try: 

525 if isinstance(command, str): 

526 command = command.encode() 

527 if isinstance(command, bytes): 

528 command = [command] 

529 if self.socket_timeout: 

530 await asyncio.wait_for( 

531 self._send_packed_command(command), self.socket_timeout 

532 ) 

533 else: 

534 self._writer.writelines(command) 

535 await self._writer.drain() 

536 except asyncio.TimeoutError: 

537 await self.disconnect(nowait=True) 

538 raise TimeoutError("Timeout writing to socket") from None 

539 except OSError as e: 

540 await self.disconnect(nowait=True) 

541 if len(e.args) == 1: 

542 err_no, errmsg = "UNKNOWN", e.args[0] 

543 else: 

544 err_no = e.args[0] 

545 errmsg = e.args[1] 

546 raise ConnectionError( 

547 f"Error {err_no} while writing to socket. {errmsg}." 

548 ) from e 

549 except BaseException: 

550 # BaseExceptions can be raised when a socket send operation is not 

551 # finished, e.g. due to a timeout. Ideally, a caller could then re-try 

552 # to send un-sent data. However, the send_packed_command() API 

553 # does not support it so there is no point in keeping the connection open. 

554 await self.disconnect(nowait=True) 

555 raise 

556 

557 async def send_command(self, *args: Any, **kwargs: Any) -> None: 

558 """Pack and send a command to the Redis server""" 

559 await self.send_packed_command( 

560 self.pack_command(*args), check_health=kwargs.get("check_health", True) 

561 ) 

562 

563 async def can_read_destructive(self): 

564 """Poll the socket to see if there's data that can be read.""" 

565 try: 

566 return await self._parser.can_read_destructive() 

567 except OSError as e: 

568 await self.disconnect(nowait=True) 

569 host_error = self._host_error() 

570 raise ConnectionError(f"Error while reading from {host_error}: {e.args}") 

571 

572 async def read_response( 

573 self, 

574 disable_decoding: bool = False, 

575 timeout: Optional[float] = None, 

576 *, 

577 disconnect_on_error: bool = True, 

578 push_request: Optional[bool] = False, 

579 ): 

580 """Read the response from a previously sent command""" 

581 read_timeout = timeout if timeout is not None else self.socket_timeout 

582 host_error = self._host_error() 

583 try: 

584 if read_timeout is not None and self.protocol in ["3", 3]: 

585 async with async_timeout(read_timeout): 

586 response = await self._parser.read_response( 

587 disable_decoding=disable_decoding, push_request=push_request 

588 ) 

589 elif read_timeout is not None: 

590 async with async_timeout(read_timeout): 

591 response = await self._parser.read_response( 

592 disable_decoding=disable_decoding 

593 ) 

594 elif self.protocol in ["3", 3]: 

595 response = await self._parser.read_response( 

596 disable_decoding=disable_decoding, push_request=push_request 

597 ) 

598 else: 

599 response = await self._parser.read_response( 

600 disable_decoding=disable_decoding 

601 ) 

602 except asyncio.TimeoutError: 

603 if timeout is not None: 

604 # user requested timeout, return None. Operation can be retried 

605 return None 

606 # it was a self.socket_timeout error. 

607 if disconnect_on_error: 

608 await self.disconnect(nowait=True) 

609 raise TimeoutError(f"Timeout reading from {host_error}") 

610 except OSError as e: 

611 if disconnect_on_error: 

612 await self.disconnect(nowait=True) 

613 raise ConnectionError(f"Error while reading from {host_error} : {e.args}") 

614 except BaseException: 

615 # Also by default close in case of BaseException. A lot of code 

616 # relies on this behaviour when doing Command/Response pairs. 

617 # See #1128. 

618 if disconnect_on_error: 

619 await self.disconnect(nowait=True) 

620 raise 

621 

622 if self.health_check_interval: 

623 next_time = asyncio.get_running_loop().time() + self.health_check_interval 

624 self.next_health_check = next_time 

625 

626 if isinstance(response, ResponseError): 

627 raise response from None 

628 return response 

629 

    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol

        Returns the packed command as a list of chunks to be written in
        order. Small arguments are concatenated into shared byte buffers. An
        argument longer than ``_buffer_cutoff`` bytes, or a memoryview, is
        appended as its own chunk so it is not joined into a larger buffer.
        """
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        # Header: "*" + number of arguments + CRLF.
        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                # Flush the pending buffer, ending with this argument's
                # "$<len>" CRLF header, then emit the argument on its own;
                # its trailing CRLF starts the next buffer.
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                # Small argument: "$<len>" CRLF payload CRLF, inline.
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output

675 

    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol

        Each command is packed with pack_command(). The resulting chunks are
        regrouped so that consecutive small chunks share one joined buffer,
        while chunks longer than ``_buffer_cutoff`` bytes (or memoryviews) are
        emitted on their own. Chunk order is preserved.
        """
        output: List[bytes] = []
        # Small chunks waiting to be joined into one buffer.
        pieces: List[bytes] = []
        # Total length of ``pieces``.
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    # Flush what has accumulated so far so order is kept.
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

705 

706 def _socket_is_empty(self): 

707 """Check if the socket is empty""" 

708 return len(self._reader._buffer) == 0 

709 

710 async def process_invalidation_messages(self): 

711 while not self._socket_is_empty(): 

712 await self.read_response(push_request=True) 

713 

714 def set_re_auth_token(self, token: TokenInterface): 

715 self._re_auth_token = token 

716 

717 async def re_auth(self): 

718 if self._re_auth_token is not None: 

719 await self.send_command( 

720 "AUTH", 

721 self._re_auth_token.try_get("oid"), 

722 self._re_auth_token.get_value(), 

723 ) 

724 await self.read_response() 

725 self._re_auth_token = None 

726 

727 

class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        """Store TCP settings; remaining keyword arguments go to AbstractConnection.

        :param socket_keepalive_options: ``{option: value}`` pairs applied with
            ``setsockopt(SOL_TCP, ...)`` when ``socket_keepalive`` is True.
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        """Keyword arguments for ``asyncio.open_connection`` (extended by SSLConnection)."""
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection

        The streams are stored on ``self`` only after the socket options were
        applied successfully. If applying them fails, the writer is closed and
        the connection stays in the not-connected state, so a later connect()
        opens a fresh transport instead of reusing a closed one.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        sock = writer.transport.get_extra_info("socket")
        if sock:
            try:
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.SOL_TCP, k, v)
            except (OSError, TypeError):
                # `socket_keepalive_options` might contain invalid options
                # causing an error. Do not leave the connection open.
                writer.close()
                raise
        self._reader = reader
        self._writer = writer

    def _host_error(self) -> str:
        return f"{self.host}:{self.port}"

783 

784 

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        # Quoted: parameter annotations are evaluated when the class body
        # runs, and ``ssl`` is None when Python was built without SSL.
        ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        **kwargs,
    ):
        """Collect the TLS settings into a RedisSSLContext; other kwargs go to Connection.

        :raises RedisError: Python was built without SSL support.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
        )
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        """TCP connection arguments plus the (lazily built) ``ssl`` context."""
        kwargs = super()._connection_arguments()
        kwargs["ssl"] = self.ssl_context.get()
        return kwargs

    # Read-only views of the settings held by ``self.ssl_context``.

    @property
    def keyfile(self):
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        return self.ssl_context.cert_reqs

    @property
    def ca_certs(self):
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        return self.ssl_context.ca_data

    @property
    def check_hostname(self):
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        return self.ssl_context.min_version

850 

851 

class RedisSSLContext:
    """TLS settings for SSLConnection plus a lazily created ``ssl.SSLContext``.

    The SSLContext is built on the first get() call and cached afterwards.
    """

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "ca_certs",
        "ca_data",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        # Quoted: evaluated at definition time otherwise, and ``ssl`` is None
        # when Python was built without SSL, which broke importing the module.
        cert_reqs: Optional[Union[str, "ssl.VerifyMode"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
    ):
        """Validate and store the TLS settings.

        :param cert_reqs: None (meaning ``ssl.CERT_NONE``), an
            ``ssl.VerifyMode``, or one of "none", "optional", "required".
        :raises RedisError: no SSL support, or an unknown ``cert_reqs`` string.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            cert_reqs = CERT_REQS[cert_reqs]
        self.cert_reqs = cert_reqs
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        # Hostname checking is forced off without certificate verification;
        # SSLContext does not allow check_hostname together with CERT_NONE.
        self.check_hostname = (
            check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.min_version = min_version
        self.ciphers = ciphers
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """Return the cached SSLContext, building it on first use."""
        if not self.context:
            context = ssl.create_default_context()
            # check_hostname is set before verify_mode so that switching to
            # CERT_NONE is allowed.
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            if self.certfile and self.keyfile:
                context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
            if self.ca_certs or self.ca_data:
                context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
            if self.min_version is not None:
                context.minimum_version = self.min_version
            if self.ciphers is not None:
                context.set_ciphers(self.ciphers)
            self.context = context
        return self.context

919 

920 

class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, *, path: str = "", **kwargs):
        """Store the socket path; remaining kwargs go to AbstractConnection."""
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    async def _connect(self):
        """Open the Unix domain socket at ``self.path``.

        Only the transport is opened here, as in the TCP connection class.
        The handshake (on_connect_check_health() or ``redis_connect_func``)
        is run once by connect_check_health() after this returns. Running it
        here as well would send AUTH/HELLO/SETNAME/SELECT twice.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_unix_connection(path=self.path)
        self._reader = reader
        self._writer = writer

    def _host_error(self) -> str:
        return self.path

943 

944 

# Upper-cased spellings that to_bool() maps to False (compared case-insensitively).
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """Interpret a URL query-string value as a boolean.

    Returns None for ``None`` or ``""`` (treated as "not specified"), False
    for strings found in FALSE_STRINGS after upper-casing, and
    ``bool(value)`` for everything else.
    """
    if value is None or value == "":
        return None
    is_false_word = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_false_word else bool(value)


# Converters that parse_url() applies to known query-string options; options
# not listed here are kept as strings. Exposed as a read-only mapping.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    dict(
        db=int,
        socket_timeout=float,
        socket_connect_timeout=float,
        socket_keepalive=to_bool,
        retry_on_timeout=to_bool,
        max_connections=int,
        health_check_interval=int,
        ssl_check_hostname=to_bool,
        timeout=float,
    )
)

969 

970 

971class ConnectKwargs(TypedDict, total=False): 

972 username: str 

973 password: str 

974 connection_class: Type[AbstractConnection] 

975 host: str 

976 port: int 

977 db: int 

978 path: str 

979 

980 

981def parse_url(url: str) -> ConnectKwargs: 

982 parsed: ParseResult = urlparse(url) 

983 kwargs: ConnectKwargs = {} 

984 

985 for name, value_list in parse_qs(parsed.query).items(): 

986 if value_list and len(value_list) > 0: 

987 value = unquote(value_list[0]) 

988 parser = URL_QUERY_ARGUMENT_PARSERS.get(name) 

989 if parser: 

990 try: 

991 kwargs[name] = parser(value) 

992 except (TypeError, ValueError): 

993 raise ValueError(f"Invalid value for '{name}' in connection URL.") 

994 else: 

995 kwargs[name] = value 

996 

997 if parsed.username: 

998 kwargs["username"] = unquote(parsed.username) 

999 if parsed.password: 

1000 kwargs["password"] = unquote(parsed.password) 

1001 

1002 # We only support redis://, rediss:// and unix:// schemes. 

1003 if parsed.scheme == "unix": 

1004 if parsed.path: 

1005 kwargs["path"] = unquote(parsed.path) 

1006 kwargs["connection_class"] = UnixDomainSocketConnection 

1007 

1008 elif parsed.scheme in ("redis", "rediss"): 

1009 if parsed.hostname: 

1010 kwargs["host"] = unquote(parsed.hostname) 

1011 if parsed.port: 

1012 kwargs["port"] = int(parsed.port) 

1013 

1014 # If there's a path argument, use it as the db argument if a 

1015 # querystring value wasn't specified 

1016 if parsed.path and "db" not in kwargs: 

1017 try: 

1018 kwargs["db"] = int(unquote(parsed.path).replace("/", "")) 

1019 except (AttributeError, ValueError): 

1020 pass 

1021 

1022 if parsed.scheme == "rediss": 

1023 kwargs["connection_class"] = SSLConnection 

1024 else: 

1025 valid_schemes = "redis://, rediss://, unix://" 

1026 raise ValueError( 

1027 f"Redis URL must specify one of the following schemes ({valid_schemes})" 

1028 ) 

1029 

1030 return kwargs 

1031 

1032 

1033_CP = TypeVar("_CP", bound="ConnectionPool") 

1034 

1035 

1036class ConnectionPool: 

1037 """ 

1038 Create a connection pool. ``If max_connections`` is set, then this 

1039 object raises :py:class:`~redis.ConnectionError` when the pool's 

1040 limit is reached. 

1041 

1042 By default, TCP connections are created unless ``connection_class`` 

1043 is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for 

1044 unix sockets. 

1045 :py:class:`~redis.SSLConnection` can be used for SSL enabled connections. 

1046 

1047 Any additional keyword arguments are passed to the constructor of 

1048 ``connection_class``. 

1049 """ 

1050 

1051 @classmethod 

1052 def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: 

1053 """ 

1054 Return a connection pool configured from the given URL. 

1055 

1056 For example:: 

1057 

1058 redis://[[username]:[password]]@localhost:6379/0 

1059 rediss://[[username]:[password]]@localhost:6379/0 

1060 unix://[username@]/path/to/socket.sock?db=0[&password=password] 

1061 

1062 Three URL schemes are supported: 

1063 

1064 - `redis://` creates a TCP socket connection. See more at: 

1065 <https://www.iana.org/assignments/uri-schemes/prov/redis> 

1066 - `rediss://` creates a SSL wrapped TCP socket connection. See more at: 

1067 <https://www.iana.org/assignments/uri-schemes/prov/rediss> 

1068 - ``unix://``: creates a Unix Domain Socket connection. 

1069 

1070 The username, password, hostname, path and all querystring values 

1071 are passed through urllib.parse.unquote in order to replace any 

1072 percent-encoded values with their corresponding characters. 

1073 

1074 There are several ways to specify a database number. The first value 

1075 found will be used: 

1076 

1077 1. A ``db`` querystring option, e.g. redis://localhost?db=0 

1078 

1079 2. If using the redis:// or rediss:// schemes, the path argument 

1080 of the url, e.g. redis://localhost/0 

1081 

1082 3. A ``db`` keyword argument to this function. 

1083 

1084 If none of these options are specified, the default db=0 is used. 

1085 

1086 All querystring options are cast to their appropriate Python types. 

1087 Boolean arguments can be specified with string values "True"/"False" 

1088 or "Yes"/"No". Values that cannot be properly cast cause a 

1089 ``ValueError`` to be raised. Once parsed, the querystring arguments 

1090 and keyword arguments are passed to the ``ConnectionPool``'s 

1091 class initializer. In the case of conflicting arguments, querystring 

1092 arguments always win. 

1093 """ 

1094 url_options = parse_url(url) 

1095 kwargs.update(url_options) 

1096 return cls(**kwargs) 

1097 

1098 def __init__( 

1099 self, 

1100 connection_class: Type[AbstractConnection] = Connection, 

1101 max_connections: Optional[int] = None, 

1102 **connection_kwargs, 

1103 ): 

1104 max_connections = max_connections or 2**31 

1105 if not isinstance(max_connections, int) or max_connections < 0: 

1106 raise ValueError('"max_connections" must be a positive integer') 

1107 

1108 self.connection_class = connection_class 

1109 self.connection_kwargs = connection_kwargs 

1110 self.max_connections = max_connections 

1111 

1112 self._available_connections: List[AbstractConnection] = [] 

1113 self._in_use_connections: Set[AbstractConnection] = set() 

1114 self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) 

1115 self._lock = asyncio.Lock() 

1116 self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) 

1117 if self._event_dispatcher is None: 

1118 self._event_dispatcher = EventDispatcher() 

1119 

1120 def __repr__(self): 

1121 conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()]) 

1122 return ( 

1123 f"<{self.__class__.__module__}.{self.__class__.__name__}" 

1124 f"(<{self.connection_class.__module__}.{self.connection_class.__name__}" 

1125 f"({conn_kwargs})>)>" 

1126 ) 

1127 

1128 def reset(self): 

1129 self._available_connections = [] 

1130 self._in_use_connections = weakref.WeakSet() 

1131 

1132 def can_get_connection(self) -> bool: 

1133 """Return True if a connection can be retrieved from the pool.""" 

1134 return ( 

1135 self._available_connections 

1136 or len(self._in_use_connections) < self.max_connections 

1137 ) 

1138 

1139 @deprecated_args( 

1140 args_to_warn=["*"], 

1141 reason="Use get_connection() without args instead", 

1142 version="5.3.0", 

1143 ) 

1144 async def get_connection(self, command_name=None, *keys, **options): 

1145 async with self._lock: 

1146 """Get a connected connection from the pool""" 

1147 connection = self.get_available_connection() 

1148 try: 

1149 await self.ensure_connection(connection) 

1150 except BaseException: 

1151 await self.release(connection) 

1152 raise 

1153 

1154 return connection 

1155 

1156 def get_available_connection(self): 

1157 """Get a connection from the pool, without making sure it is connected""" 

1158 try: 

1159 connection = self._available_connections.pop() 

1160 except IndexError: 

1161 if len(self._in_use_connections) >= self.max_connections: 

1162 raise ConnectionError("Too many connections") from None 

1163 connection = self.make_connection() 

1164 self._in_use_connections.add(connection) 

1165 return connection 

1166 

1167 def get_encoder(self): 

1168 """Return an encoder based on encoding settings""" 

1169 kwargs = self.connection_kwargs 

1170 return self.encoder_class( 

1171 encoding=kwargs.get("encoding", "utf-8"), 

1172 encoding_errors=kwargs.get("encoding_errors", "strict"), 

1173 decode_responses=kwargs.get("decode_responses", False), 

1174 ) 

1175 

1176 def make_connection(self): 

1177 """Create a new connection. Can be overridden by child classes.""" 

1178 return self.connection_class(**self.connection_kwargs) 

1179 

1180 async def ensure_connection(self, connection: AbstractConnection): 

1181 """Ensure that the connection object is connected and valid""" 

1182 await connection.connect() 

1183 # connections that the pool provides should be ready to send 

1184 # a command. if not, the connection was either returned to the 

1185 # pool before all data has been read or the socket has been 

1186 # closed. either way, reconnect and verify everything is good. 

1187 try: 

1188 if await connection.can_read_destructive(): 

1189 raise ConnectionError("Connection has data") from None 

1190 except (ConnectionError, TimeoutError, OSError): 

1191 await connection.disconnect() 

1192 await connection.connect() 

1193 if await connection.can_read_destructive(): 

1194 raise ConnectionError("Connection not ready") from None 

1195 

1196 async def release(self, connection: AbstractConnection): 

1197 """Releases the connection back to the pool""" 

1198 # Connections should always be returned to the correct pool, 

1199 # not doing so is an error that will cause an exception here. 

1200 self._in_use_connections.remove(connection) 

1201 self._available_connections.append(connection) 

1202 await self._event_dispatcher.dispatch_async( 

1203 AsyncAfterConnectionReleasedEvent(connection) 

1204 ) 

1205 

1206 async def disconnect(self, inuse_connections: bool = True): 

1207 """ 

1208 Disconnects connections in the pool 

1209 

1210 If ``inuse_connections`` is True, disconnect connections that are 

1211 current in use, potentially by other tasks. Otherwise only disconnect 

1212 connections that are idle in the pool. 

1213 """ 

1214 if inuse_connections: 

1215 connections: Iterable[AbstractConnection] = chain( 

1216 self._available_connections, self._in_use_connections 

1217 ) 

1218 else: 

1219 connections = self._available_connections 

1220 resp = await asyncio.gather( 

1221 *(connection.disconnect() for connection in connections), 

1222 return_exceptions=True, 

1223 ) 

1224 exc = next((r for r in resp if isinstance(r, BaseException)), None) 

1225 if exc: 

1226 raise exc 

1227 

1228 async def aclose(self) -> None: 

1229 """Close the pool, disconnecting all connections""" 

1230 await self.disconnect() 

1231 

1232 def set_retry(self, retry: "Retry") -> None: 

1233 for conn in self._available_connections: 

1234 conn.retry = retry 

1235 for conn in self._in_use_connections: 

1236 conn.retry = retry 

1237 

1238 async def re_auth_callback(self, token: TokenInterface): 

1239 async with self._lock: 

1240 for conn in self._available_connections: 

1241 await conn.retry.call_with_retry( 

1242 lambda: conn.send_command( 

1243 "AUTH", token.try_get("oid"), token.get_value() 

1244 ), 

1245 lambda error: self._mock(error), 

1246 ) 

1247 await conn.retry.call_with_retry( 

1248 lambda: conn.read_response(), lambda error: self._mock(error) 

1249 ) 

1250 for conn in self._in_use_connections: 

1251 conn.set_re_auth_token(token) 

1252 

1253 async def _mock(self, error: RedisError): 

1254 """ 

1255 Dummy functions, needs to be passed as error callback to retry object. 

1256 :param error: 

1257 :return: 

1258 """ 

1259 pass 

1260 

1261 

1262class BlockingConnectionPool(ConnectionPool): 

1263 """ 

1264 A blocking connection pool:: 

1265 

1266 >>> from redis.asyncio import Redis, BlockingConnectionPool 

1267 >>> client = Redis.from_pool(BlockingConnectionPool()) 

1268 

1269 It performs the same function as the default 

1270 :py:class:`~redis.asyncio.ConnectionPool` implementation, in that, 

1271 it maintains a pool of reusable connections that can be shared by 

1272 multiple async redis clients. 

1273 

1274 The difference is that, in the event that a client tries to get a 

1275 connection from the pool when all of connections are in use, rather than 

1276 raising a :py:class:`~redis.ConnectionError` (as the default 

1277 :py:class:`~redis.asyncio.ConnectionPool` implementation does), it 

1278 blocks the current `Task` for a specified number of seconds until 

1279 a connection becomes available. 

1280 

1281 Use ``max_connections`` to increase / decrease the pool size:: 

1282 

1283 >>> pool = BlockingConnectionPool(max_connections=10) 

1284 

1285 Use ``timeout`` to tell it either how many seconds to wait for a connection 

1286 to become available, or to block forever: 

1287 

1288 >>> # Block forever. 

1289 >>> pool = BlockingConnectionPool(timeout=None) 

1290 

1291 >>> # Raise a ``ConnectionError`` after five seconds if a connection is 

1292 >>> # not available. 

1293 >>> pool = BlockingConnectionPool(timeout=5) 

1294 """ 

1295 

1296 def __init__( 

1297 self, 

1298 max_connections: int = 50, 

1299 timeout: Optional[int] = 20, 

1300 connection_class: Type[AbstractConnection] = Connection, 

1301 queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, # deprecated 

1302 **connection_kwargs, 

1303 ): 

1304 super().__init__( 

1305 connection_class=connection_class, 

1306 max_connections=max_connections, 

1307 **connection_kwargs, 

1308 ) 

1309 self._condition = asyncio.Condition() 

1310 self.timeout = timeout 

1311 

1312 @deprecated_args( 

1313 args_to_warn=["*"], 

1314 reason="Use get_connection() without args instead", 

1315 version="5.3.0", 

1316 ) 

1317 async def get_connection(self, command_name=None, *keys, **options): 

1318 """Gets a connection from the pool, blocking until one is available""" 

1319 try: 

1320 async with self._condition: 

1321 async with async_timeout(self.timeout): 

1322 await self._condition.wait_for(self.can_get_connection) 

1323 connection = super().get_available_connection() 

1324 except asyncio.TimeoutError as err: 

1325 raise ConnectionError("No connection available.") from err 

1326 

1327 # We now perform the connection check outside of the lock. 

1328 try: 

1329 await self.ensure_connection(connection) 

1330 return connection 

1331 except BaseException: 

1332 await self.release(connection) 

1333 raise 

1334 

1335 async def release(self, connection: AbstractConnection): 

1336 """Releases the connection back to the pool.""" 

1337 async with self._condition: 

1338 await super().release(connection) 

1339 self._condition.notify()