Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 21%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

907 statements  

1import copy 

2import os 

3import socket 

4import sys 

5import threading 

6import time 

7import weakref 

8from abc import abstractmethod 

9from itertools import chain 

10from queue import Empty, Full, LifoQueue 

11from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union 

12from urllib.parse import parse_qs, unquote, urlparse 

13 

14from redis.cache import ( 

15 CacheEntry, 

16 CacheEntryStatus, 

17 CacheFactory, 

18 CacheFactoryInterface, 

19 CacheInterface, 

20 CacheKey, 

21) 

22 

23from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser 

24from .auth.token import TokenInterface 

25from .backoff import NoBackoff 

26from .credentials import CredentialProvider, UsernamePasswordCredentialProvider 

27from .event import AfterConnectionReleasedEvent, EventDispatcher 

28from .exceptions import ( 

29 AuthenticationError, 

30 AuthenticationWrongNumberOfArgsError, 

31 ChildDeadlockedError, 

32 ConnectionError, 

33 DataError, 

34 MaxConnectionsError, 

35 RedisError, 

36 ResponseError, 

37 TimeoutError, 

38) 

39from .retry import Retry 

40from .utils import ( 

41 CRYPTOGRAPHY_AVAILABLE, 

42 HIREDIS_AVAILABLE, 

43 SSL_AVAILABLE, 

44 compare_versions, 

45 deprecated_args, 

46 ensure_string, 

47 format_error_message, 

48 get_lib_version, 

49 str_if_bytes, 

50) 

51 

52if SSL_AVAILABLE: 

53 import ssl 

54else: 

55 ssl = None 

56 

57if HIREDIS_AVAILABLE: 

58 import hiredis 

59 

# RESP protocol framing bytes.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

# RESP version used when the caller does not specify one.
DEFAULT_RESP_VERSION = 2

# Unique sentinel to distinguish "argument not provided" from None.
SENTINEL = object()

# Prefer the C-accelerated hiredis parser when the extension is installed.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _HiredisParser
else:
    DefaultParser = _RESP2Parser

74 

75 

class HiredisRespSerializer:
    """Pack Redis commands into RESP wire format via the hiredis C extension."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol.

        The first argument may contain a multi-word command name
        (e.g. "CONFIG GET"); it is split so each word becomes a separate
        RESP argument. Returns a list holding a single packed ``bytes``
        payload. Raises DataError if an argument cannot be serialized.
        """
        output = []

        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]
        try:
            output.append(hiredis.pack_command(args))
        except TypeError as e:
            # Idiom fix: bind the exception directly instead of unpacking
            # sys.exc_info(); behavior (DataError wrapping the TypeError,
            # original traceback preserved) is unchanged.
            raise DataError(e).with_traceback(e.__traceback__)

        return output

92 

93 

class PythonRespSerializer:
    """Pure-Python packer that turns command arguments into RESP byte chunks."""

    def __init__(self, buffer_cutoff, encode) -> None:
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        chunks = []
        # A multi-word command name ("CONFIG GET") must be sent as separate
        # RESP arguments, so split the first argument up front. The split
        # pieces are bytestrings and therefore pass through encoding untouched.
        first = args[0]
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        cutoff = self._buffer_cutoff
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        for item in map(self.encode, args):
            item_len = len(item)
            oversized = (
                len(pending) > cutoff
                or item_len > cutoff
                or isinstance(item, memoryview)
            )
            header = SYM_EMPTY.join(
                (pending, SYM_DOLLAR, str(item_len).encode(), SYM_CRLF)
            )
            if oversized:
                # Emit large values / memoryviews as standalone chunks to
                # avoid copying them into one big string.
                chunks.append(header)
                chunks.append(item)
                pending = SYM_CRLF
            else:
                pending = SYM_EMPTY.join((header, item, SYM_CRLF))
        chunks.append(pending)
        return chunks

143 

144 

