Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/connection.py: 23%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

660 statements  

1import asyncio 

2import copy 

3import enum 

4import inspect 

5import socket 

6import sys 

7import warnings 

8import weakref 

9from abc import abstractmethod 

10from itertools import chain 

11from types import MappingProxyType 

12from typing import ( 

13 Any, 

14 Callable, 

15 Iterable, 

16 List, 

17 Mapping, 

18 Optional, 

19 Protocol, 

20 Set, 

21 Tuple, 

22 Type, 

23 TypedDict, 

24 TypeVar, 

25 Union, 

26) 

27from urllib.parse import ParseResult, parse_qs, unquote, urlparse 

28 

29from ..utils import SSL_AVAILABLE 

30 

31if SSL_AVAILABLE: 

32 import ssl 

33 from ssl import SSLContext, TLSVersion 

34else: 

35 ssl = None 

36 TLSVersion = None 

37 SSLContext = None 

38 

39from ..auth.token import TokenInterface 

40from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher 

41from ..utils import deprecated_args, format_error_message 

42 

43# the functionality is available in 3.11.x but has a major issue before 

44# 3.11.3. See https://github.com/redis/redis-py/issues/2633 

45if sys.version_info >= (3, 11, 3): 

46 from asyncio import timeout as async_timeout 

47else: 

48 from async_timeout import timeout as async_timeout 

49 

50from redis.asyncio.retry import Retry 

51from redis.backoff import NoBackoff 

52from redis.connection import DEFAULT_RESP_VERSION 

53from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider 

54from redis.exceptions import ( 

55 AuthenticationError, 

56 AuthenticationWrongNumberOfArgsError, 

57 ConnectionError, 

58 DataError, 

59 RedisError, 

60 ResponseError, 

61 TimeoutError, 

62) 

63from redis.typing import EncodableT 

64from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes 

65 

66from .._parsers import ( 

67 BaseParser, 

68 Encoder, 

69 _AsyncHiredisParser, 

70 _AsyncRESP2Parser, 

71 _AsyncRESP3Parser, 

72) 

73 

74SYM_STAR = b"*" 

75SYM_DOLLAR = b"$" 

76SYM_CRLF = b"\r\n" 

77SYM_LF = b"\n" 

78SYM_EMPTY = b"" 

79 

80 

81class _Sentinel(enum.Enum): 

82 sentinel = object() 

83 

84 

85SENTINEL = _Sentinel.sentinel 

86 

87 

# Parser class used when a connection is created without ``parser_class``:
# the hiredis-backed parser when HIREDIS_AVAILABLE is true, otherwise the
# pure-Python RESP2 parser.
DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]] = (
    _AsyncHiredisParser if HIREDIS_AVAILABLE else _AsyncRESP2Parser
)

93 

94 

class ConnectCallbackProtocol(Protocol):
    """Structural type of a synchronous connect callback."""

    def __call__(self, connection: "AbstractConnection"):
        """Invoked with the connection that has just been established."""


class AsyncConnectCallbackProtocol(Protocol):
    """Structural type of a coroutine connect callback."""

    async def __call__(self, connection: "AbstractConnection"):
        """Awaited with the connection that has just been established."""


# Either flavour is accepted wherever a connect callback is expected.
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]

104 

105 

