Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 20%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1212 statements  

1import copy 

2import os 

3import socket 

4import sys 

5import threading 

6import time 

7import weakref 

8from abc import abstractmethod 

9from itertools import chain 

10from queue import Empty, Full, LifoQueue 

11from typing import ( 

12 Any, 

13 Callable, 

14 Dict, 

15 Iterable, 

16 List, 

17 Literal, 

18 Optional, 

19 Type, 

20 TypeVar, 

21 Union, 

22) 

23from urllib.parse import parse_qs, unquote, urlparse 

24 

25from redis.cache import ( 

26 CacheEntry, 

27 CacheEntryStatus, 

28 CacheFactory, 

29 CacheFactoryInterface, 

30 CacheInterface, 

31 CacheKey, 

32) 

33 

34from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser 

35from .auth.token import TokenInterface 

36from .backoff import NoBackoff 

37from .credentials import CredentialProvider, UsernamePasswordCredentialProvider 

38from .event import AfterConnectionReleasedEvent, EventDispatcher 

39from .exceptions import ( 

40 AuthenticationError, 

41 AuthenticationWrongNumberOfArgsError, 

42 ChildDeadlockedError, 

43 ConnectionError, 

44 DataError, 

45 MaxConnectionsError, 

46 RedisError, 

47 ResponseError, 

48 TimeoutError, 

49) 

50from .maint_notifications import ( 

51 MaintenanceState, 

52 MaintNotificationsConfig, 

53 MaintNotificationsConnectionHandler, 

54 MaintNotificationsPoolHandler, 

55) 

56from .retry import Retry 

57from .utils import ( 

58 CRYPTOGRAPHY_AVAILABLE, 

59 HIREDIS_AVAILABLE, 

60 SSL_AVAILABLE, 

61 compare_versions, 

62 deprecated_args, 

63 ensure_string, 

64 format_error_message, 

65 get_lib_version, 

66 str_if_bytes, 

67) 

68 

69if SSL_AVAILABLE: 

70 import ssl 

71 from ssl import VerifyFlags 

72else: 

73 ssl = None 

74 VerifyFlags = None 

75 

76if HIREDIS_AVAILABLE: 

77 import hiredis 

78 

79SYM_STAR = b"*" 

80SYM_DOLLAR = b"$" 

81SYM_CRLF = b"\r\n" 

82SYM_EMPTY = b"" 

83 

84DEFAULT_RESP_VERSION = 2 

85 

86SENTINEL = object() 

87 

88DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] 

89if HIREDIS_AVAILABLE: 

90 DefaultParser = _HiredisParser 

91else: 

92 DefaultParser = _RESP2Parser 

93 

94 

95class HiredisRespSerializer: 

96 def pack(self, *args: List): 

97 """Pack a series of arguments into the Redis protocol""" 

98 output = [] 

99 

100 if isinstance(args[0], str): 

101 args = tuple(args[0].encode().split()) + args[1:] 

102 elif b" " in args[0]: 

103 args = tuple(args[0].split()) + args[1:] 

104 try: 

105 output.append(hiredis.pack_command(args)) 

106 except TypeError: 

107 _, value, traceback = sys.exc_info() 

108 raise DataError(value).with_traceback(traceback) 

109 

110 return output 

111 

112 

113class PythonRespSerializer: 

114 def __init__(self, buffer_cutoff, encode) -> None: 

115 self._buffer_cutoff = buffer_cutoff 

116 self.encode = encode 

117 

118 def pack(self, *args): 

119 """Pack a series of arguments into the Redis protocol""" 

120 output = [] 

121 # the client might have included 1 or more literal arguments in 

122 # the command name, e.g., 'CONFIG GET'. The Redis server expects these 

123 # arguments to be sent separately, so split the first argument 

124 # manually. These arguments should be bytestrings so that they are 

125 # not encoded. 

126 if isinstance(args[0], str): 

127 args = tuple(args[0].encode().split()) + args[1:] 

128 elif b" " in args[0]: 

129 args = tuple(args[0].split()) + args[1:] 

130 

131 buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) 

132 

133 buffer_cutoff = self._buffer_cutoff 

134 for arg in map(self.encode, args): 

135 # to avoid large string mallocs, chunk the command into the 

136 # output list if we're sending large values or memoryviews 

137 arg_length = len(arg) 

138 if ( 

139 len(buff) > buffer_cutoff 

140 or arg_length > buffer_cutoff 

141 or isinstance(arg, memoryview) 

142 ): 

143 buff = SYM_EMPTY.join( 

144 (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) 

145 ) 

146 output.append(buff) 

147 output.append(arg) 

148 buff = SYM_CRLF 

149 else: 

150 buff = SYM_EMPTY.join( 

151 ( 

152 buff, 

153 SYM_DOLLAR, 

154 str(arg_length).encode(), 

155 SYM_CRLF, 

156 arg, 

157 SYM_CRLF, 

158 ) 

159 ) 

160 output.append(buff) 

161 return output 

162 

163 

164class ConnectionInterface: 

165 @abstractmethod 

166 def repr_pieces(self): 

167 pass 

168 

169 @abstractmethod 

170 def register_connect_callback(self, callback): 

171 pass 

172 

173 @abstractmethod 

174 def deregister_connect_callback(self, callback): 

175 pass 

176 

177 @abstractmethod 

178 def set_parser(self, parser_class): 

179 pass 

180 

181 @abstractmethod 

182 def set_maint_notifications_pool_handler(self, maint_notifications_pool_handler): 

183 pass 

184 

185 @abstractmethod 

186 def get_protocol(self): 

187 pass 

188 

189 @abstractmethod 

190 def connect(self): 

191 pass 

192 

193 @abstractmethod 

194 def on_connect(self): 

195 pass 

196 

197 @abstractmethod 

198 def disconnect(self, *args): 

199 pass 

200 

201 @abstractmethod 

202 def check_health(self): 

203 pass 

204 

205 @abstractmethod 

206 def send_packed_command(self, command, check_health=True): 

207 pass 

208 

209 @abstractmethod 

210 def send_command(self, *args, **kwargs): 

211 pass 

212 

213 @abstractmethod 

214 def can_read(self, timeout=0): 

215 pass 

216 

217 @abstractmethod 

218 def read_response( 

219 self, 

220 disable_decoding=False, 

221 *, 

222 disconnect_on_error=True, 

223 push_request=False, 

224 ): 

225 pass 

226 

227 @abstractmethod 

228 def pack_command(self, *args): 

229 pass 

230 

231 @abstractmethod 

232 def pack_commands(self, commands): 

233 pass 

234 

235 @property 

236 @abstractmethod 

237 def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: 

238 pass 

239 

240 @abstractmethod 

241 def set_re_auth_token(self, token: TokenInterface): 

242 pass 

243 

244 @abstractmethod 

245 def re_auth(self): 

246 pass 

247 

248 @property 

249 @abstractmethod 

250 def maintenance_state(self) -> MaintenanceState: 

251 """ 

252 Returns the current maintenance state of the connection. 

253 """ 

254 pass 

255 

256 @maintenance_state.setter 

257 @abstractmethod 

258 def maintenance_state(self, state: "MaintenanceState"): 

259 """ 

260 Sets the current maintenance state of the connection. 

261 """ 

262 pass 

263 

264 @abstractmethod 

265 def getpeername(self): 

266 """ 

267 Returns the peer name of the connection. 

268 """ 

269 pass 

270 

271 @abstractmethod 

272 def mark_for_reconnect(self): 

273 """ 

274 Mark the connection to be reconnected on the next command. 

275 This is useful when a connection is moved to a different node. 

276 """ 

277 pass 

278 

279 @abstractmethod 

280 def should_reconnect(self): 

281 """ 

282 Returns True if the connection should be reconnected. 

283 """ 

284 pass 

285 

286 @abstractmethod 

287 def get_resolved_ip(self): 

288 """ 

289 Get resolved ip address for the connection. 

290 """ 

291 pass 

292 

293 @abstractmethod 

294 def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None): 

295 """ 

296 Update the timeout for the current socket. 

297 """ 

298 pass 

299 

300 @abstractmethod 

301 def set_tmp_settings( 

302 self, 

303 tmp_host_address: Optional[str] = None, 

304 tmp_relaxed_timeout: Optional[float] = None, 

305 ): 

306 """ 

307 Updates temporary host address and timeout settings for the connection. 

308 """ 

309 pass 

310 

311 @abstractmethod 

312 def reset_tmp_settings( 

313 self, 

314 reset_host_address: bool = False, 

315 reset_relaxed_timeout: bool = False, 

316 ): 

317 """ 

318 Resets temporary host address and timeout settings for the connection. 

319 """ 

320 pass 

321 

322 

323class AbstractConnection(ConnectionInterface): 

324 "Manages communication to and from a Redis server" 

325 

326 def __init__( 

327 self, 

328 db: int = 0, 

329 password: Optional[str] = None, 

330 socket_timeout: Optional[float] = None, 

331 socket_connect_timeout: Optional[float] = None, 

332 retry_on_timeout: bool = False, 

333 retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL, 

334 encoding: str = "utf-8", 

335 encoding_errors: str = "strict", 

336 decode_responses: bool = False, 

337 parser_class=DefaultParser, 

338 socket_read_size: int = 65536, 

339 health_check_interval: int = 0, 

340 client_name: Optional[str] = None, 

341 lib_name: Optional[str] = "redis-py", 

342 lib_version: Optional[str] = get_lib_version(), 

343 username: Optional[str] = None, 

344 retry: Union[Any, None] = None, 

345 redis_connect_func: Optional[Callable[[], None]] = None, 

346 credential_provider: Optional[CredentialProvider] = None, 

347 protocol: Optional[int] = 2, 

348 command_packer: Optional[Callable[[], None]] = None, 

349 event_dispatcher: Optional[EventDispatcher] = None, 

350 maint_notifications_pool_handler: Optional[ 

351 MaintNotificationsPoolHandler 

352 ] = None, 

353 maint_notifications_config: Optional[MaintNotificationsConfig] = None, 

354 maintenance_state: "MaintenanceState" = MaintenanceState.NONE, 

355 maintenance_notification_hash: Optional[int] = None, 

356 orig_host_address: Optional[str] = None, 

357 orig_socket_timeout: Optional[float] = None, 

358 orig_socket_connect_timeout: Optional[float] = None, 

359 ): 

360 """ 

361 Initialize a new Connection. 

362 To specify a retry policy for specific errors, first set 

363 `retry_on_error` to a list of the error/s to retry on, then set 

364 `retry` to a valid `Retry` object. 

365 To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. 

366 """ 

367 if (username or password) and credential_provider is not None: 

368 raise DataError( 

369 "'username' and 'password' cannot be passed along with 'credential_" 

370 "provider'. Please provide only one of the following arguments: \n" 

371 "1. 'password' and (optional) 'username'\n" 

372 "2. 'credential_provider'" 

373 ) 

374 if event_dispatcher is None: 

375 self._event_dispatcher = EventDispatcher() 

376 else: 

377 self._event_dispatcher = event_dispatcher 

378 self.pid = os.getpid() 

379 self.db = db 

380 self.client_name = client_name 

381 self.lib_name = lib_name 

382 self.lib_version = lib_version 

383 self.credential_provider = credential_provider 

384 self.password = password 

385 self.username = username 

386 self.socket_timeout = socket_timeout 

387 if socket_connect_timeout is None: 

388 socket_connect_timeout = socket_timeout 

389 self.socket_connect_timeout = socket_connect_timeout 

390 self.retry_on_timeout = retry_on_timeout 

391 if retry_on_error is SENTINEL: 

392 retry_on_errors_list = [] 

393 else: 