class ConnectionInterface:
    """Abstract contract implemented by every Redis connection object.

    Concrete implementations (plain TCP, SSL, cache proxy, ...) must provide
    connect/disconnect lifecycle handling, command packing/sending, and
    response reading.
    """

    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs describing the connection for repr()."""
        pass

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register a callable invoked after every (re)connect."""
        pass

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a previously registered connect callback."""
        pass

    @abstractmethod
    def set_parser(self, parser_class):
        """Replace the response parser with an instance of ``parser_class``."""
        pass

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version in use."""
        pass

    @abstractmethod
    def connect(self):
        """Open the connection if it is not already established."""
        pass

    @abstractmethod
    def on_connect(self):
        """Initialize a freshly opened connection (auth, protocol, db select)."""
        pass

    @abstractmethod
    def disconnect(self, *args):
        """Close the connection and release its resources."""
        pass

    @abstractmethod
    def check_health(self):
        """Verify the connection is alive (e.g. via PING/PONG)."""
        pass

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send bytes already packed into RESP format."""
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Pack ``args`` into RESP and send them."""
        pass

    @abstractmethod
    def can_read(self, timeout=0):
        """Return True if a response is waiting to be read."""
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read and return the reply for a previously sent command."""
        pass

    @abstractmethod
    def pack_command(self, *args):
        """Return the RESP-encoded chunks for a single command."""
        pass

    @abstractmethod
    def pack_commands(self, commands):
        """Return RESP-encoded chunks for a batch of commands."""
        pass

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Metadata returned by the server during the HELLO handshake."""
        pass

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to re-authenticate with later."""
        pass

    @abstractmethod
    def re_auth(self):
        """Re-authenticate using the stored token, if any."""
        pass

224 

225 

226class AbstractConnection(ConnectionInterface): 

227 "Manages communication to and from a Redis server" 

228 

229 def __init__( 

230 self, 

231 db: int = 0, 

232 password: Optional[str] = None, 

233 socket_timeout: Optional[float] = None, 

234 socket_connect_timeout: Optional[float] = None, 

235 retry_on_timeout: bool = False, 

236 retry_on_error=SENTINEL, 

237 encoding: str = "utf-8", 

238 encoding_errors: str = "strict", 

239 decode_responses: bool = False, 

240 parser_class=DefaultParser, 

241 socket_read_size: int = 65536, 

242 health_check_interval: int = 0, 

243 client_name: Optional[str] = None, 

244 lib_name: Optional[str] = "redis-py", 

245 lib_version: Optional[str] = get_lib_version(), 

246 username: Optional[str] = None, 

247 retry: Union[Any, None] = None, 

248 redis_connect_func: Optional[Callable[[], None]] = None, 

249 credential_provider: Optional[CredentialProvider] = None, 

250 protocol: Optional[int] = 2, 

251 command_packer: Optional[Callable[[], None]] = None, 

252 event_dispatcher: Optional[EventDispatcher] = None, 

253 ): 

254 """ 

255 Initialize a new Connection. 

256 To specify a retry policy for specific errors, first set 

257 `retry_on_error` to a list of the error/s to retry on, then set 

258 `retry` to a valid `Retry` object. 

259 To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. 

260 """ 

261 if (username or password) and credential_provider is not None: 

262 raise DataError( 

263 "'username' and 'password' cannot be passed along with 'credential_" 

264 "provider'. Please provide only one of the following arguments: \n" 

265 "1. 'password' and (optional) 'username'\n" 

266 "2. 'credential_provider'" 

267 ) 

268 if event_dispatcher is None: 

269 self._event_dispatcher = EventDispatcher() 

270 else: 

271 self._event_dispatcher = event_dispatcher 

272 self.pid = os.getpid() 

273 self.db = db 

274 self.client_name = client_name 

275 self.lib_name = lib_name 

276 self.lib_version = lib_version 

277 self.credential_provider = credential_provider 

278 self.password = password 

279 self.username = username 

280 self.socket_timeout = socket_timeout 

281 if socket_connect_timeout is None: 

282 socket_connect_timeout = socket_timeout 

283 self.socket_connect_timeout = socket_connect_timeout 

284 self.retry_on_timeout = retry_on_timeout 

285 if retry_on_error is SENTINEL: 

286 retry_on_error = [] 

287 if retry_on_timeout: 

288 # Add TimeoutError to the errors list to retry on 

289 retry_on_error.append(TimeoutError) 

290 self.retry_on_error = retry_on_error 

291 if retry or retry_on_error: 

292 if retry is None: 

293 self.retry = Retry(NoBackoff(), 1) 

294 else: 

295 # deep-copy the Retry object as it is mutable 

296 self.retry = copy.deepcopy(retry) 

297 # Update the retry's supported errors with the specified errors 

298 self.retry.update_supported_errors(retry_on_error) 

299 else: 

300 self.retry = Retry(NoBackoff(), 0) 

301 self.health_check_interval = health_check_interval 

302 self.next_health_check = 0 

303 self.redis_connect_func = redis_connect_func 

304 self.encoder = Encoder(encoding, encoding_errors, decode_responses) 

305 self.handshake_metadata = None 

306 self._sock = None 

307 self._socket_read_size = socket_read_size 

308 self.set_parser(parser_class) 

309 self._connect_callbacks = [] 

310 self._buffer_cutoff = 6000 

311 self._re_auth_token: Optional[TokenInterface] = None 

312 try: 

313 p = int(protocol) 

314 except TypeError: 

315 p = DEFAULT_RESP_VERSION 

316 except ValueError: 

317 raise ConnectionError("protocol must be an integer") 

318 finally: 

319 if p < 2 or p > 3: 

320 raise ConnectionError("protocol must be either 2 or 3") 

321 # p = DEFAULT_RESP_VERSION 

322 self.protocol = p 

323 self._command_packer = self._construct_command_packer(command_packer) 

324 

325 def __repr__(self): 

326 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) 

327 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" 

328 

    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs describing this connection for repr()."""
        pass

    def __del__(self):
        # Best-effort cleanup; never let destructor errors propagate during
        # interpreter shutdown or garbage collection.
        try:
            self.disconnect()
        except Exception:
            pass