class AbstractConnection:
    """Manages communication to and from a Redis server.

    Subclasses supply the transport by implementing ``_connect``,
    ``_host_error`` and ``repr_pieces``. This base class implements the
    connection handshake, command packing, writing, and response reading on
    top of the ``asyncio`` stream reader/writer pair stored in
    ``_reader``/``_writer``.
    """

    __slots__ = (
        "db",
        "username",
        "client_name",
        "lib_name",
        "lib_version",
        "credential_provider",
        "password",
        "socket_timeout",
        "socket_connect_timeout",
        "redis_connect_func",
        "retry_on_timeout",
        "retry_on_error",
        "health_check_interval",
        "next_health_check",
        "last_active_at",
        "encoder",
        "ssl_context",
        "protocol",
        "_reader",
        "_writer",
        "_parser",
        "_connect_callbacks",
        "_buffer_cutoff",
        "_lock",
        "_socket_read_size",
        "__dict__",
    )

    def __init__(
        self,
        *,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[list, _Sentinel] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class: Type[BaseParser] = DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: float = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        retry: Optional[Retry] = None,
        redis_connect_func: Optional[ConnectCallbackT] = None,
        encoder_class: Type[Encoder] = Encoder,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """Store the connection settings. No network I/O happens here.

        :raises DataError: if ``username``/``password`` are combined with a
            ``credential_provider``.
        :raises ConnectionError: if ``protocol`` is not an integer, or is
            not 2 or 3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.db = db
        self.client_name = client_name
        self.lib_name = lib_name
        self.lib_version = lib_version
        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_error = []
        else:
            # Work on a private copy. The caller's list is frequently shared
            # (e.g. pool-wide connection kwargs), and the appends below would
            # otherwise grow it by three entries for every new connection.
            retry_on_error = list(retry_on_error) if retry_on_error else []
        if retry_on_timeout:
            retry_on_error.append(TimeoutError)
            retry_on_error.append(socket.timeout)
            retry_on_error.append(asyncio.TimeoutError)
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if not retry:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check: float = -1
        self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
        self.redis_connect_func = redis_connect_func
        self._reader: Optional[asyncio.StreamReader] = None
        self._writer: Optional[asyncio.StreamWriter] = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        # Store the validated integer, so ``protocol=None`` resolves to the
        # default version instead of later being sent as ``HELLO None``.
        self.protocol = self._validate_protocol(protocol)

    @staticmethod
    def _validate_protocol(protocol: Any) -> int:
        """Return ``protocol`` as an int, validated to be 2 or 3.

        A value that ``int()`` rejects with ``TypeError`` (such as ``None``)
        falls back to ``DEFAULT_RESP_VERSION``.

        :raises ConnectionError: if ``protocol`` is not integer-like or is out
            of range.
        """
        try:
            p = int(protocol)
        except TypeError:
            p = DEFAULT_RESP_VERSION
        except ValueError:
            # Raised directly. The range check must not run in a ``finally``,
            # where ``p`` would be unbound and mask this error.
            raise ConnectionError("protocol must be an integer")
        if p < 2 or p > 3:
            raise ConnectionError("protocol must be either 2 or 3")
        return p

    def __del__(self, _warnings: Any = warnings):
        # For some reason, the individual streams don't get properly garbage
        # collected and therefore produce no resource warnings. We add one
        # here, in the same style as those from the stdlib.
        if getattr(self, "_writer", None):
            _warnings.warn(
                f"unclosed Connection {self!r}", ResourceWarning, source=self
            )

            try:
                asyncio.get_running_loop()
                self._close()
            except RuntimeError:
                # No running event loop (e.g. it was already closed): there
                # is nothing we can safely do here.
                pass

    def _close(self):
        """
        Internal method to silently close the connection without waiting
        """
        if self._writer:
            self._writer.close()
            self._writer = self._reader = None

    def __repr__(self):
        repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return ``(name, value)`` pairs shown by ``__repr__``."""
        pass

    @property
    def is_connected(self):
        """True when both the stream reader and writer are set."""
        return self._reader is not None and self._writer is not None

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def set_parser(self, parser_class: Type[BaseParser]) -> None:
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    async def connect(self):
        """Connects to the Redis server if not already connected"""
        await self.connect_check_health(check_health=True)

    async def connect_check_health(self, check_health: bool = True):
        """Open the transport, run the handshake and fire connect callbacks.

        Does nothing when already connected. The handshake is
        ``redis_connect_func`` when set, otherwise ``on_connect_check_health``.

        :param check_health: forwarded to the default handshake.
        :raises TimeoutError: if opening the transport times out.
        :raises ConnectionError: if opening the transport fails.
        """
        if self.is_connected:
            return
        try:
            await self.retry.call_with_retry(
                lambda: self._connect(), lambda error: self.disconnect()
            )
        except asyncio.CancelledError:
            raise  # in 3.7 and earlier, this is an Exception, not BaseException
        except (socket.timeout, asyncio.TimeoutError):
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))
        except Exception as exc:
            raise ConnectionError(exc) from exc

        try:
            if not self.redis_connect_func:
                # Use the default on_connect function
                await self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func. Await whatever
                # it returns if that is awaitable; this also covers callable
                # objects with an ``async __call__``, which a coroutine-function
                # check does not recognise.
                result = self.redis_connect_func(self)
                if inspect.isawaitable(result):
                    await result
        except RedisError:
            # clean up after any error in on_connect
            await self.disconnect()
            raise

        # Run any user callbacks. Right now the only internal callback is for
        # pubsub channel/pattern resubscription. First, drop dead weakrefs.
        live_callbacks = [ref for ref in self._connect_callbacks if ref() is not None]
        self._connect_callbacks = live_callbacks
        # Iterate a snapshot: a callback may (de)register callbacks while we
        # await it.
        for ref in tuple(live_callbacks):
            callback = ref()
            if callback is None:
                # Collected between the filter above and this call.
                continue
            task = callback(self)
            if task and inspect.isawaitable(task):
                await task

    @abstractmethod
    async def _connect(self):
        """Open the transport and set ``_reader``/``_writer``."""
        pass

    @abstractmethod
    def _host_error(self) -> str:
        """Return a short description of the peer used in error messages."""
        pass

    def _error_message(self, exception: BaseException) -> str:
        return format_error_message(self._host_error(), exception)

    def get_protocol(self):
        return self.protocol

    async def on_connect(self) -> None:
        """Initialize the connection, authenticate and select a database"""
        await self.on_connect_check_health(check_health=True)

    def _switch_to_resp3_parser(self) -> None:
        """Replace a RESP2 parser with a RESP3 one before sending HELLO.

        The current parser's ``EXCEPTION_CLASSES`` mapping (which cluster code
        may have updated) is carried over to the new parser instance.
        """
        previous = self._parser
        if isinstance(previous, _AsyncRESP2Parser):
            self.set_parser(_AsyncRESP3Parser)
            # update cluster exception classes
            self._parser.EXCEPTION_CLASSES = previous.EXCEPTION_CLASSES
            self._parser.on_connect(self)

    async def on_connect_check_health(self, check_health: bool = True) -> None:
        """Run the handshake: HELLO/AUTH, CLIENT SETNAME/SETINFO, SELECT.

        :param check_health: forwarded to ``send_command`` for commands sent
            after authentication. Authentication commands never run a health
            check, because PING would fail before AUTH.
        :raises AuthenticationError: if AUTH is not answered with ``OK``.
        :raises ConnectionError: on an unexpected HELLO (with AUTH), CLIENT
            SETNAME or SELECT reply.
        """
        self._parser.on_connect(self)

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = await cred_provider.get_credentials_async()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            self._switch_to_resp3_parser()
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            response = await self.read_response()
            if response.get(b"proto") != int(self.protocol) and response.get(
                "proto"
            ) != int(self.protocol):
                raise ConnectionError("Invalid RESP version")
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            await self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = await self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                await self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = await self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            self._switch_to_resp3_parser()
            await self.send_command("HELLO", self.protocol, check_health=check_health)
            # The HELLO reply is consumed but deliberately not validated here.
            await self.read_response()

        # if a client_name is given, set it
        if self.client_name:
            await self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # set the library name and version, pipeline for lower startup latency
        if self.lib_name:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.lib_name,
                check_health=check_health,
            )
        if self.lib_version:
            await self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.lib_version,
                check_health=check_health,
            )
        # if a database is specified, switch to it. Also pipeline this
        if self.db:
            await self.send_command("SELECT", self.db, check_health=check_health)

        # read responses from pipeline; one reply per SETINFO actually sent.
        for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
            try:
                await self.read_response()
            except ResponseError:
                # Best effort: an error reply to SETINFO is ignored.
                pass

        if self.db:
            if str_if_bytes(await self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    async def disconnect(self, nowait: bool = False) -> None:
        """Disconnects from the Redis server.

        :param nowait: close the writer without awaiting ``wait_closed()``.
        :raises TimeoutError: if closing exceeds ``socket_connect_timeout``.
        """
        try:
            async with async_timeout(self.socket_connect_timeout):
                self._parser.on_disconnect()
                if not self.is_connected:
                    return
                try:
                    self._writer.close()  # type: ignore[union-attr]
                    # wait for close to finish, except when handling errors and
                    # forcefully disconnecting.
                    if not nowait:
                        await self._writer.wait_closed()  # type: ignore[union-attr]
                except OSError:
                    pass
                finally:
                    self._reader = None
                    self._writer = None
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Timed out closing connection after {self.socket_connect_timeout}"
            ) from None

    async def _send_ping(self):
        """Send PING, expect PONG in return"""
        await self.send_command("PING", check_health=False)
        if str_if_bytes(await self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    async def _ping_failed(self, error):
        """Function to call when PING fails"""
        await self.disconnect()

    async def check_health(self):
        """Check the health of the connection with a PING/PONG.

        Only runs when ``health_check_interval`` is set and the event-loop
        clock has passed ``next_health_check``.
        """
        if (
            self.health_check_interval
            and asyncio.get_running_loop().time() > self.next_health_check
        ):
            await self.retry.call_with_retry(self._send_ping, self._ping_failed)

    async def _send_packed_command(self, command: Iterable[bytes]) -> None:
        """Write all chunks of ``command`` and wait for the buffer to drain."""
        self._writer.writelines(command)
        await self._writer.drain()

    async def send_packed_command(
        self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
    ) -> None:
        """Write an already-packed command, connecting first if needed.

        A ``str`` is encoded with ``str.encode()``; ``bytes`` are sent as one
        chunk. On any failure the connection is dropped without waiting.

        :raises TimeoutError: if the write exceeds ``socket_timeout``.
        :raises ConnectionError: on socket errors.
        """
        if not self.is_connected:
            await self.connect_check_health(check_health=False)
        if check_health:
            await self.check_health()

        try:
            if isinstance(command, str):
                command = command.encode()
            if isinstance(command, bytes):
                command = [command]
            if self.socket_timeout:
                await asyncio.wait_for(
                    self._send_packed_command(command), self.socket_timeout
                )
            else:
                await self._send_packed_command(command)
        except asyncio.TimeoutError:
            await self.disconnect(nowait=True)
            raise TimeoutError("Timeout writing to socket") from None
        except OSError as e:
            await self.disconnect(nowait=True)
            if len(e.args) == 1:
                err_no, errmsg = "UNKNOWN", e.args[0]
            else:
                err_no = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(
                f"Error {err_no} while writing to socket. {errmsg}."
            ) from e
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            await self.disconnect(nowait=True)
            raise

    async def send_command(self, *args: Any, **kwargs: Any) -> None:
        """Pack and send a command to the Redis server"""
        await self.send_packed_command(
            self.pack_command(*args), check_health=kwargs.get("check_health", True)
        )

    async def can_read_destructive(self):
        """Poll the socket to see if there's data that can be read."""
        try:
            return await self._parser.can_read_destructive()
        except OSError as e:
            await self.disconnect(nowait=True)
            host_error = self._host_error()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    async def read_response(
        self,
        disable_decoding: bool = False,
        timeout: Optional[float] = None,
        *,
        disconnect_on_error: bool = True,
        push_request: Optional[bool] = False,
    ):
        """Read the response from a previously sent command.

        :param timeout: overrides ``socket_timeout``. When given and it
            expires, ``None`` is returned instead of raising.
        :param disconnect_on_error: drop the connection (without waiting) when
            reading fails.
        :param push_request: forwarded to the parser only when speaking RESP3.
        :raises TimeoutError: when ``socket_timeout`` expires.
        :raises ConnectionError: on socket errors.
        :raises ResponseError: when the server replied with an error.
        """
        read_timeout = timeout if timeout is not None else self.socket_timeout
        host_error = self._host_error()
        parser_kwargs: dict = {"disable_decoding": disable_decoding}
        if self.protocol in ["3", 3]:
            parser_kwargs["push_request"] = push_request
        try:
            if read_timeout is not None:
                async with async_timeout(read_timeout):
                    response = await self._parser.read_response(**parser_kwargs)
            else:
                response = await self._parser.read_response(**parser_kwargs)
        except asyncio.TimeoutError:
            if timeout is not None:
                # user requested timeout, return None. Operation can be retried
                return None
            # it was a self.socket_timeout error.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                await self.disconnect(nowait=True)
            raise

        if self.health_check_interval:
            next_time = asyncio.get_running_loop().time() + self.health_check_interval
            self.next_health_check = next_time

        if isinstance(response, ResponseError):
            raise response from None
        return response

    def pack_command(self, *args: EncodableT) -> List[bytes]:
        """Pack a series of arguments into the Redis protocol"""
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        assert not isinstance(args[0], float)
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output

    def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
        """Pack multiple commands into the Redis protocol.

        Small chunks are coalesced until ``_buffer_cutoff`` bytes; large chunks
        and memoryviews are emitted on their own.
        """
        output: List[bytes] = []
        pieces: List[bytes] = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def _socket_is_empty(self):
        """Check if the socket is empty"""
        # NOTE(review): reads the StreamReader's private ``_buffer``.
        return len(self._reader._buffer) == 0

    async def process_invalidation_messages(self):
        """Read push messages until the reader's buffer is empty."""
        while not self._socket_is_empty():
            await self.read_response(push_request=True)

    def set_re_auth_token(self, token: TokenInterface):
        """Remember a token to authenticate with on the next ``re_auth``."""
        self._re_auth_token = token

    async def re_auth(self):
        """Send AUTH with the pending token, if any, then clear it."""
        if self._re_auth_token is not None:
            await self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            await self.read_response()
            self._re_auth_token = None

722 

class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Union[str, int] = 6379,
        socket_keepalive: bool = False,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        socket_type: int = 0,
        **kwargs,
    ):
        """Store the TCP address and socket options.

        Remaining keyword arguments go to ``AbstractConnection``.
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connection_arguments(self) -> Mapping:
        """Keyword arguments passed to ``asyncio.open_connection``."""
        return {"host": self.host, "port": self.port}

    async def _connect(self):
        """Create a TCP socket connection.

        The streams are only stored on the connection once the socket has been
        configured. If configuration fails, the transport is closed and the
        connection stays "not connected" instead of holding a dead writer.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_connection(
                **self._connection_arguments()
            )
        try:
            self._configure_socket(writer)
        except (OSError, TypeError):
            # `socket_keepalive_options` might contain invalid options
            # causing an error. Do not leave the connection open.
            writer.close()
            raise
        self._reader = reader
        self._writer = writer

    def _configure_socket(self, writer: asyncio.StreamWriter) -> None:
        """Apply TCP_NODELAY and the optional keepalive settings to the raw socket."""
        sock = writer.transport.get_extra_info("socket")
        if not sock:
            return
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        # TCP_KEEPALIVE
        if self.socket_keepalive:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            for k, v in self.socket_keepalive_options.items():
                sock.setsockopt(socket.SOL_TCP, k, v)

    def _host_error(self) -> str:
        return f"{self.host}:{self.port}"