394 retry_on_errors_list = list(retry_on_error) 

395 if retry_on_timeout: 

396 # Add TimeoutError to the errors list to retry on 

397 retry_on_errors_list.append(TimeoutError) 

398 self.retry_on_error = retry_on_errors_list 

399 if retry or self.retry_on_error: 

400 if retry is None: 

401 self.retry = Retry(NoBackoff(), 1) 

402 else: 

403 # deep-copy the Retry object as it is mutable 

404 self.retry = copy.deepcopy(retry) 

405 if self.retry_on_error: 

406 # Update the retry's supported errors with the specified errors 

407 self.retry.update_supported_errors(self.retry_on_error) 

408 else: 

409 self.retry = Retry(NoBackoff(), 0) 

410 self.health_check_interval = health_check_interval 

411 self.next_health_check = 0 

412 self.redis_connect_func = redis_connect_func 

413 self.encoder = Encoder(encoding, encoding_errors, decode_responses) 

414 self.handshake_metadata = None 

415 self._sock = None 

416 self._socket_read_size = socket_read_size 

417 self._connect_callbacks = [] 

418 self._buffer_cutoff = 6000 

419 self._re_auth_token: Optional[TokenInterface] = None 

420 try: 

421 p = int(protocol) 

422 except TypeError: 

423 p = DEFAULT_RESP_VERSION 

424 except ValueError: 

425 raise ConnectionError("protocol must be an integer") 

426 finally: 

427 if p < 2 or p > 3: 

428 raise ConnectionError("protocol must be either 2 or 3") 

429 # p = DEFAULT_RESP_VERSION 

430 self.protocol = p 

431 if self.protocol == 3 and parser_class == DefaultParser: 

432 parser_class = _RESP3Parser 

433 self.set_parser(parser_class) 

434 

435 self.maint_notifications_config = maint_notifications_config 

436 

437 # Set up maintenance notifications if enabled 

438 self._configure_maintenance_notifications( 

439 maint_notifications_pool_handler, 

440 orig_host_address, 

441 orig_socket_timeout, 

442 orig_socket_connect_timeout, 

443 ) 

444 

445 self._should_reconnect = False 

446 self.maintenance_state = maintenance_state 

447 self.maintenance_notification_hash = maintenance_notification_hash 

448 

449 self._command_packer = self._construct_command_packer(command_packer) 

450 

451 def __repr__(self): 

452 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) 

453 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" 

454 

455 @abstractmethod 

456 def repr_pieces(self): 

457 pass 

458 

459 def __del__(self): 

460 try: 

461 self.disconnect() 

462 except Exception: 

463 pass 

464 

465 def _construct_command_packer(self, packer): 

466 if packer is not None: 

467 return packer 

468 elif HIREDIS_AVAILABLE: 

469 return HiredisRespSerializer() 

470 else: 

471 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode) 

472 

473 def register_connect_callback(self, callback): 

474 """ 

475 Register a callback to be called when the connection is established either 

476 initially or reconnected. This allows listeners to issue commands that 

477 are ephemeral to the connection, for example pub/sub subscription or 

478 key tracking. The callback must be a _method_ and will be kept as 

479 a weak reference. 

480 """ 

481 wm = weakref.WeakMethod(callback) 

482 if wm not in self._connect_callbacks: 

483 self._connect_callbacks.append(wm) 

484 

485 def deregister_connect_callback(self, callback): 

486 """ 

487 De-register a previously registered callback. It will no-longer receive 

488 notifications on connection events. Calling this is not required when the 

489 listener goes away, since the callbacks are kept as weak methods. 

490 """ 

491 try: 

492 self._connect_callbacks.remove(weakref.WeakMethod(callback)) 

493 except ValueError: 

494 pass 

495 

496 def set_parser(self, parser_class): 

497 """ 

498 Creates a new instance of parser_class with socket size: 

499 _socket_read_size and assigns it to the parser for the connection 

500 :param parser_class: The required parser class 

501 """ 

502 self._parser = parser_class(socket_read_size=self._socket_read_size) 

503 

504 def _configure_maintenance_notifications( 

505 self, 

506 maint_notifications_pool_handler=None, 

507 orig_host_address=None, 

508 orig_socket_timeout=None, 

509 orig_socket_connect_timeout=None, 

510 ): 

511 """Enable maintenance notifications by setting up handlers and storing original connection parameters.""" 

512 if ( 

513 not self.maint_notifications_config 

514 or not self.maint_notifications_config.enabled 

515 ): 

516 self._maint_notifications_connection_handler = None 

517 return 

518 

519 # Set up pool handler if available 

520 if maint_notifications_pool_handler: 

521 self._parser.set_node_moving_push_handler( 

522 maint_notifications_pool_handler.handle_notification 

523 ) 

524 

525 # Set up connection handler 

526 self._maint_notifications_connection_handler = ( 

527 MaintNotificationsConnectionHandler(self, self.maint_notifications_config) 

528 ) 

529 self._parser.set_maintenance_push_handler( 

530 self._maint_notifications_connection_handler.handle_notification 

531 ) 

532 

533 # Store original connection parameters 

534 self.orig_host_address = orig_host_address if orig_host_address else self.host 

535 self.orig_socket_timeout = ( 

536 orig_socket_timeout if orig_socket_timeout else self.socket_timeout 

537 ) 

538 self.orig_socket_connect_timeout = ( 

539 orig_socket_connect_timeout 

540 if orig_socket_connect_timeout 

541 else self.socket_connect_timeout 

542 ) 

543 

544 def set_maint_notifications_pool_handler( 

545 self, maint_notifications_pool_handler: MaintNotificationsPoolHandler 

546 ): 

547 maint_notifications_pool_handler.set_connection(self) 

548 self._parser.set_node_moving_push_handler( 

549 maint_notifications_pool_handler.handle_notification 

550 ) 

551 

552 # Update maintenance notification connection handler if it doesn't exist 

553 if not self._maint_notifications_connection_handler: 

554 self._maint_notifications_connection_handler = ( 

555 MaintNotificationsConnectionHandler( 

556 self, maint_notifications_pool_handler.config 

557 ) 

558 ) 

559 self._parser.set_maintenance_push_handler( 

560 self._maint_notifications_connection_handler.handle_notification 

561 ) 

562 else: 

563 self._maint_notifications_connection_handler.config = ( 

564 maint_notifications_pool_handler.config 

565 ) 

566 

567 def connect(self): 

568 "Connects to the Redis server if not already connected" 

569 self.connect_check_health(check_health=True) 

570 

571 def connect_check_health( 

572 self, check_health: bool = True, retry_socket_connect: bool = True 

573 ): 

574 if self._sock: 

575 return 

576 try: 

577 if retry_socket_connect: 

578 sock = self.retry.call_with_retry( 

579 lambda: self._connect(), lambda error: self.disconnect(error) 

580 ) 

581 else: 

582 sock = self._connect() 

583 except socket.timeout: 

584 raise TimeoutError("Timeout connecting to server") 

585 except OSError as e: 

586 raise ConnectionError(self._error_message(e)) 

587 

588 self._sock = sock 

589 try: 

590 if self.redis_connect_func is None: 

591 # Use the default on_connect function 

592 self.on_connect_check_health(check_health=check_health) 

593 else: 

594 # Use the passed function redis_connect_func 

595 self.redis_connect_func(self) 

596 except RedisError: 

597 # clean up after any error in on_connect 

598 self.disconnect() 

599 raise 

600 

601 # run any user callbacks. right now the only internal callback 

602 # is for pubsub channel/pattern resubscription 

603 # first, remove any dead weakrefs 

604 self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()] 

605 for ref in self._connect_callbacks: 

606 callback = ref() 

607 if callback: 

608 callback(self) 

609 

610 @abstractmethod 

611 def _connect(self): 

612 pass 

613 

614 @abstractmethod 

615 def _host_error(self): 

616 pass 

617 

618 def _error_message(self, exception): 

619 return format_error_message(self._host_error(), exception) 

620 

621 def on_connect(self): 

622 self.on_connect_check_health(check_health=True) 

623 

624 def on_connect_check_health(self, check_health: bool = True): 

625 "Initialize the connection, authenticate and select a database" 

626 self._parser.on_connect(self) 

627 parser = self._parser 

628 

629 auth_args = None 

630 # if credential provider or username and/or password are set, authenticate 

631 if self.credential_provider or (self.username or self.password): 

632 cred_provider = ( 

633 self.credential_provider 

634 or UsernamePasswordCredentialProvider(self.username, self.password) 

635 ) 

636 auth_args = cred_provider.get_credentials() 

637 

638 # if resp version is specified and we have auth args, 

639 # we need to send them via HELLO 

640 if auth_args and self.protocol not in [2, "2"]: 

641 if isinstance(self._parser, _RESP2Parser): 

642 self.set_parser(_RESP3Parser) 

643 # update cluster exception classes 

644 self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES 

645 self._parser.on_connect(self) 

646 if len(auth_args) == 1: 

647 auth_args = ["default", auth_args[0]] 

648 # avoid checking health here -- PING will fail if we try 

649 # to check the health prior to the AUTH 

650 self.send_command( 

651 "HELLO", self.protocol, "AUTH", *auth_args, check_health=False 

652 ) 

653 self.handshake_metadata = self.read_response() 

654 # if response.get(b"proto") != self.protocol and response.get( 

655 # "proto" 

656 # ) != self.protocol: 

657 # raise ConnectionError("Invalid RESP version") 

658 elif auth_args: 

659 # avoid checking health here -- PING will fail if we try 

660 # to check the health prior to the AUTH 

661 self.send_command("AUTH", *auth_args, check_health=False) 

662 

663 try: 

664 auth_response = self.read_response() 

665 except AuthenticationWrongNumberOfArgsError: 

666 # a username and password were specified but the Redis 

667 # server seems to be < 6.0.0 which expects a single password 

668 # arg. retry auth with just the password. 

669 # https://github.com/andymccurdy/redis-py/issues/1274 

670 self.send_command("AUTH", auth_args[-1], check_health=False) 

671 auth_response = self.read_response() 

672 

673 if str_if_bytes(auth_response) != "OK": 

674 raise AuthenticationError("Invalid Username or Password") 

675 

676 # if resp version is specified, switch to it 

677 elif self.protocol not in [2, "2"]: 

678 if isinstance(self._parser, _RESP2Parser): 

679 self.set_parser(_RESP3Parser) 

680 # update cluster exception classes 

681 self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES 

682 self._parser.on_connect(self) 

683 self.send_command("HELLO", self.protocol, check_health=check_health) 

684 self.handshake_metadata = self.read_response() 

685 if ( 

686 self.handshake_metadata.get(b"proto") != self.protocol 

687 and self.handshake_metadata.get("proto") != self.protocol 

688 ): 

689 raise ConnectionError("Invalid RESP version") 

690 

691 # Send maintenance notifications handshake if RESP3 is active 

692 # and maintenance notifications are enabled 

693 # and we have a host to determine the endpoint type from 

694 # When the maint_notifications_config enabled mode is "auto", 

695 # we just log a warning if the handshake fails 

696 # When the mode is enabled=True, we raise an exception in case of failure 

697 if ( 

698 self.protocol not in [2, "2"] 

699 and self.maint_notifications_config 

700 and self.maint_notifications_config.enabled 

701 and self._maint_notifications_connection_handler 

702 and hasattr(self, "host") 

703 ): 

704 try: 

705 endpoint_type = self.maint_notifications_config.get_endpoint_type( 

706 self.host, self 

707 ) 

708 self.send_command( 

709 "CLIENT", 

710 "MAINT_NOTIFICATIONS", 

711 "ON", 

712 "moving-endpoint-type", 

713 endpoint_type.value, 

714 check_health=check_health, 

715 ) 