338 

339 def _construct_command_packer(self, packer): 

340 if packer is not None: 

341 return packer 

342 elif HIREDIS_AVAILABLE: 

343 return HiredisRespSerializer() 

344 else: 

345 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode) 

346 

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        # WeakMethod instances compare equal for the same bound method, so
        # duplicate registrations are filtered here.
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            # Callback was never registered (or already removed) -- ignore.
            pass

369 

    def set_parser(self, parser_class):
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    def connect(self):
        "Connects to the Redis server if not already connected"
        self.connect_check_health(check_health=True)

381 

    def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Open the socket (optionally with retries) and run post-connect setup.

        No-op when a socket already exists. Low-level socket failures are
        mapped to redis TimeoutError/ConnectionError.
        """
        if self._sock:
            return
        try:
            if retry_socket_connect:
                # Retry the raw socket connect per the configured Retry
                # policy, disconnecting between failed attempts.
                sock = self.retry.call_with_retry(
                    lambda: self._connect(), lambda error: self.disconnect(error)
                )
            else:
                sock = self._connect()
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback:
                callback(self)

420 

    @abstractmethod
    def _connect(self):
        """Create and return the underlying socket (transport specific)."""
        pass

    @abstractmethod
    def _host_error(self):
        """Return a human-readable endpoint string for error messages."""
        pass

    def _error_message(self, exception):
        # Delegate formatting so all transports produce consistent messages.
        return format_error_message(self._host_error(), exception)

    def on_connect(self):
        # Full post-connect initialization, including an eventual health check.
        self.on_connect_check_health(check_health=True)

434 

    def on_connect_check_health(self, check_health: bool = True):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)
        # Keep a reference to the current parser so its (possibly customized)
        # exception classes can be copied onto a replacement RESP3 parser below.
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # Password-only credentials authenticate as the "default" user.
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            self.handshake_metadata = self.read_response()
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            self.send_command("HELLO", self.protocol, check_health=check_health)
            self.handshake_metadata = self.read_response()
            # The server may reply with bytes or str keys depending on
            # decode_responses, so check both.
            if (
                self.handshake_metadata.get(b"proto") != self.protocol
                and self.handshake_metadata.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # if a client_name is given, set it
        if self.client_name:
            self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        try:
            # set the library name and version
            if self.lib_name:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-NAME",
                    self.lib_name,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            # Older servers do not support CLIENT SETINFO; best-effort only.
            pass

        try:
            if self.lib_version:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-VER",
                    self.lib_version,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            # Same best-effort handling as LIB-NAME above.
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db, check_health=check_health)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

545 

    def disconnect(self, *args):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()

        # Clear the attribute first so concurrent users see the connection
        # as closed even if the socket teardown below fails.
        conn_sock = self._sock
        self._sock = None
        if conn_sock is None:
            return

        # Only shut the socket down in the process that created it; a forked
        # child sharing the fd must not tear down the parent's connection.
        if os.getpid() == self.pid:
            try:
                conn_sock.shutdown(socket.SHUT_RDWR)
            except (OSError, TypeError):
                pass

        try:
            conn_sock.close()
        except OSError:
            pass

565 

    def _send_ping(self):
        """Send PING, expect PONG in return"""
        # check_health=False prevents recursive health checks.
        self.send_command("PING", check_health=False)
        if str_if_bytes(self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    def _ping_failed(self, error):
        """Function to call when PING fails"""
        self.disconnect()

    def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        # Only ping when the configured interval has elapsed since the last
        # successful read (next_health_check is advanced in read_response).
        if self.health_check_interval and time.monotonic() > self.next_health_check:
            self.retry.call_with_retry(self._send_ping, self._ping_failed)

580 

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the Redis server

        :param command: a bytes/str payload or an iterable of payload chunks
        :param check_health: run the periodic PING health check first
        """
        if not self._sock:
            # Connect without an immediate health check; the explicit check
            # below (when requested) covers it.
            self.connect_check_health(check_health=False)
        # guard against health check recursion
        if check_health:
            self.check_health()
        try:
            if isinstance(command, str):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except OSError as e:
            self.disconnect()
            # OSError args may be (errno, message) or just (message,).
            if len(e.args) == 1:
                errno, errmsg = "UNKNOWN", e.args[0]
            else:
                errno = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            self.disconnect()
            raise

