Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 20%


1187 statements  

import copy
import os
import socket
import sys
import threading
import time
import weakref
from abc import abstractmethod
from itertools import chain
from queue import Empty, Full, LifoQueue
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Type,
    TypeVar,
    Union,
)
from urllib.parse import parse_qs, unquote, urlparse

from redis.cache import (
    CacheEntry,
    CacheEntryStatus,
    CacheFactory,
    CacheFactoryInterface,
    CacheInterface,
    CacheKey,
)

from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
from .auth.token import TokenInterface
from .backoff import NoBackoff
from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
from .event import AfterConnectionReleasedEvent, EventDispatcher
from .exceptions import (
    AuthenticationError,
    AuthenticationWrongNumberOfArgsError,
    ChildDeadlockedError,
    ConnectionError,
    DataError,
    MaxConnectionsError,
    RedisError,
    ResponseError,
    TimeoutError,
)
from .maintenance_events import (
    MaintenanceEventConnectionHandler,
    MaintenanceEventPoolHandler,
    MaintenanceEventsConfig,
    MaintenanceState,
)
from .retry import Retry
from .utils import (
    CRYPTOGRAPHY_AVAILABLE,
    HIREDIS_AVAILABLE,
    SSL_AVAILABLE,
    compare_versions,
    deprecated_args,
    ensure_string,
    format_error_message,
    get_lib_version,
    str_if_bytes,
)

if SSL_AVAILABLE:
    import ssl
else:
    ssl = None

if HIREDIS_AVAILABLE:
    import hiredis

SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

DEFAULT_RESP_VERSION = 2

SENTINEL = object()

DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _HiredisParser
else:
    DefaultParser = _RESP2Parser


class HiredisRespSerializer:
    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol"""
        output = []

        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]
        try:
            output.append(hiredis.pack_command(args))
        except TypeError:
            _, value, traceback = sys.exc_info()
            raise DataError(value).with_traceback(traceback)

        return output


class PythonRespSerializer:
    def __init__(self, buffer_cutoff, encode) -> None:
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]

        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        buffer_cutoff = self._buffer_cutoff
        for arg in map(self.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            arg_length = len(arg)
            if (
                len(buff) > buffer_cutoff
                or arg_length > buffer_cutoff
                or isinstance(arg, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                )
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output
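
# Illustrative sketch of the RESP framing produced by the serializers above
# (an assumption of typical output, not part of this module): every command is
# sent as an array of bulk strings, "*<argc>" followed by "$<len>" chunks,
# each terminated by CRLF. HiredisRespSerializer produces the same bytes via
# hiredis.pack_command().
#
#   >>> enc = Encoder("utf-8", "strict", False)
#   >>> s = PythonRespSerializer(buffer_cutoff=6000, encode=enc.encode)
#   >>> b"".join(s.pack("SET", "foo", "bar"))
#   b'*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n'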


class ConnectionInterface:
    @abstractmethod
    def repr_pieces(self):
        pass

    @abstractmethod
    def register_connect_callback(self, callback):
        pass

    @abstractmethod
    def deregister_connect_callback(self, callback):
        pass

    @abstractmethod
    def set_parser(self, parser_class):
        pass

    @abstractmethod
    def set_maintenance_event_pool_handler(self, maintenance_event_pool_handler):
        pass

    @abstractmethod
    def get_protocol(self):
        pass

    @abstractmethod
    def connect(self):
        pass

    @abstractmethod
    def on_connect(self):
        pass

    @abstractmethod
    def disconnect(self, *args):
        pass

    @abstractmethod
    def check_health(self):
        pass

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        pass

    @abstractmethod
    def can_read(self, timeout=0):
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        pass

    @abstractmethod
    def pack_command(self, *args):
        pass

    @abstractmethod
    def pack_commands(self, commands):
        pass

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        pass

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        pass

    @abstractmethod
    def re_auth(self):
        pass

    @property
    @abstractmethod
    def maintenance_state(self) -> MaintenanceState:
        """
        Returns the current maintenance state of the connection.
        """
        pass

    @maintenance_state.setter
    @abstractmethod
    def maintenance_state(self, state: "MaintenanceState"):
        """
        Sets the current maintenance state of the connection.
        """
        pass

    @abstractmethod
    def getpeername(self):
        """
        Returns the peer name of the connection.
        """
        pass

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """
        pass

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """
        pass

    @abstractmethod
    def get_resolved_ip(self):
        """
        Get the resolved IP address for the connection.
        """
        pass

    @abstractmethod
    def update_current_socket_timeout(self, relax_timeout: Optional[float] = None):
        """
        Update the timeout for the current socket.
        """
        pass

    @abstractmethod
    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relax_timeout: Optional[float] = None,
    ):
        """
        Updates temporary host address and timeout settings for the connection.
        """
        pass

    @abstractmethod
    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relax_timeout: bool = False,
    ):
        """
        Resets temporary host address and timeout settings for the connection.
        """
        pass


class AbstractConnection(ConnectionInterface):
    "Manages communication to and from a Redis server"

    def __init__(
        self,
        db: int = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class=DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        retry: Union[Any, None] = None,
        redis_connect_func: Optional[Callable[[], None]] = None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        command_packer: Optional[Callable[[], None]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
        maintenance_events_pool_handler: Optional[MaintenanceEventPoolHandler] = None,
        maintenance_events_config: Optional[MaintenanceEventsConfig] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_event_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
    ):
        """
        Initialize a new Connection.
        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.pid = os.getpid()
        self.db = db
        self.client_name = client_name
        self.lib_name = lib_name
        self.lib_version = lib_version
        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_errors_list = []
        else:
            retry_on_errors_list = list(retry_on_error)
        if retry_on_timeout:
            # Add TimeoutError to the errors list to retry on
            retry_on_errors_list.append(TimeoutError)
        self.retry_on_error = retry_on_errors_list
        if retry or self.retry_on_error:
            if retry is None:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            if self.retry_on_error:
                # Update the retry's supported errors with the specified errors
                self.retry.update_supported_errors(self.retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.redis_connect_func = redis_connect_func
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self.handshake_metadata = None
        self._sock = None
        self._socket_read_size = socket_read_size
        self._connect_callbacks = []
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        try:
            p = int(protocol)
        except TypeError:
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        finally:
            if p < 2 or p > 3:
                raise ConnectionError("protocol must be either 2 or 3")
                # p = DEFAULT_RESP_VERSION
        self.protocol = p
        if self.protocol == 3 and parser_class == DefaultParser:
            parser_class = _RESP3Parser
        self.set_parser(parser_class)

        self.maintenance_events_config = maintenance_events_config

        # Set up maintenance events if enabled
        self._configure_maintenance_events(
            maintenance_events_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
        )

        self._should_reconnect = False
        self.maintenance_state = maintenance_state
        self.maintenance_event_hash = maintenance_event_hash

        self._command_packer = self._construct_command_packer(command_packer)

    def __repr__(self):
        repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        pass

    def __del__(self):
        try:
            self.disconnect()
        except Exception:
            pass

    def _construct_command_packer(self, packer):
        if packer is not None:
            return packer
        elif HIREDIS_AVAILABLE:
            return HiredisRespSerializer()
        else:
            return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)
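
    # Illustrative sketch (assumption, not library code): because callbacks
    # are stored as weakref.WeakMethod objects, only bound methods can be
    # registered, and a registration dies together with its owner object:
    #
    #   >>> import weakref
    #   >>> def plain_function(conn): ...
    #   >>> weakref.WeakMethod(plain_function)
    #   Traceback (most recent call last):
    #     ...
    #   TypeError: argument should be a bound method, not <class 'function'>
    #
    # A listener should therefore register something like `self.on_reconnect`.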


    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def set_parser(self, parser_class):
        """
        Creates a new instance of parser_class with socket size
        _socket_read_size and assigns it as the parser for the connection.
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    def _configure_maintenance_events(
        self,
        maintenance_events_pool_handler=None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
    ):
        """Enable maintenance events by setting up handlers and storing original connection parameters."""
        if (
            not self.maintenance_events_config
            or not self.maintenance_events_config.enabled
        ):
            self._maintenance_event_connection_handler = None
            return

        # Set up pool handler if available
        if maintenance_events_pool_handler:
            self._parser.set_node_moving_push_handler(
                maintenance_events_pool_handler.handle_event
            )

        # Set up connection handler
        self._maintenance_event_connection_handler = MaintenanceEventConnectionHandler(
            self, self.maintenance_events_config
        )
        self._parser.set_maintenance_push_handler(
            self._maintenance_event_connection_handler.handle_event
        )

        # Store original connection parameters
        self.orig_host_address = orig_host_address if orig_host_address else self.host
        self.orig_socket_timeout = (
            orig_socket_timeout if orig_socket_timeout else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout
            else self.socket_connect_timeout
        )

    def set_maintenance_event_pool_handler(
        self, maintenance_event_pool_handler: MaintenanceEventPoolHandler
    ):
        maintenance_event_pool_handler.set_connection(self)
        self._parser.set_node_moving_push_handler(
            maintenance_event_pool_handler.handle_event
        )

        # Create the maintenance event connection handler if it doesn't exist yet
        if not self._maintenance_event_connection_handler:
            self._maintenance_event_connection_handler = (
                MaintenanceEventConnectionHandler(
                    self, maintenance_event_pool_handler.config
                )
            )
            self._parser.set_maintenance_push_handler(
                self._maintenance_event_connection_handler.handle_event
            )
        else:
            self._maintenance_event_connection_handler.config = (
                maintenance_event_pool_handler.config
            )

    def connect(self):
        "Connects to the Redis server if not already connected"
        self.connect_check_health(check_health=True)

    def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        if self._sock:
            return
        try:
            if retry_socket_connect:
                sock = self.retry.call_with_retry(
                    lambda: self._connect(), lambda error: self.disconnect(error)
                )
            else:
                sock = self._connect()
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback:
                callback(self)

    @abstractmethod
    def _connect(self):
        pass

    @abstractmethod
    def _host_error(self):
        pass

    def _error_message(self, exception):
        return format_error_message(self._host_error(), exception)

    def on_connect(self):
        self.on_connect_check_health(check_health=True)

    def on_connect_check_health(self, check_health: bool = True):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            self.handshake_metadata = self.read_response()
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            self.send_command("HELLO", self.protocol, check_health=check_health)
            self.handshake_metadata = self.read_response()
            if (
                self.handshake_metadata.get(b"proto") != self.protocol
                and self.handshake_metadata.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # Send maintenance notifications handshake if RESP3 is active and
        # maintenance events are enabled and we have a host to determine
        # the endpoint type from
        if (
            self.protocol not in [2, "2"]
            and self.maintenance_events_config
            and self.maintenance_events_config.enabled
            and self._maintenance_event_connection_handler
            and hasattr(self, "host")
        ):
            try:
                endpoint_type = self.maintenance_events_config.get_endpoint_type(
                    self.host, self
                )
                self.send_command(
                    "CLIENT",
                    "MAINT_NOTIFICATIONS",
                    "ON",
                    "moving-endpoint-type",
                    endpoint_type.value,
                    check_health=check_health,
                )
                response = self.read_response()
                if str_if_bytes(response) != "OK":
                    raise ConnectionError(
                        "The server doesn't support maintenance notifications"
                    )
            except Exception as e:
                # Log a warning but don't fail the connection
                import logging

                logger = logging.getLogger(__name__)
                logger.warning(f"Failed to enable maintenance notifications: {e}")

        # if a client_name is given, set it
        if self.client_name:
            self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        try:
            # set the library name and version
            if self.lib_name:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-NAME",
                    self.lib_name,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        try:
            if self.lib_version:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-VER",
                    self.lib_version,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db, check_health=check_health)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    def disconnect(self, *args):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()

        conn_sock = self._sock
        self._sock = None
        # reset the reconnect flag
        self._should_reconnect = False
        if conn_sock is None:
            return

        if os.getpid() == self.pid:
            try:
                conn_sock.shutdown(socket.SHUT_RDWR)
            except (OSError, TypeError):
                pass

        try:
            conn_sock.close()
        except OSError:
            pass

    def _send_ping(self):
        """Send PING, expect PONG in return"""
        self.send_command("PING", check_health=False)
        if str_if_bytes(self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    def _ping_failed(self, error):
        """Function to call when PING fails"""
        self.disconnect()

    def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if self.health_check_interval and time.monotonic() > self.next_health_check:
            self.retry.call_with_retry(self._send_ping, self._ping_failed)

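    # Illustrative timing sketch (assumption, not library code): with
    # health_check_interval=30, check_health() sends a PING only once the
    # connection has been idle past the scheduled deadline; next_health_check
    # is pushed forward every time read_response() completes:
    #
    #   conn = Connection(health_check_interval=30)
    #   conn.connect()
    #   conn.send_command("GET", "k")   # fresh connection, no PING issued
    #   # ... 30+ seconds without reading a response ...
    #   conn.send_command("GET", "k")   # check_health() PINGs first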

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the Redis server"""
        if not self._sock:
            self.connect_check_health(check_health=False)
        # guard against health check recursion
        if check_health:
            self.check_health()
        try:
            if isinstance(command, str):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except OSError as e:
            self.disconnect()
            if len(e.args) == 1:
                errno, errmsg = "UNKNOWN", e.args[0]
            else:
                errno = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            self.disconnect()
            raise

    def send_command(self, *args, **kwargs):
        """Pack and send a command to the Redis server"""
        self.send_packed_command(
            self._command_packer.pack(*args),
            check_health=kwargs.get("check_health", True),
        )

    def can_read(self, timeout=0):
        """Poll the socket to see if there's data that can be read."""
        sock = self._sock
        if not sock:
            self.connect()

        host_error = self._host_error()

        try:
            return self._parser.can_read(timeout)
        except OSError as e:
            self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the response from a previously sent command"""

        host_error = self._host_error()

        try:
            if self.protocol in ["3", 3]:
                response = self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = self._parser.read_response(disable_decoding=disable_decoding)
        except socket.timeout:
            if disconnect_on_error:
                self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                self.disconnect()
            raise

        if self.health_check_interval:
            self.next_health_check = time.monotonic() + self.health_check_interval

        if isinstance(response, ResponseError):
            try:
                raise response
            finally:
                del response  # avoid creating ref cycles
        return response

    def pack_command(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        return self._command_packer.pack(*args)

    def pack_commands(self, commands):
        """Pack multiple commands into the Redis protocol"""
        output = []
        pieces = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self._command_packer.pack(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                        buffer_length = 0
                        pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output
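
    # Illustrative sketch (assumption, not library code): pack_commands()
    # coalesces the packed chunks of a pipeline into as few buffers as
    # possible, so many small commands become a single sendall() payload:
    #
    #   >>> c = Connection()
    #   >>> c.pack_commands([("SET", "k1", "v1"), ("SET", "k2", "v2")])
    #   [b'*3\r\n$3\r\nSET\r\n$2\r\nk1\r\n$2\r\nv1\r\n*3\r\n$3\r\nSET\r\n$2\r\nk2\r\n$2\r\nv2\r\n']
    #
    # Only chunks larger than _buffer_cutoff (or memoryviews) are emitted as
    # separate list items.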


    def get_protocol(self) -> Union[int, str]:
        return self.protocol

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._handshake_metadata

    @handshake_metadata.setter
    def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
        self._handshake_metadata = value

    def set_re_auth_token(self, token: TokenInterface):
        self._re_auth_token = token

    def re_auth(self):
        if self._re_auth_token is not None:
            self.send_command(
                "AUTH",
                self._re_auth_token.try_get("oid"),
                self._re_auth_token.get_value(),
            )
            self.read_response()
            self._re_auth_token = None

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution if needed.

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """

        # Method 1: Try to get the actual IP from the established socket
        # connection. This is most accurate as it shows the exact IP being used
        try:
            if self._sock is not None:
                peer_addr = self._sock.getpeername()
                if peer_addr and len(peer_addr) >= 1:
                    # For TCP sockets, peer_addr is typically a (host, port)
                    # tuple. Return just the host part
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fall back to DNS resolution of the host.
        # This is less accurate but works when the socket is not available
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to an IP.
                # This mimics what the connection would do during _connect()
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # Return the IP from the first result;
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # and sockaddr[0] is the IP address
                    return addr_info[0][4][0]
        except (AttributeError, OSError, socket.gaierror):
            # DNS resolution might fail
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def getpeername(self):
        if not self._sock:
            return None
        return self._sock.getpeername()[0]

    def mark_for_reconnect(self):
        self._should_reconnect = True

    def should_reconnect(self):
        return self._should_reconnect

    def update_current_socket_timeout(self, relax_timeout: Optional[float] = None):
        if self._sock:
            timeout = relax_timeout if relax_timeout != -1 else self.socket_timeout
            self._sock.settimeout(timeout)
            self.update_parser_buffer_timeout(timeout)

    def update_parser_buffer_timeout(self, timeout: Optional[float] = None):
        if self._parser and self._parser._buffer:
            self._parser._buffer.socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relax_timeout: Optional[float] = None,
    ):
        """
        The value of SENTINEL is used to indicate that the property should not be updated.
        """
        if tmp_host_address is not SENTINEL:
            self.host = tmp_host_address
        if tmp_relax_timeout != -1:
            self.socket_timeout = tmp_relax_timeout
            self.socket_connect_timeout = tmp_relax_timeout


    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relax_timeout: bool = False,
    ):
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relax_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout
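
    # Illustrative maintenance-window sketch (assumption, not library code):
    # during a "MOVING" maintenance notification, a pool handler might
    # temporarily redirect this connection and relax its timeouts, then
    # restore the originals once the move completes (the address below is
    # hypothetical):
    #
    #   conn.set_tmp_settings(tmp_host_address="10.0.0.5", tmp_relax_timeout=30)
    #   conn.mark_for_reconnect()   # signal that the connection should be re-established
    #   # ... maintenance completes ...
    #   conn.reset_tmp_settings(reset_host_address=True, reset_relax_timeout=True)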



class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        "Create a TCP socket connection"
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        err = None

        for res in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            family, socktype, proto, canonname, socket_address = res
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as _:
                err = _
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()

        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        return f"{self.host}:{self.port}"


class CacheProxyConnection(ConnectionInterface):
    DUMMY_CACHE_VALUE = b"foo"
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

    def repr_pieces(self):
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def connect(self):
        self._conn.connect()

        server_name = self._conn.handshake_metadata.get(b"server", None)
        if server_name is None:
            server_name = self._conn.handshake_metadata.get("server", None)
        server_ver = self._conn.handshake_metadata.get(b"version", None)
        if server_ver is None:
            server_ver = self._conn.handshake_metadata.get("version", None)
        if server_ver is None or server_name is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )


    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args):
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        self._conn.send_packed_command(command)

    def send_command(self, *args, **kwargs):
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is a write command or is not allowed
            # to be cached.
            if not self._cache.is_cachable(CacheKey(command=args[0], redis_keys=())):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

            if kwargs.get("keys") is None:
                raise ValueError("Cannot create cache key.")

            # Create the cache key.
            self._current_command_cache_key = CacheKey(
                command=args[0], redis_keys=tuple(kwargs.get("keys"))
            )

        with self._cache_lock:
            # We have to trigger invalidation processing in case
            # the entry was cached by another connection, to avoid
            # queueing invalidations in stale connections.
            if self._cache.get(self._current_command_cache_key):
                entry = self._cache.get(self._current_command_cache_key)

                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)

                return

            # Set a temporary entry value to prevent
            # race conditions with another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send the command over the socket only if it's an allowed
        # read-only command that isn't cached yet.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        with self._cache_lock:
            # Check if the command response exists in the cache and is not in progress.
            if (
                self._current_command_cache_key is not None
                and self._cache.get(self._current_command_cache_key) is not None
                and self._cache.get(self._current_command_cache_key).status
                != CacheEntryStatus.IN_PROGRESS
            ):
                res = copy.deepcopy(
                    self._cache.get(self._current_command_cache_key).cache_value
                )
                self._current_command_cache_key = None
                return res

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent a not-allowed command from being cached.
            if self._current_command_cache_key is None:
                return response
            # If the response is None, prevent it from being cached.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that are still valid
            # and weren't invalidated by another connection in the meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def _connect(self):
        self._conn._connect()

    def _host_error(self):
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        with self._cache_lock:
            # Flush the local cache when the DB is flushed on the server side
            if data[1] is None:
                self._cache.flush()
            else:
                self._cache.delete_by_redis_keys(data[1])
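
    # Illustrative push-message sketch (assumption, not library code): RESP3
    # client tracking delivers invalidations as ["invalidate", [key, ...]],
    # with None in place of the key list when the whole DB is flushed:
    #
    #   self._on_invalidation_callback([b"invalidate", [b"foo"]])  # drop "foo"
    #   self._on_invalidation_callback([b"invalidate", None])      # flush cache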


    def get_protocol(self):
        return self._conn.get_protocol()

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()


class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required".
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full OCSP validation (i.e. not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled OCSP response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM-armoured string containing the expected certificate to be returned from the OCSP verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by the ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Wrap the socket with SSL support, handling potential errors.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except (OSError, RedisError):
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        # validation for the stapled case
        if self.ssl_validate_ocsp_stapled:
            import OpenSSL

            from .ocsp import ocsp_staple_verifier

            # if a context is provided use it - otherwise, a basic context
            if self.ssl_ocsp_context is None:
                staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
                staple_ctx.use_certificate_file(self.certfile)
                staple_ctx.use_privatekey_file(self.keyfile)
            else:
                staple_ctx = self.ssl_ocsp_context

            staple_ctx.set_ocsp_client_callback(
                ocsp_staple_verifier, self.ssl_ocsp_expected_cert
            )

            # need another socket
            con = OpenSSL.SSL.Connection(staple_ctx, socket.socket())
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
            return sslsock

        # pure ocsp validation
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
            from .ocsp import OCSPVerifier

            o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
            if o.is_valid():
                return sslsock
            else:
                raise ConnectionError("ocsp validation error")
        return sslsock


class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, path="", socket_timeout=None, **kwargs):
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        "Create a Unix domain socket connection"
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.settimeout(self.socket_connect_timeout)
        try:
            sock.connect(self.path)
        except OSError:
            # Prevent ResourceWarnings for unclosed sockets.
            try:
                sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                pass
            sock.close()
            raise
        sock.settimeout(self.socket_timeout)
        return sock

    def _host_error(self):
        return self.path


FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    if value is None or value == "":
        return None
    if isinstance(value, str) and value.upper() in FALSE_STRINGS:
        return False
    return bool(value)
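
# Illustrative sketch (assumption, not library code): to_bool() treats the
# FALSE_STRINGS values case-insensitively and reserves None for "unset":
#
#   >>> to_bool("")      # returns None: value not set
#   >>> to_bool("no")
#   False
#   >>> to_bool("0")
#   False
#   >>> to_bool("true")
#   True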


URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
    "timeout": float,
}


def parse_url(url):
    if not (
        url.startswith("redis://")
        or url.startswith("rediss://")
        or url.startswith("unix://")
    ):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    url = urlparse(url)
    kwargs = {}

    for name, value in parse_qs(url.query).items():
        if value and len(value) > 0:
            value = unquote(value[0])
            parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
            if parser:
                try:
                    kwargs[name] = parser(value)
                except (TypeError, ValueError):
                    raise ValueError(f"Invalid value for '{name}' in connection URL.")
            else:
                kwargs[name] = value

    if url.username:
        kwargs["username"] = unquote(url.username)
    if url.password:
        kwargs["password"] = unquote(url.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if url.scheme == "unix":
        if url.path:
            kwargs["path"] = unquote(url.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    else:  # implied: url.scheme in ("redis", "rediss"):
        if url.hostname:
            kwargs["host"] = unquote(url.hostname)
        if url.port:
            kwargs["port"] = int(url.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if url.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(url.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if url.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    return kwargs
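
# Illustrative sketch (assumption, not library code): parse_url() turns a URL
# into constructor kwargs, applying URL_QUERY_ARGUMENT_PARSERS to querystring
# values and mapping the URL path to a db number:
#
#   >>> opts = parse_url("redis://user:secret@example.com:6380/2?socket_timeout=5")
#   >>> opts == {
#   ...     "socket_timeout": 5.0, "username": "user", "password": "secret",
#   ...     "host": "example.com", "port": 6380, "db": 2,
#   ... }
#   True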


_CP = TypeVar("_CP", bound="ConnectionPool")


class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - ``redis://`` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - ``rediss://`` creates an SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://`` creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0
        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0
        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)

        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        kwargs.update(url_options)
        return cls(**kwargs)
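
    # Illustrative usage sketch (assumption: get_connection()/release() keep
    # their usual ConnectionPool signatures, which are defined later in this
    # module):
    #
    #   pool = ConnectionPool.from_url("redis://localhost:6379/0?max_connections=10")
    #   conn = pool.get_connection()
    #   try:
    #       conn.send_command("PING")
    #       assert str_if_bytes(conn.read_response()) == "PONG"
    #   finally:
    #       pool.release(conn)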


    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        **connection_kwargs,
    ):
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            if connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self.connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                if self._cache_factory is not None:
                    self.cache = self._cache_factory.get_cache()
                else:
                    self.cache = CacheFactory(
                        self.connection_kwargs.get("cache_config")
                    ).get_cache()

        connection_kwargs.pop("cache", None)
        connection_kwargs.pop("cache_config", None)

        if connection_kwargs.get(
            "maintenance_events_pool_handler"
        ) or connection_kwargs.get("maintenance_events_config"):
            if connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError(
                    "Push handlers on connection are only supported with RESP version 3"
                )
            config = connection_kwargs.get("maintenance_events_config", None) or (
                connection_kwargs.get("maintenance_events_pool_handler").config
                if connection_kwargs.get("maintenance_events_pool_handler")
                else None
            )

            if config and config.enabled:
                connection_kwargs.update(
                    {
                        "orig_host_address": connection_kwargs.get("host"),
                        "orig_socket_timeout": connection_kwargs.get(
                            "socket_timeout", None
                        ),
                        "orig_socket_connect_timeout": connection_kwargs.get(
                            "socket_connect_timeout", None
                        ),
                    }
                )

        self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        self.reset()

    def __repr__(self) -> str:
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    def get_protocol(self):
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        return self.connection_kwargs.get("protocol", None)

    def maintenance_events_pool_handler_enabled(self):

1805 """ 

1806 Returns: 

1807 True if the maintenance events pool handler is enabled, False otherwise. 

1808 """ 

1809 maintenance_events_config = self.connection_kwargs.get( 

1810 "maintenance_events_config", None 

1811 ) 

1812 

1813 return bool(maintenance_events_config and maintenance_events_config.enabled) 

1814 

1815 def set_maintenance_events_pool_handler( 

1816 self, maintenance_events_pool_handler: MaintenanceEventPoolHandler 

1817 ): 

1818 self.connection_kwargs.update( 

1819 { 

1820 "maintenance_events_pool_handler": maintenance_events_pool_handler, 

1821 "maintenance_events_config": maintenance_events_pool_handler.config, 

1822 } 

1823 ) 

1824 

1825 self._update_maintenance_events_configs_for_connections( 

1826 maintenance_events_pool_handler 

1827 ) 

1828 

1829 def _update_maintenance_events_configs_for_connections( 

1830 self, maintenance_events_pool_handler 

1831 ): 

1832 """Update the maintenance events config for all connections in the pool.""" 

1833 with self._lock: 

1834 for conn in self._available_connections: 

1835 conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) 

1836 conn.maintenance_events_config = maintenance_events_pool_handler.config 

1837 for conn in self._in_use_connections: 

1838 conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) 

1839 conn.maintenance_events_config = maintenance_events_pool_handler.config 

1840 

1841 def reset(self) -> None: 

1842 self._created_connections = 0 

1843 self._available_connections = [] 

1844 self._in_use_connections = set() 

1845 

1846 # this must be the last operation in this method. while reset() is 

1847 # called when holding _fork_lock, other threads in this process 

1848 # can call _checkpid() which compares self.pid and os.getpid() without 

1849 # holding any lock (for performance reasons). keeping this assignment 

1850 # as the last operation ensures that those other threads will also 

1851 # notice a pid difference and block waiting for the first thread to 

1852 # release _fork_lock. when each of these threads eventually acquires 

1853 # _fork_lock, it will notice that another thread already called 

1854 # reset() and will immediately release _fork_lock and continue on. 

1855 self.pid = os.getpid() 

1856 

1857 def _checkpid(self) -> None: 

1858 # _checkpid() attempts to keep ConnectionPool fork-safe on modern 

1859 # systems. this is called by all ConnectionPool methods that 

1860 # manipulate the pool's state such as get_connection() and release(). 

1861 # 

1862 # _checkpid() determines whether the process has forked by comparing 

1863 # the current process id to the process id saved on the ConnectionPool 

1864 # instance. if these values are the same, _checkpid() simply returns. 

1865 # 

1866 # when the process ids differ, _checkpid() assumes that the process 

1867 # has forked and that we're now running in the child process. the child 

1868 # process cannot use the parent's file descriptors (e.g., sockets). 

1869 # therefore, when _checkpid() sees the process id change, it calls 

1870 # reset() in order to reinitialize the child's ConnectionPool. this 

1871 # will cause the child to make all new connection objects. 

1872 # 

1873 # _checkpid() is protected by self._fork_lock to ensure that multiple 

1874 # threads in the child process do not call reset() multiple times. 

1875 # 

1876 # there is an extremely small chance this could fail in the following 

1877 # scenario: 

1878 # 1. process A calls _checkpid() for the first time and acquires 

1879 # self._fork_lock. 

1880 # 2. while holding self._fork_lock, process A forks (the fork() 

1881 # could happen in a different thread owned by process A) 

1882 # 3. process B (the forked child process) inherits the 

1883 # ConnectionPool's state from the parent. that state includes 

1884 # a locked _fork_lock. process B will not be notified when 

1885 # process A releases the _fork_lock and will thus never be 

1886 # able to acquire the _fork_lock. 

1887 # 

1888 # to mitigate this possible deadlock, _checkpid() will only wait 5 

1889 # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in 

1890 # that time it is assumed that the child is deadlocked and a 

1891 # redis.ChildDeadlockedError error is raised. 

1892 if self.pid != os.getpid(): 

1893 acquired = self._fork_lock.acquire(timeout=5) 

1894 if not acquired: 

1895 raise ChildDeadlockedError 

1896 # reset() the instance for the new process if another thread 

1897 # hasn't already done so 

1898 try: 

1899 if self.pid != os.getpid(): 

1900 self.reset() 

1901 finally: 

1902 self._fork_lock.release() 

1903 
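A sketch of the behavior _checkpid() guards (POSIX only, and assuming a reachable server): the first pool operation in a forked child sees the pid mismatch, calls reset(), and hands out a fresh socket instead of reusing the parent's.

import os

from redis.connection import ConnectionPool

pool = ConnectionPool(host="localhost", port=6379)
pool.release(pool.get_connection())  # parent now owns one pooled connection

pid = os.fork()
if pid == 0:
    # child: self.pid != os.getpid(), so _checkpid() resets the pool and
    # this checkout creates a brand-new connection object
    pool.release(pool.get_connection())
    os._exit(0)
os.waitpid(pid, 0)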

1904 @deprecated_args( 

1905 args_to_warn=["*"], 

1906 reason="Use get_connection() without args instead", 

1907 version="5.3.0", 

1908 ) 

1909 def get_connection(self, command_name=None, *keys, **options) -> "Connection": 

1910 "Get a connection from the pool" 

1911 

1912 self._checkpid() 

1913 with self._lock: 

1914 try: 

1915 connection = self._available_connections.pop() 

1916 except IndexError: 

1917 connection = self.make_connection() 

1918 self._in_use_connections.add(connection) 

1919 

1920 try: 

1921 # ensure this connection is connected to Redis 

1922 connection.connect() 

1923 # connections that the pool provides should be ready to send 

1924 # a command. if not, the connection was either returned to the 

1925 # pool before all data has been read or the socket has been 

1926 # closed. either way, reconnect and verify everything is good. 

1927 try: 

1928 if ( 

1929 connection.can_read() 

1930 and self.cache is None 

1931 and not self.maintenance_events_pool_handler_enabled() 

1932 ): 

1933 raise ConnectionError("Connection has data") 

1934 except (ConnectionError, TimeoutError, OSError): 

1935 connection.disconnect() 

1936 connection.connect() 

1937 if connection.can_read(): 

1938 raise ConnectionError("Connection not ready") 

1939 except BaseException: 

1940 # release the connection back to the pool so that we don't 

1941 # leak it 

1942 self.release(connection) 

1943 raise 

1944 return connection 

1945 
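When driving the pool by hand, the checkout must always be paired with a release; a sketch (assuming a local server):

from redis.connection import ConnectionPool

pool = ConnectionPool(host="localhost", port=6379)
connection = pool.get_connection()
try:
    connection.send_command("PING")
    print(connection.read_response())  # b'PONG'
finally:
    pool.release(connection)  # return it even if the command failed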

1946 def get_encoder(self) -> Encoder: 

1947 "Return an encoder based on encoding settings" 

1948 kwargs = self.connection_kwargs 

1949 return Encoder( 

1950 encoding=kwargs.get("encoding", "utf-8"), 

1951 encoding_errors=kwargs.get("encoding_errors", "strict"), 

1952 decode_responses=kwargs.get("decode_responses", False), 

1953 ) 

1954 
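A small sketch of the encoder reflecting the pool's settings (no server needed):

from redis.connection import ConnectionPool

pool = ConnectionPool(encoding="utf-8", decode_responses=True)
encoder = pool.get_encoder()
assert encoder.encode("héllo") == "héllo".encode("utf-8")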

1955 def make_connection(self) -> "ConnectionInterface": 

1956 "Create a new connection" 

1957 if self._created_connections >= self.max_connections: 

1958 raise MaxConnectionsError("Too many connections") 

1959 self._created_connections += 1 

1960 

1961 kwargs = dict(self.connection_kwargs) 

1962 

1963 if self.cache is not None: 

1964 return CacheProxyConnection( 

1965 self.connection_class(**kwargs), self.cache, self._lock 

1966 ) 

1967 return self.connection_class(**kwargs) 

1968 

1969 def release(self, connection: "Connection") -> None: 

1970 "Releases the connection back to the pool" 

1971 self._checkpid() 

1972 with self._lock: 

1973 try: 

1974 self._in_use_connections.remove(connection) 

1975 except KeyError: 

1976 # Gracefully fail when a connection is returned to this pool 

1977 # that the pool doesn't actually own 

1978 return 

1979 

1980 if self.owns_connection(connection): 

1981 if connection.should_reconnect(): 

1982 connection.disconnect() 

1983 self._available_connections.append(connection) 

1984 self._event_dispatcher.dispatch( 

1985 AfterConnectionReleasedEvent(connection) 

1986 ) 

1987 else: 

1988 # Pool doesn't own this connection, do not add it back 

1989 # to the pool. 

1990 # The created connections count should not be changed, 

1991 # because the connection was not created by the pool. 

1992 connection.disconnect() 

1993 return 

1994 

1995 def owns_connection(self, connection: "Connection") -> bool: 

1996 return connection.pid == self.pid 

1997 

1998 def disconnect(self, inuse_connections: bool = True) -> None: 

1999 """ 

2000 Disconnects connections in the pool 

2001 

2002 If ``inuse_connections`` is True, disconnect connections that are 

2003 currently in use, potentially by other threads. Otherwise only disconnect 

2004 connections that are idle in the pool. 

2005 """ 

2006 self._checkpid() 

2007 with self._lock: 

2008 if inuse_connections: 

2009 connections = chain( 

2010 self._available_connections, self._in_use_connections 

2011 ) 

2012 else: 

2013 connections = self._available_connections 

2014 

2015 for connection in connections: 

2016 connection.disconnect() 

2017 

2018 def close(self) -> None: 

2019 """Close the pool, disconnecting all connections""" 

2020 self.disconnect() 

2021 

2022 def set_retry(self, retry: Retry) -> None: 

2023 self.connection_kwargs.update({"retry": retry}) 

2024 for conn in self._available_connections: 

2025 conn.retry = retry 

2026 for conn in self._in_use_connections: 

2027 conn.retry = retry 

2028 
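Since the kwargs are updated as well, the policy applies to current and future connections alike; a sketch:

from redis.backoff import ExponentialBackoff
from redis.connection import ConnectionPool
from redis.retry import Retry

pool = ConnectionPool(host="localhost", port=6379)
# retry failed calls up to 3 times with exponential backoff
pool.set_retry(Retry(ExponentialBackoff(), retries=3))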

2029 def re_auth_callback(self, token: TokenInterface): 

2030 with self._lock: 

2031 for conn in self._available_connections: 

2032 conn.retry.call_with_retry( 

2033 lambda: conn.send_command( 

2034 "AUTH", token.try_get("oid"), token.get_value() 

2035 ), 

2036 lambda error: self._mock(error), 

2037 ) 

2038 conn.retry.call_with_retry( 

2039 lambda: conn.read_response(), lambda error: self._mock(error) 

2040 ) 

2041 for conn in self._in_use_connections: 

2042 conn.set_re_auth_token(token) 

2043 

2044 def _should_update_connection( 

2045 self, 

2046 conn: "Connection", 

2047 matching_pattern: Literal[ 

2048 "connected_address", "configured_address", "event_hash" 

2049 ] = "connected_address", 

2050 matching_address: Optional[str] = None, 

2051 matching_event_hash: Optional[int] = None, 

2052 ) -> bool: 

2053 """ 

2054 Check if the connection should be updated based on the matching address. 

2055 """ 

2056 if matching_pattern == "connected_address": 

2057 if matching_address and conn.getpeername() != matching_address: 

2058 return False 

2059 elif matching_pattern == "configured_address": 

2060 if matching_address and conn.host != matching_address: 

2061 return False 

2062 elif matching_pattern == "event_hash": 

2063 if ( 

2064 matching_event_hash 

2065 and conn.maintenance_event_hash != matching_event_hash 

2066 ): 

2067 return False 

2068 return True 

2069 

2070 def update_connection_settings( 

2071 self, 

2072 conn: "Connection", 

2073 state: Optional["MaintenanceState"] = None, 

2074 maintenance_event_hash: Optional[int] = None, 

2075 host_address: Optional[str] = None, 

2076 relax_timeout: Optional[float] = None, 

2077 update_event_hash: bool = False, 

2078 reset_host_address: bool = False, 

2079 reset_relax_timeout: bool = False, 

2080 ): 

2081 """ 

2082 Update the settings for a single connection. 

2083 """ 

2084 if state: 

2085 conn.maintenance_state = state 

2086 

2087 if update_event_hash: 

2088 # update the event hash only if requested 

2089 conn.maintenance_event_hash = maintenance_event_hash 

2090 

2091 if host_address is not None: 

2092 conn.set_tmp_settings(tmp_host_address=host_address) 

2093 

2094 if relax_timeout is not None: 

2095 conn.set_tmp_settings(tmp_relax_timeout=relax_timeout) 

2096 

2097 if reset_relax_timeout or reset_host_address: 

2098 conn.reset_tmp_settings( 

2099 reset_host_address=reset_host_address, 

2100 reset_relax_timeout=reset_relax_timeout, 

2101 ) 

2102 

2103 conn.update_current_socket_timeout(relax_timeout) 

2104 

2105 def update_connections_settings( 

2106 self, 

2107 state: Optional["MaintenanceState"] = None, 

2108 maintenance_event_hash: Optional[int] = None, 

2109 host_address: Optional[str] = None, 

2110 relax_timeout: Optional[float] = None, 

2111 matching_address: Optional[str] = None, 

2112 matching_event_hash: Optional[int] = None, 

2113 matching_pattern: Literal[ 

2114 "connected_address", "configured_address", "event_hash" 

2115 ] = "connected_address", 

2116 update_event_hash: bool = False, 

2117 reset_host_address: bool = False, 

2118 reset_relax_timeout: bool = False, 

2119 include_free_connections: bool = True, 

2120 ): 

2121 """ 

2122 Update the settings for all matching connections in the pool. 

2123 

2124 This method does not create new connections. 

2125 This method does not affect the connection kwargs. 

2126 

2127 :param state: The maintenance state to set for the connection. 

2128 :param maintenance_event_hash: The hash of the maintenance event 

2129 to set for the connection. 

2130 :param host_address: The host address to set for the connection. 

2131 :param relax_timeout: The relax timeout to set for the connection. 

2132 :param matching_address: The address to match for the connection. 

2133 :param matching_event_hash: The event hash to match for the connection. 

2134 :param matching_pattern: The pattern to match for the connection. 

2135 :param update_event_hash: Whether to update the event hash for the connection. 

2136 :param reset_host_address: Whether to reset the host address to the original address. 

2137 :param reset_relax_timeout: Whether to reset the relax timeout to the original timeout. 

2138 :param include_free_connections: Whether to include free/available connections. 

2139 """ 

2140 with self._lock: 

2141 for conn in self._in_use_connections: 

2142 if self._should_update_connection( 

2143 conn, 

2144 matching_pattern, 

2145 matching_address, 

2146 matching_event_hash, 

2147 ): 

2148 self.update_connection_settings( 

2149 conn, 

2150 state=state, 

2151 maintenance_event_hash=maintenance_event_hash, 

2152 host_address=host_address, 

2153 relax_timeout=relax_timeout, 

2154 update_event_hash=update_event_hash, 

2155 reset_host_address=reset_host_address, 

2156 reset_relax_timeout=reset_relax_timeout, 

2157 ) 

2158 

2159 if include_free_connections: 

2160 for conn in self._available_connections: 

2161 if self._should_update_connection( 

2162 conn, 

2163 matching_pattern, 

2164 matching_address, 

2165 matching_event_hash, 

2166 ): 

2167 self.update_connection_settings( 

2168 conn, 

2169 state=state, 

2170 maintenance_event_hash=maintenance_event_hash, 

2171 host_address=host_address, 

2172 relax_timeout=relax_timeout, 

2173 update_event_hash=update_event_hash, 

2174 reset_host_address=reset_host_address, 

2175 reset_relax_timeout=reset_relax_timeout, 

2176 ) 

2177 
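A hedged sketch of how a maintenance handler might invoke this method; the peer address "10.0.0.5:6379" is hypothetical, and its format is assumed to match what conn.getpeername() reports:

from redis.connection import ConnectionPool

pool = ConnectionPool(host="localhost", port=6379)
# relax timeouts only on connections attached to the migrating node
pool.update_connections_settings(
    relax_timeout=20.0,
    matching_pattern="connected_address",
    matching_address="10.0.0.5:6379",
    include_free_connections=True,
)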

2178 def update_connection_kwargs( 

2179 self, 

2180 **kwargs, 

2181 ): 

2182 """ 

2183 Update the connection kwargs for all future connections. 

2184 

2185 This method updates the connection kwargs for all future connections created by the pool. 

2186 Existing connections are not affected. 

2187 """ 

2188 self.connection_kwargs.update(kwargs) 

2189 
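A sketch; note that connections already sitting in the pool keep the settings they were created with:

from redis.connection import ConnectionPool

pool = ConnectionPool(host="localhost", port=6379)
# only connections created after this call observe the new timeout
pool.update_connection_kwargs(socket_timeout=5.0)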

2190 def update_active_connections_for_reconnect( 

2191 self, 

2192 moving_address_src: Optional[str] = None, 

2193 ): 

2194 """ 

2195 Mark all active connections for reconnect. 

2196 This is used when a cluster node is migrated to a different address. 

2197 

2198 :param moving_address_src: The address of the node that is being moved. 

2199 """ 

2200 with self._lock: 

2201 for conn in self._in_use_connections: 

2202 if self._should_update_connection( 

2203 conn, "connected_address", moving_address_src 

2204 ): 

2205 conn.mark_for_reconnect() 

2206 

2207 def disconnect_free_connections( 

2208 self, 

2209 moving_address_src: Optional[str] = None, 

2210 ): 

2211 """ 

2212 Disconnect all free/available connections. 

2213 This is used when a cluster node is migrated to a different address. 

2214 

2215 :param moving_address_src: The address of the node that is being moved. 

2216 """ 

2217 with self._lock: 

2218 for conn in self._available_connections: 

2219 if self._should_update_connection( 

2220 conn, "connected_address", moving_address_src 

2221 ): 

2222 conn.disconnect() 

2223 

2224 def _mock(self, error: RedisError): 

2225 """ 

2226 Dummy function; it needs to be passed as the error callback to the 

2227 retry object, which requires one even when there is nothing to do. 

2228 :param error: the exception raised by the retried call (ignored) 

2229 """ 

2230 pass 

2231 

2232 

2233class BlockingConnectionPool(ConnectionPool): 

2234 """ 

2235 Thread-safe blocking connection pool:: 

2236 

2237 >>> from redis.client import Redis 

2238 >>> client = Redis(connection_pool=BlockingConnectionPool()) 

2239 

2240 It performs the same function as the default 

2241 :py:class:`~redis.ConnectionPool` implementation, in that it 

2242 maintains a pool of reusable connections that can be shared by 

2243 multiple redis clients (safely across threads if required). 

2244 

2245 The difference is that, when a client tries to get a 

2246 connection from the pool while all connections are in use, rather than 

2247 raising a :py:class:`~redis.ConnectionError` (as the default 

2248 :py:class:`~redis.ConnectionPool` implementation does), it 

2249 makes the client wait ("blocks") for a specified number of seconds until 

2250 a connection becomes available. 

2251 

2252 Use ``max_connections`` to increase / decrease the pool size:: 

2253 

2254 >>> pool = BlockingConnectionPool(max_connections=10) 

2255 

2256 Use ``timeout`` to tell it either how many seconds to wait for a connection 

2257 to become available, or to block forever: 

2258 

2259 >>> # Block forever. 

2260 >>> pool = BlockingConnectionPool(timeout=None) 

2261 

2262 >>> # Raise a ``ConnectionError`` after five seconds if a connection is 

2263 >>> # not available. 

2264 >>> pool = BlockingConnectionPool(timeout=5) 

2265 """ 

2266 

2267 def __init__( 

2268 self, 

2269 max_connections=50, 

2270 timeout=20, 

2271 connection_class=Connection, 

2272 queue_class=LifoQueue, 

2273 **connection_kwargs, 

2274 ): 

2275 self.queue_class = queue_class 

2276 self.timeout = timeout 

2277 self._in_maintenance = False 

2278 self._locked = False 

2279 super().__init__( 

2280 connection_class=connection_class, 

2281 max_connections=max_connections, 

2282 **connection_kwargs, 

2283 ) 

2284 

2285 def reset(self): 

2286 # Create and fill up a thread-safe queue with ``None`` values. 

2287 try: 

2288 if self._in_maintenance: 

2289 self._lock.acquire() 

2290 self._locked = True 

2291 self.pool = self.queue_class(self.max_connections) 

2292 while True: 

2293 try: 

2294 self.pool.put_nowait(None) 

2295 except Full: 

2296 break 

2297 

2298 # Keep a list of actual connection instances so that we can 

2299 # disconnect them later. 

2300 self._connections = [] 

2301 finally: 

2302 if self._locked: 

2303 try: 

2304 self._lock.release() 

2305 except Exception: 

2306 pass 

2307 self._locked = False 

2308 

2309 # this must be the last operation in this method. while reset() is 

2310 # called when holding _fork_lock, other threads in this process 

2311 # can call _checkpid() which compares self.pid and os.getpid() without 

2312 # holding any lock (for performance reasons). keeping this assignment 

2313 # as the last operation ensures that those other threads will also 

2314 # notice a pid difference and block waiting for the first thread to 

2315 # release _fork_lock. when each of these threads eventually acquires 

2316 # _fork_lock, it will notice that another thread already called 

2317 # reset() and will immediately release _fork_lock and continue on. 

2318 self.pid = os.getpid() 

2319 

2320 def make_connection(self): 

2321 "Make a fresh connection." 

2322 try: 

2323 if self._in_maintenance: 

2324 self._lock.acquire() 

2325 self._locked = True 

2326 

2327 if self.cache is not None: 

2328 connection = CacheProxyConnection( 

2329 self.connection_class(**self.connection_kwargs), 

2330 self.cache, 

2331 self._lock, 

2332 ) 

2333 else: 

2334 connection = self.connection_class(**self.connection_kwargs) 

2335 self._connections.append(connection) 

2336 return connection 

2337 finally: 

2338 if self._locked: 

2339 try: 

2340 self._lock.release() 

2341 except Exception: 

2342 pass 

2343 self._locked = False 

2344 

2345 @deprecated_args( 

2346 args_to_warn=["*"], 

2347 reason="Use get_connection() without args instead", 

2348 version="5.3.0", 

2349 ) 

2350 def get_connection(self, command_name=None, *keys, **options): 

2351 """ 

2352 Get a connection, blocking for ``self.timeout`` until a connection 

2353 is available from the pool. 

2354 

2355 If the connection returned is ``None``, then a new connection is created. 

2356 Because we use a last-in first-out queue, the existing connections 

2357 (having been returned to the pool after the initial ``None`` values 

2358 were added) will be returned before ``None`` values. This means we only 

2359 create new connections when we need to, i.e. the actual number of 

2360 connections will only increase in response to demand. 

2361 """ 

2362 # Make sure we haven't changed process. 

2363 self._checkpid() 

2364 

2365 # Try to get a connection from the pool. If one isn't available within 

2366 # self.timeout then raise a ``ConnectionError``. 

2367 connection = None 

2368 try: 

2369 if self._in_maintenance: 

2370 self._lock.acquire() 

2371 self._locked = True 

2372 try: 

2373 connection = self.pool.get(block=True, timeout=self.timeout) 

2374 except Empty: 

2375 # Note that this is not caught by the redis client and will be 

2376 # raised unless handled by application code (pass timeout=None to block forever). 

2377 raise ConnectionError("No connection available.") 

2378 

2379 # If the ``connection`` is actually ``None`` then that's a cue to make 

2380 # a new connection to add to the pool. 

2381 if connection is None: 

2382 connection = self.make_connection() 

2383 finally: 

2384 if self._locked: 

2385 try: 

2386 self._lock.release() 

2387 except Exception: 

2388 pass 

2389 self._locked = False 

2390 

2391 try: 

2392 # ensure this connection is connected to Redis 

2393 connection.connect() 

2394 # connections that the pool provides should be ready to send 

2395 # a command. if not, the connection was either returned to the 

2396 # pool before all data has been read or the socket has been 

2397 # closed. either way, reconnect and verify everything is good. 

2398 try: 

2399 if connection.can_read(): 

2400 raise ConnectionError("Connection has data") 

2401 except (ConnectionError, TimeoutError, OSError): 

2402 connection.disconnect() 

2403 connection.connect() 

2404 if connection.can_read(): 

2405 raise ConnectionError("Connection not ready") 

2406 except BaseException: 

2407 # release the connection back to the pool so that we don't leak it 

2408 self.release(connection) 

2409 raise 

2410 

2411 return connection 

2412 
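A sketch of the exhaustion behavior (assuming a local server): with timeout=0 the wait expires immediately, so the second checkout raises instead of blocking.

from redis.connection import BlockingConnectionPool
from redis.exceptions import ConnectionError

pool = BlockingConnectionPool(
    host="localhost", port=6379, max_connections=1, timeout=0
)
first = pool.get_connection()
try:
    pool.get_connection()  # pool is exhausted
except ConnectionError:
    print("no connection available within the timeout")
finally:
    pool.release(first)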

2413 def release(self, connection): 

2414 "Releases the connection back to the pool." 

2415 # Make sure we haven't changed process. 

2416 self._checkpid() 

2417 

2418 try: 

2419 if self._in_maintenance: 

2420 self._lock.acquire() 

2421 self._locked = True 

2422 if not self.owns_connection(connection): 

2423 # pool doesn't own this connection. do not add it back 

2424 # to the pool. instead add a None value which is a placeholder 

2425 # that will cause the pool to recreate the connection if 

2426 # it's needed. 

2427 connection.disconnect() 

2428 self.pool.put_nowait(None) 

2429 return 

2430 if connection.should_reconnect(): 

2431 connection.disconnect() 

2432 # Put the connection back into the pool. 

2433 try: 

2434 self.pool.put_nowait(connection) 

2435 except Full: 

2436 # perhaps the pool has been reset() after a fork? regardless, 

2437 # we don't want this connection 

2438 pass 

2439 finally: 

2440 if self._locked: 

2441 try: 

2442 self._lock.release() 

2443 except Exception: 

2444 pass 

2445 self._locked = False 

2446 

2447 def disconnect(self): 

2448 "Disconnects all connections in the pool." 

2449 self._checkpid() 

2450 try: 

2451 if self._in_maintenance: 

2452 self._lock.acquire() 

2453 self._locked = True 

2454 for connection in self._connections: 

2455 connection.disconnect() 

2456 finally: 

2457 if self._locked: 

2458 try: 

2459 self._lock.release() 

2460 except Exception: 

2461 pass 

2462 self._locked = False 

2463 

2464 def update_connections_settings( 

2465 self, 

2466 state: Optional["MaintenanceState"] = None, 

2467 maintenance_event_hash: Optional[int] = None, 

2468 relax_timeout: Optional[float] = None, 

2469 host_address: Optional[str] = None, 

2470 matching_address: Optional[str] = None, 

2471 matching_event_hash: Optional[int] = None, 

2472 matching_pattern: Literal[ 

2473 "connected_address", "configured_address", "event_hash" 

2474 ] = "connected_address", 

2475 update_event_hash: bool = False, 

2476 reset_host_address: bool = False, 

2477 reset_relax_timeout: bool = False, 

2478 include_free_connections: bool = True, 

2479 ): 

2480 """ 

2481 Override base class method to work with BlockingConnectionPool's structure. 

2482 """ 

2483 with self._lock: 

2484 if include_free_connections: 

2485 for conn in tuple(self._connections): 

2486 if self._should_update_connection( 

2487 conn, 

2488 matching_pattern, 

2489 matching_address, 

2490 matching_event_hash, 

2491 ): 

2492 self.update_connection_settings( 

2493 conn, 

2494 state=state, 

2495 maintenance_event_hash=maintenance_event_hash, 

2496 host_address=host_address, 

2497 relax_timeout=relax_timeout, 

2498 update_event_hash=update_event_hash, 

2499 reset_host_address=reset_host_address, 

2500 reset_relax_timeout=reset_relax_timeout, 

2501 ) 

2502 else: 

2503 connections_in_queue = {conn for conn in self.pool.queue if conn} 

2504 for conn in self._connections: 

2505 if conn not in connections_in_queue: 

2506 if self._should_update_connection( 

2507 conn, 

2508 matching_pattern, 

2509 matching_address, 

2510 matching_event_hash, 

2511 ): 

2512 self.update_connection_settings( 

2513 conn, 

2514 state=state, 

2515 maintenance_event_hash=maintenance_event_hash, 

2516 host_address=host_address, 

2517 relax_timeout=relax_timeout, 

2518 update_event_hash=update_event_hash, 

2519 reset_host_address=reset_host_address, 

2520 reset_relax_timeout=reset_relax_timeout, 

2521 ) 

2522 

2523 def update_active_connections_for_reconnect( 

2524 self, 

2525 moving_address_src: Optional[str] = None, 

2526 ): 

2527 """ 

2528 Mark all active connections for reconnect. 

2529 This is used when a cluster node is migrated to a different address. 

2530 

2531 :param moving_address_src: The address of the node that is being moved. 

2532 """ 

2533 with self._lock: 

2534 connections_in_queue = {conn for conn in self.pool.queue if conn} 

2535 for conn in self._connections: 

2536 if conn not in connections_in_queue: 

2537 if self._should_update_connection( 

2538 conn, 

2539 matching_pattern="connected_address", 

2540 matching_address=moving_address_src, 

2541 ): 

2542 conn.mark_for_reconnect() 

2543 

2544 def disconnect_free_connections( 

2545 self, 

2546 moving_address_src: Optional[str] = None, 

2547 ): 

2548 """ 

2549 Disconnect all free/available connections. 

2550 This is used when a cluster node is migrated to a different address. 

2551 

2552 :param moving_address_src: The address of the node that is being moved. 

2553 """ 

2554 with self._lock: 

2555 existing_connections = self.pool.queue 

2556 

2557 for conn in existing_connections: 

2558 if conn: 

2559 if self._should_update_connection( 

2560 conn, "connected_address", moving_address_src 

2561 ): 

2562 conn.disconnect() 

2563 

2564 def _update_maintenance_events_config_for_connections( 

2565 self, maintenance_events_config 

2566 ): 

2567 for conn in tuple(self._connections): 

2568 conn.maintenance_events_config = maintenance_events_config 

2569 

2570 def _update_maintenance_events_configs_for_connections( 

2571 self, maintenance_events_pool_handler 

2572 ): 

2573 """Update the maintenance events config for all connections in the pool.""" 

2574 with self._lock: 

2575 for conn in tuple(self._connections): 

2576 conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) 

2577 conn.maintenance_events_config = maintenance_events_pool_handler.config 

2578 

2579 def set_in_maintenance(self, in_maintenance: bool): 

2580 """ 

2581 Sets a flag indicating that this BlockingConnectionPool is in maintenance mode. 

2582 

2583 This is used to prevent new connections from being created while we are in maintenance mode. 

2584 The pool will be in maintenance mode only when we are processing a MOVING event. 

2585 """ 

2586 self._in_maintenance = in_maintenance
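A hedged sketch of how a pool handler might bracket MOVING-event processing with this flag so that checkouts, releases, and resets serialize on self._lock; the new address is hypothetical:

from redis.connection import BlockingConnectionPool

pool = BlockingConnectionPool(host="localhost", port=6379)
pool.set_in_maintenance(True)
try:
    # route future connections to the node's new (hypothetical) address
    pool.update_connection_kwargs(host="10.0.0.6")
    pool.disconnect_free_connections()
finally:
    pool.set_in_maintenance(False)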