716 response = self.read_response() 

717 if str_if_bytes(response) != "OK": 

718 raise ResponseError( 

719 "The server doesn't support maintenance notifications" 

720 ) 

721 except Exception as e: 

722 if ( 

723 isinstance(e, ResponseError) 

724 and self.maint_notifications_config.enabled == "auto" 

725 ): 

726 # Log warning but don't fail the connection 

727 import logging 

728 

729 logger = logging.getLogger(__name__) 

730 logger.warning(f"Failed to enable maintenance notifications: {e}") 

731 else: 

732 raise 

733 

734 # if a client_name is given, set it 

735 if self.client_name: 

736 self.send_command( 

737 "CLIENT", 

738 "SETNAME", 

739 self.client_name, 

740 check_health=check_health, 

741 ) 

742 if str_if_bytes(self.read_response()) != "OK": 

743 raise ConnectionError("Error setting client name") 

744 

745 try: 

746 # set the library name and version 

747 if self.lib_name: 

748 self.send_command( 

749 "CLIENT", 

750 "SETINFO", 

751 "LIB-NAME", 

752 self.lib_name, 

753 check_health=check_health, 

754 ) 

755 self.read_response() 

756 except ResponseError: 

757 pass 

758 

759 try: 

760 if self.lib_version: 

761 self.send_command( 

762 "CLIENT", 

763 "SETINFO", 

764 "LIB-VER", 

765 self.lib_version, 

766 check_health=check_health, 

767 ) 

768 self.read_response() 

769 except ResponseError: 

770 pass 

771 

772 # if a database is specified, switch to it 

773 if self.db: 

774 self.send_command("SELECT", self.db, check_health=check_health) 

775 if str_if_bytes(self.read_response()) != "OK": 

776 raise ConnectionError("Invalid Database") 

777 

778 def disconnect(self, *args): 

779 "Disconnects from the Redis server" 

780 self._parser.on_disconnect() 

781 

782 conn_sock = self._sock 

783 self._sock = None 

784 # reset the reconnect flag 

785 self._should_reconnect = False 

786 if conn_sock is None: 

787 return 

788 

789 if os.getpid() == self.pid: 

790 try: 

791 conn_sock.shutdown(socket.SHUT_RDWR) 

792 except (OSError, TypeError): 

793 pass 

794 

795 try: 

796 conn_sock.close() 

797 except OSError: 

798 pass 

799 

800 def _send_ping(self): 

801 """Send PING, expect PONG in return""" 

802 self.send_command("PING", check_health=False) 

803 if str_if_bytes(self.read_response()) != "PONG": 

804 raise ConnectionError("Bad response from PING health check") 

805 

806 def _ping_failed(self, error): 

807 """Function to call when PING fails""" 

808 self.disconnect() 

809 

810 def check_health(self): 

811 """Check the health of the connection with a PING/PONG""" 

812 if self.health_check_interval and time.monotonic() > self.next_health_check: 

813 self.retry.call_with_retry(self._send_ping, self._ping_failed) 

814 

815 def send_packed_command(self, command, check_health=True): 

816 """Send an already packed command to the Redis server""" 

817 if not self._sock: 

818 self.connect_check_health(check_health=False) 

819 # guard against health check recursion 

820 if check_health: 

821 self.check_health() 

822 try: 

823 if isinstance(command, str): 

824 command = [command] 

825 for item in command: 

826 self._sock.sendall(item) 

827 except socket.timeout: 

828 self.disconnect() 

829 raise TimeoutError("Timeout writing to socket") 

830 except OSError as e: 

831 self.disconnect() 

832 if len(e.args) == 1: 

833 errno, errmsg = "UNKNOWN", e.args[0] 

834 else: 

835 errno = e.args[0] 

836 errmsg = e.args[1] 

837 raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.") 

838 except BaseException: 

839 # BaseExceptions can be raised when a socket send operation is not 

840 # finished, e.g. due to a timeout. Ideally, a caller could then re-try 

841 # to send un-sent data. However, the send_packed_command() API 

842 # does not support it so there is no point in keeping the connection open. 

843 self.disconnect() 

844 raise 

845 

846 def send_command(self, *args, **kwargs): 

847 """Pack and send a command to the Redis server""" 

848 self.send_packed_command( 

849 self._command_packer.pack(*args), 

850 check_health=kwargs.get("check_health", True), 

851 ) 

852 

853 def can_read(self, timeout=0): 

854 """Poll the socket to see if there's data that can be read.""" 

855 sock = self._sock 

856 if not sock: 

857 self.connect() 

858 

859 host_error = self._host_error() 

860 

861 try: 

862 return self._parser.can_read(timeout) 

863 

864 except OSError as e: 

865 self.disconnect() 

866 raise ConnectionError(f"Error while reading from {host_error}: {e.args}") 

867 

868 def read_response( 

869 self, 

870 disable_decoding=False, 

871 *, 

872 disconnect_on_error=True, 

873 push_request=False, 

874 ): 

875 """Read the response from a previously sent command""" 

876 

877 host_error = self._host_error() 

878 

879 try: 

880 if self.protocol in ["3", 3]: 

881 response = self._parser.read_response( 

882 disable_decoding=disable_decoding, push_request=push_request 

883 ) 

884 else: 

885 response = self._parser.read_response(disable_decoding=disable_decoding) 

886 except socket.timeout: 

887 if disconnect_on_error: 

888 self.disconnect() 

889 raise TimeoutError(f"Timeout reading from {host_error}") 

890 except OSError as e: 

891 if disconnect_on_error: 

892 self.disconnect() 

893 raise ConnectionError(f"Error while reading from {host_error} : {e.args}") 

894 except BaseException: 

895 # Also by default close in case of BaseException. A lot of code 

896 # relies on this behaviour when doing Command/Response pairs. 

897 # See #1128. 

898 if disconnect_on_error: 

899 self.disconnect() 

900 raise 

901 

902 if self.health_check_interval: 

903 self.next_health_check = time.monotonic() + self.health_check_interval 

904 

905 if isinstance(response, ResponseError): 

906 try: 

907 raise response 

908 finally: 

909 del response # avoid creating ref cycles 

910 return response 

911 

912 def pack_command(self, *args): 

913 """Pack a series of arguments into the Redis protocol""" 

914 return self._command_packer.pack(*args) 

915 

916 def pack_commands(self, commands): 

917 """Pack multiple commands into the Redis protocol""" 

918 output = [] 

919 pieces = [] 

920 buffer_length = 0 

921 buffer_cutoff = self._buffer_cutoff 

922 

923 for cmd in commands: 

924 for chunk in self._command_packer.pack(*cmd): 

925 chunklen = len(chunk) 

926 if ( 

927 buffer_length > buffer_cutoff 

928 or chunklen > buffer_cutoff 

929 or isinstance(chunk, memoryview) 

930 ): 

931 if pieces: 

932 output.append(SYM_EMPTY.join(pieces)) 

933 buffer_length = 0 

934 pieces = [] 

935 

936 if chunklen > buffer_cutoff or isinstance(chunk, memoryview): 

937 output.append(chunk) 

938 else: 

939 pieces.append(chunk) 

940 buffer_length += chunklen 

941 

942 if pieces: 

943 output.append(SYM_EMPTY.join(pieces)) 

944 return output 

945 

946 def get_protocol(self) -> Union[int, str]: 

947 return self.protocol 

948 

949 @property 

950 def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: 

951 return self._handshake_metadata 

952 

953 @handshake_metadata.setter 

954 def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]): 

955 self._handshake_metadata = value 

956 

957 def set_re_auth_token(self, token: TokenInterface): 

958 self._re_auth_token = token 

959 

960 def re_auth(self): 

961 if self._re_auth_token is not None: 

962 self.send_command( 

963 "AUTH", 

964 self._re_auth_token.try_get("oid"), 

965 self._re_auth_token.get_value(), 

966 ) 

967 self.read_response() 

968 self._re_auth_token = None 

969 

970 def get_resolved_ip(self) -> Optional[str]: 

971 """ 

972 Extract the resolved IP address from an 

973 established connection or resolve it from the host. 

974 

975 First tries to get the actual IP from the socket (most accurate), 

976 then falls back to DNS resolution if needed. 

977 

978 Args: 

979 connection: The connection object to extract the IP from 

980 

981 Returns: 

982 str: The resolved IP address, or None if it cannot be determined 

983 """ 

984 

985 # Method 1: Try to get the actual IP from the established socket connection 

986 # This is most accurate as it shows the exact IP being used 

987 try: 

988 if self._sock is not None: 

989 peer_addr = self._sock.getpeername() 

990 if peer_addr and len(peer_addr) >= 1: 

991 # For TCP sockets, peer_addr is typically (host, port) tuple 

992 # Return just the host part 

993 return peer_addr[0] 

994 except (AttributeError, OSError): 

995 # Socket might not be connected or getpeername() might fail 

996 pass 

997 

998 # Method 2: Fallback to DNS resolution of the host 

999 # This is less accurate but works when socket is not available 

1000 try: 

1001 host = getattr(self, "host", "localhost") 

1002 port = getattr(self, "port", 6379) 

1003 if host: 

1004 # Use getaddrinfo to resolve the hostname to IP 

1005 # This mimics what the connection would do during _connect() 

1006 addr_info = socket.getaddrinfo( 

1007 host, port, socket.AF_UNSPEC, socket.SOCK_STREAM 

1008 ) 

1009 if addr_info: 

1010 # Return the IP from the first result 

1011 # addr_info[0] is (family, socktype, proto, canonname, sockaddr) 

1012 # sockaddr[0] is the IP address 

1013 return addr_info[0][4][0] 

1014 except (AttributeError, OSError, socket.gaierror): 

1015 # DNS resolution might fail 

1016 pass 

1017 

1018 return None 

1019 

1020 @property 

1021 def maintenance_state(self) -> MaintenanceState: 

1022 return self._maintenance_state 

1023 

1024 @maintenance_state.setter 

1025 def maintenance_state(self, state: "MaintenanceState"): 

1026 self._maintenance_state = state 

1027 

1028 def getpeername(self): 

1029 if not self._sock: 

1030 return None 

1031 return self._sock.getpeername()[0] 

1032 

1033 def mark_for_reconnect(self): 

1034 self._should_reconnect = True 

1035 

1036 def should_reconnect(self): 

1037 return self._should_reconnect 

1038 

1039 def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None): 

1040 if self._sock: 

1041 timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout 

1042 self._sock.settimeout(timeout) 

1043 self.update_parser_buffer_timeout(timeout) 

1044 

1045 def update_parser_buffer_timeout(self, timeout: Optional[float] = None): 

1046 if self._parser and self._parser._buffer: 

1047 self._parser._buffer.socket_timeout = timeout 

1048 

1049 def set_tmp_settings( 

1050 self, 

1051 tmp_host_address: Optional[Union[str, object]] = SENTINEL, 

1052 tmp_relaxed_timeout: Optional[float] = None, 

1053 ): 

1054 """ 

1055 The value of SENTINEL is used to indicate that the property should not be updated. 

1056 """ 

1057 if tmp_host_address is not SENTINEL: 

1058 self.host = tmp_host_address 

1059 if tmp_relaxed_timeout != -1: 

1060 self.socket_timeout = tmp_relaxed_timeout 

1061 self.socket_connect_timeout = tmp_relaxed_timeout 

1062 

1063 def reset_tmp_settings( 

1064 self, 

1065 reset_host_address: bool = False, 

1066 reset_relaxed_timeout: bool = False, 

1067 ): 

1068 if reset_host_address: 

1069 self.host = self.orig_host_address 