611 

612 def send_command(self, *args, **kwargs): 

613 """Pack and send a command to the Redis server""" 

614 self.send_packed_command( 

615 self._command_packer.pack(*args), 

616 check_health=kwargs.get("check_health", True), 

617 ) 

618 

    def can_read(self, timeout=0):
        """Poll the socket to see if there's data that can be read."""
        sock = self._sock
        if not sock:
            # Establish the connection first; the parser reads from it below.
            self.connect()

        host_error = self._host_error()

        try:
            return self._parser.can_read(timeout)
        except OSError as e:
            self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

632 

    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the response from a previously sent command

        :param disable_decoding: skip response decoding when True
        :param disconnect_on_error: drop the connection on read failures
        :param push_request: (RESP3 only) expect a push-type message
        """

        host_error = self._host_error()

        try:
            if self.protocol in ["3", 3]:
                # RESP3 parsers understand out-of-band push messages.
                response = self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = self._parser.read_response(disable_decoding=disable_decoding)
        except socket.timeout:
            if disconnect_on_error:
                self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                self.disconnect()
            raise

        if self.health_check_interval:
            # A successful read proves liveness; push the next check out.
            self.next_health_check = time.monotonic() + self.health_check_interval

        if isinstance(response, ResponseError):
            try:
                raise response
            finally:
                del response  # avoid creating ref cycles
        return response

676 

    def pack_command(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        return self._command_packer.pack(*args)

    def pack_commands(self, commands):
        """Pack multiple commands into the Redis protocol"""
        output = []
        pieces = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self._command_packer.pack(*cmd):
                chunklen = len(chunk)
                # Flush the accumulated buffer before any oversized chunk or
                # memoryview so large payloads go out as standalone items
                # instead of being copied into one big string.
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                        buffer_length = 0
                        pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        # Emit whatever small chunks remain buffered.
        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

710 

    def get_protocol(self) -> Union[int, str]:
        """Return the configured RESP protocol version."""
        return self.protocol

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Server metadata captured from the HELLO handshake (None before it)."""
        return self._handshake_metadata

    @handshake_metadata.setter
    def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
        self._handshake_metadata = value

    def set_re_auth_token(self, token: TokenInterface):
        """Remember a token to authenticate with on the next re_auth() call."""
        self._re_auth_token = token

    def re_auth(self):
        """Re-authenticate with the stored token (if any), then clear it."""
        if self._re_auth_token is not None:
            self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            self.read_response()
            self._re_auth_token = None

734 

735 

736class Connection(AbstractConnection): 

737 "Manages TCP communication to and from a Redis server" 

738 

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        """Create a TCP connection.

        :param host: server hostname or IP address
        :param port: server port (coerced to int)
        :param socket_keepalive: enable SO_KEEPALIVE on the socket
        :param socket_keepalive_options: mapping of TCP keepalive options
        :param socket_type: address-family hint passed to getaddrinfo
        :param kwargs: forwarded to AbstractConnection.__init__
        """
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        # Normalize None to a fresh dict (avoids a shared mutable default).
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

754 

755 def repr_pieces(self): 

756 pieces = [("host", self.host), ("port", self.port), ("db", self.db)] 

757 if self.client_name: 

758 pieces.append(("client_name", self.client_name)) 

759 return pieces 

760 

    def _connect(self):
        "Create a TCP socket connection"
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        err = None
        for res in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            family, socktype, proto, canonname, socket_address = res
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as _:
                # Remember the failure and fall through to the next
                # address returned by getaddrinfo.
                err = _
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()

        if err is not None:
            # Every candidate address failed; surface the last error seen.
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

805 

    def _host_error(self):
        # Endpoint string interpolated into error messages.
        return f"{self.host}:{self.port}"

808 

809 