778 

779 

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """

    def __init__(
        self,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        # The VerifyMode reference is quoted: parameter annotations are
        # evaluated when the ``def`` runs, and ``ssl`` is bound to ``None`` when
        # Python lacks SSL support, which would make importing this module fail.
        ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        **kwargs,
    ):
        """Build the TLS settings, then initialise the TCP connection.

        :raises RedisError: if Python was built without SSL support.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.ssl_context: RedisSSLContext = RedisSSLContext(
            keyfile=ssl_keyfile,
            certfile=ssl_certfile,
            cert_reqs=ssl_cert_reqs,
            ca_certs=ssl_ca_certs,
            ca_data=ssl_ca_data,
            check_hostname=ssl_check_hostname,
            min_version=ssl_min_version,
            ciphers=ssl_ciphers,
        )
        super().__init__(**kwargs)

    def _connection_arguments(self) -> Mapping:
        """TCP connection arguments plus the ``ssl`` context."""
        kwargs = super()._connection_arguments()
        kwargs["ssl"] = self.ssl_context.get()
        return kwargs

    @property
    def keyfile(self):
        return self.ssl_context.keyfile

    @property
    def certfile(self):
        return self.ssl_context.certfile

    @property
    def cert_reqs(self):
        return self.ssl_context.cert_reqs

    @property
    def ca_certs(self):
        return self.ssl_context.ca_certs

    @property
    def ca_data(self):
        return self.ssl_context.ca_data

    @property
    def check_hostname(self):
        return self.ssl_context.check_hostname

    @property
    def min_version(self):
        return self.ssl_context.min_version