1070 if reset_relaxed_timeout: 

1071 self.socket_timeout = self.orig_socket_timeout 

1072 self.socket_connect_timeout = self.orig_socket_connect_timeout 

1073 

1074 

1075class Connection(AbstractConnection): 

1076 "Manages TCP communication to and from a Redis server" 

1077 

1078 def __init__( 

1079 self, 

1080 host="localhost", 

1081 port=6379, 

1082 socket_keepalive=False, 

1083 socket_keepalive_options=None, 

1084 socket_type=0, 

1085 **kwargs, 

1086 ): 

1087 self.host = host 

1088 self.port = int(port) 

1089 self.socket_keepalive = socket_keepalive 

1090 self.socket_keepalive_options = socket_keepalive_options or {} 

1091 self.socket_type = socket_type 

1092 super().__init__(**kwargs) 

1093 

1094 def repr_pieces(self): 

1095 pieces = [("host", self.host), ("port", self.port), ("db", self.db)] 

1096 if self.client_name: 

1097 pieces.append(("client_name", self.client_name)) 

1098 return pieces 

1099 

1100 def _connect(self): 

1101 "Create a TCP socket connection" 

1102 # we want to mimic what socket.create_connection does to support 

1103 # ipv4/ipv6, but we want to set options prior to calling 

1104 # socket.connect() 

1105 err = None 

1106 

1107 for res in socket.getaddrinfo( 

1108 self.host, self.port, self.socket_type, socket.SOCK_STREAM 

1109 ): 

1110 family, socktype, proto, canonname, socket_address = res 

1111 sock = None 

1112 try: 

1113 sock = socket.socket(family, socktype, proto) 

1114 # TCP_NODELAY 

1115 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) 

1116 

1117 # TCP_KEEPALIVE 

1118 if self.socket_keepalive: 

1119 sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) 

1120 for k, v in self.socket_keepalive_options.items(): 

1121 sock.setsockopt(socket.IPPROTO_TCP, k, v) 

1122 

1123 # set the socket_connect_timeout before we connect 

1124 sock.settimeout(self.socket_connect_timeout) 

1125 

1126 # connect 

1127 sock.connect(socket_address) 

1128 

1129 # set the socket_timeout now that we're connected 

1130 sock.settimeout(self.socket_timeout) 

1131 return sock 

1132 

1133 except OSError as _: 

1134 err = _ 

1135 if sock is not None: 

1136 try: 

1137 sock.shutdown(socket.SHUT_RDWR) # ensure a clean close 

1138 except OSError: 

1139 pass 

1140 sock.close() 

1141 

1142 if err is not None: 

1143 raise err 

1144 raise OSError("socket.getaddrinfo returned an empty list") 

1145 

1146 def _host_error(self): 

1147 return f"{self.host}:{self.port}" 

1148 

1149 

1150class CacheProxyConnection(ConnectionInterface): 

1151 DUMMY_CACHE_VALUE = b"foo" 

1152 MIN_ALLOWED_VERSION = "7.4.0" 

1153 DEFAULT_SERVER_NAME = "redis" 

1154 

1155 def __init__( 

1156 self, 

1157 conn: ConnectionInterface, 

1158 cache: CacheInterface, 

1159 pool_lock: threading.RLock, 

1160 ): 

1161 self.pid = os.getpid() 

1162 self._conn = conn 

1163 self.retry = self._conn.retry 

1164 self.host = self._conn.host 

1165 self.port = self._conn.port 

1166 self.credential_provider = conn.credential_provider 

1167 self._pool_lock = pool_lock 

1168 self._cache = cache 

1169 self._cache_lock = threading.RLock() 

1170 self._current_command_cache_key = None 

1171 self._current_options = None 

1172 self.register_connect_callback(self._enable_tracking_callback) 

1173 

1174 def repr_pieces(self): 

1175 return self._conn.repr_pieces() 

1176 

1177 def register_connect_callback(self, callback): 

1178 self._conn.register_connect_callback(callback) 

1179 

1180 def deregister_connect_callback(self, callback): 

1181 self._conn.deregister_connect_callback(callback) 

1182 

1183 def set_parser(self, parser_class): 

1184 self._conn.set_parser(parser_class) 

1185 

1186 def connect(self): 

1187 self._conn.connect() 

1188 

1189 server_name = self._conn.handshake_metadata.get(b"server", None) 

1190 if server_name is None: 

1191 server_name = self._conn.handshake_metadata.get("server", None) 

1192 server_ver = self._conn.handshake_metadata.get(b"version", None) 

1193 if server_ver is None: 

1194 server_ver = self._conn.handshake_metadata.get("version", None) 

1195 if server_ver is None or server_ver is None: 

1196 raise ConnectionError("Cannot retrieve information about server version") 

1197 

1198 server_ver = ensure_string(server_ver) 

1199 server_name = ensure_string(server_name) 

1200 

1201 if ( 

1202 server_name != self.DEFAULT_SERVER_NAME 

1203 or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1 

1204 ): 

1205 raise ConnectionError( 

1206 "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later" # noqa: E501 

1207 ) 

1208 

1209 def on_connect(self): 

1210 self._conn.on_connect() 

1211 

1212 def disconnect(self, *args): 

1213 with self._cache_lock: 

1214 self._cache.flush() 

1215 self._conn.disconnect(*args) 

1216 

1217 def check_health(self): 

1218 self._conn.check_health() 

1219 

1220 def send_packed_command(self, command, check_health=True): 

1221 # TODO: Investigate if it's possible to unpack command 

1222 # or extract keys from packed command 

1223 self._conn.send_packed_command(command) 

1224 

1225 def send_command(self, *args, **kwargs): 

1226 self._process_pending_invalidations() 

1227 

1228 with self._cache_lock: 

1229 # Command is write command or not allowed 

1230 # to be cached. 

1231 if not self._cache.is_cachable(CacheKey(command=args[0], redis_keys=())): 

1232 self._current_command_cache_key = None 

1233 self._conn.send_command(*args, **kwargs) 

1234 return 

1235 

1236 if kwargs.get("keys") is None: 

1237 raise ValueError("Cannot create cache key.") 

1238 

1239 # Creates cache key. 

1240 self._current_command_cache_key = CacheKey( 

1241 command=args[0], redis_keys=tuple(kwargs.get("keys")) 

1242 ) 

1243 

1244 with self._cache_lock: 

1245 # We have to trigger invalidation processing in case if 

1246 # it was cached by another connection to avoid 

1247 # queueing invalidations in stale connections. 

1248 if self._cache.get(self._current_command_cache_key): 

1249 entry = self._cache.get(self._current_command_cache_key) 

1250 

1251 if entry.connection_ref != self._conn: 

1252 with self._pool_lock: 

1253 while entry.connection_ref.can_read(): 

1254 entry.connection_ref.read_response(push_request=True) 

1255 

1256 return 

1257 

1258 # Set temporary entry value to prevent 

1259 # race condition from another connection. 

1260 self._cache.set( 

1261 CacheEntry( 

1262 cache_key=self._current_command_cache_key, 

1263 cache_value=self.DUMMY_CACHE_VALUE, 

1264 status=CacheEntryStatus.IN_PROGRESS, 

1265 connection_ref=self._conn, 

1266 ) 

1267 ) 

1268 

1269 # Send command over socket only if it's allowed 

1270 # read-only command that not yet cached. 

1271 self._conn.send_command(*args, **kwargs) 

1272 

1273 def can_read(self, timeout=0): 

1274 return self._conn.can_read(timeout) 

1275 

1276 def read_response( 

1277 self, disable_decoding=False, *, disconnect_on_error=True, push_request=False 

1278 ): 

1279 with self._cache_lock: 

1280 # Check if command response exists in a cache and it's not in progress. 

1281 if ( 

1282 self._current_command_cache_key is not None 

1283 and self._cache.get(self._current_command_cache_key) is not None 

1284 and self._cache.get(self._current_command_cache_key).status 

1285 != CacheEntryStatus.IN_PROGRESS 

1286 ): 

1287 res = copy.deepcopy( 

1288 self._cache.get(self._current_command_cache_key).cache_value 

1289 ) 

1290 self._current_command_cache_key = None 

1291 return res 

1292 

1293 response = self._conn.read_response( 

1294 disable_decoding=disable_decoding, 

1295 disconnect_on_error=disconnect_on_error, 

1296 push_request=push_request, 

1297 ) 

1298 

1299 with self._cache_lock: 

1300 # Prevent not-allowed command from caching. 

1301 if self._current_command_cache_key is None: 

1302 return response 

1303 # If response is None prevent from caching. 

1304 if response is None: 

1305 self._cache.delete_by_cache_keys([self._current_command_cache_key]) 

1306 return response 

1307 

1308 cache_entry = self._cache.get(self._current_command_cache_key) 

1309 

1310 # Cache only responses that still valid 

1311 # and wasn't invalidated by another connection in meantime. 

1312 if cache_entry is not None: 

1313 cache_entry.status = CacheEntryStatus.VALID 

1314 cache_entry.cache_value = response 

1315 self._cache.set(cache_entry) 

1316 

1317 self._current_command_cache_key = None 

1318 

1319 return response 

1320 

1321 def pack_command(self, *args): 

1322 return self._conn.pack_command(*args) 

1323 

1324 def pack_commands(self, commands): 

1325 return self._conn.pack_commands(commands) 

1326 

1327 @property 

1328 def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: 

1329 return self._conn.handshake_metadata 

1330 

1331 def _connect(self): 

1332 self._conn._connect() 

1333 

1334 def _host_error(self): 

1335 self._conn._host_error() 

1336 

1337 def _enable_tracking_callback(self, conn: ConnectionInterface) -> None: 

1338 conn.send_command("CLIENT", "TRACKING", "ON") 

1339 conn.read_response() 

1340 conn._parser.set_invalidation_push_handler(self._on_invalidation_callback) 

1341 

1342 def _process_pending_invalidations(self): 

1343 while self.can_read(): 

1344 self._conn.read_response(push_request=True) 

1345 

1346 def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]): 

1347 with self._cache_lock: 

1348 # Flush cache when DB flushed on server-side 

1349 if data[1] is None: 

1350 self._cache.flush() 

1351 else: 

1352 self._cache.delete_by_redis_keys(data[1]) 

1353 

1354 def get_protocol(self): 

1355 return self._conn.get_protocol() 

1356 

1357 def set_re_auth_token(self, token: TokenInterface): 

1358 self._conn.set_re_auth_token(token) 

1359 

1360 def re_auth(self): 

1361 self._conn.re_auth() 

1362 

1363 

1364class SSLConnection(Connection): 

1365 """Manages SSL connections to and from the Redis server(s). 

1366 This class extends the Connection class, adding SSL functionality, and making 

1367 use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext) 

1368 """ # noqa 

1369 

1370 def __init__( 

1371 self, 

1372 ssl_keyfile=None, 

1373 ssl_certfile=None, 

1374 ssl_cert_reqs="required", 

1375 ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None, 

1376 ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None, 

1377 ssl_ca_certs=None, 

1378 ssl_ca_data=None, 

1379 ssl_check_hostname=True, 

1380 ssl_ca_path=None, 

1381 ssl_password=None, 

1382 ssl_validate_ocsp=False, 

1383 ssl_validate_ocsp_stapled=False, 

1384 ssl_ocsp_context=None, 

1385 ssl_ocsp_expected_cert=None, 

1386 ssl_min_version=None, 

1387 ssl_ciphers=None, 

1388 **kwargs, 

1389 ): 