class CacheProxyConnection(ConnectionInterface):
    """Connection decorator that adds client-side caching (Redis >= 7.4)."""

    # Placeholder value stored while a cached response is still in flight.
    DUMMY_CACHE_VALUE = b"foo"
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        """Wrap ``conn`` so read-only command responses are served from ``cache``.

        :param conn: the real connection to delegate to
        :param cache: shared cache of command responses
        :param pool_lock: lock guarding the owning connection pool
        """
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        # Cache key of the command currently awaiting its response, if any.
        self._current_command_cache_key = None
        self._current_options = None
        # Enable server-assisted invalidation tracking on every (re)connect.
        self.register_connect_callback(self._enable_tracking_callback)

833 

    def repr_pieces(self):
        """Delegate repr details to the wrapped connection."""
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        """Register a connect callback on the wrapped connection."""
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        """Remove a connect callback from the wrapped connection."""
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        """Swap the response parser on the wrapped connection."""
        self._conn.set_parser(parser_class)

845 

846 def connect(self): 

847 self._conn.connect() 

848 

849 server_name = self._conn.handshake_metadata.get(b"server", None) 

850 if server_name is None: 

851 server_name = self._conn.handshake_metadata.get("server", None) 

852 server_ver = self._conn.handshake_metadata.get(b"version", None) 

853 if server_ver is None: 

854 server_ver = self._conn.handshake_metadata.get("version", None) 

855 if server_ver is None or server_ver is None: 

856 raise ConnectionError("Cannot retrieve information about server version") 

857 

858 server_ver = ensure_string(server_ver) 

859 server_name = ensure_string(server_name) 

860 

861 if ( 

862 server_name != self.DEFAULT_SERVER_NAME 

863 or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1 

864 ): 

865 raise ConnectionError( 

866 "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later" # noqa: E501 

867 ) 

868 

    def on_connect(self):
        """Run the wrapped connection's post-connect initialization."""
        self._conn.on_connect()

    def disconnect(self, *args):
        """Flush the local cache, then disconnect the wrapped connection."""
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args)

    def check_health(self):
        """Delegate the PING/PONG health check to the wrapped connection."""
        self._conn.check_health()

879 

880 def send_packed_command(self, command, check_health=True): 

881 # TODO: Investigate if it's possible to unpack command 

882 # or extract keys from packed command 

883 self._conn.send_packed_command(command) 

884 

    def send_command(self, *args, **kwargs):
        """Send a command, consulting the client-side cache first.

        Non-cachable commands are passed straight through. For cachable
        commands a cache key is derived from the command name and the
        ``keys`` kwarg; an existing cache entry short-circuits the network
        round trip (read_response will then return the cached value).
        """
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(CacheKey(command=args[0], redis_keys=())):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

            if kwargs.get("keys") is None:
                raise ValueError("Cannot create cache key.")

            # Creates cache key.
            self._current_command_cache_key = CacheKey(
                command=args[0], redis_keys=tuple(kwargs.get("keys"))
            )

        with self._cache_lock:
            # We have to trigger invalidation processing in case if
            # it was cached by another connection to avoid
            # queueing invalidations in stale connections.
            if self._cache.get(self._current_command_cache_key):
                entry = self._cache.get(self._current_command_cache_key)

                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        # Drain pending push messages on the connection that
                        # originally cached this entry.
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)

                return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

932 

    def can_read(self, timeout=0):
        """Return True if the wrapped connection has data to read."""
        return self._conn.can_read(timeout)

935 

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        """Return the cached response when available, else read from the wire.

        Freshly read responses for cachable commands are written back into
        the cache unless the entry was invalidated in the meantime.
        """
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            if (
                self._current_command_cache_key is not None
                and self._cache.get(self._current_command_cache_key) is not None
                and self._cache.get(self._current_command_cache_key).status
                != CacheEntryStatus.IN_PROGRESS
            ):
                # Deep-copy so callers cannot mutate the cached object.
                res = copy.deepcopy(
                    self._cache.get(self._current_command_cache_key).cache_value
                )
                self._current_command_cache_key = None
                return res

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

980 

    def pack_command(self, *args):
        """Delegate RESP packing of one command to the wrapped connection."""
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        """Delegate RESP packing of a command batch to the wrapped connection."""
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """HELLO handshake metadata of the wrapped connection."""
        return self._conn.handshake_metadata

990 

    def _connect(self):
        # NOTE(review): unlike Connection._connect this discards the socket
        # returned by the delegate -- presumably unused here; confirm callers.
        self._conn._connect()

993 

994 def _host_error(self): 

995 self._conn._host_error() 

996 

997 def _enable_tracking_callback(self, conn: ConnectionInterface) -> None: 

998 conn.send_command("CLIENT", "TRACKING", "ON") 

999 conn.read_response() 

1000 conn._parser.set_invalidation_push_handler(self._on_invalidation_callback) 