845 

846 

class RedisSSLContext:
    """TLS settings for a connection plus a lazily built ``ssl.SSLContext``."""

    __slots__ = (
        "keyfile",
        "certfile",
        "cert_reqs",
        "ca_certs",
        "ca_data",
        "context",
        "check_hostname",
        "min_version",
        "ciphers",
    )

    def __init__(
        self,
        keyfile: Optional[str] = None,
        certfile: Optional[str] = None,
        # Quoted: evaluated at ``def`` time, where ``ssl`` may be ``None`` on
        # builds without SSL support.
        cert_reqs: Optional[Union[str, "ssl.VerifyMode"]] = None,
        ca_certs: Optional[str] = None,
        ca_data: Optional[str] = None,
        check_hostname: bool = False,
        min_version: Optional[TLSVersion] = None,
        ciphers: Optional[str] = None,
    ):
        """Validate and store the TLS settings.

        ``cert_reqs`` may be ``None`` (no verification), an ``ssl.VerifyMode``,
        or one of ``"none"``, ``"optional"`` or ``"required"`` (any case).
        Hostname checking is forced off when verification is disabled.

        :raises RedisError: without SSL support, or for an unknown
            ``cert_reqs`` string.
        """
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = keyfile
        self.certfile = certfile
        if cert_reqs is None:
            cert_reqs = ssl.CERT_NONE
        elif isinstance(cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            normalized = cert_reqs.lower()
            if normalized not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
                )
            cert_reqs = CERT_REQS[normalized]
        self.cert_reqs = cert_reqs
        self.ca_certs = ca_certs
        self.ca_data = ca_data
        self.check_hostname = (
            check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.min_version = min_version
        self.ciphers = ciphers
        self.context: Optional[SSLContext] = None

    def get(self) -> SSLContext:
        """Return the ``SSLContext``, building and caching it on first use."""
        if self.context is None:
            context = ssl.create_default_context()
            context.check_hostname = self.check_hostname
            context.verify_mode = self.cert_reqs
            # A certfile alone is enough: ``load_cert_chain`` reads the private
            # key from it when ``keyfile`` is None (combined PEM).
            if self.certfile:
                context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
            if self.ca_certs or self.ca_data:
                context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
            if self.min_version is not None:
                context.minimum_version = self.min_version
            if self.ciphers is not None:
                context.set_ciphers(self.ciphers)
            self.context = context
        return self.context

914 

915 

class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, *, path: str = "", **kwargs):
        """Store the socket path; other keyword arguments go to the base class."""
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    async def _connect(self):
        """Open the Unix domain socket and store the stream pair.

        Only the transport is opened here, as in the TCP ``Connection``. The
        caller (``connect_check_health``) runs the handshake afterwards, or the
        user's ``redis_connect_func``. Running ``on_connect`` here as well sent
        the handshake twice and ignored a custom ``redis_connect_func``.
        """
        async with async_timeout(self.socket_connect_timeout):
            reader, writer = await asyncio.open_unix_connection(path=self.path)
        self._reader = reader
        self._writer = writer

    def _host_error(self) -> str:
        return self.path