1390 """Constructor 

1391 

1392 Args: 

1393 ssl_keyfile: Path to an ssl private key. Defaults to None. 

1394 ssl_certfile: Path to an ssl certificate. Defaults to None. 

1395 ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), 

1396 or an ssl.VerifyMode. Defaults to "required". 

1397 ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None. 

1398 ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None. 

1399 ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None. 

1400 ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates. 

1401 ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True. 

1402 ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None. 

1403 ssl_password: Password for unlocking an encrypted private key. Defaults to None. 

1404 

1405 ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification) 

1406 ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response 

1407 ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert 

1408 ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service. 

1409 ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module. 

1410 ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information. 

1411 

1412 Raises: 

1413 RedisError 

1414 """ # noqa 

1415 if not SSL_AVAILABLE: 

1416 raise RedisError("Python wasn't built with SSL support") 

1417 

1418 self.keyfile = ssl_keyfile 

1419 self.certfile = ssl_certfile 

1420 if ssl_cert_reqs is None: 

1421 ssl_cert_reqs = ssl.CERT_NONE 

1422 elif isinstance(ssl_cert_reqs, str): 

1423 CERT_REQS = { # noqa: N806 

1424 "none": ssl.CERT_NONE, 

1425 "optional": ssl.CERT_OPTIONAL, 

1426 "required": ssl.CERT_REQUIRED, 

1427 } 

1428 if ssl_cert_reqs not in CERT_REQS: 

1429 raise RedisError( 

1430 f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}" 

1431 ) 

1432 ssl_cert_reqs = CERT_REQS[ssl_cert_reqs] 

1433 self.cert_reqs = ssl_cert_reqs 

1434 self.ssl_include_verify_flags = ssl_include_verify_flags 

1435 self.ssl_exclude_verify_flags = ssl_exclude_verify_flags 

1436 self.ca_certs = ssl_ca_certs 

1437 self.ca_data = ssl_ca_data 

1438 self.ca_path = ssl_ca_path 

1439 self.check_hostname = ( 

1440 ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False 

1441 ) 

1442 self.certificate_password = ssl_password 

1443 self.ssl_validate_ocsp = ssl_validate_ocsp 

1444 self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled 

1445 self.ssl_ocsp_context = ssl_ocsp_context 

1446 self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert 

1447 self.ssl_min_version = ssl_min_version 

1448 self.ssl_ciphers = ssl_ciphers 

1449 super().__init__(**kwargs) 

1450 

1451 def _connect(self): 

1452 """ 

1453 Wrap the socket with SSL support, handling potential errors. 

1454 """ 

1455 sock = super()._connect() 

1456 try: 

1457 return self._wrap_socket_with_ssl(sock) 

1458 except (OSError, RedisError): 

1459 sock.close() 

1460 raise 

1461 

1462 def _wrap_socket_with_ssl(self, sock): 

1463 """ 

1464 Wraps the socket with SSL support. 

1465 

1466 Args: 

1467 sock: The plain socket to wrap with SSL. 

1468 

1469 Returns: 

1470 An SSL wrapped socket. 

1471 """ 

1472 context = ssl.create_default_context() 

1473 context.check_hostname = self.check_hostname 

1474 context.verify_mode = self.cert_reqs 

1475 if self.ssl_include_verify_flags: 

1476 for flag in self.ssl_include_verify_flags: 

1477 context.verify_flags |= flag 

1478 if self.ssl_exclude_verify_flags: 

1479 for flag in self.ssl_exclude_verify_flags: 

1480 context.verify_flags &= ~flag 

1481 if self.certfile or self.keyfile: 

1482 context.load_cert_chain( 

1483 certfile=self.certfile, 

1484 keyfile=self.keyfile, 

1485 password=self.certificate_password, 

1486 ) 

1487 if ( 

1488 self.ca_certs is not None 

1489 or self.ca_path is not None 

1490 or self.ca_data is not None 

1491 ): 

1492 context.load_verify_locations( 

1493 cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data 

1494 ) 

1495 if self.ssl_min_version is not None: 

1496 context.minimum_version = self.ssl_min_version 

1497 if self.ssl_ciphers: 

1498 context.set_ciphers(self.ssl_ciphers) 

1499 if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False: 

1500 raise RedisError("cryptography is not installed.") 

1501 

1502 if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp: 

1503 raise RedisError( 

1504 "Either an OCSP staple or pure OCSP connection must be validated " 

1505 "- not both." 

1506 ) 

1507 

1508 sslsock = context.wrap_socket(sock, server_hostname=self.host) 

1509 

1510 # validation for the stapled case 

1511 if self.ssl_validate_ocsp_stapled: 

1512 import OpenSSL 

1513 

1514 from .ocsp import ocsp_staple_verifier 

1515 

1516 # if a context is provided use it - otherwise, a basic context 

1517 if self.ssl_ocsp_context is None: 

1518 staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) 

1519 staple_ctx.use_certificate_file(self.certfile) 

1520 staple_ctx.use_privatekey_file(self.keyfile) 

1521 else: 

1522 staple_ctx = self.ssl_ocsp_context 

1523 

1524 staple_ctx.set_ocsp_client_callback( 

1525 ocsp_staple_verifier, self.ssl_ocsp_expected_cert 

1526 ) 

1527 

1528 # need another socket 

1529 con = OpenSSL.SSL.Connection(staple_ctx, socket.socket()) 

1530 con.request_ocsp() 

1531 con.connect((self.host, self.port)) 

1532 con.do_handshake() 

1533 con.shutdown() 

1534 return sslsock 

1535 

1536 # pure ocsp validation 

1537 if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE: 

1538 from .ocsp import OCSPVerifier 

1539 

1540 o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs) 

1541 if o.is_valid(): 

1542 return sslsock 

1543 else: 

1544 raise ConnectionError("ocsp validation error") 

1545 return sslsock 

1546 

1547 

1548class UnixDomainSocketConnection(AbstractConnection): 

1549 "Manages UDS communication to and from a Redis server" 

1550 

1551 def __init__(self, path="", socket_timeout=None, **kwargs): 

1552 super().__init__(**kwargs) 

1553 self.path = path 

1554 self.socket_timeout = socket_timeout 

1555 

1556 def repr_pieces(self): 

1557 pieces = [("path", self.path), ("db", self.db)] 

1558 if self.client_name: 

1559 pieces.append(("client_name", self.client_name)) 

1560 return pieces 

1561 

1562 def _connect(self): 

1563 "Create a Unix domain socket connection" 

1564 sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) 

1565 sock.settimeout(self.socket_connect_timeout) 

1566 try: 

1567 sock.connect(self.path) 

1568 except OSError: 

1569 # Prevent ResourceWarnings for unclosed sockets. 

1570 try: 

1571 sock.shutdown(socket.SHUT_RDWR) # ensure a clean close 

1572 except OSError: 

1573 pass 

1574 sock.close() 

1575 raise 

1576 sock.settimeout(self.socket_timeout) 

1577 return sock 

1578 

1579 def _host_error(self): 

1580 return self.path 

1581 

1582 

1583FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") 

1584 

1585 

1586def to_bool(value): 

1587 if value is None or value == "": 

1588 return None 

1589 if isinstance(value, str) and value.upper() in FALSE_STRINGS: 

1590 return False 

1591 return bool(value) 

1592 

1593 

1594def parse_ssl_verify_flags(value): 

1595 # flags are passed in as a string representation of a list, 

1596 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN 

1597 verify_flags_str = value.replace("[", "").replace("]", "") 

1598 

1599 verify_flags = [] 

1600 for flag in verify_flags_str.split(","): 

1601 flag = flag.strip() 

1602 if not hasattr(VerifyFlags, flag): 

1603 raise ValueError(f"Invalid ssl verify flag: {flag}") 

1604 verify_flags.append(getattr(VerifyFlags, flag)) 

1605 return verify_flags 

1606 

1607 

1608URL_QUERY_ARGUMENT_PARSERS = { 

1609 "db": int, 

1610 "socket_timeout": float, 

1611 "socket_connect_timeout": float, 

1612 "socket_keepalive": to_bool, 

1613 "retry_on_timeout": to_bool, 

1614 "retry_on_error": list, 

1615 "max_connections": int, 

1616 "health_check_interval": int, 

1617 "ssl_check_hostname": to_bool, 

1618 "ssl_include_verify_flags": parse_ssl_verify_flags, 

1619 "ssl_exclude_verify_flags": parse_ssl_verify_flags, 

1620 "timeout": float, 

1621} 

1622 

1623 

1624def parse_url(url): 

1625 if not ( 

1626 url.startswith("redis://") 

1627 or url.startswith("rediss://") 

1628 or url.startswith("unix://") 

1629 ): 

1630 raise ValueError( 

1631 "Redis URL must specify one of the following " 

1632 "schemes (redis://, rediss://, unix://)" 

1633 ) 

1634 

1635 url = urlparse(url) 

1636 kwargs = {} 

1637 

1638 for name, value in parse_qs(url.query).items(): 

1639 if value and len(value) > 0: 

1640 value = unquote(value[0]) 

1641 parser = URL_QUERY_ARGUMENT_PARSERS.get(name) 

1642 if parser: 

1643 try: 

1644 kwargs[name] = parser(value) 

1645 except (TypeError, ValueError): 

1646 raise ValueError(f"Invalid value for '{name}' in connection URL.") 

1647 else: 

1648 kwargs[name] = value 

1649 

1650 if url.username: 

1651 kwargs["username"] = unquote(url.username) 

1652 if url.password: 

1653 kwargs["password"] = unquote(url.password) 

1654 

1655 # We only support redis://, rediss:// and unix:// schemes. 

1656 if url.scheme == "unix": 

1657 if url.path: 

1658 kwargs["path"] = unquote(url.path) 

1659 kwargs["connection_class"] = UnixDomainSocketConnection 

1660 

1661 else: # implied: url.scheme in ("redis", "rediss"): 

1662 if url.hostname: 

1663 kwargs["host"] = unquote(url.hostname) 

1664 if url.port: 

1665 kwargs["port"] = int(url.port) 

1666 

1667 # If there's a path argument, use it as the db argument if a 

1668 # querystring value wasn't specified 

1669 if url.path and "db" not in kwargs: 

1670 try: 

1671 kwargs["db"] = int(unquote(url.path).replace("/", "")) 

1672 except (AttributeError, ValueError): 

1673 pass 

1674 

1675 if url.scheme == "rediss": 

1676 kwargs["connection_class"] = SSLConnection 

1677 

1678 return kwargs 

1679 

1680 

1681_CP = TypeVar("_CP", bound="ConnectionPool") 

1682 

1683 

1684class ConnectionPool: 

1685 """ 

1686 Create a connection pool. ``If max_connections`` is set, then this 

1687 object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's 

1688 limit is reached. 

1689 

1690 By default, TCP connections are created unless ``connection_class`` 

1691 is specified. Use class:`.UnixDomainSocketConnection` for 

1692 unix sockets. 

1693 :py:class:`~redis.SSLConnection` can be used for SSL enabled connections. 

1694 

1695 Any additional keyword arguments are passed to the constructor of 

1696 ``connection_class``. 

1697 """ 

1698 

1699 @classmethod 

1700 def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: 

