Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 21%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

905 statements  

1import copy 

2import os 

3import socket 

4import sys 

5import threading 

6import time 

7import weakref 

8from abc import abstractmethod 

9from itertools import chain 

10from queue import Empty, Full, LifoQueue 

11from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union 

12from urllib.parse import parse_qs, unquote, urlparse 

13 

14from redis.cache import ( 

15 CacheEntry, 

16 CacheEntryStatus, 

17 CacheFactory, 

18 CacheFactoryInterface, 

19 CacheInterface, 

20 CacheKey, 

21) 

22 

23from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser 

24from .auth.token import TokenInterface 

25from .backoff import NoBackoff 

26from .credentials import CredentialProvider, UsernamePasswordCredentialProvider 

27from .event import AfterConnectionReleasedEvent, EventDispatcher 

28from .exceptions import ( 

29 AuthenticationError, 

30 AuthenticationWrongNumberOfArgsError, 

31 ChildDeadlockedError, 

32 ConnectionError, 

33 DataError, 

34 RedisError, 

35 ResponseError, 

36 TimeoutError, 

37) 

38from .retry import Retry 

39from .utils import ( 

40 CRYPTOGRAPHY_AVAILABLE, 

41 HIREDIS_AVAILABLE, 

42 SSL_AVAILABLE, 

43 compare_versions, 

44 deprecated_args, 

45 ensure_string, 

46 format_error_message, 

47 get_lib_version, 

48 str_if_bytes, 

49) 

50 

51if SSL_AVAILABLE: 

52 import ssl 

53else: 

54 ssl = None 

55 

56if HIREDIS_AVAILABLE: 

57 import hiredis 

58 

59SYM_STAR = b"*" 

60SYM_DOLLAR = b"$" 

61SYM_CRLF = b"\r\n" 

62SYM_EMPTY = b"" 

63 

64DEFAULT_RESP_VERSION = 2 

65 

66SENTINEL = object() 

67 

68DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] 

69if HIREDIS_AVAILABLE: 

70 DefaultParser = _HiredisParser 

71else: 

72 DefaultParser = _RESP2Parser 

73 

74 

class HiredisRespSerializer:
    """Serializes Redis commands into RESP wire format via the hiredis C packer."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol"""
        # A command name given as a single string may embed spaces
        # (e.g. "CONFIG GET"); split it into separate arguments first.
        first = args[0]
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        packed = []
        try:
            packed.append(hiredis.pack_command(args))
        except TypeError:
            # Re-raise as the library's DataError, preserving the traceback.
            _, value, traceback = sys.exc_info()
            raise DataError(value).with_traceback(traceback)

        return packed

91 

92 