1001 

    def _process_pending_invalidations(self):
        """Drain queued invalidation push messages waiting on the socket.

        Each read is dispatched through the parser's push handler (installed
        by _enable_tracking_callback), which evicts affected cache entries.
        """
        while self.can_read():
            self._conn.read_response(push_request=True)

1005 

1006 def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]): 

1007 with self._cache_lock: 

1008 # Flush cache when DB flushed on server-side 

1009 if data[1] is None: 

1010 self._cache.flush() 

1011 else: 

1012 self._cache.delete_by_redis_keys(data[1]) 

1013 

1014 def get_protocol(self): 

1015 return self._conn.get_protocol() 

1016 

1017 def set_re_auth_token(self, token: TokenInterface): 

1018 self._conn.set_re_auth_token(token) 

1019 

1020 def re_auth(self): 

1021 self._conn.re_auth() 

1022 

1023 

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required".
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True; forced to False whenever certificate verification is disabled (ssl_cert_reqs of None/"none"/ssl.CERT_NONE).
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        # Normalize ssl_cert_reqs to an ssl.VerifyMode: None means no
        # verification; a string is mapped through CERT_REQS below.
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # ssl.SSLContext forbids check_hostname=True with CERT_NONE, so
        # hostname checking is disabled when verification is off.
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Wrap the socket with SSL support, handling potential errors.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except (OSError, RedisError):
            # The plain socket would otherwise leak if the TLS handshake
            # (or OCSP validation) fails.
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.

        Raises:
            RedisError: On invalid OCSP configuration.
            ConnectionError: If pure OCSP validation fails.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        # Full OCSP validation needs the optional cryptography package.
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        # Stapled and full OCSP validation are mutually exclusive modes.
        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        # validation for the stapled case
        if self.ssl_validate_ocsp_stapled:
            import OpenSSL

            from .ocsp import ocsp_staple_verifier

            # if a context is provided use it - otherwise, a basic context
            if self.ssl_ocsp_context is None:
                staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
                staple_ctx.use_certificate_file(self.certfile)
                staple_ctx.use_privatekey_file(self.keyfile)
            else:
                staple_ctx = self.ssl_ocsp_context

            staple_ctx.set_ocsp_client_callback(
                ocsp_staple_verifier, self.ssl_ocsp_expected_cert
            )

            # need another socket: the verifier runs during a separate
            # pyOpenSSL handshake that requests the stapled OCSP response.
            con = OpenSSL.SSL.Connection(staple_ctx, socket.socket())
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
            return sslsock

        # pure ocsp validation
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
            from .ocsp import OCSPVerifier

            o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
            if o.is_valid():
                return sslsock
            else:
                raise ConnectionError("ocsp validation error")
        return sslsock

1193 

1194 

class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, path="", socket_timeout=None, **kwargs):
        """
        Args:
            path: Filesystem path of the Unix domain socket.
            socket_timeout: Timeout applied once connected; the connect
                itself uses ``socket_connect_timeout`` from the base class.
            **kwargs: Forwarded to AbstractConnection.
        """
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        # (name, value) pairs used to build this connection's repr().
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        "Create a Unix domain socket connection"
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.settimeout(self.socket_connect_timeout)
        try:
            sock.connect(self.path)
        except OSError:
            # Prevent ResourceWarnings for unclosed sockets.
            try:
                sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                pass
            sock.close()
            raise
        # Switch from the connect timeout to the steady-state I/O timeout.
        sock.settimeout(self.socket_timeout)
        return sock

    def _host_error(self):
        # Host description used when formatting errors; for UDS, the path.
        return self.path

1228 

1229 

# Strings (compared case-insensitively) interpreted as boolean False
# when parsing URL querystring options.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Coerce a querystring value to a bool, or ``None`` for absent/empty."""
    if value is None or value == "":
        return None
    is_falsey_string = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_falsey_string else bool(value)

1239 

1240 

# Casts applied by parse_url() to recognized querystring options; options
# not listed here are passed through as (unquoted) strings.
URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
    "timeout": float,
}

1253 

1254 