1701 """ 

1702 Return a connection pool configured from the given URL. 

1703 

1704 For example:: 

1705 

1706 redis://[[username]:[password]]@localhost:6379/0 

1707 rediss://[[username]:[password]]@localhost:6379/0 

1708 unix://[username@]/path/to/socket.sock?db=0[&password=password] 

1709 

1710 Three URL schemes are supported: 

1711 

1712 - `redis://` creates a TCP socket connection. See more at: 

1713 <https://www.iana.org/assignments/uri-schemes/prov/redis> 

1714 - `rediss://` creates a SSL wrapped TCP socket connection. See more at: 

1715 <https://www.iana.org/assignments/uri-schemes/prov/rediss> 

1716 - ``unix://``: creates a Unix Domain Socket connection. 

1717 

1718 The username, password, hostname, path and all querystring values 

1719 are passed through urllib.parse.unquote in order to replace any 

1720 percent-encoded values with their corresponding characters. 

1721 

1722 There are several ways to specify a database number. The first value 

1723 found will be used: 

1724 

1725 1. A ``db`` querystring option, e.g. redis://localhost?db=0 

1726 2. If using the redis:// or rediss:// schemes, the path argument 

1727 of the url, e.g. redis://localhost/0 

1728 3. A ``db`` keyword argument to this function. 

1729 

1730 If none of these options are specified, the default db=0 is used. 

1731 

1732 All querystring options are cast to their appropriate Python types. 

1733 Boolean arguments can be specified with string values "True"/"False" 

1734 or "Yes"/"No". Values that cannot be properly cast cause a 

1735 ``ValueError`` to be raised. Once parsed, the querystring arguments 

1736 and keyword arguments are passed to the ``ConnectionPool``'s 

1737 class initializer. In the case of conflicting arguments, querystring 

1738 arguments always win. 

1739 """ 

1740 url_options = parse_url(url) 

1741 

1742 if "connection_class" in kwargs: 

1743 url_options["connection_class"] = kwargs["connection_class"] 

1744 

1745 kwargs.update(url_options) 

1746 return cls(**kwargs) 

1747 

1748 def __init__( 

1749 self, 

1750 connection_class=Connection, 

1751 max_connections: Optional[int] = None, 

1752 cache_factory: Optional[CacheFactoryInterface] = None, 

1753 **connection_kwargs, 

1754 ): 

1755 max_connections = max_connections or 2**31 

1756 if not isinstance(max_connections, int) or max_connections < 0: 

1757 raise ValueError('"max_connections" must be a positive integer') 

1758 

1759 self.connection_class = connection_class 

1760 self.connection_kwargs = connection_kwargs 

1761 self.max_connections = max_connections 

1762 self.cache = None 

1763 self._cache_factory = cache_factory 

1764 

1765 if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"): 

1766 if self.connection_kwargs.get("protocol") not in [3, "3"]: 

1767 raise RedisError("Client caching is only supported with RESP version 3") 

1768 

1769 cache = self.connection_kwargs.get("cache") 

1770 

1771 if cache is not None: 

1772 if not isinstance(cache, CacheInterface): 

1773 raise ValueError("Cache must implement CacheInterface") 

1774 

1775 self.cache = cache 

1776 else: 

1777 if self._cache_factory is not None: 

1778 self.cache = self._cache_factory.get_cache() 

1779 else: 

1780 self.cache = CacheFactory( 

1781 self.connection_kwargs.get("cache_config") 

1782 ).get_cache() 

1783 

1784 connection_kwargs.pop("cache", None) 

1785 connection_kwargs.pop("cache_config", None) 

1786 

1787 if self.connection_kwargs.get( 

1788 "maint_notifications_pool_handler" 

1789 ) or self.connection_kwargs.get("maint_notifications_config"): 

1790 if self.connection_kwargs.get("protocol") not in [3, "3"]: 

1791 raise RedisError( 

1792 "Push handlers on connection are only supported with RESP version 3" 

1793 ) 

1794 config = self.connection_kwargs.get("maint_notifications_config", None) or ( 

1795 self.connection_kwargs.get("maint_notifications_pool_handler").config 

1796 if self.connection_kwargs.get("maint_notifications_pool_handler") 

1797 else None 

1798 ) 

1799 

1800 if config and config.enabled: 

1801 self._update_connection_kwargs_for_maint_notifications() 

1802 

1803 self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) 

1804 if self._event_dispatcher is None: 

1805 self._event_dispatcher = EventDispatcher() 

1806 

1807 # a lock to protect the critical section in _checkpid(). 

1808 # this lock is acquired when the process id changes, such as 

1809 # after a fork. during this time, multiple threads in the child 

1810 # process could attempt to acquire this lock. the first thread 

1811 # to acquire the lock will reset the data structures and lock 

1812 # object of this pool. subsequent threads acquiring this lock 

1813 # will notice the first thread already did the work and simply 

1814 # release the lock. 

1815 

1816 self._fork_lock = threading.RLock() 

1817 self._lock = threading.RLock() 

1818 

1819 self.reset() 

1820 

1821 def __repr__(self) -> str: 

1822 conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()]) 

1823 return ( 

1824 f"<{self.__class__.__module__}.{self.__class__.__name__}" 

1825 f"(<{self.connection_class.__module__}.{self.connection_class.__name__}" 

1826 f"({conn_kwargs})>)>" 

1827 ) 

1828 

1829 def get_protocol(self): 

1830 """ 

1831 Returns: 

1832 The RESP protocol version, or ``None`` if the protocol is not specified, 

1833 in which case the server default will be used. 

1834 """ 

1835 return self.connection_kwargs.get("protocol", None) 

1836 

1837 def maint_notifications_pool_handler_enabled(self): 

1838 """ 

1839 Returns: 

1840 True if the maintenance notifications pool handler is enabled, False otherwise. 

1841 """ 

1842 maint_notifications_config = self.connection_kwargs.get( 

1843 "maint_notifications_config", None 

1844 ) 

1845 

1846 return maint_notifications_config and maint_notifications_config.enabled 

1847 

1848 def set_maint_notifications_pool_handler( 

1849 self, maint_notifications_pool_handler: MaintNotificationsPoolHandler 

1850 ): 

1851 self.connection_kwargs.update( 

1852 { 

1853 "maint_notifications_pool_handler": maint_notifications_pool_handler, 

1854 "maint_notifications_config": maint_notifications_pool_handler.config, 

1855 } 

1856 ) 

1857 self._update_connection_kwargs_for_maint_notifications() 

1858 

1859 self._update_maint_notifications_configs_for_connections( 

1860 maint_notifications_pool_handler 

1861 ) 

1862 

1863 def _update_maint_notifications_configs_for_connections( 

1864 self, maint_notifications_pool_handler 

1865 ): 

1866 """Update the maintenance notifications config for all connections in the pool.""" 

1867 with self._lock: 

1868 for conn in self._available_connections: 

1869 conn.set_maint_notifications_pool_handler( 

1870 maint_notifications_pool_handler 

1871 ) 

1872 conn.maint_notifications_config = ( 

1873 maint_notifications_pool_handler.config 

1874 ) 

1875 for conn in self._in_use_connections: 

1876 conn.set_maint_notifications_pool_handler( 

1877 maint_notifications_pool_handler 

1878 ) 

1879 conn.maint_notifications_config = ( 

1880 maint_notifications_pool_handler.config 

1881 ) 

1882 

1883 def _update_connection_kwargs_for_maint_notifications(self): 

1884 """Store original connection parameters for maintenance notifications.""" 

1885 if self.connection_kwargs.get("orig_host_address", None) is None: 

1886 # If orig_host_address is None it means we haven't 

1887 # configured the original values yet 

1888 self.connection_kwargs.update( 

1889 { 

1890 "orig_host_address": self.connection_kwargs.get("host"), 

1891 "orig_socket_timeout": self.connection_kwargs.get( 

1892 "socket_timeout", None 

1893 ), 

1894 "orig_socket_connect_timeout": self.connection_kwargs.get( 

1895 "socket_connect_timeout", None 

1896 ), 

1897 } 

1898 ) 

1899 

1900 def reset(self) -> None: 

1901 self._created_connections = 0 

1902 self._available_connections = [] 

1903 self._in_use_connections = set() 

1904 

1905 # this must be the last operation in this method. while reset() is 

1906 # called when holding _fork_lock, other threads in this process 

1907 # can call _checkpid() which compares self.pid and os.getpid() without 

1908 # holding any lock (for performance reasons). keeping this assignment 

1909 # as the last operation ensures that those other threads will also 

1910 # notice a pid difference and block waiting for the first thread to 

1911 # release _fork_lock. when each of these threads eventually acquire 

1912 # _fork_lock, they will notice that another thread already called 

1913 # reset() and they will immediately release _fork_lock and continue on. 

1914 self.pid = os.getpid() 

1915 

1916 def _checkpid(self) -> None: 

1917 # _checkpid() attempts to keep ConnectionPool fork-safe on modern 

1918 # systems. this is called by all ConnectionPool methods that 

1919 # manipulate the pool's state such as get_connection() and release(). 

1920 # 

1921 # _checkpid() determines whether the process has forked by comparing 

1922 # the current process id to the process id saved on the ConnectionPool 

1923 # instance. if these values are the same, _checkpid() simply returns. 

1924 # 

1925 # when the process ids differ, _checkpid() assumes that the process 

1926 # has forked and that we're now running in the child process. the child 

1927 # process cannot use the parent's file descriptors (e.g., sockets). 

1928 # therefore, when _checkpid() sees the process id change, it calls 

1929 # reset() in order to reinitialize the child's ConnectionPool. this 

1930 # will cause the child to make all new connection objects. 

1931 # 

1932 # _checkpid() is protected by self._fork_lock to ensure that multiple 

1933 # threads in the child process do not call reset() multiple times. 

1934 # 

1935 # there is an extremely small chance this could fail in the following 

1936 # scenario: 

1937 # 1. process A calls _checkpid() for the first time and acquires 

1938 # self._fork_lock. 

1939 # 2. while holding self._fork_lock, process A forks (the fork() 

1940 # could happen in a different thread owned by process A) 

1941 # 3. process B (the forked child process) inherits the 

1942 # ConnectionPool's state from the parent. that state includes 

1943 # a locked _fork_lock. process B will not be notified when 

1944 # process A releases the _fork_lock and will thus never be 

1945 # able to acquire the _fork_lock. 

1946 # 

1947 # to mitigate this possible deadlock, _checkpid() will only wait 5 

1948 # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in 

1949 # that time it is assumed that the child is deadlocked and a 

1950 # redis.ChildDeadlockedError error is raised. 

1951 if self.pid != os.getpid(): 

1952 acquired = self._fork_lock.acquire(timeout=5) 

1953 if not acquired: 

1954 raise ChildDeadlockedError 

1955 # reset() the instance for the new process if another thread 

1956 # hasn't already done so 

1957 try: 

1958 if self.pid != os.getpid(): 

1959 self.reset() 

1960 finally: 

1961 self._fork_lock.release() 

1962 

1963 @deprecated_args( 

1964 args_to_warn=["*"], 

1965 reason="Use get_connection() without args instead", 

1966 version="5.3.0", 

1967 ) 

1968 def get_connection(self, command_name=None, *keys, **options) -> "Connection": 

1969 "Get a connection from the pool" 

1970 

1971 self._checkpid() 

1972 with self._lock: 

1973 try: 

1974 connection = self._available_connections.pop() 

1975 except IndexError: 

1976 connection = self.make_connection() 

1977 self._in_use_connections.add(connection) 

1978 

1979 try: 

1980 # ensure this connection is connected to Redis 

1981 connection.connect() 

1982 # connections that the pool provides should be ready to send 

1983 # a command. if not, the connection was either returned to the 

1984 # pool before all data has been read or the socket has been 

1985 # closed. either way, reconnect and verify everything is good. 

1986 try: 

1987 if ( 

1988 connection.can_read() 

1989 and self.cache is None 

1990 and not self.maint_notifications_pool_handler_enabled() 

1991 ): 

1992 raise ConnectionError("Connection has data") 

1993 except (ConnectionError, TimeoutError, OSError): 

1994 connection.disconnect() 

1995 connection.connect() 

1996 if connection.can_read(): 

1997 raise ConnectionError("Connection not ready") 