938 

939 

# Upper-cased spellings that to_bool() maps to False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value) -> Optional[bool]:
    """
    Convert *value* (typically a URL querystring string) to a boolean.

    ``None`` and ``""`` mean "not specified" and yield ``None``. Strings that
    match an entry of ``FALSE_STRINGS`` case-insensitively yield ``False``;
    anything else follows ordinary Python truthiness.
    """
    if value is None or value == "":
        return None
    looks_false = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if looks_false else bool(value)

949 

950 

# Casts applied by parse_url() to recognised querystring options; options not
# listed here are passed through as plain strings. Wrapped in MappingProxyType
# so the table is read-only at runtime.
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
    dict(
        db=int,
        socket_timeout=float,
        socket_connect_timeout=float,
        socket_keepalive=to_bool,
        retry_on_timeout=to_bool,
        max_connections=int,
        health_check_interval=int,
        ssl_check_hostname=to_bool,
        timeout=float,
    )
)

964 

965 

# Keyword arguments produced by parse_url(). Every key is optional
# (total=False) because only the parts present in the URL are filled in.
ConnectKwargs = TypedDict(
    "ConnectKwargs",
    {
        "username": str,
        "password": str,
        "connection_class": Type[AbstractConnection],
        "host": str,
        "port": int,
        "db": int,
        "path": str,
    },
    total=False,
)