def parse_url(url):
    """Parse a Redis URL into a dict of ConnectionPool keyword arguments.

    Supports the redis://, rediss:// and unix:// schemes. Username,
    password, hostname, path and querystring values are percent-decoded;
    recognized querystring options are cast via URL_QUERY_ARGUMENT_PARSERS.

    Raises:
        ValueError: For an unsupported scheme or an uncastable option value.
    """
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    parsed = urlparse(url)
    kwargs = {}

    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        raw = unquote(values[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if parser is None:
            kwargs[name] = raw
        else:
            try:
                kwargs[name] = parser(raw)
            except (TypeError, ValueError):
                raise ValueError(f"Invalid value for '{name}' in connection URL.")

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
    else:
        # Implied: parsed.scheme in ("redis", "rediss").
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # The URL path doubles as the db number unless a ``db``
        # querystring option already set it.
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    return kwargs

1310 

1311 

1312_CP = TypeVar("_CP", bound="ConnectionPool") 

1313 

1314 

class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use class:`.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0
        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0
        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)

        # An explicitly supplied connection_class keyword wins over the one
        # implied by the URL scheme (e.g. SSLConnection for rediss://).
        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        **connection_kwargs,
    ):
        """Initialize the pool.

        Args:
            connection_class: Class used to create new connections.
            max_connections: Upper bound on connections created by this pool;
                ``None``/0 falls back to 2**31 (effectively unlimited).
            cache_factory: Optional factory producing the client-side cache;
                only consulted when ``cache_config``/``cache`` kwargs are set.
            **connection_kwargs: Forwarded to ``connection_class``. The
                ``cache`` and ``cache_config`` entries are consumed here;
                ``protocol`` and ``event_dispatcher`` are inspected.

        Raises:
            ValueError: If ``max_connections`` is invalid, or ``cache`` does
                not implement CacheInterface.
            RedisError: If caching is requested without RESP protocol 3.
        """
        max_connections = max_connections or 2**31
        # NOTE(review): the message says "positive" but 0 passes this check
        # (it is replaced by 2**31 above anyway) — confirm intended.
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        # Client-side caching relies on server push invalidations, which
        # require RESP3.
        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            if connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self.connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                if self._cache_factory is not None:
                    self.cache = self._cache_factory.get_cache()
                else:
                    self.cache = CacheFactory(
                        self.connection_kwargs.get("cache_config")
                    ).get_cache()

        # Cache settings are pool-level options; do not leak them into the
        # per-connection constructor kwargs.
        connection_kwargs.pop("cache", None)
        connection_kwargs.pop("cache_config", None)

        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        self.reset()

    def __repr__(self) -> str:
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def get_protocol(self):
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        return self.connection_kwargs.get("protocol", None)

    def reset(self) -> None:
        """Re-initialize pool state for the current process (fork-safety)."""
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self) -> None:
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        # 1. process A calls _checkpid() for the first time and acquires
        #    self._fork_lock.
        # 2. while holding self._fork_lock, process A forks (the fork()
        #    could happen in a different thread owned by process A)
        # 3. process B (the forked child process) inherits the
        #    ConnectionPool's state from the parent. that state includes
        #    a locked _fork_lock. process B will not be notified when
        #    process A releases the _fork_lock and will thus never be
        #    able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        "Get a connection from the pool"

        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                # when client-side caching is enabled, buffered data is
                # presumably pending invalidation push messages, so the
                # "has data" check is skipped in that case.
                if connection.can_read() and self.cache is None:
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise

        return connection

    def get_encoder(self) -> Encoder:
        "Return an encoder based on encoding settings"
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self) -> "ConnectionInterface":
        "Create a new connection"
        if self._created_connections >= self.max_connections:
            raise MaxConnectionsError("Too many connections")
        self._created_connections += 1

        # With client-side caching enabled, wrap the raw connection in a
        # proxy that serves and stores responses through the shared cache.
        if self.cache is not None:
            return CacheProxyConnection(
                self.connection_class(**self.connection_kwargs), self.cache, self._lock
            )

        return self.connection_class(**self.connection_kwargs)

    def release(self, connection: "Connection") -> None:
        "Releases the connection back to the pool"
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own
                return

            if self.owns_connection(connection):
                self._available_connections.append(connection)
                self._event_dispatcher.dispatch(
                    AfterConnectionReleasedEvent(connection)
                )
            else:
                # Pool doesn't own this connection, do not add it back
                # to the pool.
                # The created connections count should not be changed,
                # because the connection was not created by the pool.
                connection.disconnect()
                return

    def owns_connection(self, connection: "Connection") -> bool:
        # A connection created before a fork belongs to the parent process.
        return connection.pid == self.pid

    def disconnect(self, inuse_connections: bool = True) -> None:
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other threads. Otherwise only disconnect
        connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(
                    self._available_connections, self._in_use_connections
                )
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()

    def close(self) -> None:
        """Close the pool, disconnecting all connections"""
        self.disconnect()

    def set_retry(self, retry: Retry) -> None:
        """Install ``retry`` as the policy for current and future connections."""
        self.connection_kwargs.update({"retry": retry})
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections with a freshly acquired token.

        Idle connections are re-AUTHed immediately; in-use connections are
        handed the token so they can re-authenticate themselves later.
        """
        with self._lock:
            for conn in self._available_connections:
                conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    async def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        # NOTE(review): declared ``async`` in this synchronous pool, so the
        # retry failure callback produces an un-awaited coroutine — confirm
        # this is intended (it is a no-op either way).
        pass

1656 

1657 

1658class BlockingConnectionPool(ConnectionPool): 

1659 """ 

1660 Thread-safe blocking connection pool:: 

1661 

1662 >>> from redis.client import Redis 

1663 >>> client = Redis(connection_pool=BlockingConnectionPool()) 

1664 

1665 It performs the same function as the default 

1666 :py:class:`~redis.ConnectionPool` implementation, in that, 

1667 it maintains a pool of reusable connections that can be shared by 

1668 multiple redis clients (safely across threads if required). 

1669 

1670 The difference is that, in the event that a client tries to get a 

1671 connection from the pool when all of connections are in use, rather than 

1672 raising a :py:class:`~redis.ConnectionError` (as the default 

1673 :py:class:`~redis.ConnectionPool` implementation does), it 

1674 makes the client wait ("blocks") for a specified number of seconds until 

1675 a connection becomes available. 

1676 

1677 Use ``max_connections`` to increase / decrease the pool size:: 

1678 

1679 >>> pool = BlockingConnectionPool(max_connections=10) 

1680 

1681 Use ``timeout`` to tell it either how many seconds to wait for a connection 

1682 to become available, or to block forever: 

1683 

1684 >>> # Block forever. 

1685 >>> pool = BlockingConnectionPool(timeout=None) 

1686 

1687 >>> # Raise a ``ConnectionError`` after five seconds if a connection is 

1688 >>> # not available. 

1689 >>> pool = BlockingConnectionPool(timeout=5) 

1690 """ 

1691 

    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):
        """Initialize the blocking pool.

        Args:
            max_connections: Pool size; also the capacity of the queue.
            timeout: Seconds get_connection() blocks waiting for a free
                connection before raising ConnectionError; ``None`` blocks
                forever.
            connection_class: Class used to create new connections.
            queue_class: Queue implementation holding idle connections.
            **connection_kwargs: Forwarded to ``connection_class``.
        """
        self.queue_class = queue_class
        self.timeout = timeout
        # queue_class must be set before super().__init__() because the
        # base constructor calls reset(), which builds the queue.
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )

1707 

    def reset(self):
        """Re-initialize pool state for the current process (fork-safety)."""
        # Create and fill up a thread safe queue with ``None`` values.
        # Each ``None`` acts as a permit to lazily create one real connection.
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break

        # Keep a list of actual connection instances so that we can
        # disconnect them later.
        self._connections = []

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

1731 

1732 def make_connection(self): 

1733 "Make a fresh connection." 

1734 if self.cache is not None: 

1735 connection = CacheProxyConnection( 

1736 self.connection_class(**self.connection_kwargs), self.cache, self._lock 

1737 ) 

1738 else: 

1739 connection = self.connection_class(**self.connection_kwargs) 

1740 self._connections.append(connection) 

1741 return connection 

1742 

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            # Note that this is not caught by the redis client and will be
            # raised unless handled by application code. If you want to
            # block indefinitely instead, construct the pool with
            # ``timeout=None``.
            raise ConnectionError("No connection available.")

        # If the ``connection`` is actually ``None`` then that's a cue to make
        # a new connection to add to the pool.
        if connection is None:
            connection = self.make_connection()

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        return connection

1799 

1800 def release(self, connection): 

1801 "Releases the connection back to the pool." 

1802 # Make sure we haven't changed process. 

1803 self._checkpid() 

1804 if not self.owns_connection(connection): 

1805 # pool doesn't own this connection. do not add it back 

1806 # to the pool. instead add a None value which is a placeholder 

1807 # that will cause the pool to recreate the connection if 

1808 # its needed. 

1809 connection.disconnect() 

1810 self.pool.put_nowait(None) 

1811 return 

1812 

1813 # Put the connection back into the pool. 

1814 try: 

1815 self.pool.put_nowait(connection) 

1816 except Full: 

1817 # perhaps the pool has been reset() after a fork? regardless, 

1818 # we don't want this connection 

1819 pass 

1820 

    def disconnect(self):
        "Disconnects all connections in the pool."
        self._checkpid()
        # _connections tracks every connection this pool ever created, so
        # this also closes connections that are currently checked out.
        for connection in self._connections:
            connection.disconnect()