1998 except BaseException: 

1999 # release the connection back to the pool so that we don't 

2000 # leak it 

2001 self.release(connection) 

2002 raise 

2003 return connection 

2004 

2005 def get_encoder(self) -> Encoder: 

2006 "Return an encoder based on encoding settings" 

2007 kwargs = self.connection_kwargs 

2008 return Encoder( 

2009 encoding=kwargs.get("encoding", "utf-8"), 

2010 encoding_errors=kwargs.get("encoding_errors", "strict"), 

2011 decode_responses=kwargs.get("decode_responses", False), 

2012 ) 

2013 

2014 def make_connection(self) -> "ConnectionInterface": 

2015 "Create a new connection" 

2016 if self._created_connections >= self.max_connections: 

2017 raise MaxConnectionsError("Too many connections") 

2018 self._created_connections += 1 

2019 

2020 kwargs = dict(self.connection_kwargs) 

2021 

2022 if self.cache is not None: 

2023 return CacheProxyConnection( 

2024 self.connection_class(**kwargs), self.cache, self._lock 

2025 ) 

2026 return self.connection_class(**kwargs) 

2027 

2028 def release(self, connection: "Connection") -> None: 

2029 "Releases the connection back to the pool" 

2030 self._checkpid() 

2031 with self._lock: 

2032 try: 

2033 self._in_use_connections.remove(connection) 

2034 except KeyError: 

2035 # Gracefully fail when a connection is returned to this pool 

2036 # that the pool doesn't actually own 

2037 return 

2038 

2039 if self.owns_connection(connection): 

2040 if connection.should_reconnect(): 

2041 connection.disconnect() 

2042 self._available_connections.append(connection) 

2043 self._event_dispatcher.dispatch( 

2044 AfterConnectionReleasedEvent(connection) 

2045 ) 

2046 else: 

2047 # Pool doesn't own this connection, do not add it back 

2048 # to the pool. 

2049 # The created connections count should not be changed, 

2050 # because the connection was not created by the pool. 

2051 connection.disconnect() 

2052 return 

2053 

2054 def owns_connection(self, connection: "Connection") -> int: 

2055 return connection.pid == self.pid 

2056 

2057 def disconnect(self, inuse_connections: bool = True) -> None: 

2058 """ 

2059 Disconnects connections in the pool 

2060 

2061 If ``inuse_connections`` is True, disconnect connections that are 

2062 current in use, potentially by other threads. Otherwise only disconnect 

2063 connections that are idle in the pool. 

2064 """ 

2065 self._checkpid() 

2066 with self._lock: 

2067 if inuse_connections: 

2068 connections = chain( 

2069 self._available_connections, self._in_use_connections 

2070 ) 

2071 else: 

2072 connections = self._available_connections 

2073 

2074 for connection in connections: 

2075 connection.disconnect() 

2076 

2077 def close(self) -> None: 

2078 """Close the pool, disconnecting all connections""" 

2079 self.disconnect() 

2080 

2081 def set_retry(self, retry: Retry) -> None: 

2082 self.connection_kwargs.update({"retry": retry}) 

2083 for conn in self._available_connections: 

2084 conn.retry = retry 

2085 for conn in self._in_use_connections: 

2086 conn.retry = retry 

2087 

2088 def re_auth_callback(self, token: TokenInterface): 

2089 with self._lock: 

2090 for conn in self._available_connections: 

2091 conn.retry.call_with_retry( 

2092 lambda: conn.send_command( 

2093 "AUTH", token.try_get("oid"), token.get_value() 

2094 ), 

2095 lambda error: self._mock(error), 

2096 ) 

2097 conn.retry.call_with_retry( 

2098 lambda: conn.read_response(), lambda error: self._mock(error) 

2099 ) 

2100 for conn in self._in_use_connections: 

2101 conn.set_re_auth_token(token) 

2102 

2103 def _should_update_connection( 

2104 self, 

2105 conn: "Connection", 

2106 matching_pattern: Literal[ 

2107 "connected_address", "configured_address", "notification_hash" 

2108 ] = "connected_address", 

2109 matching_address: Optional[str] = None, 

2110 matching_notification_hash: Optional[int] = None, 

2111 ) -> bool: 

2112 """ 

2113 Check if the connection should be updated based on the matching criteria. 

2114 """ 

2115 if matching_pattern == "connected_address": 

2116 if matching_address and conn.getpeername() != matching_address: 

2117 return False 

2118 elif matching_pattern == "configured_address": 

2119 if matching_address and conn.host != matching_address: 

2120 return False 

2121 elif matching_pattern == "notification_hash": 

2122 if ( 

2123 matching_notification_hash 

2124 and conn.maintenance_notification_hash != matching_notification_hash 

2125 ): 

2126 return False 

2127 return True 

2128 

2129 def update_connection_settings( 

2130 self, 

2131 conn: "Connection", 

2132 state: Optional["MaintenanceState"] = None, 

2133 maintenance_notification_hash: Optional[int] = None, 

2134 host_address: Optional[str] = None, 

2135 relaxed_timeout: Optional[float] = None, 

2136 update_notification_hash: bool = False, 

2137 reset_host_address: bool = False, 

2138 reset_relaxed_timeout: bool = False, 

2139 ): 

2140 """ 

2141 Update the settings for a single connection. 

2142 """ 

2143 if state: 

2144 conn.maintenance_state = state 

2145 

2146 if update_notification_hash: 

2147 # update the notification hash only if requested 

2148 conn.maintenance_notification_hash = maintenance_notification_hash 

2149 

2150 if host_address is not None: 

2151 conn.set_tmp_settings(tmp_host_address=host_address) 

2152 

2153 if relaxed_timeout is not None: 

2154 conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout) 

2155 

2156 if reset_relaxed_timeout or reset_host_address: 

2157 conn.reset_tmp_settings( 

2158 reset_host_address=reset_host_address, 

2159 reset_relaxed_timeout=reset_relaxed_timeout, 

2160 ) 

2161 

2162 conn.update_current_socket_timeout(relaxed_timeout) 

2163 

2164 def update_connections_settings( 

2165 self, 

2166 state: Optional["MaintenanceState"] = None, 

2167 maintenance_notification_hash: Optional[int] = None, 

2168 host_address: Optional[str] = None, 

2169 relaxed_timeout: Optional[float] = None, 

2170 matching_address: Optional[str] = None, 

2171 matching_notification_hash: Optional[int] = None, 

2172 matching_pattern: Literal[ 

2173 "connected_address", "configured_address", "notification_hash" 

2174 ] = "connected_address", 

2175 update_notification_hash: bool = False, 

2176 reset_host_address: bool = False, 

2177 reset_relaxed_timeout: bool = False, 

2178 include_free_connections: bool = True, 

2179 ): 

2180 """ 

2181 Update the settings for all matching connections in the pool. 

2182 

2183 This method does not create new connections. 

2184 This method does not affect the connection kwargs. 

2185 

2186 :param state: The maintenance state to set for the connection. 

2187 :param maintenance_notification_hash: The hash of the maintenance notification 

2188 to set for the connection. 

2189 :param host_address: The host address to set for the connection. 

2190 :param relaxed_timeout: The relaxed timeout to set for the connection. 

2191 :param matching_address: The address to match for the connection. 

2192 :param matching_notification_hash: The notification hash to match for the connection. 

2193 :param matching_pattern: The pattern to match for the connection. 

2194 :param update_notification_hash: Whether to update the notification hash for the connection. 

2195 :param reset_host_address: Whether to reset the host address to the original address. 

2196 :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout. 

2197 :param include_free_connections: Whether to include free/available connections. 

2198 """ 

2199 with self._lock: 

2200 for conn in self._in_use_connections: 

2201 if self._should_update_connection( 

2202 conn, 

2203 matching_pattern, 

2204 matching_address, 

2205 matching_notification_hash, 

2206 ): 

2207 self.update_connection_settings( 

2208 conn, 

2209 state=state, 

2210 maintenance_notification_hash=maintenance_notification_hash, 

2211 host_address=host_address, 

2212 relaxed_timeout=relaxed_timeout, 

2213 update_notification_hash=update_notification_hash, 

2214 reset_host_address=reset_host_address, 

2215 reset_relaxed_timeout=reset_relaxed_timeout, 

2216 ) 

2217 

2218 if include_free_connections: 

2219 for conn in self._available_connections: 

2220 if self._should_update_connection( 

2221 conn, 

2222 matching_pattern, 

2223 matching_address, 

2224 matching_notification_hash, 

2225 ): 

2226 self.update_connection_settings( 

2227 conn, 

2228 state=state, 

2229 maintenance_notification_hash=maintenance_notification_hash, 

2230 host_address=host_address, 

2231 relaxed_timeout=relaxed_timeout, 

2232 update_notification_hash=update_notification_hash, 

2233 reset_host_address=reset_host_address, 

2234 reset_relaxed_timeout=reset_relaxed_timeout, 

2235 ) 

2236 

2237 def update_connection_kwargs( 

2238 self, 

2239 **kwargs, 

2240 ): 

2241 """ 

2242 Update the connection kwargs for all future connections. 

2243 

2244 This method updates the connection kwargs for all future connections created by the pool. 

2245 Existing connections are not affected. 

2246 """ 

2247 self.connection_kwargs.update(kwargs) 

2248 

2249 def update_active_connections_for_reconnect( 

2250 self, 

2251 moving_address_src: Optional[str] = None, 

2252 ): 

2253 """ 

2254 Mark all active connections for reconnect. 

2255 This is used when a cluster node is migrated to a different address. 

2256 

2257 :param moving_address_src: The address of the node that is being moved. 

2258 """ 

2259 with self._lock: 

2260 for conn in self._in_use_connections: 

2261 if self._should_update_connection( 

2262 conn, "connected_address", moving_address_src 

2263 ): 

2264 conn.mark_for_reconnect() 

2265 

2266 def disconnect_free_connections( 

2267 self, 

2268 moving_address_src: Optional[str] = None, 

2269 ): 

2270 """ 

2271 Disconnect all free/available connections. 

2272 This is used when a cluster node is migrated to a different address. 

2273 

2274 :param moving_address_src: The address of the node that is being moved. 

2275 """ 

2276 with self._lock: 

2277 for conn in self._available_connections: 

2278 if self._should_update_connection( 

2279 conn, "connected_address", moving_address_src 

2280 ): 

2281 conn.disconnect() 

2282 

2283 async def _mock(self, error: RedisError): 

2284 """ 

2285 Dummy functions, needs to be passed as error callback to retry object. 

2286 :param error: 

2287 :return: 

2288 """ 

2289 pass 

2290 

2291 

2292class BlockingConnectionPool(ConnectionPool): 

2293 """ 

2294 Thread-safe blocking connection pool:: 

2295 

2296 >>> from redis.client import Redis 

2297 >>> client = Redis(connection_pool=BlockingConnectionPool()) 

2298 

2299 It performs the same function as the default 

2300 :py:class:`~redis.ConnectionPool` implementation, in that, 

2301 it maintains a pool of reusable connections that can be shared by 

2302 multiple redis clients (safely across threads if required). 

2303 

2304 The difference is that, in the event that a client tries to get a 

2305 connection from the pool when all of connections are in use, rather than 

2306 raising a :py:class:`~redis.ConnectionError` (as the default 

2307 :py:class:`~redis.ConnectionPool` implementation does), it 

2308 makes the client wait ("blocks") for a specified number of seconds until 

2309 a connection becomes available. 

2310 

2311 Use ``max_connections`` to increase / decrease the pool size:: 

2312 

2313 >>> pool = BlockingConnectionPool(max_connections=10) 

2314 

2315 Use ``timeout`` to tell it either how many seconds to wait for a connection 

2316 to become available, or to block forever: 

2317 

2318 >>> # Block forever. 

2319 >>> pool = BlockingConnectionPool(timeout=None) 

2320 

2321 >>> # Raise a ``ConnectionError`` after five seconds if a connection is 

2322 >>> # not available. 

2323 >>> pool = BlockingConnectionPool(timeout=5) 

2324 """ 