class PythonRespSerializer:
    """Pure-Python RESP serializer used when hiredis is unavailable."""

    def __init__(self, buffer_cutoff, encode) -> None:
        # buffer_cutoff: byte threshold above which values are emitted as
        # separate chunks instead of being folded into one buffer.
        # encode: callable converting each argument to bytes.
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        first = args[0]
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        chunks = []
        # RESP array header: *<number of arguments>\r\n
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))
        cutoff = self._buffer_cutoff

        for arg in map(self.encode, args):
            arg_len = len(arg)
            # Large values and memoryviews are appended as their own chunks
            # to avoid a single huge string allocation.
            big = (
                len(pending) > cutoff
                or arg_len > cutoff
                or isinstance(arg, memoryview)
            )
            if big:
                header = SYM_EMPTY.join(
                    (pending, SYM_DOLLAR, str(arg_len).encode(), SYM_CRLF)
                )
                chunks.append(header)
                chunks.append(arg)
                pending = SYM_CRLF
            else:
                # Fold the $<len>\r\n<value>\r\n frame into the running buffer.
                pending = SYM_EMPTY.join(
                    (
                        pending,
                        SYM_DOLLAR,
                        str(arg_len).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        chunks.append(pending)
        return chunks

142 

143 

class ConnectionInterface:
    """Abstract interface every Redis connection implementation must satisfy.

    Note: this class does not inherit from abc.ABC, so @abstractmethod is
    documentation-only here and instantiation is not blocked at runtime.
    """

    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs used to build this connection's repr()."""
        pass

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register a callback invoked whenever the connection (re)connects."""
        pass

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a previously registered connect callback."""
        pass

    @abstractmethod
    def set_parser(self, parser_class):
        """Replace the response parser with a new instance of parser_class."""
        pass

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version in use (2 or 3)."""
        pass

    @abstractmethod
    def connect(self):
        """Establish the connection to the Redis server if not connected."""
        pass

    @abstractmethod
    def on_connect(self):
        """Initialize the connection (auth, protocol handshake, SELECT db)."""
        pass

    @abstractmethod
    def disconnect(self, *args):
        """Close the connection to the Redis server."""
        pass

    @abstractmethod
    def check_health(self):
        """Verify the connection is alive (e.g. via PING/PONG)."""
        pass

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send an already RESP-encoded command to the server."""
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Pack and send a command to the server."""
        pass

    @abstractmethod
    def can_read(self, timeout=0):
        """Return True if response data is available to read."""
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read and return the response to a previously sent command."""
        pass

    @abstractmethod
    def pack_command(self, *args):
        """Encode a single command into RESP wire format."""
        pass

    @abstractmethod
    def pack_commands(self, commands):
        """Encode multiple commands into RESP wire format chunks."""
        pass

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Server metadata returned by the HELLO handshake, if any."""
        pass

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to re-authenticate with on the next opportunity."""
        pass

    @abstractmethod
    def re_auth(self):
        """Re-authenticate using the stored token, if one is set."""
        pass

223 

224 

class AbstractConnection(ConnectionInterface):
    "Manages communication to and from a Redis server"

    def __init__(
        self,
        db: int = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error=SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class=DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        retry: Union[Any, None] = None,
        redis_connect_func: Optional[Callable[[], None]] = None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        command_packer: Optional[Callable[[], None]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new Connection.
        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        # Track the creating process so forked children never close a socket
        # owned by the parent (see disconnect()).
        self.pid = os.getpid()
        self.db = db
        self.client_name = client_name
        self.lib_name = lib_name
        self.lib_version = lib_version
        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_error = []
        else:
            # Copy the caller's list so appending TimeoutError below never
            # mutates an argument the caller may reuse elsewhere.
            retry_on_error = list(retry_on_error)
        if retry_on_timeout:
            # Add TimeoutError to the errors list to retry on
            retry_on_error.append(TimeoutError)
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if retry is None:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.redis_connect_func = redis_connect_func
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self.handshake_metadata = None
        self._sock = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks = []
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        # Validate the requested RESP protocol version. int() raises
        # TypeError for None (fall back to the default) and ValueError for
        # non-numeric strings. The range check is performed *after* the
        # try/except rather than in a ``finally`` clause: with a ``finally``,
        # a ValueError would leave ``p`` unbound and the intended
        # ConnectionError would be masked by a NameError.
        try:
            p = int(protocol)
        except TypeError:
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        if p < 2 or p > 3:
            raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p
        self._command_packer = self._construct_command_packer(command_packer)

    def __repr__(self):
        repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs describing this connection for repr()."""
        pass

    def __del__(self):
        # Best-effort cleanup; never raise from a destructor.
        try:
            self.disconnect()
        except Exception:
            pass

    def _construct_command_packer(self, packer):
        """Pick the command packer: explicit > hiredis > pure Python."""
        if packer is not None:
            return packer
        elif HIREDIS_AVAILABLE:
            return HiredisRespSerializer()
        else:
            return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def set_parser(self, parser_class):
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    def connect(self):
        "Connects to the Redis server if not already connected"
        self.connect_check_health(check_health=True)

    def connect_check_health(self, check_health: bool = True):
        """Connect, optionally skipping the health check during setup."""
        if self._sock:
            return
        try:
            sock = self.retry.call_with_retry(
                lambda: self._connect(), lambda error: self.disconnect(error)
            )
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback:
                callback(self)

    @abstractmethod
    def _connect(self):
        """Create and return the underlying socket (transport-specific)."""
        pass

    @abstractmethod
    def _host_error(self):
        """Return a human-readable identifier of the endpoint for errors."""
        pass

    def _error_message(self, exception):
        return format_error_message(self._host_error(), exception)

    def on_connect(self):
        self.on_connect_check_health(check_health=True)

    def on_connect_check_health(self, check_health: bool = True):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            self.handshake_metadata = self.read_response()
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            self.send_command("HELLO", self.protocol, check_health=check_health)
            self.handshake_metadata = self.read_response()
            # The HELLO reply may use bytes or str keys depending on
            # decode_responses; check both.
            if (
                self.handshake_metadata.get(b"proto") != self.protocol
                and self.handshake_metadata.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # if a client_name is given, set it
        if self.client_name:
            self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        try:
            # set the library name and version
            if self.lib_name:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-NAME",
                    self.lib_name,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            # CLIENT SETINFO is best-effort; older servers reject it.
            pass

        try:
            if self.lib_version:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-VER",
                    self.lib_version,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db, check_health=check_health)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    def disconnect(self, *args):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()

        conn_sock = self._sock
        self._sock = None
        if conn_sock is None:
            return

        # Only the process that created the socket may shut it down; a
        # forked child closing it would break the parent's connection.
        if os.getpid() == self.pid:
            try:
                conn_sock.shutdown(socket.SHUT_RDWR)
            except (OSError, TypeError):
                pass

        try:
            conn_sock.close()
        except OSError:
            pass

    def _send_ping(self):
        """Send PING, expect PONG in return"""
        self.send_command("PING", check_health=False)
        if str_if_bytes(self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    def _ping_failed(self, error):
        """Function to call when PING fails"""
        self.disconnect()

    def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if self.health_check_interval and time.monotonic() > self.next_health_check:
            self.retry.call_with_retry(self._send_ping, self._ping_failed)

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the Redis server"""
        if not self._sock:
            self.connect_check_health(check_health=False)
        # guard against health check recursion
        if check_health:
            self.check_health()
        try:
            if isinstance(command, str):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except OSError as e:
            self.disconnect()
            if len(e.args) == 1:
                errno, errmsg = "UNKNOWN", e.args[0]
            else:
                errno = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            self.disconnect()
            raise

    def send_command(self, *args, **kwargs):
        """Pack and send a command to the Redis server"""
        self.send_packed_command(
            self._command_packer.pack(*args),
            check_health=kwargs.get("check_health", True),
        )

    def can_read(self, timeout=0):
        """Poll the socket to see if there's data that can be read."""
        sock = self._sock
        if not sock:
            self.connect()

        host_error = self._host_error()

        try:
            return self._parser.can_read(timeout)
        except OSError as e:
            self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the response from a previously sent command"""

        host_error = self._host_error()

        try:
            if self.protocol in ["3", 3]:
                response = self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = self._parser.read_response(disable_decoding=disable_decoding)
        except socket.timeout:
            if disconnect_on_error:
                self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                self.disconnect()
            raise

        if self.health_check_interval:
            self.next_health_check = time.monotonic() + self.health_check_interval

        if isinstance(response, ResponseError):
            try:
                raise response
            finally:
                del response  # avoid creating ref cycles
        return response

    def pack_command(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        return self._command_packer.pack(*args)

    def pack_commands(self, commands):
        """Pack multiple commands into the Redis protocol"""
        output = []
        pieces = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self._command_packer.pack(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                        buffer_length = 0
                        pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def get_protocol(self) -> Union[int, str]:
        return self.protocol

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._handshake_metadata

    @handshake_metadata.setter
    def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
        self._handshake_metadata = value

    def set_re_auth_token(self, token: TokenInterface):
        self._re_auth_token = token

    def re_auth(self):
        """Re-authenticate with the stored token, then clear it."""
        if self._re_auth_token is not None:
            self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            self.read_response()
            self._re_auth_token = None

728 

729 

class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """(name, value) pairs shown in repr(); client_name only when set."""
        parts = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            parts.append(("client_name", self.client_name))
        return parts

    def _connect(self):
        "Create a TCP socket connection"
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        last_error = None
        candidates = socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        )
        for family, socktype, proto, _canonname, address in candidates:
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for opt, val in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, opt, val)

                # use the connect timeout while establishing the connection,
                # then switch to the regular read/write timeout
                sock.settimeout(self.socket_connect_timeout)
                sock.connect(address)
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as exc:
                last_error = exc
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()

        # every resolved address failed, or resolution produced nothing
        if last_error is not None:
            raise last_error
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        """Endpoint identifier used in error messages."""
        return f"{self.host}:{self.port}"

802 

803 

804class CacheProxyConnection(ConnectionInterface): 

805 DUMMY_CACHE_VALUE = b"foo" 

806 MIN_ALLOWED_VERSION = "7.4.0" 

807 DEFAULT_SERVER_NAME = "redis" 

808 

809 def __init__( 

810 self, 

811 conn: ConnectionInterface, 

812 cache: CacheInterface, 

813 pool_lock: threading.RLock, 

814 ): 

815 self.pid = os.getpid() 

816 self._conn = conn 

817 self.retry = self._conn.retry 

818 self.host = self._conn.host 

819 self.port = self._conn.port 

820 self.credential_provider = conn.credential_provider 

821 self._pool_lock = pool_lock 

822 self._cache = cache 

823 self._cache_lock = threading.RLock() 

824 self._current_command_cache_key = None 

825 self._current_options = None 

826 self.register_connect_callback(self._enable_tracking_callback) 

827 

828 def repr_pieces(self): 

829 return self._conn.repr_pieces() 

830 

831 def register_connect_callback(self, callback): 

832 self._conn.register_connect_callback(callback) 

833 

834 def deregister_connect_callback(self, callback): 

835 self._conn.deregister_connect_callback(callback) 

836 

837 def set_parser(self, parser_class): 

838 self._conn.set_parser(parser_class) 

839 

840 def connect(self): 

841 self._conn.connect() 

842 

843 server_name = self._conn.handshake_metadata.get(b"server", None) 

844 if server_name is None: 

845 server_name = self._conn.handshake_metadata.get("server", None) 

846 server_ver = self._conn.handshake_metadata.get(b"version", None) 

847 if server_ver is None: 

848 server_ver = self._conn.handshake_metadata.get("version", None) 

849 if server_ver is None or server_ver is None: 

850 raise ConnectionError("Cannot retrieve information about server version") 

851 

852 server_ver = ensure_string(server_ver) 

853 server_name = ensure_string(server_name) 

854 

855 if ( 

856 server_name != self.DEFAULT_SERVER_NAME 

857 or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1 

858 ): 

859 raise ConnectionError( 

860 "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later" # noqa: E501 

861 ) 

862 

863 def on_connect(self): 

864 self._conn.on_connect() 

865 

866 def disconnect(self, *args): 

867 with self._cache_lock: 

868 self._cache.flush() 

869 self._conn.disconnect(*args) 

870 

871 def check_health(self): 

872 self._conn.check_health() 

873 

874 def send_packed_command(self, command, check_health=True): 

875 # TODO: Investigate if it's possible to unpack command 

876 # or extract keys from packed command 

877 self._conn.send_packed_command(command) 

878 

879 def send_command(self, *args, **kwargs): 

880 self._process_pending_invalidations() 

881 

882 with self._cache_lock: 

883 # Command is write command or not allowed 

884 # to be cached. 

885 if not self._cache.is_cachable(CacheKey(command=args[0], redis_keys=())): 

886 self._current_command_cache_key = None 

887 self._conn.send_command(*args, **kwargs) 

888 return 

889 

890 if kwargs.get("keys") is None: 

891 raise ValueError("Cannot create cache key.") 

892 

893 # Creates cache key. 

894 self._current_command_cache_key = CacheKey( 

895 command=args[0], redis_keys=tuple(kwargs.get("keys")) 

896 ) 

897 

898 with self._cache_lock: 

899 # We have to trigger invalidation processing in case if 

900 # it was cached by another connection to avoid 

901 # queueing invalidations in stale connections. 

902 if self._cache.get(self._current_command_cache_key): 

903 entry = self._cache.get(self._current_command_cache_key) 

904 

905 if entry.connection_ref != self._conn: 

906 with self._pool_lock: 

907 while entry.connection_ref.can_read(): 

908 entry.connection_ref.read_response(push_request=True) 

909 

910 return 

911 

912 # Set temporary entry value to prevent 

913 # race condition from another connection. 

914 self._cache.set( 

915 CacheEntry( 

916 cache_key=self._current_command_cache_key, 

917 cache_value=self.DUMMY_CACHE_VALUE, 

918 status=CacheEntryStatus.IN_PROGRESS, 

919 connection_ref=self._conn, 

920 ) 

921 ) 

922 

923 # Send command over socket only if it's allowed 

924 # read-only command that not yet cached. 

925 self._conn.send_command(*args, **kwargs) 

926 

927 def can_read(self, timeout=0): 

928 return self._conn.can_read(timeout) 

929 

930 def read_response( 

931 self, disable_decoding=False, *, disconnect_on_error=True, push_request=False 

932 ): 

933 with self._cache_lock: 

934 # Check if command response exists in a cache and it's not in progress. 

935 if ( 

936 self._current_command_cache_key is not None 

937 and self._cache.get(self._current_command_cache_key) is not None 

938 and self._cache.get(self._current_command_cache_key).status 

939 != CacheEntryStatus.IN_PROGRESS 

940 ): 

941 res = copy.deepcopy( 

942 self._cache.get(self._current_command_cache_key).cache_value 

943 ) 

944 self._current_command_cache_key = None 

945 return res 

946 

947 response = self._conn.read_response( 

948 disable_decoding=disable_decoding, 

949 disconnect_on_error=disconnect_on_error, 

950 push_request=push_request, 

951 ) 

952 

953 with self._cache_lock: 

954 # Prevent not-allowed command from caching. 

955 if self._current_command_cache_key is None: 

956 return response 

957 # If response is None prevent from caching. 

958 if response is None: 

959 self._cache.delete_by_cache_keys([self._current_command_cache_key]) 

960 return response 

961 

962 cache_entry = self._cache.get(self._current_command_cache_key) 

963 

964 # Cache only responses that still valid 

965 # and wasn't invalidated by another connection in meantime. 

966 if cache_entry is not None: 

967 cache_entry.status = CacheEntryStatus.VALID 

968 cache_entry.cache_value = response 

969 self._cache.set(cache_entry) 

970 

971 self._current_command_cache_key = None 

972 

973 return response 

974 

975 def pack_command(self, *args): 

976 return self._conn.pack_command(*args) 

977 

978 def pack_commands(self, commands): 

979 return self._conn.pack_commands(commands) 

980 

981 @property 

982 def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: 

983 return self._conn.handshake_metadata 

984 

985 def _connect(self): 

986 self._conn._connect() 

987 

988 def _host_error(self): 

989 self._conn._host_error() 

990 

991 def _enable_tracking_callback(self, conn: ConnectionInterface) -> None: 

992 conn.send_command("CLIENT", "TRACKING", "ON") 

993 conn.read_response() 

994 conn._parser.set_invalidation_push_handler(self._on_invalidation_callback) 

995 

    def _process_pending_invalidations(self):
        """Drain any buffered invalidation push messages from the socket."""
        while self.can_read():
            self._conn.read_response(push_request=True)

999 

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        """Evict local cache entries named in an invalidation push message.

        ``data[1]`` is the list of invalidated Redis keys, or None when the
        server signals a full flush.
        """
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                self._cache.delete_by_redis_keys(data[1])

1007 

    def get_protocol(self):
        """Return the RESP protocol version of the wrapped connection."""
        return self._conn.get_protocol()

1010 

    def set_re_auth_token(self, token: TokenInterface):
        """Forward a re-authentication token to the wrapped connection."""
        self._conn.set_re_auth_token(token)

1013 

    def re_auth(self):
        """Trigger re-authentication on the wrapped connection."""
        self._conn.re_auth()

1016 

1017 

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required".
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True. Forced to False when ssl_cert_reqs resolves to ssl.CERT_NONE.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        # Map the string spelling of verify_mode to the ssl enum; an
        # ssl.VerifyMode value passes through unchanged.
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # ssl.SSLContext rejects check_hostname=True with CERT_NONE, so the
        # flag is forced off when certificate verification is disabled.
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Wrap the socket with SSL support, handling potential errors.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except (OSError, RedisError):
            # Close the plain socket so a failed handshake does not leak it.
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.

        Raises:
            RedisError: If both OCSP modes are requested at once, or full
                OCSP validation is requested without cryptography installed.
            ConnectionError: If pure OCSP validation fails.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        # Fail fast, before any handshake, when the requested OCSP
        # configuration cannot be honored.
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        # validation for the stapled case
        if self.ssl_validate_ocsp_stapled:
            import OpenSSL

            from .ocsp import ocsp_staple_verifier

            # if a context is provided use it - otherwise, a basic context
            if self.ssl_ocsp_context is None:
                staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
                staple_ctx.use_certificate_file(self.certfile)
                staple_ctx.use_privatekey_file(self.keyfile)
            else:
                staple_ctx = self.ssl_ocsp_context

            staple_ctx.set_ocsp_client_callback(
                ocsp_staple_verifier, self.ssl_ocsp_expected_cert
            )

            # need another socket
            con = OpenSSL.SSL.Connection(staple_ctx, socket.socket())
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
            return sslsock

        # pure ocsp validation
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
            from .ocsp import OCSPVerifier

            o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
            if o.is_valid():
                return sslsock
            else:
                raise ConnectionError("ocsp validation error")
        return sslsock

1187 

1188 

class UnixDomainSocketConnection(AbstractConnection):
    """Connection to a Redis server over a Unix domain socket."""

    def __init__(self, path="", socket_timeout=None, **kwargs):
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return (name, value) pairs describing this connection for repr()."""
        pieces = [("path", self.path), ("db", self.db)]
        return (
            pieces + [("client_name", self.client_name)]
            if self.client_name
            else pieces
        )

    def _connect(self):
        """Open, connect and return a Unix domain socket for ``self.path``."""
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.settimeout(self.socket_connect_timeout)
        try:
            sock.connect(self.path)
        except OSError:
            # Tear the socket down quietly so a failed connect does not
            # leave an open fd behind (avoids ResourceWarnings).
            self._abort_socket(sock)
            raise
        sock.settimeout(self.socket_timeout)
        return sock

    @staticmethod
    def _abort_socket(sock):
        """Best-effort shutdown and close of a socket that failed to connect."""
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass
        sock.close()

    def _host_error(self):
        """The socket path doubles as the host identifier in error messages."""
        return self.path

1222 

1223 

# Upper-cased string values that to_bool() interprets as False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")

1225 

1226 

def to_bool(value):
    """Interpret a URL querystring value as a boolean.

    Returns None for a missing/empty value, False for recognised false-y
    strings ("0", "f", "false", "n", "no", case-insensitive), and the
    ordinary truthiness of ``value`` otherwise.
    """
    if value is None or value == "":
        return None
    is_false_string = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_false_string else bool(value)

1233 

1234 

# Casts applied by parse_url() to known querystring options; options not
# listed here are passed through as (unquoted) strings.
URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
    "timeout": float,
}

1247 

1248 

def parse_url(url):
    """Parse a redis://, rediss:// or unix:// URL into connection kwargs.

    Username, password, hostname, path and querystring values are passed
    through ``urllib.parse.unquote``; known querystring options are cast via
    URL_QUERY_ARGUMENT_PARSERS. For unix:// URLs the path becomes ``path``
    and ``connection_class`` is set to UnixDomainSocketConnection; for
    rediss:// URLs ``connection_class`` is set to SSLConnection.

    Raises:
        ValueError: On an unsupported scheme or an uncastable option value.
    """
    # str.startswith accepts a tuple of prefixes: one call instead of three.
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    url = urlparse(url)
    kwargs = {}

    for name, value in parse_qs(url.query).items():
        # A truthy list is sufficient; the extra len(value) > 0 check
        # previously here was redundant.
        if value:
            value = unquote(value[0])
            parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
            if parser:
                try:
                    kwargs[name] = parser(value)
                except (TypeError, ValueError):
                    raise ValueError(f"Invalid value for '{name}' in connection URL.")
            else:
                kwargs[name] = value

    if url.username:
        kwargs["username"] = unquote(url.username)
    if url.password:
        kwargs["password"] = unquote(url.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if url.scheme == "unix":
        if url.path:
            kwargs["path"] = unquote(url.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    else:  # implied: url.scheme in ("redis", "rediss"):
        if url.hostname:
            kwargs["host"] = unquote(url.hostname)
        if url.port:
            kwargs["port"] = int(url.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if url.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(url.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if url.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    return kwargs

1304 

1305 

# Type variable so ConnectionPool.from_url() is typed to return the subclass
# it is invoked on (e.g. BlockingConnectionPool.from_url -> BlockingConnectionPool).
_CP = TypeVar("_CP", bound="ConnectionPool")

1307 

1308 

class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use class:`.UnixDomainSocketConnection` for
    unix sockets.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0
        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0
        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)

        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        **connection_kwargs,
    ):
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        # Client-side caching requires RESP3 (invalidation push messages).
        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            if connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self.connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                if self._cache_factory is not None:
                    self.cache = self._cache_factory.get_cache()
                else:
                    self.cache = CacheFactory(
                        self.connection_kwargs.get("cache_config")
                    ).get_cache()

        # Cache settings are pool-level concerns; don't leak them into
        # individual connection constructors.
        connection_kwargs.pop("cache", None)
        connection_kwargs.pop("cache_config", None)

        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        self.reset()

    def __repr__(self) -> str:
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def get_protocol(self):
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        return self.connection_kwargs.get("protocol", None)

    def reset(self) -> None:
        """Reinitialize the pool's bookkeeping (used at init and after fork)."""
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self) -> None:
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        "Get a connection from the pool"

        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read() and self.cache is None:
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise

        return connection

    def get_encoder(self) -> Encoder:
        "Return an encoder based on encoding settings"
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self) -> "ConnectionInterface":
        "Create a new connection"
        if self._created_connections >= self.max_connections:
            raise ConnectionError("Too many connections")
        self._created_connections += 1

        # When caching is enabled, wrap the connection so reads/writes go
        # through the client-side cache.
        if self.cache is not None:
            return CacheProxyConnection(
                self.connection_class(**self.connection_kwargs), self.cache, self._lock
            )

        return self.connection_class(**self.connection_kwargs)

    def release(self, connection: "Connection") -> None:
        "Releases the connection back to the pool"
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own
                return

            if self.owns_connection(connection):
                self._available_connections.append(connection)
                self._event_dispatcher.dispatch(
                    AfterConnectionReleasedEvent(connection)
                )
            else:
                # Pool doesn't own this connection, do not add it back
                # to the pool.
                # The created connections count should not be changed,
                # because the connection was not created by the pool.
                connection.disconnect()
                return

    def owns_connection(self, connection: "Connection") -> int:
        """True when ``connection`` was created by this pool in this process."""
        return connection.pid == self.pid

    def disconnect(self, inuse_connections: bool = True) -> None:
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        currently in use, potentially by other threads. Otherwise only
        disconnect connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(
                    self._available_connections, self._in_use_connections
                )
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()

    def close(self) -> None:
        """Close the pool, disconnecting all connections"""
        self.disconnect()

    def set_retry(self, retry: Retry) -> None:
        """Set the retry policy on the pool and every existing connection."""
        self.connection_kwargs.update({"retry": retry})
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections with a freshly issued token.

        Idle connections are re-AUTHed immediately; in-use connections are
        handed the token via set_re_auth_token() for later re-auth.
        """
        with self._lock:
            for conn in self._available_connections:
                conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    def _mock(self, error: RedisError):
        """
        Dummy function, needs to be passed as error callback to retry object.

        NOTE: this was previously declared ``async`` (mirroring the asyncio
        pool); the synchronous retry path above calls it directly, which
        produced an un-awaited coroutine and a RuntimeWarning. A plain no-op
        function preserves the intended ignore-the-error behavior.
        :param error:
        :return:
        """
        pass

1649 

1650 

class BlockingConnectionPool(ConnectionPool):
    """
    Thread-safe blocking connection pool::

        >>> from redis.client import Redis
        >>> client = Redis(connection_pool=BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple redis clients (safely across threads if required).

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.ConnectionPool` implementation does), it
    makes the client wait ("blocks") for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):
        # queue_class and timeout must be set before super().__init__(),
        # which calls reset() and therefore needs self.queue_class.
        self.queue_class = queue_class
        self.timeout = timeout
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )

    def reset(self):
        # Create and fill up a thread safe queue with ``None`` values.
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break

        # Keep a list of actual connection instances so that we can
        # disconnect them later.
        self._connections = []

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def make_connection(self):
        "Make a fresh connection."
        # Wrap with the cache proxy when client-side caching is enabled.
        if self.cache is not None:
            connection = CacheProxyConnection(
                self.connection_class(**self.connection_kwargs), self.cache, self._lock
            )
        else:
            connection = self.connection_class(**self.connection_kwargs)
        self._connections.append(connection)
        return connection

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            # Note that this is not caught by the redis client and will be
            # raised unless handled by application code. Construct the pool
            # with timeout=None if you would rather block forever than ever
            # see this error.
            raise ConnectionError("No connection available.")

        # If the ``connection`` is actually ``None`` then that's a cue to make
        # a new connection to add to the pool.
        if connection is None:
            connection = self.make_connection()

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        return connection

    def release(self, connection):
        "Releases the connection back to the pool."
        # Make sure we haven't changed process.
        self._checkpid()
        if not self.owns_connection(connection):
            # pool doesn't own this connection. do not add it back
            # to the pool. instead add a None value which is a placeholder
            # that will cause the pool to recreate the connection if
            # its needed.
            connection.disconnect()
            self.pool.put_nowait(None)
            return

        # Put the connection back into the pool.
        try:
            self.pool.put_nowait(connection)
        except Full:
            # perhaps the pool has been reset() after a fork? regardless,
            # we don't want this connection
            pass

    def disconnect(self):
        "Disconnects all connections in the pool."
        self._checkpid()
        for connection in self._connections:
            connection.disconnect()