974 

975 

def parse_url(url: str) -> ConnectKwargs:
    """
    Parse a ``redis://``, ``rediss://`` or ``unix://`` URL into connection kwargs.

    Querystring options that have an entry in ``URL_QUERY_ARGUMENT_PARSERS``
    are cast with that parser; any other option is kept as a string. Only the
    first value of a repeated option is used, and options with blank values
    are dropped (``parse_qs`` default behaviour).

    :param url: the connection URL.
    :return: keyword arguments (possibly including ``connection_class``).
    :raises ValueError: if a querystring value cannot be cast, or the URL
        scheme is not one of the supported schemes.
    """
    parsed: ParseResult = urlparse(url)
    kwargs: ConnectKwargs = {}

    for name, value_list in parse_qs(parsed.query).items():
        if not value_list:
            continue
        # NOTE(review): parse_qs() already percent-decodes values, so this
        # extra unquote() decodes a literal "%XX" a second time. Kept because
        # ConnectionPool.from_url documents unquoting of querystring values.
        value = unquote(value_list[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if parser:
            try:
                kwargs[name] = parser(value)
            except (TypeError, ValueError) as err:
                # Chain explicitly so the traceback shows the cast failure as
                # the direct cause instead of an error raised "during handling".
                raise ValueError(
                    f"Invalid value for '{name}' in connection URL."
                ) from err
        else:
            kwargs[name] = value

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    elif parsed.scheme in ("redis", "rediss"):
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        # ParseResult.port is already an int (or None).
        if parsed.port:
            kwargs["port"] = parsed.port

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified. Non-numeric paths (e.g. "/")
        # are ignored.
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection
    else:
        valid_schemes = "redis://, rediss://, unix://"
        raise ValueError(
            f"Redis URL must specify one of the following schemes ({valid_schemes})"
        )

    return kwargs

1026 

1027 

# Self-type for ConnectionPool.from_url: calling it on a subclass is typed as
# returning that subclass. The bound is a forward reference (string) because
# ConnectionPool is defined further down.
_CP = TypeVar("_CP", bound="ConnectionPool")

1029 

1030 

class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - ``redis://`` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - ``rediss://`` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://`` creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        Querystring options with a known type (see
        ``URL_QUERY_ARGUMENT_PARSERS``) are cast to that type; any other
        option is passed through as a string. Boolean arguments can be
        specified with string values "True"/"False" or "Yes"/"No". Values
        that cannot be properly cast cause a ``ValueError`` to be raised.
        Once parsed, the querystring arguments and keyword arguments are
        passed to the ``ConnectionPool``'s class initializer. In the case of
        conflicting arguments, querystring arguments always win.
        """
        url_options = parse_url(url)
        # URL-derived options override caller-supplied keyword arguments.
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class: Type[AbstractConnection] = Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        # None (or 0) means "no practical limit": fall back to 2**31.
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        self._available_connections: List[AbstractConnection] = []
        self._in_use_connections: Set[AbstractConnection] = set()
        self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
        self._lock = asyncio.Lock()
        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

    def __repr__(self):
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def reset(self):
        """Forget all tracked connections without disconnecting them."""
        # NOTE(review): __init__ tracks in-use connections in a plain set(),
        # but reset() switches to a WeakSet, so after a reset an unreleased
        # connection that is garbage-collected silently stops counting toward
        # max_connections. Confirm which behaviour is intended.
        self._available_connections = []
        self._in_use_connections = weakref.WeakSet()

    def can_get_connection(self) -> bool:
        """Return True if a connection can be retrieved from the pool."""
        # bool() so the annotated return type holds; the bare ``or`` returned
        # the list of idle connections itself whenever it was non-empty.
        return bool(
            self._available_connections
            or len(self._in_use_connections) < self.max_connections
        )

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Get a connected connection from the pool"""
        # The pool lock is held for the whole checkout, including the
        # connection health check in ensure_connection().
        async with self._lock:
            connection = self.get_available_connection()
            try:
                await self.ensure_connection(connection)
            except BaseException:
                # Return the connection to the pool so it is not stuck in the
                # in-use set, then propagate the original error.
                await self.release(connection)
                raise

        return connection

    def get_available_connection(self):
        """Get a connection from the pool, without making sure it is connected"""
        try:
            # Idle connections are reused most-recently-released first.
            connection = self._available_connections.pop()
        except IndexError:
            if len(self._in_use_connections) >= self.max_connections:
                raise ConnectionError("Too many connections") from None
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def get_encoder(self):
        """Return an encoder based on encoding settings"""
        kwargs = self.connection_kwargs
        return self.encoder_class(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        """Create a new connection. Can be overridden by child classes."""
        return self.connection_class(**self.connection_kwargs)

    async def ensure_connection(self, connection: AbstractConnection):
        """Ensure that the connection object is connected and valid"""
        await connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if await connection.can_read_destructive():
                raise ConnectionError("Connection has data") from None
        except (ConnectionError, TimeoutError, OSError):
            await connection.disconnect()
            await connection.connect()
            if await connection.can_read_destructive():
                raise ConnectionError("Connection not ready") from None

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool"""
        # Connections should always be returned to the correct pool,
        # not doing so is an error that will cause an exception here.
        self._in_use_connections.remove(connection)
        self._available_connections.append(connection)
        await self._event_dispatcher.dispatch_async(
            AsyncAfterConnectionReleasedEvent(connection)
        )

    async def disconnect(self, inuse_connections: bool = True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        currently in use, potentially by other tasks. Otherwise only disconnect
        connections that are idle in the pool.

        Every disconnect is attempted; if any of them fail, the first
        exception (in pool order) is re-raised afterwards.
        """
        if inuse_connections:
            connections: Iterable[AbstractConnection] = chain(
                self._available_connections, self._in_use_connections
            )
        else:
            connections = self._available_connections
        resp = await asyncio.gather(
            *(connection.disconnect() for connection in connections),
            return_exceptions=True,
        )
        exc = next((r for r in resp if isinstance(r, BaseException)), None)
        if exc is not None:
            raise exc

    async def aclose(self) -> None:
        """Close the pool, disconnecting all connections"""
        await self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Assign *retry* to every connection the pool currently tracks."""
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    async def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate with *token*.

        Idle connections send ``AUTH`` immediately (through their retry
        policy); in-use connections are handed the token via
        ``set_re_auth_token``.
        """
        async with self._lock:
            for conn in self._available_connections:
                # The lambdas capture ``conn`` by reference; that is safe here
                # because each call is awaited before the loop advances.
                await conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                await conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        No-op error callback passed to the retry object's call_with_retry().

        :param error: the error reported by the retry object (ignored).
        """
        pass

1254 

1255 

class BlockingConnectionPool(ConnectionPool):
    """
    A blocking connection pool::

        >>> from redis.asyncio import Redis, BlockingConnectionPool
        >>> client = Redis.from_pool(BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple async redis clients.

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of the connections are in use, rather
    than raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.asyncio.ConnectionPool` implementation does), it
    blocks the current `Task` for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever::

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections: int = 50,
        timeout: Optional[float] = 20,
        connection_class: Type[AbstractConnection] = Connection,
        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,  # deprecated
        **connection_kwargs,
    ):
        # ``queue_class`` is accepted for backward compatibility only; this
        # implementation does not use it.
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )
        # Tasks wait on this condition for a free slot; release() notifies.
        self._condition = asyncio.Condition()
        # Seconds to wait for a free connection (may be fractional);
        # None waits forever.
        self.timeout = timeout

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
        """Gets a connection from the pool, blocking until one is available

        :raises ConnectionError: if no connection frees up within ``timeout``.
        """
        try:
            async with self._condition:
                async with async_timeout(self.timeout):
                    # wait_for() releases the condition while waiting.
                    await self._condition.wait_for(self.can_get_connection)
                    connection = super().get_available_connection()
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)
            return connection
        except BaseException:
            await self.release(connection)
            raise

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool.

        Also wakes one task waiting in ``get_connection()``.
        """
        async with self._condition:
            await super().release(connection)
            self._condition.notify()