2325 

2326 def __init__( 

2327 self, 

2328 max_connections=50, 

2329 timeout=20, 

2330 connection_class=Connection, 

2331 queue_class=LifoQueue, 

2332 **connection_kwargs, 

2333 ): 

2334 self.queue_class = queue_class 

2335 self.timeout = timeout 

2336 self._in_maintenance = False 

2337 self._locked = False 

2338 super().__init__( 

2339 connection_class=connection_class, 

2340 max_connections=max_connections, 

2341 **connection_kwargs, 

2342 ) 

2343 

2344 def reset(self): 

2345 # Create and fill up a thread safe queue with ``None`` values. 

2346 try: 

2347 if self._in_maintenance: 

2348 self._lock.acquire() 

2349 self._locked = True 

2350 self.pool = self.queue_class(self.max_connections) 

2351 while True: 

2352 try: 

2353 self.pool.put_nowait(None) 

2354 except Full: 

2355 break 

2356 

2357 # Keep a list of actual connection instances so that we can 

2358 # disconnect them later. 

2359 self._connections = [] 

2360 finally: 

2361 if self._locked: 

2362 try: 

2363 self._lock.release() 

2364 except Exception: 

2365 pass 

2366 self._locked = False 

2367 

2368 # this must be the last operation in this method. while reset() is 

2369 # called when holding _fork_lock, other threads in this process 

2370 # can call _checkpid() which compares self.pid and os.getpid() without 

2371 # holding any lock (for performance reasons). keeping this assignment 

2372 # as the last operation ensures that those other threads will also 

2373 # notice a pid difference and block waiting for the first thread to 

2374 # release _fork_lock. when each of these threads eventually acquire 

2375 # _fork_lock, they will notice that another thread already called 

2376 # reset() and they will immediately release _fork_lock and continue on. 

2377 self.pid = os.getpid() 

2378 

2379 def make_connection(self): 

2380 "Make a fresh connection." 

2381 try: 

2382 if self._in_maintenance: 

2383 self._lock.acquire() 

2384 self._locked = True 

2385 

2386 if self.cache is not None: 

2387 connection = CacheProxyConnection( 

2388 self.connection_class(**self.connection_kwargs), 

2389 self.cache, 

2390 self._lock, 

2391 ) 

2392 else: 

2393 connection = self.connection_class(**self.connection_kwargs) 

2394 self._connections.append(connection) 

2395 return connection 

2396 finally: 

2397 if self._locked: 

2398 try: 

2399 self._lock.release() 

2400 except Exception: 

2401 pass 

2402 self._locked = False 

2403 

2404 @deprecated_args( 

2405 args_to_warn=["*"], 

2406 reason="Use get_connection() without args instead", 

2407 version="5.3.0", 

2408 ) 

2409 def get_connection(self, command_name=None, *keys, **options): 

2410 """ 

2411 Get a connection, blocking for ``self.timeout`` until a connection 

2412 is available from the pool. 

2413 

2414 If the connection returned is ``None`` then creates a new connection. 

2415 Because we use a last-in first-out queue, the existing connections 

2416 (having been returned to the pool after the initial ``None`` values 

2417 were added) will be returned before ``None`` values. This means we only 

2418 create new connections when we need to, i.e.: the actual number of 

2419 connections will only increase in response to demand. 

2420 """ 

2421 # Make sure we haven't changed process. 

2422 self._checkpid() 

2423 

2424 # Try and get a connection from the pool. If one isn't available within 

2425 # self.timeout then raise a ``ConnectionError``. 

2426 connection = None 

2427 try: 

2428 if self._in_maintenance: 

2429 self._lock.acquire() 

2430 self._locked = True 

2431 try: 

2432 connection = self.pool.get(block=True, timeout=self.timeout) 

2433 except Empty: 

2434 # Note that this is not caught by the redis client and will be 

2435 # raised unless handled by application code. If you want never to 

2436 raise ConnectionError("No connection available.") 

2437 

2438 # If the ``connection`` is actually ``None`` then that's a cue to make 

2439 # a new connection to add to the pool. 

2440 if connection is None: 

2441 connection = self.make_connection() 

2442 finally: 

2443 if self._locked: 

2444 try: 

2445 self._lock.release() 

2446 except Exception: 

2447 pass 

2448 self._locked = False 

2449 

2450 try: 

2451 # ensure this connection is connected to Redis 

2452 connection.connect() 

2453 # connections that the pool provides should be ready to send 

2454 # a command. if not, the connection was either returned to the 

2455 # pool before all data has been read or the socket has been 

2456 # closed. either way, reconnect and verify everything is good. 

2457 try: 

2458 if connection.can_read(): 

2459 raise ConnectionError("Connection has data") 

2460 except (ConnectionError, TimeoutError, OSError): 

2461 connection.disconnect() 

2462 connection.connect() 

2463 if connection.can_read(): 

2464 raise ConnectionError("Connection not ready") 

2465 except BaseException: 

2466 # release the connection back to the pool so that we don't leak it 

2467 self.release(connection) 

2468 raise 

2469 

2470 return connection 

2471 

2472 def release(self, connection): 

2473 "Releases the connection back to the pool." 

2474 # Make sure we haven't changed process. 

2475 self._checkpid() 

2476 

2477 try: 

2478 if self._in_maintenance: 

2479 self._lock.acquire() 

2480 self._locked = True 

2481 if not self.owns_connection(connection): 

2482 # pool doesn't own this connection. do not add it back 

2483 # to the pool. instead add a None value which is a placeholder 

2484 # that will cause the pool to recreate the connection if 

2485 # its needed. 

2486 connection.disconnect() 

2487 self.pool.put_nowait(None) 

2488 return 

2489 if connection.should_reconnect(): 

2490 connection.disconnect() 

2491 # Put the connection back into the pool. 

2492 try: 

2493 self.pool.put_nowait(connection) 

2494 except Full: 

2495 # perhaps the pool has been reset() after a fork? regardless, 

2496 # we don't want this connection 

2497 pass 

2498 finally: 

2499 if self._locked: 

2500 try: 

2501 self._lock.release() 

2502 except Exception: 

2503 pass 

2504 self._locked = False 

2505 

2506 def disconnect(self): 

2507 "Disconnects all connections in the pool." 

2508 self._checkpid() 

2509 try: 

2510 if self._in_maintenance: 

2511 self._lock.acquire() 

2512 self._locked = True 

2513 for connection in self._connections: 

2514 connection.disconnect() 

2515 finally: 

2516 if self._locked: 

2517 try: 

2518 self._lock.release() 

2519 except Exception: 

2520 pass 

2521 self._locked = False 

2522 

2523 def update_connections_settings( 

2524 self, 

2525 state: Optional["MaintenanceState"] = None, 

2526 maintenance_notification_hash: Optional[int] = None, 

2527 relaxed_timeout: Optional[float] = None, 

2528 host_address: Optional[str] = None, 

2529 matching_address: Optional[str] = None, 

2530 matching_notification_hash: Optional[int] = None, 

2531 matching_pattern: Literal[ 

2532 "connected_address", "configured_address", "notification_hash" 

2533 ] = "connected_address", 

2534 update_notification_hash: bool = False, 

2535 reset_host_address: bool = False, 

2536 reset_relaxed_timeout: bool = False, 

2537 include_free_connections: bool = True, 

2538 ): 

2539 """ 

2540 Override base class method to work with BlockingConnectionPool's structure. 

2541 """ 

2542 with self._lock: 

2543 if include_free_connections: 

2544 for conn in tuple(self._connections): 

2545 if self._should_update_connection( 

2546 conn, 

2547 matching_pattern, 

2548 matching_address, 

2549 matching_notification_hash, 

2550 ): 

2551 self.update_connection_settings( 

2552 conn, 

2553 state=state, 

2554 maintenance_notification_hash=maintenance_notification_hash, 

2555 host_address=host_address, 

2556 relaxed_timeout=relaxed_timeout, 

2557 update_notification_hash=update_notification_hash, 

2558 reset_host_address=reset_host_address, 

2559 reset_relaxed_timeout=reset_relaxed_timeout, 

2560 ) 

2561 else: 

2562 connections_in_queue = {conn for conn in self.pool.queue if conn} 

2563 for conn in self._connections: 

2564 if conn not in connections_in_queue: 

2565 if self._should_update_connection( 

2566 conn, 

2567 matching_pattern, 

2568 matching_address, 

2569 matching_notification_hash, 

2570 ): 

2571 self.update_connection_settings( 

2572 conn, 

2573 state=state, 

2574 maintenance_notification_hash=maintenance_notification_hash, 

2575 host_address=host_address, 

2576 relaxed_timeout=relaxed_timeout, 

2577 update_notification_hash=update_notification_hash, 

2578 reset_host_address=reset_host_address, 

2579 reset_relaxed_timeout=reset_relaxed_timeout, 

2580 ) 

2581 

2582 def update_active_connections_for_reconnect( 

2583 self, 

2584 moving_address_src: Optional[str] = None, 

2585 ): 

2586 """ 

2587 Mark all active connections for reconnect. 

2588 This is used when a cluster node is migrated to a different address. 

2589 

2590 :param moving_address_src: The address of the node that is being moved. 

2591 """ 

2592 with self._lock: 

2593 connections_in_queue = {conn for conn in self.pool.queue if conn} 

2594 for conn in self._connections: 

2595 if conn not in connections_in_queue: 

2596 if self._should_update_connection( 

2597 conn, 

2598 matching_pattern="connected_address", 

2599 matching_address=moving_address_src, 

2600 ): 

2601 conn.mark_for_reconnect() 

2602 

2603 def disconnect_free_connections( 

2604 self, 

2605 moving_address_src: Optional[str] = None, 

2606 ): 

2607 """ 

2608 Disconnect all free/available connections. 

2609 This is used when a cluster node is migrated to a different address. 

2610 

2611 :param moving_address_src: The address of the node that is being moved. 

2612 """ 

2613 with self._lock: 

2614 existing_connections = self.pool.queue 

2615 

2616 for conn in existing_connections: 

2617 if conn: 

2618 if self._should_update_connection( 

2619 conn, "connected_address", moving_address_src 

2620 ): 

2621 conn.disconnect() 

2622 

2623 def _update_maint_notifications_config_for_connections( 

2624 self, maint_notifications_config 

2625 ): 

2626 for conn in tuple(self._connections): 

2627 conn.maint_notifications_config = maint_notifications_config 

2628 

2629 def _update_maint_notifications_configs_for_connections( 

2630 self, maint_notifications_pool_handler 

2631 ): 

2632 """Update the maintenance notifications config for all connections in the pool.""" 

2633 with self._lock: 

2634 for conn in tuple(self._connections): 

2635 conn.set_maint_notifications_pool_handler( 

2636 maint_notifications_pool_handler 

2637 ) 

2638 conn.maint_notifications_config = ( 

2639 maint_notifications_pool_handler.config 

2640 ) 

2641 

2642 def set_in_maintenance(self, in_maintenance: bool): 

2643 """ 

2644 Sets a flag that this Blocking ConnectionPool is in maintenance mode. 

2645 

2646 This is used to prevent new connections from being created while we are in maintenance mode. 

2647 The pool will be in maintenance mode only when we are processing a MOVING notification. 

2648 """ 

2649 self._in_maintenance = in_maintenance