Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 23%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1637 statements  

1import copy 

2import os 

3import socket 

4import sys 

5import threading 

6import time 

7import weakref 

8from abc import ABC, abstractmethod 

9from itertools import chain 

10from queue import Empty, Full, LifoQueue 

11from typing import ( 

12 Any, 

13 Callable, 

14 Dict, 

15 Iterable, 

16 List, 

17 Literal, 

18 Optional, 

19 Type, 

20 TypeVar, 

21 Union, 

22) 

23from urllib.parse import parse_qs, unquote, urlparse 

24 

25from redis.cache import ( 

26 CacheEntry, 

27 CacheEntryStatus, 

28 CacheFactory, 

29 CacheFactoryInterface, 

30 CacheInterface, 

31 CacheKey, 

32 CacheProxy, 

33) 

34 

35from ._defaults import ( 

36 DEFAULT_SOCKET_CONNECT_TIMEOUT, 

37 DEFAULT_SOCKET_READ_SIZE, 

38 DEFAULT_SOCKET_TIMEOUT, 

39 get_default_socket_keepalive_options, 

40) 

41from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser 

42from .auth.token import TokenInterface 

43from .backoff import NoBackoff 

44from .credentials import CredentialProvider, UsernamePasswordCredentialProvider 

45from .driver_info import DriverInfo, resolve_driver_info 

46from .event import AfterConnectionReleasedEvent, EventDispatcher 

47from .exceptions import ( 

48 AuthenticationError, 

49 AuthenticationWrongNumberOfArgsError, 

50 ChildDeadlockedError, 

51 ConnectionError, 

52 DataError, 

53 MaxConnectionsError, 

54 RedisError, 

55 ResponseError, 

56 TimeoutError, 

57) 

58from .maint_notifications import ( 

59 MaintenanceState, 

60 MaintNotificationsConfig, 

61 MaintNotificationsConnectionHandler, 

62 MaintNotificationsPoolHandler, 

63 OSSMaintNotificationsHandler, 

64) 

65from .observability.attributes import ( 

66 DB_CLIENT_CONNECTION_POOL_NAME, 

67 DB_CLIENT_CONNECTION_STATE, 

68 AttributeBuilder, 

69 ConnectionState, 

70 CSCReason, 

71 CSCResult, 

72 get_pool_name, 

73) 

74from .observability.metrics import CloseReason 

75from .observability.recorder import ( 

76 init_csc_items, 

77 record_connection_closed, 

78 record_connection_count, 

79 record_connection_create_time, 

80 record_connection_wait_time, 

81 record_csc_eviction, 

82 record_csc_network_saved, 

83 record_csc_request, 

84 record_error_count, 

85 register_csc_items_callback, 

86) 

87from .retry import Retry 

88from .utils import ( 

89 CRYPTOGRAPHY_AVAILABLE, 

90 DEFAULT_RESP_VERSION, 

91 HIREDIS_AVAILABLE, 

92 SENTINEL, 

93 SSL_AVAILABLE, 

94 check_protocol_version, 

95 compare_versions, 

96 deprecated_args, 

97 ensure_string, 

98 format_error_message, 

99 str_if_bytes, 

100) 

101 

102if SSL_AVAILABLE: 

103 import ssl 

104 from ssl import VerifyFlags 

105else: 

106 ssl = None 

107 VerifyFlags = None 

108 

109if HIREDIS_AVAILABLE: 

110 import hiredis 

111 

# Byte tokens used to frame commands in the Redis wire protocol.
SYM_STAR = b"*"  # prefix of the argument-count header
SYM_DOLLAR = b"$"  # prefix of each argument's length header
SYM_CRLF = b"\r\n"  # terminator after every header and argument
SYM_EMPTY = b""  # joiner used to concatenate byte pieces

# Use the hiredis-backed parser whenever hiredis could be imported;
# otherwise fall back to the pure-Python RESP2 parser.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] = (
    _HiredisParser if HIREDIS_AVAILABLE else _RESP2Parser
)

122 

123 

class HiredisRespSerializer:
    """Serialize commands into the Redis protocol using ``hiredis``."""

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol.

        A first argument containing spaces (e.g. ``"CONFIG GET"``) is split
        into separate arguments, since each word must be sent on its own.

        Returns:
            A list holding the packed command (a list, like
            ``PythonRespSerializer.pack`` returns).

        Raises:
            DataError: if ``hiredis.pack_command`` raises ``TypeError`` for
                an unsupported argument; the original error is chained as
                the cause.
        """
        if isinstance(args[0], str):
            args = tuple(args[0].encode().split()) + args[1:]
        elif b" " in args[0]:
            args = tuple(args[0].split()) + args[1:]
        # Hand hiredis immutable ``bytes`` instead of bytearray/memoryview
        # buffers (presumably pack_command does not accept those - confirm).
        args = tuple(
            bytes(arg) if isinstance(arg, (bytearray, memoryview)) else arg
            for arg in args
        )
        try:
            packed = hiredis.pack_command(args)
        except TypeError as exc:
            raise DataError(exc) from exc
        return [packed]

144 

145 

class PythonRespSerializer:
    """Pure-Python serializer for the Redis protocol.

    Args:
        buffer_cutoff: byte threshold above which argument data is emitted as
            a separate output chunk rather than concatenated into a buffer.
        encode: callable applied to every argument before packing.
    """

    def __init__(self, buffer_cutoff, encode) -> None:
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol.

        Returns a list of byte chunks whose concatenation is the full command.
        """
        # The command name may embed literal sub-arguments (e.g. 'CONFIG GET').
        # The server expects them as separate arguments, so split the first
        # argument here; the pieces are kept as bytestrings so that they are
        # not encoded.
        head = args[0]
        if isinstance(head, str):
            args = tuple(head.encode().split()) + args[1:]
        elif b" " in head:
            args = tuple(head.split()) + args[1:]

        cutoff = self._buffer_cutoff
        chunks = []
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        for item in map(self.encode, args):
            size = len(item)
            emit_separately = (
                len(pending) > cutoff
                or size > cutoff
                or isinstance(item, memoryview)
            )
            if emit_separately:
                # Flush the buffer with this argument's length header, then
                # append the argument as its own chunk to avoid building one
                # large concatenated bytes object.
                chunks.append(
                    SYM_EMPTY.join((pending, SYM_DOLLAR, str(size).encode(), SYM_CRLF))
                )
                chunks.append(item)
                pending = SYM_CRLF
            else:
                pending = SYM_EMPTY.join(
                    (pending, SYM_DOLLAR, str(size).encode(), SYM_CRLF, item, SYM_CRLF)
                )

        chunks.append(pending)
        return chunks

195 

196 

class ConnectionInterface:
    """
    Interface of the operations a Redis connection object exposes.

    Note: this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers below are not enforced when a subclass is
    instantiated; they document the expected API.
    """

    @abstractmethod
    def repr_pieces(self):
        ...

    @abstractmethod
    def register_connect_callback(self, callback):
        ...

    @abstractmethod
    def deregister_connect_callback(self, callback):
        ...

    @abstractmethod
    def set_parser(self, parser_class):
        ...

    @abstractmethod
    def get_protocol(self):
        ...

    @abstractmethod
    def connect(self):
        ...

    @abstractmethod
    def on_connect(self):
        ...

    @abstractmethod
    def disconnect(self, *args, **kwargs):
        ...

    @abstractmethod
    def check_health(self):
        ...

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        ...

    @abstractmethod
    def send_command(self, *args, **kwargs):
        ...

    @abstractmethod
    def can_read(self, timeout: float = 0) -> bool:
        # TODO: Rename this API; it detects pending data or dirty/closed
        # connection state, not only whether application data can be read.
        ...

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        timeout: Union[float, object] = SENTINEL,
        disconnect_on_error=True,
        push_request=False,
    ):
        ...

    @abstractmethod
    def pack_command(self, *args):
        ...

    @abstractmethod
    def pack_commands(self, commands):
        ...

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        ...

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        ...

    @abstractmethod
    def re_auth(self):
        ...

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """

    @abstractmethod
    def reset_should_reconnect(self):
        """
        Reset the internal flag to False.
        """

    @abstractmethod
    def extract_connection_details(self) -> str:
        ...

    @property
    @abstractmethod
    def is_connected(self) -> bool:
        """
        Return ``True`` if the connection to the server is active.
        """

313 

314 

class MaintNotificationsAbstractConnection:
    """
    Abstract class for handling maintenance notifications logic.
    This class is expected to be used as base class together with ConnectionInterface.

    This class is intended to be used with multiple inheritance!

    All logic related to maintenance notifications is encapsulated in this class.

    Note: this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers below are not enforced at instantiation
    time; the concrete connection class must provide them.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig],
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize the maintenance notifications for the connection.

        Args:
            maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
            maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
            maintenance_state (MaintenanceState): The current maintenance state of the connection.
            maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
            orig_host_address (Optional[str]): The original host address of the connection.
            orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
            orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
            oss_cluster_maint_notifications_handler (Optional[OSSMaintNotificationsHandler]): The OSS cluster handler for maintenance notifications.
            parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser the notification
                handlers are attached to. Required when maintenance notifications are
                enabled; a ``RedisError`` is raised if it is missing in that case.
            event_dispatcher (Optional[EventDispatcher]): Stored as ``self.event_dispatcher``;
                a new ``EventDispatcher`` is created when omitted.

        Raises:
            RedisError: if notifications are enabled and ``parser`` is missing or is
                neither a hiredis nor a RESP3 parser.
        """
        self.maint_notifications_config = maint_notifications_config
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash

        if event_dispatcher is not None:
            self.event_dispatcher = event_dispatcher
        else:
            self.event_dispatcher = EventDispatcher()

        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            oss_cluster_maint_notifications_handler,
            parser,
        )
        # Notification ids already handled on this connection.
        self._processed_start_maint_notifications = set()
        self._skipped_end_maint_notifications = set()

    @abstractmethod
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
        pass

    @abstractmethod
    def _get_socket(self) -> Optional[socket.socket]:
        pass

    @abstractmethod
    def get_protocol(self) -> Optional[Union[int, str]]:
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        pass

    @property
    @abstractmethod
    def host(self) -> str:
        pass

    @host.setter
    @abstractmethod
    def host(self, value: str):
        pass

    @property
    @abstractmethod
    def socket_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_timeout.setter
    @abstractmethod
    def socket_timeout(self, value: Optional[Union[float, int]]):
        pass

    @property
    @abstractmethod
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_connect_timeout.setter
    @abstractmethod
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        timeout: Union[float, object] = SENTINEL,
        disconnect_on_error=True,
        push_request=False,
    ):
        pass

    @abstractmethod
    def disconnect(self, *args, **kwargs):
        pass

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Enable maintenance notifications by setting up
        handlers and storing original connection parameters.

        Should be used ONLY with parsers that support push notifications.

        When notifications are disabled, all handler attributes are set to
        ``None`` and the ``orig_*`` attributes are NOT set.

        Raises:
            RedisError: if ``parser`` is missing or is neither a hiredis nor a
                RESP3 parser while notifications are enabled.
        """
        if (
            not self.maint_notifications_config
            or not self.maint_notifications_config.enabled
        ):
            self._maint_notifications_pool_handler = None
            self._maint_notifications_connection_handler = None
            self._oss_cluster_maint_notifications_handler = None
            return

        if not parser:
            raise RedisError(
                "To configure maintenance notifications, a parser must be provided!"
            )

        if not isinstance(parser, _HiredisParser) and not isinstance(
            parser, _RESP3Parser
        ):
            raise RedisError(
                "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
            )

        if maint_notifications_pool_handler:
            # Extract a reference to a new pool handler that copies all properties
            # of the original one and has a different connection reference
            # This is needed because when we attach the handler to the parser
            # we need to make sure that the handler has a reference to the
            # connection that the parser is attached to.
            self._maint_notifications_pool_handler = (
                maint_notifications_pool_handler.get_handler_for_connection()
            )
            self._maint_notifications_pool_handler.set_connection(self)
        else:
            self._maint_notifications_pool_handler = None

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )

        if oss_cluster_maint_notifications_handler:
            self._oss_cluster_maint_notifications_handler = (
                oss_cluster_maint_notifications_handler
            )
        else:
            self._oss_cluster_maint_notifications_handler = None

        # Set up OSS cluster handler to parser if available
        if self._oss_cluster_maint_notifications_handler:
            parser.set_oss_cluster_maint_push_handler(
                self._oss_cluster_maint_notifications_handler.handle_notification
            )

        # Set up pool handler to parser if available
        if self._maint_notifications_pool_handler:
            parser.set_node_moving_push_handler(
                self._maint_notifications_pool_handler.handle_notification
            )

        # Set up connection handler
        parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

        # Store original connection parameters so reset_tmp_settings() can
        # restore them; falsy values fall back to the current settings.
        self.orig_host_address = orig_host_address if orig_host_address else self.host
        self.orig_socket_timeout = (
            orig_socket_timeout if orig_socket_timeout else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout
            else self.socket_connect_timeout
        )

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """
        Attach a per-connection copy of ``maint_notifications_pool_handler`` to
        this connection's parser, and create or update the connection handler
        with the pool handler's config.
        """
        # Deep copy the pool handler to avoid sharing the same pool handler
        # between multiple connections, because otherwise each connection will override
        # the connection reference and the pool handler will only hold a reference
        # to the last connection that was set.
        maint_notifications_pool_handler_copy = (
            maint_notifications_pool_handler.get_handler_for_connection()
        )

        maint_notifications_pool_handler_copy.set_connection(self)
        self._get_parser().set_node_moving_push_handler(
            maint_notifications_pool_handler_copy.handle_notification
        )

        self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, maint_notifications_pool_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )

    def set_maint_notifications_cluster_handler_for_connection(
        self, oss_cluster_maint_notifications_handler: OSSMaintNotificationsHandler
    ):
        """
        Attach ``oss_cluster_maint_notifications_handler`` to this connection's
        parser, and create or update the connection handler with its config.
        """
        self._get_parser().set_oss_cluster_maint_push_handler(
            oss_cluster_maint_notifications_handler.handle_notification
        )

        self._oss_cluster_maint_notifications_handler = (
            oss_cluster_maint_notifications_handler
        )

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, oss_cluster_maint_notifications_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            self._maint_notifications_connection_handler.config = (
                oss_cluster_maint_notifications_handler.config
            )

    def activate_maint_notifications_handling_if_enabled(self, check_health=True):
        """Send the maintenance-notifications handshake when it applies."""
        # Send maintenance notifications handshake if RESP3 is active
        # and maintenance notifications are enabled
        # and we have a host to determine the endpoint type from
        # When the maint_notifications_config enabled mode is "auto",
        # a ResponseError from the handshake is only logged (at debug level)
        # When the mode is enabled=True, we raise an exception in case of failure
        host = getattr(self, "host", None)
        if (
            self.get_protocol() not in [2, "2"]
            and self.maint_notifications_config
            and self.maint_notifications_config.enabled
            and self._maint_notifications_connection_handler
            and host is not None
        ):
            self._enable_maintenance_notifications(
                maint_notifications_config=self.maint_notifications_config,
                check_health=check_health,
            )

    def _enable_maintenance_notifications(
        self, maint_notifications_config: MaintNotificationsConfig, check_health=True
    ):
        """
        Send ``CLIENT MAINT_NOTIFICATIONS ON moving-endpoint-type <type>``.

        Raises:
            ValueError: if the connection has no ``host`` value.
            ResponseError: if the server does not reply ``OK`` - unless
                ``maint_notifications_config.enabled == "auto"``, in which case
                the failure is logged at debug level and swallowed.
            Any other exception from sending or reading is propagated.
        """
        try:
            host = getattr(self, "host", None)
            if host is None:
                raise ValueError(
                    "Cannot enable maintenance notifications for connection"
                    " object that doesn't have a host attribute."
                )
            endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
            self.send_command(
                "CLIENT",
                "MAINT_NOTIFICATIONS",
                "ON",
                "moving-endpoint-type",
                endpoint_type.value,
                check_health=check_health,
            )
            response = self.read_response()
            if not response or str_if_bytes(response) != "OK":
                raise ResponseError(
                    "The server doesn't support maintenance notifications"
                )
        except ResponseError as e:
            if maint_notifications_config.enabled != "auto":
                raise
            # "auto" mode: the server may simply not support the feature, so
            # record the failure at debug level and keep the connection usable.
            import logging

            logger = logging.getLogger(__name__)
            logger.debug("Failed to enable maintenance notifications: %s", e)

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution if needed.

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """

        # Method 1: Try to get the actual IP from the established socket connection
        # This is most accurate as it shows the exact IP being used
        try:
            conn_socket = self._get_socket()
            if conn_socket is not None:
                peer_addr = conn_socket.getpeername()
                # IP sockets report a (host, port, ...) tuple. AF_UNIX sockets
                # report a path string, which carries no IP - indexing it would
                # yield just its first character, so skip it.
                if isinstance(peer_addr, tuple) and peer_addr:
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host
        # This is less accurate but works when socket is not available
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP
                # This mimics what the connection would do during _connect()
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # Return the IP from the first result
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # sockaddr[0] is the IP address
                    return str(addr_info[0][4][0])
        except (AttributeError, OSError, socket.gaierror):
            # DNS resolution might fail
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def add_maint_start_notification(self, id: int):
        self._processed_start_maint_notifications.add(id)

    def get_processed_start_notifications(self) -> set:
        return self._processed_start_maint_notifications

    def add_skipped_end_notification(self, id: int):
        self._skipped_end_maint_notifications.add(id)

    def get_skipped_end_notifications(self) -> set:
        return self._skipped_end_maint_notifications

    def reset_received_notifications(self):
        self._processed_start_maint_notifications.clear()
        self._skipped_end_maint_notifications.clear()

    def getpeername(self):
        """
        Returns the peer name of the connection, or ``None`` without a socket.

        For IP sockets this is the peer host (first element of the address
        tuple). For AF_UNIX sockets, whose peer name is a path, the whole path
        is returned instead of its first character.
        """
        conn_socket = self._get_socket()
        if not conn_socket:
            return None
        peer_addr = conn_socket.getpeername()
        if isinstance(peer_addr, tuple):
            return peer_addr[0]
        return peer_addr

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """
        Apply ``relaxed_timeout`` to the live socket and parser; ``-1`` means
        "use ``self.socket_timeout``". No-op when there is no socket.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
            # if the current timeout is 0 it means we are in the middle of a can_read call
            # in this case we don't want to change the timeout because the operation
            # is non-blocking and should return immediately
            # Changing the state from non-blocking to blocking in the middle of a read operation
            # will lead to a deadlock
            if conn_socket.gettimeout() != 0:
                conn_socket.settimeout(timeout)
            self.update_parser_timeout(timeout)

    def update_parser_timeout(self, timeout: Optional[float] = None):
        """
        Propagate ``timeout`` to the parser. For RESP3 parsers a falsy timeout
        is ignored; for hiredis parsers it is always stored.
        """
        parser = self._get_parser()
        if parser and parser._buffer:
            if isinstance(parser, _RESP3Parser) and timeout:
                parser._buffer.socket_timeout = timeout
            elif isinstance(parser, _HiredisParser):
                parser._socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        The value of SENTINEL is used to indicate that the property should not be updated.

        A ``tmp_relaxed_timeout`` of ``-1`` leaves both timeouts untouched; any
        other value is applied to ``socket_timeout`` and
        ``socket_connect_timeout``.
        NOTE(review): the default ``None`` is applied as-is (no timeout);
        callers presumably always pass an explicit value - confirm.
        """
        if tmp_host_address and tmp_host_address is not SENTINEL:
            self.host = str(tmp_host_address)
        if tmp_relaxed_timeout != -1:
            self.socket_timeout = tmp_relaxed_timeout
            self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Restore the host and/or timeouts saved as ``orig_*`` attributes by
        ``_configure_maintenance_notifications``.
        """
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relaxed_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout

783 

784class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface): 

785 "Manages communication to and from a Redis server" 

786 

    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        db: int = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = DEFAULT_SOCKET_TIMEOUT,
        socket_connect_timeout: Optional[float] = DEFAULT_SOCKET_CONNECT_TIMEOUT,
        retry_on_timeout: bool = False,
        retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class=DefaultParser,
        socket_read_size: int = DEFAULT_SOCKET_READ_SIZE,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Union[Optional[str], object] = SENTINEL,
        lib_version: Union[Optional[str], object] = SENTINEL,
        driver_info: Union[Optional[DriverInfo], object] = SENTINEL,
        username: Optional[str] = None,
        retry: Union[Any, None] = None,
        redis_connect_func: Optional[Callable[[], None]] = None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = None,
        legacy_responses: bool = True,
        command_packer: Optional[Callable[[], None]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Initialize a new Connection.

        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.

        Parameters
        ----------
        driver_info : DriverInfo, optional
            Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
            are ignored. If not provided, a DriverInfo will be created from lib_name
            and lib_version. Explicit None disables CLIENT SETINFO.
        lib_name : str, optional
            **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
        lib_version : str, optional
            **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.
        socket_connect_timeout : float, optional
            Falls back to ``socket_timeout`` when ``None``.
        retry_on_error : iterable of exception types, optional
            Copied into a new list; ``TimeoutError`` is appended when
            ``retry_on_timeout`` is true.
        retry : Retry, optional
            Deep-copied before use. When omitted but retryable errors are
            configured, ``Retry(NoBackoff(), 1)`` is used; with neither,
            ``Retry(NoBackoff(), 0)``.
        protocol : int, optional
            RESP version, 2 or 3. ``None`` selects ``DEFAULT_RESP_VERSION``.
            With 3, a ``_RESP2Parser`` ``parser_class`` is swapped for
            ``_RESP3Parser``.
        event_dispatcher : EventDispatcher, optional
            Shared with the maintenance-notifications setup; a new
            ``EventDispatcher`` is created when omitted.

        Raises
        ------
        DataError
            If ``username``/``password`` are combined with ``credential_provider``.
        ConnectionError
            If ``protocol`` is not an integer, or is outside 2..3.
        """
        # Credentials come either from username/password or from a
        # credential provider - never both.
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        self.pid = os.getpid()
        self.db = db
        self.client_name = client_name

        # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version.
        self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)

        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self._socket_timeout = socket_timeout
        # No explicit connect timeout: reuse the read/write socket timeout.
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self._socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        # Build a fresh list so appending TimeoutError below never mutates
        # the caller's iterable.
        if retry_on_error is SENTINEL:
            retry_on_errors_list = []
        else:
            retry_on_errors_list = list(retry_on_error)
        if retry_on_timeout:
            # Add TimeoutError to the errors list to retry on
            retry_on_errors_list.append(TimeoutError)
        self.retry_on_error = retry_on_errors_list
        if retry or self.retry_on_error:
            if retry is None:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            if self.retry_on_error:
                # Update the retry's supported errors with the specified errors
                self.retry.update_supported_errors(self.retry_on_error)
        else:
            # Neither a Retry object nor retryable errors were given
            # (presumably this means "no retries").
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.redis_connect_func = redis_connect_func
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self.handshake_metadata = None
        self._sock = None
        self._socket_read_size = socket_read_size
        self._connect_callbacks = []
        # Chunking threshold handed to PythonRespSerializer by
        # _construct_command_packer(); must be set before that call below.
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        # int(None) raises TypeError, so an omitted protocol selects the
        # default; a non-numeric string raises ValueError -> ConnectionError.
        try:
            p = int(protocol)
        except TypeError:
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        else:
            if p < 2 or p > 3:
                raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p
        self.legacy_responses = legacy_responses
        if self.protocol == 3 and parser_class == _RESP2Parser:
            # If the protocol is 3 but the parser is RESP2, change it to RESP3
            # This is needed because the parser might be set before the protocol
            # or might be provided as a kwarg to the constructor
            # We need to react on discrepancy only for RESP2 and RESP3
            # as hiredis supports both
            parser_class = _RESP3Parser
        self.set_parser(parser_class)

        self._command_packer = self._construct_command_packer(command_packer)
        self._should_reconnect = False

        # Set up maintenance notifications. This runs last because it is
        # handed self._parser (presumably created by set_parser() above).
        MaintNotificationsAbstractConnection.__init__(
            self,
            maint_notifications_config,
            maint_notifications_pool_handler,
            maintenance_state,
            maintenance_notification_hash,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            oss_cluster_maint_notifications_handler,
            self._parser,
            event_dispatcher=self._event_dispatcher,
        )

942 

def __repr__(self):
    """Return ``<module.ClassName(name=value,...)>`` built from ``repr_pieces()``."""
    klass = type(self)
    rendered = ",".join(f"{name}={value}" for name, value in self.repr_pieces())
    return f"<{klass.__module__}.{klass.__name__}({rendered})>"

@abstractmethod
def repr_pieces(self):
    """Return the ``(name, value)`` pairs that ``__repr__`` renders."""

950 

def __del__(self):
    """Finalizer: try to disconnect, never letting an error escape."""
    # Exceptions raised from a finalizer cannot be handled by any caller,
    # so a failed disconnect during garbage collection is ignored on purpose.
    try:
        self.disconnect()
    except Exception:
        return

@property
def is_connected(self) -> bool:
    """Whether a socket object is currently attached to this connection."""
    current_sock = self._sock
    return current_sock is not None

960 

def _construct_command_packer(self, packer):
    """Choose the serializer used to encode outgoing commands.

    An explicitly supplied ``packer`` always wins. Otherwise the hiredis
    serializer is used when ``HIREDIS_AVAILABLE`` is truthy, and the
    pure-Python serializer (configured with this connection's buffer cutoff
    and encoder) is the fallback.
    """
    if packer is not None:
        return packer
    if not HIREDIS_AVAILABLE:
        return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)
    return HiredisRespSerializer()

968 

def register_connect_callback(self, callback):
    """
    Register ``callback`` to run whenever this connection (re)connects.

    Listeners use this to re-issue connection-scoped commands, such as
    pub/sub subscriptions or key tracking. ``callback`` must be a bound
    _method_: only a weak reference to it is stored, and registering the
    same method twice has no effect.
    """
    weak_cb = weakref.WeakMethod(callback)
    if weak_cb in self._connect_callbacks:
        return
    self._connect_callbacks.append(weak_cb)

def deregister_connect_callback(self, callback):
    """
    Stop notifying a previously registered callback about connect events.

    Removing an unknown callback is a no-op. Explicit de-registration is
    optional, because callbacks are held as weak methods and disappear with
    their owner.
    """
    weak_cb = weakref.WeakMethod(callback)
    try:
        self._connect_callbacks.remove(weak_cb)
    except ValueError:
        # Not registered (or already removed): nothing to do.
        return

991 

def set_parser(self, parser_class):
    """
    Replace this connection's parser with a fresh ``parser_class`` instance.

    :param parser_class: parser type to instantiate; it receives
        ``socket_read_size=self._socket_read_size``.
    """
    read_size = self._socket_read_size
    self._parser = parser_class(socket_read_size=read_size)

def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
    """Return the parser currently attached to this connection."""
    current = self._parser
    return current

1002 

def connect(self):
    """Connect to the Redis server if not already connected.

    A single attempt covers both the socket connect and the handshake
    (``retry_socket_connect=False``), and ``self.retry`` retries that whole
    connect/handshake sequence. The retry failure callback passes the error
    to ``disconnect`` as the ``error`` keyword. That is the only form
    ``disconnect`` reads, and it matches ``connect_check_health``'s own
    failure callback, so the close is attributed to the failure rather than
    being recorded as an application close.
    """
    self.retry.call_with_retry(
        lambda: self.connect_check_health(
            check_health=True, retry_socket_connect=False
        ),
        lambda error: self.disconnect(error=error),
    )

1013 

def connect_check_health(
    self, check_health: bool = True, retry_socket_connect: bool = True
):
    """Open the socket and run the connection handshake unless already connected.

    :param check_health: forwarded to ``on_connect_check_health`` when no
        custom ``redis_connect_func`` is configured.
    :param retry_socket_connect: if true, ``self._connect`` is wrapped in
        ``self.retry.call_with_retry`` with ``failure_callback`` as the
        failure hook. If false, ``self._connect`` is attempted exactly once.
    :raises TimeoutError: if the socket connect raises ``socket.timeout``.
    :raises ConnectionError: if the socket connect raises any other
        ``OSError``; the message is built by ``_error_message``.

    Both socket-level failures are reported via ``record_error_count``. A
    ``RedisError`` raised by the handshake disconnects the socket and is
    re-raised unchanged. On success, the live registered connect callbacks
    are invoked with this connection.
    """
    if self._sock:
        return
    # Track actual retry attempts for error reporting.
    # A one-element list lets the nested callback below update it.
    actual_retry_attempts = [0]

    def failure_callback(error, failure_count):
        # Remember how many attempts failed, then drop any partial socket.
        actual_retry_attempts[0] = failure_count
        self.disconnect(error=error, failure_count=failure_count)

    try:
        if retry_socket_connect:
            sock = self.retry.call_with_retry(
                self._connect,
                failure_callback,
                with_failure_count=True,
            )
        else:
            sock = self._connect()
    except socket.timeout:
        # Must be handled before OSError: socket.timeout is an OSError subclass.
        e = TimeoutError("Timeout connecting to server")
        record_error_count(
            server_address=self.host,
            server_port=self.port,
            network_peer_address=self.host,
            network_peer_port=self.port,
            error_type=e,
            retry_attempts=actual_retry_attempts[0],
        )
        raise e
    except OSError as e:
        # Wrap the low-level socket error in the client's ConnectionError.
        e = ConnectionError(self._error_message(e))
        record_error_count(
            server_address=getattr(self, "host", None),
            server_port=getattr(self, "port", None),
            network_peer_address=getattr(self, "host", None),
            network_peer_port=getattr(self, "port", None),
            error_type=e,
            retry_attempts=actual_retry_attempts[0],
        )
        raise e

    self._sock = sock
    try:
        if self.redis_connect_func is None:
            # Use the default on_connect function
            self.on_connect_check_health(check_health=check_health)
        else:
            # Use the passed function redis_connect_func
            self.redis_connect_func(self)
    except RedisError:
        # clean up after any error in on_connect
        self.disconnect()
        raise

    # run any user callbacks. right now the only internal callback
    # is for pubsub channel/pattern resubscription
    # first, remove any dead weakrefs
    self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
    for ref in self._connect_callbacks:
        callback = ref()
        if callback:
            callback(self)

1079 

@abstractmethod
def _connect(self):
    """Create and return a connected socket; implemented by subclasses."""

@abstractmethod
def _host_error(self):
    """Return a short endpoint description used in error messages."""

def _error_message(self, exception):
    """Format ``exception`` together with this connection's endpoint description."""
    endpoint = self._host_error()
    return format_error_message(endpoint, exception)

def on_connect(self):
    """Run the post-connect handshake with health checks enabled."""
    self.on_connect_check_health(check_health=True)

1093 

def on_connect_check_health(self, check_health: bool = True):
    """Initialize the connection, authenticate and select a database.

    Steps, in order:

    1. With credentials (a ``credential_provider`` or a username/password):
       authenticate with ``HELLO <protocol> AUTH ...`` when the configured
       protocol is not RESP2, otherwise with ``AUTH``. Without credentials
       and with a non-RESP2 protocol, send a plain ``HELLO <protocol>``.
    2. Activate maintenance notifications if they are enabled.
    3. Set the client name with ``CLIENT SETNAME`` if one is configured.
    4. Report the library name and version with ``CLIENT SETINFO``. A
       ``ResponseError`` (e.g. from a server without SETINFO) is ignored.
    5. ``SELECT`` ``self.db`` when it is truthy, so database 0 is never
       selected explicitly.

    :param check_health: forwarded to the commands sent after
        authentication. The authentication commands themselves always use
        ``check_health=False``, because a PING before AUTH would fail.
    :raises AuthenticationError: if the ``AUTH`` reply is not ``OK``.
    :raises ConnectionError: if the plain ``HELLO`` reply reports another
        RESP version, or if ``CLIENT SETNAME`` / ``SELECT`` do not reply ``OK``.
    """
    self._parser.on_connect(self)
    # Keep the original parser so its EXCEPTION_CLASSES can be copied onto a
    # replacement RESP3 parser below.
    parser = self._parser

    auth_args = None
    # if credential provider or username and/or password are set, authenticate
    if self.credential_provider or (self.username or self.password):
        cred_provider = (
            self.credential_provider
            or UsernamePasswordCredentialProvider(self.username, self.password)
        )
        auth_args = cred_provider.get_credentials()

    # if resp version is specified and we have auth args,
    # we need to send them via HELLO
    if auth_args and self.protocol not in [2, "2"]:
        if isinstance(self._parser, _RESP2Parser):
            self.set_parser(_RESP3Parser)
            # update cluster exception classes
            self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
            self._parser.on_connect(self)
        # HELLO ... AUTH takes a username and a password; a bare password
        # is sent for the "default" user.
        if len(auth_args) == 1:
            auth_args = ["default", auth_args[0]]
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        self.send_command(
            "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
        )
        self.handshake_metadata = self.read_response()
        # NOTE(review): unlike the credential-less HELLO branch below, the
        # negotiated protocol is not verified here; the check is disabled.
        # if response.get(b"proto") != self.protocol and response.get(
        #     "proto"
        # ) != self.protocol:
        #     raise ConnectionError("Invalid RESP version")
    elif auth_args:
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        self.send_command("AUTH", *auth_args, check_health=False)

        try:
            auth_response = self.read_response()
        except AuthenticationWrongNumberOfArgsError:
            # a username and password were specified but the Redis
            # server seems to be < 6.0.0 which expects a single password
            # arg. retry auth with just the password.
            # https://github.com/andymccurdy/redis-py/issues/1274
            self.send_command("AUTH", auth_args[-1], check_health=False)
            auth_response = self.read_response()

        if str_if_bytes(auth_response) != "OK":
            raise AuthenticationError("Invalid Username or Password")

    # if resp version is specified, switch to it
    elif self.protocol not in [2, "2"]:
        if isinstance(self._parser, _RESP2Parser):
            self.set_parser(_RESP3Parser)
            # update cluster exception classes
            self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
            self._parser.on_connect(self)
        self.send_command("HELLO", self.protocol, check_health=check_health)
        self.handshake_metadata = self.read_response()
        # The HELLO reply may use bytes or str keys depending on decoding.
        if (
            self.handshake_metadata.get(b"proto") != self.protocol
            and self.handshake_metadata.get("proto") != self.protocol
        ):
            raise ConnectionError("Invalid RESP version")

    # Activate maintenance notifications for this connection
    # if enabled in the configuration
    # This is a no-op if maintenance notifications are not enabled
    self.activate_maint_notifications_handling_if_enabled(check_health=check_health)

    # if a client_name is given, set it
    if self.client_name:
        self.send_command(
            "CLIENT",
            "SETNAME",
            self.client_name,
            check_health=check_health,
        )
        if str_if_bytes(self.read_response()) != "OK":
            raise ConnectionError("Error setting client name")

    # Set the library name and version from driver_info
    try:
        if self.driver_info and self.driver_info.formatted_name:
            self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            self.read_response()
    except ResponseError:
        # Best effort: a server that rejects SETINFO must not fail the connect.
        pass

    try:
        if self.driver_info and self.driver_info.lib_version:
            self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            self.read_response()
    except ResponseError:
        # Best effort, as above.
        pass

    # if a database is specified, switch to it
    if self.db:
        self.send_command("SELECT", self.db, check_health=check_health)
        if str_if_bytes(self.read_response()) != "OK":
            raise ConnectionError("Invalid Database")

1209 

def disconnect(self, *args, **kwargs):
    """Disconnects from the Redis server.

    Always resets the parser, clears ``self._sock`` and clears the
    reconnect flag. When a socket was actually open, it also:

    - shuts the socket down and closes it;
    - reports the close to metrics;
    - resets temporary maintenance settings.

    Positional ``args`` are accepted but not used.

    Keyword arguments read here:

    :param error: exception that caused the disconnect. When given, the
        close reason is ``HEALTHCHECK_FAILED`` or ``ERROR`` instead of
        ``APPLICATION_CLOSE``.
    :param failure_count: number of failed attempts reported by a retry
        callback. An error count is recorded once it exceeds
        ``self.retry.get_retries()``.
    :param health_check_failed: true when called because a health-check
        PING failed.
    """
    self._parser.on_disconnect()

    conn_sock = self._sock
    self._sock = None
    # reset the reconnect flag
    self.reset_should_reconnect()

    # No socket was open: nothing to close or record.
    if conn_sock is None:
        return

    # shutdown() acts on the underlying socket, which a forked child shares
    # with its parent, so only the owning process (presumably the pid stored
    # at creation in self.pid) shuts it down. close() only releases this
    # process's descriptor.
    if os.getpid() == self.pid:
        try:
            conn_sock.shutdown(socket.SHUT_RDWR)
        except (OSError, TypeError):
            pass

    try:
        conn_sock.close()
    except OSError:
        pass

    error = kwargs.get("error")
    failure_count = kwargs.get("failure_count")
    health_check_failed = kwargs.get("health_check_failed")

    if error:
        if health_check_failed:
            close_reason = CloseReason.HEALTHCHECK_FAILED
        else:
            close_reason = CloseReason.ERROR

        # Presumably: only count the error once all retries are exhausted.
        if failure_count is not None and failure_count > self.retry.get_retries():
            record_error_count(
                server_address=self.host,
                server_port=self.port,
                network_peer_address=self.host,
                network_peer_port=self.port,
                error_type=error,
                retry_attempts=failure_count,
            )

        record_connection_closed(
            close_reason=close_reason,
            error_type=error,
        )
    else:
        record_connection_closed(
            close_reason=CloseReason.APPLICATION_CLOSE,
        )

    if self.maintenance_state == MaintenanceState.MAINTENANCE:
        # this block will be executed only if the connection was in maintenance state
        # and the connection was closed.
        # The state change won't be applied on connections that are in Moving state
        # because their state and configurations will be handled when the moving ttl expires.
        self.reset_tmp_settings(reset_relaxed_timeout=True)
        self.maintenance_state = MaintenanceState.NONE
        # reset the sets that keep track of received start maint
        # notifications and skipped end maint notifications
        self.reset_received_notifications()

1272 

def mark_for_reconnect(self):
    """Set the flag reported by ``should_reconnect``."""
    self._should_reconnect = True

def should_reconnect(self):
    """Return whether ``mark_for_reconnect`` was called since the flag was last cleared."""
    flag = self._should_reconnect
    return flag

def reset_should_reconnect(self):
    """Clear the reconnect flag (``disconnect`` also does this)."""
    self._should_reconnect = False

1281 

def _send_ping(self):
    """Send PING without a nested health check and require a PONG reply."""
    self.send_command("PING", check_health=False)
    reply = str_if_bytes(self.read_response())
    if reply != "PONG":
        raise ConnectionError("Bad response from PING health check")

def _ping_failed(self, error, failure_count):
    """Retry failure hook: disconnect and mark the close as a failed health check."""
    self.disconnect(
        error=error,
        failure_count=failure_count,
        health_check_failed=True,
    )

def check_health(self):
    """PING the server (with retries) once ``health_check_interval`` has elapsed.

    ``next_health_check`` is pushed forward by every successful
    ``read_response``, so an active connection is not pinged.
    """
    if not self.health_check_interval:
        return
    if time.monotonic() <= self.next_health_check:
        return
    self.retry.call_with_retry(
        self._send_ping, self._ping_failed, with_failure_count=True
    )

1302 

def send_packed_command(self, command, check_health=True):
    """Send an already packed command to the Redis server.

    Connects first (without a health check) if no socket is open.

    :param command: the packed payload, either a single ``str`` or an
        iterable of items, each written with ``socket.sendall``.
    :param check_health: run ``check_health()`` before writing. Callers pass
        ``False`` to avoid recursing from inside the health check itself.
    :raises TimeoutError: if the write times out.
    :raises ConnectionError: for any other ``OSError`` during the write.

    The connection is disconnected on every failure, because a partially
    written command cannot be resumed through this API.
    """
    if not self._sock:
        self.connect_check_health(check_health=False)
    # guard against health check recursion
    if check_health:
        self.check_health()
    try:
        if isinstance(command, str):
            command = [command]
        for item in command:
            self._sock.sendall(item)
    except socket.timeout:
        self.disconnect()
        raise TimeoutError("Timeout writing to socket")
    except OSError as e:
        self.disconnect()
        # OSError may carry (errno, strerror, ...), a single message, or no
        # arguments at all. The no-argument case must not raise IndexError
        # and mask the real failure.
        if len(e.args) >= 2:
            err_no, errmsg = e.args[0], e.args[1]
        else:
            err_no = "UNKNOWN"
            errmsg = e.args[0] if e.args else type(e).__name__
        raise ConnectionError(
            f"Error {err_no} while writing to socket. {errmsg}."
        ) from e
    except BaseException:
        # BaseExceptions can be raised when a socket send operation is not
        # finished, e.g. due to a timeout. Ideally, a caller could then re-try
        # to send un-sent data. However, the send_packed_command() API
        # does not support it so there is no point in keeping the connection open.
        self.disconnect()
        raise

1333 

def send_command(self, *args, **kwargs):
    """Pack ``args`` into the wire format and send them.

    Only the ``check_health`` keyword (default ``True``) is consulted here;
    any other keyword arguments are ignored.
    """
    packed = self._command_packer.pack(*args)
    should_check = kwargs.get("check_health", True)
    self.send_packed_command(packed, check_health=should_check)

1340 

def can_read(self, timeout: float = 0) -> bool:
    """Poll the socket to see if there's data that can be read.

    Connects first when no socket is open.

    :param timeout: forwarded to the parser's ``can_read``.
    :raises ConnectionError: if polling raises ``OSError``; the connection
        is disconnected before the error propagates.
    """
    # TODO: Rename this API; it detects pending data or dirty/closed
    # connection state, not only whether application data can be read.
    if not self._sock:
        self.connect()

    endpoint = self._host_error()

    try:
        return self._parser.can_read(timeout)
    except OSError as exc:
        self.disconnect()
        raise ConnectionError(f"Error while reading from {endpoint}: {exc.args}")

1357 

def read_response(
    self,
    disable_decoding=False,
    *,
    timeout: Union[float, object] = SENTINEL,
    disconnect_on_error=True,
    push_request=False,
):
    """Read the response from a previously sent command.

    :param disable_decoding: forwarded to the parser.
    :param timeout: forwarded to the parser. The ``SENTINEL`` default
        presumably means "no per-call override" -- TODO confirm in the parser.
    :param disconnect_on_error: if true (the default), the connection is
        disconnected before any exception raised while reading propagates.
    :param push_request: forwarded to the parser only for RESP3.
    :returns: the parsed reply.
    :raises TimeoutError: if the read raises ``socket.timeout``.
    :raises ConnectionError: if the read raises any other ``OSError``.
    :raises ResponseError: if the parser returns a ``ResponseError`` (an
        error reply from the server); it is raised instead of returned.
    """

    host_error = self._host_error()

    try:
        if self.protocol in ["3", 3]:
            response = self._parser.read_response(
                disable_decoding=disable_decoding,
                push_request=push_request,
                timeout=timeout,
            )
        else:
            # push_request is only passed to the parser for RESP3.
            response = self._parser.read_response(
                disable_decoding=disable_decoding, timeout=timeout
            )
    except socket.timeout:
        if disconnect_on_error:
            self.disconnect()
        raise TimeoutError(f"Timeout reading from {host_error}")
    except OSError as e:
        if disconnect_on_error:
            self.disconnect()
        raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
    except BaseException:
        # Also by default close in case of BaseException. A lot of code
        # relies on this behaviour when doing Command/Response pairs.
        # See #1128.
        if disconnect_on_error:
            self.disconnect()
        raise

    # A successful read proves the connection is alive, so postpone the
    # next health-check PING.
    if self.health_check_interval:
        self.next_health_check = time.monotonic() + self.health_check_interval

    if isinstance(response, ResponseError):
        try:
            raise response
        finally:
            del response  # avoid creating ref cycles
    return response

1406 

def pack_command(self, *args):
    """Serialize a single command with the configured command packer."""
    packer = self._command_packer
    return packer.pack(*args)

def pack_commands(self, commands):
    """Serialize several commands, batching small chunks into joined buffers.

    Chunks are accumulated and joined with ``SYM_EMPTY`` until the
    accumulated size exceeds ``self._buffer_cutoff``. Chunks larger than the
    cutoff, and ``memoryview`` chunks, are emitted as separate items so
    they are never copied into a joined buffer.
    """
    cutoff = self._buffer_cutoff
    result = []
    pending = []
    pending_size = 0

    for command in commands:
        for chunk in self._command_packer.pack(*command):
            size = len(chunk)
            standalone = size > cutoff or isinstance(chunk, memoryview)
            if standalone or pending_size > cutoff:
                # Flush whatever has been batched so far.
                if pending:
                    result.append(SYM_EMPTY.join(pending))
                pending = []
                pending_size = 0
            if standalone:
                result.append(chunk)
            else:
                pending.append(chunk)
                pending_size += size

    if pending:
        result.append(SYM_EMPTY.join(pending))
    return result

1440 

def get_protocol(self) -> Union[int, str]:
    """Return the RESP protocol version configured for this connection."""
    version = self.protocol
    return version

@property
def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
    """Parsed reply of the last ``HELLO`` handshake (``None`` until one has run)."""
    stored = self._handshake_metadata
    return stored

@handshake_metadata.setter
def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
    self._handshake_metadata = value

1451 

def set_re_auth_token(self, token: TokenInterface):
    """Stage ``token`` to be sent by the next ``re_auth`` call."""
    self._re_auth_token = token

def re_auth(self):
    """Send ``AUTH <oid> <token value>`` for the staged token, if any, then clear it."""
    token = self._re_auth_token
    if token is None:
        return
    self.send_command("AUTH", token.try_get("oid"), token.get_value())
    self.read_response()
    self._re_auth_token = None

1464 

def _get_socket(self) -> Optional[socket.socket]:
    """Return the raw socket, or ``None`` while disconnected."""
    current = self._sock
    return current

@property
def socket_timeout(self) -> Optional[Union[float, int]]:
    """Timeout applied with ``settimeout`` once the socket is connected."""
    value = self._socket_timeout
    return value

@socket_timeout.setter
def socket_timeout(self, value: Optional[Union[float, int]]):
    self._socket_timeout = value

@property
def socket_connect_timeout(self) -> Optional[Union[float, int]]:
    """Timeout applied with ``settimeout`` while the socket is connecting."""
    value = self._socket_connect_timeout
    return value

@socket_connect_timeout.setter
def socket_connect_timeout(self, value: Optional[Union[float, int]]):
    self._socket_connect_timeout = value

1483 

def extract_connection_details(self) -> str:
    """Describe the current socket for diagnostics.

    Returns ``"not connected"`` when no socket is open. Otherwise returns the
    resolved server IP and the local port of the socket.

    ``getsockname()`` only returns a ``(host, port, ...)`` tuple for IP
    sockets. For other families (e.g. Unix domain sockets, where it returns
    a path), the local port is reported as ``None`` rather than as a
    character of the path.
    """
    sock = self._sock
    if sock is None:
        return "not connected"

    local_port = None
    try:
        local_address = sock.getsockname()
    except (AttributeError, OSError):
        local_address = None
    if isinstance(local_address, tuple) and len(local_address) > 1:
        local_port = local_address[1]

    return f"connected to ip {self.get_resolved_ip()}, local socket port: {local_port}"

1495 

1496 

class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=True,
        socket_keepalive_options=SENTINEL,
        socket_type=0,
        **kwargs,
    ):
        """
        Initialize a TCP connection.

        Parameters
        ----------
        host : str
            Server host name or IP address.
        port : int | str
            Server port; converted with ``int()``.
        socket_keepalive : bool
            If `True`, TCP keepalive is enabled for TCP socket connections.
        socket_keepalive_options : Mapping[int, int | bytes] | object | None
            Mapping of TCP keepalive socket option constants to values, for
            example `{socket.TCP_KEEPIDLE: 30}`. If left unspecified, redis-py
            uses TCP keepalive defaults when `socket_keepalive` is enabled:
            idle 30 seconds, interval 5 seconds, and 3 probes. Platform-specific
            options that are not available are skipped. Pass `None` or `{}` to
            avoid setting additional TCP keepalive options.
        socket_type : int
            Passed as the address-family argument of ``socket.getaddrinfo``
            (``0`` means any family).
        **kwargs
            Forwarded to the base connection initializer.
        """
        self._host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        if socket_keepalive_options is SENTINEL:
            socket_keepalive_options = get_default_socket_keepalive_options()
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        """Create a TCP socket connection.

        Mimics ``socket.create_connection`` by trying every address returned
        by ``getaddrinfo`` (IPv4/IPv6), but applies socket options before
        ``connect()`` is called.

        :returns: the connected socket, with ``socket_timeout`` applied.
        :raises OSError: the error from the last failed address, or a generic
            ``OSError`` if ``getaddrinfo`` returned no addresses.
        """
        err = None

        for res in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            family, socktype, proto, _canonname, socket_address = res
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as exc:
                # Remember the failure and fall through to the next address.
                err = exc
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()
            except BaseException:
                # Any other exception aborts the whole attempt, e.g. a
                # TypeError from an invalid keepalive option value or a
                # KeyboardInterrupt during a blocking connect. Close the
                # half-configured socket instead of leaking it.
                if sock is not None:
                    sock.close()
                raise

        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        return f"{self.host}:{self.port}"

    @property
    def host(self) -> str:
        return self._host

    @host.setter
    def host(self, value: str):
        self._host = value

1595 

1596 

1597class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface): 

# Placeholder cache value stored with status IN_PROGRESS while the reply to
# a cacheable command sent by ``send_command`` is still outstanding.
DUMMY_CACHE_VALUE = b"foo"
# Minimum server version ``connect`` accepts for client-side caching.
MIN_ALLOWED_VERSION = "7.4.0"
# Server name that ``connect`` requires in the HELLO handshake metadata.
DEFAULT_SERVER_NAME = "redis"

1601 

def __init__(
    self,
    conn: ConnectionInterface,
    cache: CacheInterface,
    pool_lock: threading.RLock,
):
    """Wrap ``conn`` so that replies to cacheable commands can be served from ``cache``.

    :param conn: the wrapped connection; network I/O is delegated to it.
    :param cache: client-side cache of command replies. Entries may have been
        created through other connections (see ``send_command``).
    :param pool_lock: lock held while ``send_command`` drains pending
        messages from the connection that owns an existing cache entry.
    """
    self.pid = os.getpid()
    self._conn = conn
    # Mirror the wrapped connection's configuration on the proxy.
    self.retry = self._conn.retry
    # ``host`` is a property that writes through to ``self._conn``, so this
    # assignment only stores the wrapped connection's own value back into it.
    self.host = self._conn.host
    self.port = self._conn.port
    self.db = self._conn.db
    self._event_dispatcher = self._conn._event_dispatcher
    self.credential_provider = conn.credential_provider
    self._pool_lock = pool_lock
    self._cache = cache
    # Serializes this proxy's cache lookups and updates.
    self._cache_lock = threading.RLock()
    # CacheKey of the command most recently sent via ``send_command``
    # (``None`` when that command is not cacheable); used by ``read_response``.
    self._current_command_cache_key = None
    # NOTE(review): not read anywhere in this section of the file.
    self._current_options = None
    # Registered on the wrapped connection, so it runs on every (re)connect;
    # presumably enables server-assisted key tracking -- TODO confirm.
    self.register_connect_callback(self._enable_tracking_callback)

    if isinstance(self._conn, MaintNotificationsAbstractConnection):
        # Initialize the proxy's maintenance-notification state from the
        # wrapped connection's settings and parser.
        # NOTE(review): ``event_dispatcher`` is read without the underscore
        # here, unlike ``_event_dispatcher`` above -- confirm both exist.
        MaintNotificationsAbstractConnection.__init__(
            self,
            self._conn.maint_notifications_config,
            self._conn._maint_notifications_pool_handler,
            self._conn.maintenance_state,
            self._conn.maintenance_notification_hash,
            self._conn.host,
            self._conn.socket_timeout,
            self._conn.socket_connect_timeout,
            self._conn._oss_cluster_maint_notifications_handler,
            self._conn._get_parser(),
            event_dispatcher=self._conn.event_dispatcher,
        )

1637 

def repr_pieces(self):
    """Reuse the wrapped connection's ``repr`` fields."""
    inner = self._conn
    return inner.repr_pieces()

@property
def is_connected(self) -> bool:
    """Connection state of the wrapped connection."""
    inner = self._conn
    return inner.is_connected

def register_connect_callback(self, callback):
    """Register ``callback`` on the wrapped connection."""
    inner = self._conn
    inner.register_connect_callback(callback)

def deregister_connect_callback(self, callback):
    """Remove ``callback`` from the wrapped connection."""
    inner = self._conn
    inner.deregister_connect_callback(callback)

def set_parser(self, parser_class):
    """Delegate to the wrapped connection's ``set_parser``."""
    inner = self._conn
    inner.set_parser(parser_class)

1653 

def set_maint_notifications_pool_handler_for_connection(
    self, maint_notifications_pool_handler
):
    """Forward the pool handler when the wrapped connection supports
    maintenance notifications; otherwise do nothing."""
    inner = self._conn
    if not isinstance(inner, MaintNotificationsAbstractConnection):
        return
    inner.set_maint_notifications_pool_handler_for_connection(
        maint_notifications_pool_handler
    )

def set_maint_notifications_cluster_handler_for_connection(
    self, oss_cluster_maint_notifications_handler
):
    """Forward the OSS-cluster handler when the wrapped connection supports
    maintenance notifications; otherwise do nothing."""
    inner = self._conn
    if not isinstance(inner, MaintNotificationsAbstractConnection):
        return
    inner.set_maint_notifications_cluster_handler_for_connection(
        oss_cluster_maint_notifications_handler
    )

def get_protocol(self):
    """RESP protocol version reported by the wrapped connection."""
    inner = self._conn
    return inner.get_protocol()

1672 

def connect(self):
    """Connect the wrapped connection and check that it supports client-side caching.

    The server name and version are read from the wrapped connection's
    ``handshake_metadata`` (the ``HELLO`` reply), whose keys may be
    ``bytes`` or ``str``.

    :raises ConnectionError: in either of these cases:

        - the handshake metadata is missing or lacks the server name/version
          (e.g. no ``HELLO`` was exchanged, so ``handshake_metadata`` is
          still ``None``);
        - the server is not ``DEFAULT_SERVER_NAME`` at
          ``MIN_ALLOWED_VERSION`` or later.
    """
    self._conn.connect()

    # Without a HELLO exchange handshake_metadata stays None; report that as
    # missing server information instead of failing with AttributeError.
    metadata = self._conn.handshake_metadata or {}

    def lookup(field):
        # Replies may be raw (bytes keys) or decoded (str keys).
        value = metadata.get(field.encode(), None)
        if value is None:
            value = metadata.get(field, None)
        return value

    server_name = lookup("server")
    server_ver = lookup("version")
    if server_ver is None or server_name is None:
        raise ConnectionError("Cannot retrieve information about server version")

    server_ver = ensure_string(server_ver)
    server_name = ensure_string(server_name)

    # compare_versions(...) == 1 is treated as "server older than the
    # minimum", per the error message below.
    if (
        server_name != self.DEFAULT_SERVER_NAME
        or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
    ):
        raise ConnectionError(
            "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
        )

1695 

def on_connect(self):
    """Run the wrapped connection's post-connect handshake."""
    inner = self._conn
    inner.on_connect()

def disconnect(self, *args, **kwargs):
    """Drop all cached replies, then disconnect the wrapped connection.

    The flush runs under ``_cache_lock``. All arguments are forwarded
    unchanged to the wrapped connection's ``disconnect``.
    """
    with self._cache_lock:
        self._cache.flush()
    inner = self._conn
    inner.disconnect(*args, **kwargs)

def check_health(self):
    """Run the wrapped connection's health check."""
    inner = self._conn
    inner.check_health()

1706 

def send_packed_command(self, command, check_health=True):
    """Send an already packed command through the wrapped connection.

    Packed payloads do not consult or update the client-side cache.

    :param command: packed payload, forwarded unchanged.
    :param check_health: forwarded unchanged to the wrapped connection, so
        callers that pass ``False`` (e.g. to avoid health-check recursion)
        are honoured.
    """
    # TODO: Investigate if it's possible to unpack command
    # or extract keys from packed command
    self._conn.send_packed_command(command, check_health=check_health)

1711 

def send_command(self, *args, **kwargs):
    """Send a command, skipping the network when its reply is already cached.

    Non-cacheable commands are forwarded unchanged. For a cacheable command
    the caller must pass ``keys=`` (the Redis keys it reads). A ``CacheKey``
    is built from the command name, those keys and all arguments, and kept
    for the following ``read_response`` call.

    If an entry for that key already exists, readable data pending on the
    connection that created it is drained first. If the entry still exists
    afterwards, nothing is sent. Otherwise an ``IN_PROGRESS`` placeholder
    is stored and the command is sent over the wrapped connection.

    :raises ValueError: if a cacheable command is sent without ``keys``.
    """
    # Helper defined elsewhere; presumably handles invalidation messages
    # already received for this proxy -- TODO confirm.
    self._process_pending_invalidations()

    with self._cache_lock:
        # Command is write command or not allowed
        # to be cached.
        if not self._cache.is_cachable(
            CacheKey(command=args[0], redis_keys=(), redis_args=())
        ):
            self._current_command_cache_key = None
            self._conn.send_command(*args, **kwargs)
            return

    if kwargs.get("keys") is None:
        raise ValueError("Cannot create cache key.")

    # Creates cache key.
    self._current_command_cache_key = CacheKey(
        command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args
    )

    with self._cache_lock:
        # We have to trigger invalidation processing in case if
        # it was cached by another connection to avoid
        # queueing invalidations in stale connections.
        if self._cache.get(self._current_command_cache_key):
            entry = self._cache.get(self._current_command_cache_key)

            with self._pool_lock:
                # Drain whatever is readable on the owning connection without
                # blocking (timeout=0) and without disconnecting it on error.
                while entry.connection_ref.can_read():
                    try:
                        entry.connection_ref.read_response(
                            push_request=True,
                            timeout=0,
                            disconnect_on_error=False,
                        )
                    except TimeoutError:
                        break

            # Re-check: if the entry was invalidated during the drain,
            # fall through to send the command over the network.
            # NOTE(review): an IN_PROGRESS placeholder from another connection
            # also survives this check, so nothing is sent here. read_response,
            # however, skips IN_PROGRESS entries and reads from this connection
            # -- verify this cannot block or desync.
            if self._cache.get(self._current_command_cache_key):
                return

        # Set temporary entry value to prevent
        # race condition from another connection.
        self._cache.set(
            CacheEntry(
                cache_key=self._current_command_cache_key,
                cache_value=self.DUMMY_CACHE_VALUE,
                status=CacheEntryStatus.IN_PROGRESS,
                connection_ref=self._conn,
            )
        )

    # Send command over socket only if it's allowed
    # read-only command that not yet cached.
    self._conn.send_command(*args, **kwargs)

1770 

def can_read(self, timeout: float = 0) -> bool:
    """Report whether the wrapped connection has pending input within ``timeout``."""
    # TODO: Rename this API; it detects pending data or dirty/closed
    # connection state, not only whether application data can be read.
    inner = self._conn
    return inner.can_read(timeout)

1775 

def read_response(
    self,
    disable_decoding=False,
    *,
    timeout: Union[float, object] = SENTINEL,
    disconnect_on_error=True,
    push_request=False,
):
    """Return the reply to the command last sent through this proxy.

    Cache hit: if ``send_command`` prepared a cache key and the cache holds
    a completed (non-``IN_PROGRESS``) entry for it, a deep copy of the
    cached value is returned without touching the socket. This is recorded
    as a cache hit, and the copy means the caller cannot mutate the cached
    object.

    Otherwise the reply is read from the wrapped connection, with all
    arguments forwarded unchanged. For a cacheable command the reply is
    stored in its cache entry (marked ``VALID``) if that entry still exists,
    i.e. was not invalidated meanwhile. ``None`` replies are never cached.

    The pending cache key is cleared on every exit path. A later read not
    preceded by ``send_command`` (e.g. after ``send_packed_command``) can
    therefore never be answered from a stale key. If reading fails (e.g. an
    error reply such as WRONGTYPE), this connection's own ``IN_PROGRESS``
    placeholder is removed, so it does not linger in the shared cache.
    """
    with self._cache_lock:
        # Check if command response exists in a cache and it's not in progress.
        cache_key = self._current_command_cache_key
        if cache_key is not None:
            cached = self._cache.get(cache_key)
            if cached is not None and cached.status != CacheEntryStatus.IN_PROGRESS:
                res = copy.deepcopy(cached.cache_value)
                self._current_command_cache_key = None
                record_csc_request(
                    result=CSCResult.HIT,
                )
                record_csc_network_saved(
                    bytes_saved=len(res) if hasattr(res, "__len__") else 0,
                )
                return res
            record_csc_request(
                result=CSCResult.MISS,
            )

    try:
        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            timeout=timeout,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )
    except BaseException:
        with self._cache_lock:
            cache_key = self._current_command_cache_key
            self._current_command_cache_key = None
            if cache_key is not None:
                pending = self._cache.get(cache_key)
                # Only drop the placeholder this connection created.
                if (
                    pending is not None
                    and pending.status == CacheEntryStatus.IN_PROGRESS
                    and pending.connection_ref is self._conn
                ):
                    self._cache.delete_by_cache_keys([cache_key])
        raise

    with self._cache_lock:
        cache_key = self._current_command_cache_key
        # Prevent not-allowed command from caching.
        if cache_key is None:
            return response
        self._current_command_cache_key = None

        # If response is None prevent from caching.
        if response is None:
            self._cache.delete_by_cache_keys([cache_key])
            return response

        cache_entry = self._cache.get(cache_key)

        # Cache only responses that still valid
        # and wasn't invalidated by another connection in meantime.
        if cache_entry is not None:
            cache_entry.status = CacheEntryStatus.VALID
            cache_entry.cache_value = response
            self._cache.set(cache_entry)

    return response

1835 

def pack_command(self, *args):
    """Serialize one command using the wrapped connection's packer."""
    inner = self._conn
    return inner.pack_command(*args)

def pack_commands(self, commands):
    """Serialize a batch of commands using the wrapped connection's packer."""
    inner = self._conn
    return inner.pack_commands(commands)

@property
def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
    """``HELLO`` reply recorded by the wrapped connection."""
    inner = self._conn
    return inner.handshake_metadata

1845 

def set_re_auth_token(self, token: TokenInterface):
    """Stage ``token`` on the wrapped connection for its next ``re_auth``."""
    inner = self._conn
    inner.set_re_auth_token(token)

def re_auth(self):
    """Trigger re-authentication on the wrapped connection."""
    inner = self._conn
    inner.re_auth()

1851 

def mark_for_reconnect(self):
    """Set the wrapped connection's reconnect flag."""
    inner = self._conn
    inner.mark_for_reconnect()

def should_reconnect(self):
    """Return the wrapped connection's reconnect flag."""
    inner = self._conn
    return inner.should_reconnect()

def reset_should_reconnect(self):
    """Clear the wrapped connection's reconnect flag."""
    inner = self._conn
    inner.reset_should_reconnect()

1860 

@property
def host(self) -> str:
    """Host of the wrapped connection; reads and writes go straight through."""
    inner = self._conn
    return inner.host

@host.setter
def host(self, value: str):
    inner = self._conn
    inner.host = value

@property
def socket_timeout(self) -> Optional[Union[float, int]]:
    """Socket timeout of the wrapped connection; reads and writes go straight through."""
    inner = self._conn
    return inner.socket_timeout

@socket_timeout.setter
def socket_timeout(self, value: Optional[Union[float, int]]):
    inner = self._conn
    inner.socket_timeout = value

@property
def socket_connect_timeout(self) -> Optional[Union[float, int]]:
    """Connect timeout of the wrapped connection; reads and writes go straight through."""
    inner = self._conn
    return inner.socket_connect_timeout

@socket_connect_timeout.setter
def socket_connect_timeout(self, value: Optional[Union[float, int]]):
    inner = self._conn
    inner.socket_connect_timeout = value

1884 

1885 @property 

1886 def _maint_notifications_connection_handler( 

1887 self, 

1888 ) -> Optional[MaintNotificationsConnectionHandler]: 

1889 if isinstance(self._conn, MaintNotificationsAbstractConnection): 

1890 return self._conn._maint_notifications_connection_handler 

1891 

1892 @_maint_notifications_connection_handler.setter 

1893 def _maint_notifications_connection_handler( 

1894 self, value: Optional[MaintNotificationsConnectionHandler] 

1895 ): 

1896 self._conn._maint_notifications_connection_handler = value 

1897 

1898 def _get_socket(self) -> Optional[socket.socket]: 

1899 if isinstance(self._conn, MaintNotificationsAbstractConnection): 

1900 return self._conn._get_socket() 

1901 else: 

1902 raise NotImplementedError( 

1903 "Maintenance notifications are not supported by this connection type" 

1904 ) 

1905 

1906 def _get_maint_notifications_connection_instance( 

1907 self, connection 

1908 ) -> MaintNotificationsAbstractConnection: 

1909 """ 

1910 Validate that connection instance supports maintenance notifications. 

1911 With this helper method we ensure that we are working 

1912 with the correct connection type. 

1913 After twe validate that connection instance supports maintenance notifications 

1914 we can safely return the connection instance 

1915 as MaintNotificationsAbstractConnection. 

1916 """ 

1917 if not isinstance(connection, MaintNotificationsAbstractConnection): 

1918 raise NotImplementedError( 

1919 "Maintenance notifications are not supported by this connection type" 

1920 ) 

1921 else: 

1922 return connection 

1923 

1924 @property 

1925 def maintenance_state(self) -> MaintenanceState: 

1926 con = self._get_maint_notifications_connection_instance(self._conn) 

1927 return con.maintenance_state 

1928 

1929 @maintenance_state.setter 

1930 def maintenance_state(self, state: MaintenanceState): 

1931 con = self._get_maint_notifications_connection_instance(self._conn) 

1932 con.maintenance_state = state 

1933 

1934 def getpeername(self): 

1935 con = self._get_maint_notifications_connection_instance(self._conn) 

1936 return con.getpeername() 

1937 

1938 def get_resolved_ip(self): 

1939 con = self._get_maint_notifications_connection_instance(self._conn) 

1940 return con.get_resolved_ip() 

1941 

1942 def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None): 

1943 con = self._get_maint_notifications_connection_instance(self._conn) 

1944 con.update_current_socket_timeout(relaxed_timeout) 

1945 

1946 def set_tmp_settings( 

1947 self, 

1948 tmp_host_address: Optional[str] = None, 

1949 tmp_relaxed_timeout: Optional[float] = None, 

1950 ): 

1951 con = self._get_maint_notifications_connection_instance(self._conn) 

1952 con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout) 

1953 

1954 def reset_tmp_settings( 

1955 self, 

1956 reset_host_address: bool = False, 

1957 reset_relaxed_timeout: bool = False, 

1958 ): 

1959 con = self._get_maint_notifications_connection_instance(self._conn) 

1960 con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout) 

1961 

1962 def _connect(self): 

1963 self._conn._connect() 

1964 

1965 def _host_error(self): 

1966 self._conn._host_error() 

1967 

1968 def _enable_tracking_callback(self, conn: ConnectionInterface) -> None: 

1969 conn.send_command("CLIENT", "TRACKING", "ON") 

1970 conn.read_response() 

1971 conn._parser.set_invalidation_push_handler(self._on_invalidation_callback) 

1972 

1973 def _process_pending_invalidations(self): 

1974 while self.can_read(): 

1975 try: 

1976 self._conn.read_response( 

1977 push_request=True, timeout=0, disconnect_on_error=False 

1978 ) 

1979 except TimeoutError: 

1980 break 

1981 

1982 def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]): 

1983 with self._cache_lock: 

1984 # Flush cache when DB flushed on server-side 

1985 if data[1] is None: 

1986 self._cache.flush() 

1987 else: 

1988 keys_deleted = self._cache.delete_by_redis_keys(data[1]) 

1989 

1990 if len(keys_deleted) > 0: 

1991 record_csc_eviction( 

1992 count=len(keys_deleted), 

1993 reason=CSCReason.INVALIDATION, 

1994 ) 

1995 

1996 def extract_connection_details(self) -> str: 

1997 return self._conn.extract_connection_details() 

1998 

1999 

2000class SSLConnection(Connection): 

2001 """Manages SSL connections to and from the Redis server(s). 

2002 This class extends the Connection class, adding SSL functionality, and making 

2003 use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext) 

2004 """ # noqa 

2005 

2006 def __init__( 

2007 self, 

2008 ssl_keyfile=None, 

2009 ssl_certfile=None, 

2010 ssl_cert_reqs="required", 

2011 ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None, 

2012 ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None, 

2013 ssl_ca_certs=None, 

2014 ssl_ca_data=None, 

2015 ssl_check_hostname=True, 

2016 ssl_ca_path=None, 

2017 ssl_password=None, 

2018 ssl_validate_ocsp=False, 

2019 ssl_validate_ocsp_stapled=False, 

2020 ssl_ocsp_context=None, 

2021 ssl_ocsp_expected_cert=None, 

2022 ssl_min_version=None, 

2023 ssl_ciphers=None, 

2024 **kwargs, 

2025 ): 

2026 """Constructor 

2027 

2028 Args: 

2029 ssl_keyfile: Path to an ssl private key. Defaults to None. 

2030 ssl_certfile: Path to an ssl certificate. Defaults to None. 

2031 ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), 

2032 or an ssl.VerifyMode. Defaults to "required". 

2033 ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None. 

2034 ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None. 

2035 ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None. 

2036 ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates. 

2037 ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True. 

2038 ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None. 

2039 ssl_password: Password for unlocking an encrypted private key. Defaults to None. 

2040 

2041 ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification) 

2042 ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response 

2043 ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert 

2044 ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service. 

2045 ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module. 

2046 ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information. 

2047 

2048 Raises: 

2049 RedisError 

2050 """ # noqa 

2051 if not SSL_AVAILABLE: 

2052 raise RedisError("Python wasn't built with SSL support") 

2053 

2054 self.keyfile = ssl_keyfile 

2055 self.certfile = ssl_certfile 

2056 if ssl_cert_reqs is None: 

2057 ssl_cert_reqs = ssl.CERT_NONE 

2058 elif isinstance(ssl_cert_reqs, str): 

2059 CERT_REQS = { # noqa: N806 

2060 "none": ssl.CERT_NONE, 

2061 "optional": ssl.CERT_OPTIONAL, 

2062 "required": ssl.CERT_REQUIRED, 

2063 } 

2064 if ssl_cert_reqs not in CERT_REQS: 

2065 raise RedisError( 

2066 f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}" 

2067 ) 

2068 ssl_cert_reqs = CERT_REQS[ssl_cert_reqs] 

2069 self.cert_reqs = ssl_cert_reqs 

2070 self.ssl_include_verify_flags = ssl_include_verify_flags 

2071 self.ssl_exclude_verify_flags = ssl_exclude_verify_flags 

2072 self.ca_certs = ssl_ca_certs 

2073 self.ca_data = ssl_ca_data 

2074 self.ca_path = ssl_ca_path 

2075 self.check_hostname = ( 

2076 ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False 

2077 ) 

2078 self.certificate_password = ssl_password 

2079 self.ssl_validate_ocsp = ssl_validate_ocsp 

2080 self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled 

2081 self.ssl_ocsp_context = ssl_ocsp_context 

2082 self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert 

2083 self.ssl_min_version = ssl_min_version 

2084 self.ssl_ciphers = ssl_ciphers 

2085 super().__init__(**kwargs) 

2086 

2087 def _connect(self): 

2088 """ 

2089 Wrap the socket with SSL support, handling potential errors. 

2090 """ 

2091 sock = super()._connect() 

2092 try: 

2093 return self._wrap_socket_with_ssl(sock) 

2094 except (OSError, RedisError): 

2095 sock.close() 

2096 raise 

2097 

2098 def _wrap_socket_with_ssl(self, sock): 

2099 """ 

2100 Wraps the socket with SSL support. 

2101 

2102 Args: 

2103 sock: The plain socket to wrap with SSL. 

2104 

2105 Returns: 

2106 An SSL wrapped socket. 

2107 """ 

2108 context = ssl.create_default_context() 

2109 context.check_hostname = self.check_hostname 

2110 context.verify_mode = self.cert_reqs 

2111 if self.ssl_include_verify_flags: 

2112 for flag in self.ssl_include_verify_flags: 

2113 context.verify_flags |= flag 

2114 if self.ssl_exclude_verify_flags: 

2115 for flag in self.ssl_exclude_verify_flags: 

2116 context.verify_flags &= ~flag 

2117 if self.certfile or self.keyfile: 

2118 context.load_cert_chain( 

2119 certfile=self.certfile, 

2120 keyfile=self.keyfile, 

2121 password=self.certificate_password, 

2122 ) 

2123 if ( 

2124 self.ca_certs is not None 

2125 or self.ca_path is not None 

2126 or self.ca_data is not None 

2127 ): 

2128 context.load_verify_locations( 

2129 cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data 

2130 ) 

2131 if self.ssl_min_version is not None: 

2132 context.minimum_version = self.ssl_min_version 

2133 if self.ssl_ciphers: 

2134 context.set_ciphers(self.ssl_ciphers) 

2135 if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False: 

2136 raise RedisError("cryptography is not installed.") 

2137 

2138 if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp: 

2139 raise RedisError( 

2140 "Either an OCSP staple or pure OCSP connection must be validated " 

2141 "- not both." 

2142 ) 

2143 

2144 sslsock = context.wrap_socket(sock, server_hostname=self.host) 

2145 

2146 # validation for the stapled case 

2147 if self.ssl_validate_ocsp_stapled: 

2148 import OpenSSL 

2149 

2150 from .ocsp import ocsp_staple_verifier 

2151 

2152 # if a context is provided use it - otherwise, a basic context 

2153 if self.ssl_ocsp_context is None: 

2154 staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) 

2155 staple_ctx.use_certificate_file(self.certfile) 

2156 staple_ctx.use_privatekey_file(self.keyfile) 

2157 else: 

2158 staple_ctx = self.ssl_ocsp_context 

2159 

2160 staple_ctx.set_ocsp_client_callback( 

2161 ocsp_staple_verifier, self.ssl_ocsp_expected_cert 

2162 ) 

2163 

2164 # need another socket 

2165 con = OpenSSL.SSL.Connection(staple_ctx, socket.socket()) 

2166 con.request_ocsp() 

2167 con.connect((self.host, self.port)) 

2168 con.do_handshake() 

2169 con.shutdown() 

2170 return sslsock 

2171 

2172 # pure ocsp validation 

2173 if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE: 

2174 from .ocsp import OCSPVerifier 

2175 

2176 o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs) 

2177 if o.is_valid(): 

2178 return sslsock 

2179 else: 

2180 raise ConnectionError("ocsp validation error") 

2181 return sslsock 

2182 

2183 

2184class UnixDomainSocketConnection(AbstractConnection): 

2185 "Manages UDS communication to and from a Redis server" 

2186 

2187 def __init__(self, path="", socket_timeout=DEFAULT_SOCKET_TIMEOUT, **kwargs): 

2188 super().__init__(**kwargs) 

2189 self.path = path 

2190 self.socket_timeout = socket_timeout 

2191 

2192 def repr_pieces(self): 

2193 pieces = [("path", self.path), ("db", self.db)] 

2194 if self.client_name: 

2195 pieces.append(("client_name", self.client_name)) 

2196 return pieces 

2197 

2198 def _connect(self): 

2199 "Create a Unix domain socket connection" 

2200 sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) 

2201 sock.settimeout(self.socket_connect_timeout) 

2202 try: 

2203 sock.connect(self.path) 

2204 except OSError: 

2205 # Prevent ResourceWarnings for unclosed sockets. 

2206 try: 

2207 sock.shutdown(socket.SHUT_RDWR) # ensure a clean close 

2208 except OSError: 

2209 pass 

2210 sock.close() 

2211 raise 

2212 sock.settimeout(self.socket_timeout) 

2213 return sock 

2214 

2215 def _host_error(self): 

2216 return self.path 

2217 

2218 

2219FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") 

2220 

2221 

2222def to_bool(value): 

2223 if value is None or value == "": 

2224 return None 

2225 if isinstance(value, str) and value.upper() in FALSE_STRINGS: 

2226 return False 

2227 return bool(value) 

2228 

2229 

2230def parse_ssl_verify_flags(value): 

2231 # flags are passed in as a string representation of a list, 

2232 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN 

2233 verify_flags_str = value.replace("[", "").replace("]", "") 

2234 

2235 verify_flags = [] 

2236 for flag in verify_flags_str.split(","): 

2237 flag = flag.strip() 

2238 if not hasattr(VerifyFlags, flag): 

2239 raise ValueError(f"Invalid ssl verify flag: {flag}") 

2240 verify_flags.append(getattr(VerifyFlags, flag)) 

2241 return verify_flags 

2242 

2243 

2244URL_QUERY_ARGUMENT_PARSERS = { 

2245 "db": int, 

2246 "socket_timeout": float, 

2247 "socket_connect_timeout": float, 

2248 "socket_read_size": int, 

2249 "socket_keepalive": to_bool, 

2250 "retry_on_timeout": to_bool, 

2251 "retry_on_error": list, 

2252 "max_connections": int, 

2253 "health_check_interval": int, 

2254 "ssl_check_hostname": to_bool, 

2255 "ssl_include_verify_flags": parse_ssl_verify_flags, 

2256 "ssl_exclude_verify_flags": parse_ssl_verify_flags, 

2257 "ssl_min_version": int, 

2258 "timeout": float, 

2259 "protocol": int, 

2260 "legacy_responses": to_bool, 

2261} 

2262 

2263 

2264def parse_url(url): 

2265 if not ( 

2266 url.startswith("redis://") 

2267 or url.startswith("rediss://") 

2268 or url.startswith("unix://") 

2269 ): 

2270 raise ValueError( 

2271 "Redis URL must specify one of the following " 

2272 "schemes (redis://, rediss://, unix://)" 

2273 ) 

2274 

2275 url = urlparse(url) 

2276 kwargs = {} 

2277 

2278 for name, value in parse_qs(url.query).items(): 

2279 if value and len(value) > 0: 

2280 value = unquote(value[0]) 

2281 parser = URL_QUERY_ARGUMENT_PARSERS.get(name) 

2282 if parser: 

2283 try: 

2284 kwargs[name] = parser(value) 

2285 except (TypeError, ValueError): 

2286 raise ValueError(f"Invalid value for '{name}' in connection URL.") 

2287 else: 

2288 kwargs[name] = value 

2289 

2290 if url.username: 

2291 kwargs["username"] = unquote(url.username) 

2292 if url.password: 

2293 kwargs["password"] = unquote(url.password) 

2294 

2295 # We only support redis://, rediss:// and unix:// schemes. 

2296 if url.scheme == "unix": 

2297 if url.path: 

2298 kwargs["path"] = unquote(url.path) 

2299 kwargs["connection_class"] = UnixDomainSocketConnection 

2300 

2301 else: # implied: url.scheme in ("redis", "rediss"): 

2302 if url.hostname: 

2303 kwargs["host"] = unquote(url.hostname) 

2304 if url.port: 

2305 kwargs["port"] = int(url.port) 

2306 

2307 # If there's a path argument, use it as the db argument if a 

2308 # querystring value wasn't specified 

2309 if url.path and "db" not in kwargs: 

2310 try: 

2311 kwargs["db"] = int(unquote(url.path).replace("/", "")) 

2312 except (AttributeError, ValueError): 

2313 pass 

2314 

2315 if url.scheme == "rediss": 

2316 kwargs["connection_class"] = SSLConnection 

2317 

2318 return kwargs 

2319 

2320 

2321_CP = TypeVar("_CP", bound="ConnectionPool") 

2322 

2323 

2324class ConnectionPoolInterface(ABC): 

2325 @abstractmethod 

2326 def get_protocol(self): 

2327 pass 

2328 

2329 @abstractmethod 

2330 def reset(self): 

2331 pass 

2332 

2333 @abstractmethod 

2334 @deprecated_args( 

2335 args_to_warn=["*"], 

2336 reason="Use get_connection() without args instead", 

2337 version="5.3.0", 

2338 ) 

2339 def get_connection( 

2340 self, command_name: Optional[str], *keys, **options 

2341 ) -> ConnectionInterface: 

2342 pass 

2343 

2344 @abstractmethod 

2345 def get_encoder(self): 

2346 pass 

2347 

2348 @abstractmethod 

2349 def release(self, connection: ConnectionInterface): 

2350 pass 

2351 

2352 @abstractmethod 

2353 def disconnect(self, inuse_connections: bool = True): 

2354 pass 

2355 

2356 @abstractmethod 

2357 def close(self): 

2358 pass 

2359 

2360 @abstractmethod 

2361 def set_retry(self, retry: Retry): 

2362 pass 

2363 

2364 @abstractmethod 

2365 def re_auth_callback(self, token: TokenInterface): 

2366 pass 

2367 

2368 @abstractmethod 

2369 def get_connection_count(self) -> list[tuple[int, dict]]: 

2370 """ 

2371 Returns a connection count (both idle and in use). 

2372 """ 

2373 pass 

2374 

2375 

2376class MaintNotificationsAbstractConnectionPool: 

2377 """ 

2378 Abstract class for handling maintenance notifications logic. 

2379 This class is mixed into the ConnectionPool classes. 

2380 

2381 This class is not intended to be used directly! 

2382 

2383 All logic related to maintenance notifications and 

2384 connection pool handling is encapsulated in this class. 

2385 """ 

2386 

2387 def __init__( 

2388 self, 

2389 maint_notifications_config: Optional[MaintNotificationsConfig] = None, 

2390 oss_cluster_maint_notifications_handler: Optional[ 

2391 OSSMaintNotificationsHandler 

2392 ] = None, 

2393 **kwargs, 

2394 ): 

2395 # Initialize maintenance notifications 

2396 is_protocol_supported = check_protocol_version(kwargs.get("protocol"), 3) 

2397 

2398 if maint_notifications_config is None and is_protocol_supported: 

2399 maint_notifications_config = MaintNotificationsConfig() 

2400 

2401 if maint_notifications_config and maint_notifications_config.enabled: 

2402 if not is_protocol_supported: 

2403 raise RedisError( 

2404 "Maintenance notifications handlers on connection are only supported with RESP version 3" 

2405 ) 

2406 

2407 self._event_dispatcher = kwargs.get("event_dispatcher", None) 

2408 if self._event_dispatcher is None: 

2409 self._event_dispatcher = EventDispatcher() 

2410 

2411 self._maint_notifications_pool_handler = MaintNotificationsPoolHandler( 

2412 self, maint_notifications_config 

2413 ) 

2414 if oss_cluster_maint_notifications_handler: 

2415 self._oss_cluster_maint_notifications_handler = ( 

2416 oss_cluster_maint_notifications_handler 

2417 ) 

2418 self._update_connection_kwargs_for_maint_notifications( 

2419 oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler 

2420 ) 

2421 self._maint_notifications_pool_handler = None 

2422 else: 

2423 self._oss_cluster_maint_notifications_handler = None 

2424 self._maint_notifications_pool_handler = MaintNotificationsPoolHandler( 

2425 self, maint_notifications_config 

2426 ) 

2427 

2428 self._update_connection_kwargs_for_maint_notifications( 

2429 maint_notifications_pool_handler=self._maint_notifications_pool_handler 

2430 ) 

2431 else: 

2432 self._maint_notifications_pool_handler = None 

2433 self._oss_cluster_maint_notifications_handler = None 

2434 

2435 @property 

2436 @abstractmethod 

2437 def connection_kwargs(self) -> Dict[str, Any]: 

2438 pass 

2439 

2440 @connection_kwargs.setter 

2441 @abstractmethod 

2442 def connection_kwargs(self, value: Dict[str, Any]): 

2443 pass 

2444 

2445 @abstractmethod 

2446 def _get_pool_lock(self) -> threading.RLock: 

2447 pass 

2448 

2449 @abstractmethod 

2450 def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]: 

2451 pass 

2452 

2453 @abstractmethod 

2454 def _get_in_use_connections( 

2455 self, 

2456 ) -> Iterable["MaintNotificationsAbstractConnection"]: 

2457 pass 

2458 

2459 def maint_notifications_enabled(self): 

2460 """ 

2461 Returns: 

2462 True if the maintenance notifications are enabled, False otherwise. 

2463 The maintenance notifications config is stored in the pool handler. 

2464 If the pool handler is not set, the maintenance notifications are not enabled. 

2465 """ 

2466 if self._oss_cluster_maint_notifications_handler: 

2467 maint_notifications_config = ( 

2468 self._oss_cluster_maint_notifications_handler.config 

2469 ) 

2470 else: 

2471 maint_notifications_config = ( 

2472 self._maint_notifications_pool_handler.config 

2473 if self._maint_notifications_pool_handler 

2474 else None 

2475 ) 

2476 

2477 return maint_notifications_config and maint_notifications_config.enabled 

2478 

2479 def update_maint_notifications_config( 

2480 self, 

2481 maint_notifications_config: MaintNotificationsConfig, 

2482 oss_cluster_maint_notifications_handler: Optional[ 

2483 OSSMaintNotificationsHandler 

2484 ] = None, 

2485 ): 

2486 """ 

2487 Updates the maintenance notifications configuration. 

2488 This method should be called only if the pool was created 

2489 without enabling the maintenance notifications and 

2490 in a later point in time maintenance notifications 

2491 are requested to be enabled. 

2492 """ 

2493 if ( 

2494 self.maint_notifications_enabled() 

2495 and not maint_notifications_config.enabled 

2496 ): 

2497 raise ValueError( 

2498 "Cannot disable maintenance notifications after enabling them" 

2499 ) 

2500 if oss_cluster_maint_notifications_handler: 

2501 self._oss_cluster_maint_notifications_handler = ( 

2502 oss_cluster_maint_notifications_handler 

2503 ) 

2504 else: 

2505 # first update pool settings 

2506 if not self._maint_notifications_pool_handler: 

2507 self._maint_notifications_pool_handler = MaintNotificationsPoolHandler( 

2508 self, maint_notifications_config 

2509 ) 

2510 else: 

2511 self._maint_notifications_pool_handler.config = ( 

2512 maint_notifications_config 

2513 ) 

2514 

2515 # then update connection kwargs and existing connections 

2516 self._update_connection_kwargs_for_maint_notifications( 

2517 maint_notifications_pool_handler=self._maint_notifications_pool_handler, 

2518 oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler, 

2519 ) 

2520 self._update_maint_notifications_configs_for_connections( 

2521 maint_notifications_pool_handler=self._maint_notifications_pool_handler, 

2522 oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler, 

2523 ) 

2524 

2525 def _update_connection_kwargs_for_maint_notifications( 

2526 self, 

2527 maint_notifications_pool_handler: Optional[ 

2528 MaintNotificationsPoolHandler 

2529 ] = None, 

2530 oss_cluster_maint_notifications_handler: Optional[ 

2531 OSSMaintNotificationsHandler 

2532 ] = None, 

2533 ): 

2534 """ 

2535 Update the connection kwargs for all future connections. 

2536 """ 

2537 if not self.maint_notifications_enabled(): 

2538 return 

2539 if maint_notifications_pool_handler: 

2540 self.connection_kwargs.update( 

2541 { 

2542 "maint_notifications_pool_handler": maint_notifications_pool_handler, 

2543 "maint_notifications_config": maint_notifications_pool_handler.config, 

2544 } 

2545 ) 

2546 if oss_cluster_maint_notifications_handler: 

2547 self.connection_kwargs.update( 

2548 { 

2549 "oss_cluster_maint_notifications_handler": oss_cluster_maint_notifications_handler, 

2550 "maint_notifications_config": oss_cluster_maint_notifications_handler.config, 

2551 } 

2552 ) 

2553 

2554 # Store original connection parameters for maintenance notifications. 

2555 if self.connection_kwargs.get("orig_host_address", None) is None: 

2556 # If orig_host_address is None it means we haven't 

2557 # configured the original values yet 

2558 self.connection_kwargs.update( 

2559 { 

2560 "orig_host_address": self.connection_kwargs.get("host"), 

2561 "orig_socket_timeout": self.connection_kwargs.get( 

2562 "socket_timeout", DEFAULT_SOCKET_TIMEOUT 

2563 ), 

2564 "orig_socket_connect_timeout": self.connection_kwargs.get( 

2565 "socket_connect_timeout", DEFAULT_SOCKET_CONNECT_TIMEOUT 

2566 ), 

2567 } 

2568 ) 

2569 

2570 def _update_maint_notifications_configs_for_connections( 

2571 self, 

2572 maint_notifications_pool_handler: Optional[ 

2573 MaintNotificationsPoolHandler 

2574 ] = None, 

2575 oss_cluster_maint_notifications_handler: Optional[ 

2576 OSSMaintNotificationsHandler 

2577 ] = None, 

2578 ): 

2579 """Update the maintenance notifications config for all connections in the pool.""" 

2580 with self._get_pool_lock(): 

2581 for conn in self._get_free_connections(): 

2582 if oss_cluster_maint_notifications_handler: 

2583 # set cluster handler for conn 

2584 conn.set_maint_notifications_cluster_handler_for_connection( 

2585 oss_cluster_maint_notifications_handler 

2586 ) 

2587 conn.maint_notifications_config = ( 

2588 oss_cluster_maint_notifications_handler.config 

2589 ) 

2590 elif maint_notifications_pool_handler: 

2591 conn.set_maint_notifications_pool_handler_for_connection( 

2592 maint_notifications_pool_handler 

2593 ) 

2594 conn.maint_notifications_config = ( 

2595 maint_notifications_pool_handler.config 

2596 ) 

2597 else: 

2598 raise ValueError( 

2599 "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set" 

2600 ) 

2601 conn.disconnect() 

2602 for conn in self._get_in_use_connections(): 

2603 if oss_cluster_maint_notifications_handler: 

2604 conn.maint_notifications_config = ( 

2605 oss_cluster_maint_notifications_handler.config 

2606 ) 

2607 conn._configure_maintenance_notifications( 

2608 oss_cluster_maint_notifications_handler=oss_cluster_maint_notifications_handler 

2609 ) 

2610 elif maint_notifications_pool_handler: 

2611 conn.set_maint_notifications_pool_handler_for_connection( 

2612 maint_notifications_pool_handler 

2613 ) 

2614 conn.maint_notifications_config = ( 

2615 maint_notifications_pool_handler.config 

2616 ) 

2617 else: 

2618 raise ValueError( 

2619 "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set" 

2620 ) 

2621 conn.mark_for_reconnect() 

2622 

2623 def _should_update_connection( 

2624 self, 

2625 conn: "MaintNotificationsAbstractConnection", 

2626 matching_pattern: Literal[ 

2627 "connected_address", "configured_address", "notification_hash" 

2628 ] = "connected_address", 

2629 matching_address: Optional[str] = None, 

2630 matching_notification_hash: Optional[int] = None, 

2631 ) -> bool: 

2632 """ 

2633 Check if the connection should be updated based on the matching criteria. 

2634 """ 

2635 if matching_pattern == "connected_address": 

2636 if matching_address and conn.getpeername() != matching_address: 

2637 return False 

2638 elif matching_pattern == "configured_address": 

2639 if matching_address and conn.host != matching_address: 

2640 return False 

2641 elif matching_pattern == "notification_hash": 

2642 if ( 

2643 matching_notification_hash 

2644 and conn.maintenance_notification_hash != matching_notification_hash 

2645 ): 

2646 return False 

2647 return True 

2648 

2649 def update_connection_settings( 

2650 self, 

2651 conn: "MaintNotificationsAbstractConnection", 

2652 state: Optional["MaintenanceState"] = None, 

2653 maintenance_notification_hash: Optional[int] = None, 

2654 host_address: Optional[str] = None, 

2655 relaxed_timeout: Optional[float] = None, 

2656 update_notification_hash: bool = False, 

2657 reset_host_address: bool = False, 

2658 reset_relaxed_timeout: bool = False, 

2659 ): 

2660 """ 

2661 Update the settings for a single connection. 

2662 """ 

2663 if state: 

2664 conn.maintenance_state = state 

2665 

2666 if update_notification_hash: 

2667 # update the notification hash only if requested 

2668 conn.maintenance_notification_hash = maintenance_notification_hash 

2669 

2670 if host_address is not None: 

2671 conn.set_tmp_settings(tmp_host_address=host_address) 

2672 

2673 if relaxed_timeout is not None: 

2674 conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout) 

2675 

2676 if reset_relaxed_timeout or reset_host_address: 

2677 conn.reset_tmp_settings( 

2678 reset_host_address=reset_host_address, 

2679 reset_relaxed_timeout=reset_relaxed_timeout, 

2680 ) 

2681 

2682 conn.update_current_socket_timeout(relaxed_timeout) 

2683 

2684 def update_connections_settings( 

2685 self, 

2686 state: Optional["MaintenanceState"] = None, 

2687 maintenance_notification_hash: Optional[int] = None, 

2688 host_address: Optional[str] = None, 

2689 relaxed_timeout: Optional[float] = None, 

2690 matching_address: Optional[str] = None, 

2691 matching_notification_hash: Optional[int] = None, 

2692 matching_pattern: Literal[ 

2693 "connected_address", "configured_address", "notification_hash" 

2694 ] = "connected_address", 

2695 update_notification_hash: bool = False, 

2696 reset_host_address: bool = False, 

2697 reset_relaxed_timeout: bool = False, 

2698 include_free_connections: bool = True, 

2699 ): 

2700 """ 

2701 Update the settings for all matching connections in the pool. 

2702 

2703 This method does not create new connections. 

2704 This method does not affect the connection kwargs. 

2705 

2706 :param state: The maintenance state to set for the connection. 

2707 :param maintenance_notification_hash: The hash of the maintenance notification 

2708 to set for the connection. 

2709 :param host_address: The host address to set for the connection. 

2710 :param relaxed_timeout: The relaxed timeout to set for the connection. 

2711 :param matching_address: The address to match for the connection. 

2712 :param matching_notification_hash: The notification hash to match for the connection. 

2713 :param matching_pattern: The pattern to match for the connection. 

2714 :param update_notification_hash: Whether to update the notification hash for the connection. 

2715 :param reset_host_address: Whether to reset the host address to the original address. 

2716 :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout. 

2717 :param include_free_connections: Whether to include free/available connections. 

2718 """ 

2719 with self._get_pool_lock(): 

2720 for conn in self._get_in_use_connections(): 

2721 if self._should_update_connection( 

2722 conn, 

2723 matching_pattern, 

2724 matching_address, 

2725 matching_notification_hash, 

2726 ): 

2727 self.update_connection_settings( 

2728 conn, 

2729 state=state, 

2730 maintenance_notification_hash=maintenance_notification_hash, 

2731 host_address=host_address, 

2732 relaxed_timeout=relaxed_timeout, 

2733 update_notification_hash=update_notification_hash, 

2734 reset_host_address=reset_host_address, 

2735 reset_relaxed_timeout=reset_relaxed_timeout, 

2736 ) 

2737 

2738 if include_free_connections: 

2739 for conn in self._get_free_connections(): 

2740 if self._should_update_connection( 

2741 conn, 

2742 matching_pattern, 

2743 matching_address, 

2744 matching_notification_hash, 

2745 ): 

2746 self.update_connection_settings( 

2747 conn, 

2748 state=state, 

2749 maintenance_notification_hash=maintenance_notification_hash, 

2750 host_address=host_address, 

2751 relaxed_timeout=relaxed_timeout, 

2752 update_notification_hash=update_notification_hash, 

2753 reset_host_address=reset_host_address, 

2754 reset_relaxed_timeout=reset_relaxed_timeout, 

2755 ) 

2756 

2757 def update_connection_kwargs( 

2758 self, 

2759 **kwargs, 

2760 ): 

2761 """ 

2762 Update the connection kwargs for all future connections. 

2763 

2764 This method updates the connection kwargs for all future connections created by the pool. 

2765 Existing connections are not affected. 

2766 """ 

2767 self.connection_kwargs.update(kwargs) 

2768 

2769 def update_active_connections_for_reconnect( 

2770 self, 

2771 moving_address_src: Optional[str] = None, 

2772 ): 

2773 """ 

2774 Mark all active connections for reconnect. 

2775 This is used when a cluster node is migrated to a different address. 

2776 

2777 :param moving_address_src: The address of the node that is being moved. 

2778 """ 

2779 with self._get_pool_lock(): 

2780 for conn in self._get_in_use_connections(): 

2781 if self._should_update_connection( 

2782 conn, "connected_address", moving_address_src 

2783 ): 

2784 conn.mark_for_reconnect() 

2785 

2786 def disconnect_free_connections( 

2787 self, 

2788 moving_address_src: Optional[str] = None, 

2789 ): 

2790 """ 

2791 Disconnect all free/available connections. 

2792 This is used when a cluster node is migrated to a different address. 

2793 

2794 :param moving_address_src: The address of the node that is being moved. 

2795 """ 

2796 with self._get_pool_lock(): 

2797 for conn in self._get_free_connections(): 

2798 if self._should_update_connection( 

2799 conn, "connected_address", moving_address_src 

2800 ): 

2801 conn.disconnect() 

2802 

2803 

class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface):
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.MaxConnectionsError` when the
    pool's limit is reached. When ``max_connections`` is ``None`` (or ``0``)
    the limit defaults to 100.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    If ``maint_notifications_config`` is provided, the connection pool will support
    maintenance notifications.
    Maintenance notifications are supported only with RESP3.
    If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3,
    the maintenance notifications will be enabled by default.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

            1. A ``db`` querystring option, e.g. redis://localhost?db=0
            2. If using the redis:// or rediss:// schemes, the path argument
               of the url, e.g. redis://localhost/0
            3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win, except that an explicit ``connection_class``
        keyword argument takes precedence over the parsed URL options.
        """
        url_options = parse_url(url)

        # An explicitly passed connection_class overrides any connection_class
        # present in the parsed URL options.
        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        # URL-derived options win over keyword arguments on conflict.
        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **connection_kwargs,
    ):
        # ``None`` and ``0`` both fall back to the default limit of 100.
        max_connections = max_connections or 100
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self._connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        # issubclass() raises TypeError when connection_class is not a class
        # (e.g. a factory callable); such callables are treated as connections
        # without maintenance-notification support.
        try:
            supports_maint_notifications = issubclass(
                connection_class, MaintNotificationsAbstractConnection
            )
            is_unix_domain_socket_connection = issubclass(
                connection_class, UnixDomainSocketConnection
            )
        except TypeError:
            supports_maint_notifications = False
            is_unix_domain_socket_connection = False

        if is_unix_domain_socket_connection or not supports_maint_notifications:
            if (
                maint_notifications_config
                and maint_notifications_config.enabled is True
            ):
                raise RedisError(
                    "Maintenance notifications are not supported with "
                    f"{connection_class}"
                )
            maint_notifications_config = MaintNotificationsConfig(enabled=False)

        self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            if not check_protocol_version(self._connection_kwargs.get("protocol"), 3):
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self._connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                if self._cache_factory is not None:
                    self.cache = CacheProxy(self._cache_factory.get_cache())
                else:
                    self.cache = CacheFactory(
                        self._connection_kwargs.get("cache_config")
                    ).get_cache()

            init_csc_items()
            register_csc_items_callback(
                callback=lambda: self.cache.size,
                pool_name=get_pool_name(self),
            )

        # The cache is held by the pool (self.cache) and attached in
        # make_connection(), so it is not forwarded as a connection kwarg.
        connection_kwargs.pop("cache", None)
        connection_kwargs.pop("cache_config", None)

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        # Generate unique pool ID for observability (matches go-redis behavior)
        import secrets

        self._pool_id = secrets.token_hex(4)

        MaintNotificationsAbstractConnectionPool.__init__(
            self,
            maint_notifications_config=maint_notifications_config,
            **connection_kwargs,
        )

        self.reset()

    # Keys that should be redacted in __repr__ to avoid exposing sensitive information
    SENSITIVE_REPR_KEYS = frozenset(
        {
            "password",
            "username",
            "ssl_password",
            "credential_provider",
        }
    )

    def __repr__(self) -> str:
        conn_kwargs = ",".join(
            [
                f"{k}={'<REDACTED>' if k in self.SENSITIVE_REPR_KEYS else v}"
                for k, v in self.connection_kwargs.items()
            ]
        )
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    @property
    def connection_kwargs(self) -> Dict[str, Any]:
        return self._connection_kwargs

    @connection_kwargs.setter
    def connection_kwargs(self, value: Dict[str, Any]):
        self._connection_kwargs = value

    def get_protocol(self):
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        return self.connection_kwargs.get("protocol", None)

    def _record_removed_connections(self, idle_count: int, in_use_count: int) -> None:
        """Decrement the IDLE/USED connection gauges for connections being dropped."""
        if idle_count <= 0 and in_use_count <= 0:
            return
        pool_name = get_pool_name(self)
        if idle_count > 0:
            record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.IDLE,
                counter=-idle_count,
            )
        if in_use_count > 0:
            record_connection_count(
                pool_name=pool_name,
                connection_state=ConnectionState.USED,
                counter=-in_use_count,
            )

    def reset(self) -> None:
        """
        Forget all tracked connections and bind the pool to the current pid.

        Gauges for the dropped connections are decremented first. Connections
        are not disconnected here.
        """
        # The tracking attributes do not exist yet when reset() runs from
        # __init__, so only record metrics when there is something to drop.
        if hasattr(self, "_available_connections") and hasattr(
            self, "_in_use_connections"
        ):
            with self._lock:
                self._record_removed_connections(
                    len(self._available_connections),
                    len(self._in_use_connections),
                )

        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def __del__(self) -> None:
        """Clean up connection pool and record metrics when garbage collected."""
        try:
            if not hasattr(self, "_available_connections") or not hasattr(
                self, "_in_use_connections"
            ):
                return
            # No lock here: GC may run while the lock is held elsewhere.
            self._record_removed_connections(
                len(self._available_connections),
                len(self._in_use_connections),
            )
        except Exception:
            # Best effort: metric recording must never raise out of __del__.
            pass

    def _checkpid(self) -> None:
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        """
        Get a connection from the pool, creating a new one if none is idle.

        The returned connection has been connected and checked for unread
        data. ``command_name``, ``keys`` and ``options`` are deprecated and
        not used.

        :raises MaxConnectionsError: if a new connection is needed but the
            pool already created ``max_connections`` connections.
        :raises ConnectionError: if the connection cannot be made ready; the
            connection is released back to the pool before re-raising.
        """
        self._checkpid()
        is_created = False

        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                # Start timing connection creation for observability.
                start_time_created = time.monotonic()

                connection = self.make_connection()
                is_created = True
            self._in_use_connections.add(connection)

        # Record state transition: IDLE -> USED
        # (make_connection already recorded IDLE +1 for new connections)
        # This ensures counters stay balanced if connect() fails and release() is called
        pool_name = get_pool_name(self)
        record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.IDLE,
            counter=-1,
        )
        record_connection_count(
            pool_name=pool_name,
            connection_state=ConnectionState.USED,
            counter=1,
        )

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if (
                    connection.can_read()
                    and self.cache is None
                    and not self.maint_notifications_enabled()
                ):
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise

        if is_created:
            record_connection_create_time(
                connection_pool=self,
                duration_seconds=time.monotonic() - start_time_created,
            )

        return connection

    def get_encoder(self) -> Encoder:
        "Return an encoder based on encoding settings"
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self) -> "ConnectionInterface":
        """
        Create a new (not yet connected) connection.

        :raises MaxConnectionsError: if ``max_connections`` connections have
            already been created by this pool.
        """
        if self._created_connections >= self.max_connections:
            raise MaxConnectionsError("Too many connections")
        # Reserve the slot right after the limit check, but give it back if
        # construction fails; otherwise a failing constructor would leak a
        # slot every time and eventually exhaust the pool.
        self._created_connections += 1
        try:
            kwargs = dict(self.connection_kwargs)

            if self.cache is not None:
                connection = CacheProxyConnection(
                    self.connection_class(**kwargs), self.cache, self._lock
                )
            else:
                connection = self.connection_class(**kwargs)
        except BaseException:
            self._created_connections -= 1
            raise

        # Record new connection created (starts as IDLE) - only after successful construction
        record_connection_count(
            pool_name=get_pool_name(self),
            connection_state=ConnectionState.IDLE,
            counter=1,
        )

        return connection

    def release(self, connection: "Connection") -> None:
        """
        Release ``connection`` back to the pool.

        Connections not currently tracked as in-use are ignored. Connections
        the pool does not own (see :meth:`owns_connection`) are disconnected
        and dropped instead of being returned to the idle list.
        """
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own
                return

            pool_name = get_pool_name(self)
            if self.owns_connection(connection):
                if connection.should_reconnect():
                    connection.disconnect()
                self._available_connections.append(connection)
                self._event_dispatcher.dispatch(
                    AfterConnectionReleasedEvent(connection)
                )

                # Record state transition: USED -> IDLE
                record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.USED,
                    counter=-1,
                )
                record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.IDLE,
                    counter=1,
                )
            else:
                # Pool doesn't own this connection, do not add it back
                # to the pool.
                # The created connections count should not be changed,
                # because the connection was not created by the pool.
                connection.disconnect()
                # get_connection() counted this connection as USED under this
                # pool's name, so decrement under the same name to keep the
                # gauge balanced.
                record_connection_count(
                    pool_name=pool_name,
                    connection_state=ConnectionState.USED,
                    counter=-1,
                )
                return

    def owns_connection(self, connection: "Connection") -> bool:
        """Return whether ``connection.pid`` matches the pid recorded by the last reset()."""
        return connection.pid == self.pid

    def disconnect(self, inuse_connections: bool = True) -> None:
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        currently in use, potentially by other threads. Otherwise only disconnect
        connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(
                    self._available_connections, self._in_use_connections
                )
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()

    def close(self) -> None:
        """Close the pool, disconnecting all connections"""
        self.disconnect()

    def set_retry(self, retry: Retry) -> None:
        """Use ``retry`` for future connections and for all existing ones."""
        # Hold the pool lock: another thread may add or remove connections
        # while we iterate, which would otherwise raise
        # "Set changed size during iteration".
        with self._lock:
            self.connection_kwargs.update({"retry": retry})
            for conn in self._available_connections:
                conn.retry = retry
            for conn in self._in_use_connections:
                conn.retry = retry

    def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate connections with ``token``.

        Idle connections send ``AUTH`` immediately (with retries); in-use
        connections get the token via ``set_re_auth_token`` instead.
        """
        with self._lock:
            for conn in self._available_connections:
                # Bind ``conn`` explicitly so the callbacks never depend on
                # the loop variable's late-bound value.
                conn.retry.call_with_retry(
                    lambda conn=conn: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                conn.retry.call_with_retry(
                    lambda conn=conn: conn.read_response(),
                    lambda error: self._mock(error),
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    def _get_pool_lock(self):
        return self._lock

    def _get_free_connections(self):
        with self._lock:
            return list(self._available_connections)

    def _get_in_use_connections(self):
        with self._lock:
            return set(self._in_use_connections)

    def _mock(self, error: RedisError):
        """
        Dummy functions, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass

    def get_connection_count(self) -> List[tuple[int, dict]]:
        """
        Return ``[(idle_count, idle_attrs), (in_use_count, in_use_attrs)]``.

        Counts come from ``_get_free_connections`` / ``_get_in_use_connections``
        so subclasses that override those hooks are reported correctly.
        """
        attributes = AttributeBuilder.build_base_attributes()
        attributes[DB_CLIENT_CONNECTION_POOL_NAME] = get_pool_name(self)
        free_connections_attributes = attributes.copy()
        in_use_connections_attributes = attributes.copy()

        free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.IDLE.value
        )
        in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = (
            ConnectionState.USED.value
        )

        return [
            (len(self._get_free_connections()), free_connections_attributes),
            (len(self._get_in_use_connections()), in_use_connections_attributes),
        ]

3360 

3361 

3362class BlockingConnectionPool(ConnectionPool): 

3363 """ 

3364 Thread-safe blocking connection pool:: 

3365 

3366 >>> from redis.client import Redis 

3367 >>> client = Redis(connection_pool=BlockingConnectionPool()) 

3368 

3369 It performs the same function as the default 

3370 :py:class:`~redis.ConnectionPool` implementation, in that, 

3371 it maintains a pool of reusable connections that can be shared by 

3372 multiple redis clients (safely across threads if required). 

3373 

3374 The difference is that, in the event that a client tries to get a 

3375 connection from the pool when all of connections are in use, rather than 

3376 raising a :py:class:`~redis.ConnectionError` (as the default 

3377 :py:class:`~redis.ConnectionPool` implementation does), it 

3378 makes the client wait ("blocks") for a specified number of seconds until 

3379 a connection becomes available. 

3380 

3381 Use ``max_connections`` to increase / decrease the pool size:: 

3382 

3383 >>> pool = BlockingConnectionPool(max_connections=10) 

3384 

3385 Use ``timeout`` to tell it either how many seconds to wait for a connection 

3386 to become available, or to block forever: 

3387 

3388 >>> # Block forever. 

3389 >>> pool = BlockingConnectionPool(timeout=None) 

3390 

3391 >>> # Raise a ``ConnectionError`` after five seconds if a connection is 

3392 >>> # not available. 

3393 >>> pool = BlockingConnectionPool(timeout=5) 

3394 """ 

3395 

3396 def __init__( 

3397 self, 

3398 max_connections=50, 

3399 timeout=20, 

3400 connection_class=Connection, 

3401 queue_class=LifoQueue, 

3402 **connection_kwargs, 

3403 ): 

3404 self.queue_class = queue_class 

3405 self.timeout = timeout 

3406 self._in_maintenance = False 

3407 self._locked = False 

3408 super().__init__( 

3409 connection_class=connection_class, 

3410 max_connections=max_connections, 

3411 **connection_kwargs, 

3412 ) 

3413 

3414 def reset(self): 

3415 # Create and fill up a thread safe queue with ``None`` values. 

3416 try: 

3417 if self._in_maintenance: 

3418 self._lock.acquire() 

3419 self._locked = True 

3420 

3421 # Record metrics for connections being removed before clearing 

3422 # Note: Access pool.queue directly to avoid deadlock since we may 

3423 # already hold self._lock (which is non-reentrant) 

3424 if ( 

3425 hasattr(self, "_connections") 

3426 and self._connections 

3427 and hasattr(self, "pool") 

3428 ): 

3429 with self._lock: 

3430 connections_in_queue = {conn for conn in self.pool.queue if conn} 

3431 idle_count = len(connections_in_queue) 

3432 in_use_count = len(self._connections) - idle_count 

3433 if idle_count > 0 or in_use_count > 0: 

3434 pool_name = get_pool_name(self) 

3435 if idle_count > 0: 

3436 record_connection_count( 

3437 pool_name=pool_name, 

3438 connection_state=ConnectionState.IDLE, 

3439 counter=-idle_count, 

3440 ) 

3441 if in_use_count > 0: 

3442 record_connection_count( 

3443 pool_name=pool_name, 

3444 connection_state=ConnectionState.USED, 

3445 counter=-in_use_count, 

3446 ) 

3447 

3448 self.pool = self.queue_class(self.max_connections) 

3449 while True: 

3450 try: 

3451 self.pool.put_nowait(None) 

3452 except Full: 

3453 break 

3454 

3455 # Keep a list of actual connection instances so that we can 

3456 # disconnect them later. 

3457 self._connections = [] 

3458 finally: 

3459 if self._locked: 

3460 try: 

3461 self._lock.release() 

3462 except Exception: 

3463 pass 

3464 self._locked = False 

3465 

3466 # this must be the last operation in this method. while reset() is 

3467 # called when holding _fork_lock, other threads in this process 

3468 # can call _checkpid() which compares self.pid and os.getpid() without 

3469 # holding any lock (for performance reasons). keeping this assignment 

3470 # as the last operation ensures that those other threads will also 

3471 # notice a pid difference and block waiting for the first thread to 

3472 # release _fork_lock. when each of these threads eventually acquire 

3473 # _fork_lock, they will notice that another thread already called 

3474 # reset() and they will immediately release _fork_lock and continue on. 

3475 self.pid = os.getpid() 

3476 

3477 def __del__(self) -> None: 

3478 """Clean up connection pool and record metrics when garbage collected.""" 

3479 try: 

3480 # Note: Access pool.queue directly to avoid potential deadlock 

3481 # if GC runs while the lock is held by the same thread 

3482 if ( 

3483 hasattr(self, "_connections") 

3484 and self._connections 

3485 and hasattr(self, "pool") 

3486 ): 

3487 connections_in_queue = {conn for conn in self.pool.queue if conn} 

3488 idle_count = len(connections_in_queue) 

3489 in_use_count = len(self._connections) - idle_count 

3490 if idle_count > 0 or in_use_count > 0: 

3491 pool_name = get_pool_name(self) 

3492 if idle_count > 0: 

3493 record_connection_count( 

3494 pool_name=pool_name, 

3495 connection_state=ConnectionState.IDLE, 

3496 counter=-idle_count, 

3497 ) 

3498 if in_use_count > 0: 

3499 record_connection_count( 

3500 pool_name=pool_name, 

3501 connection_state=ConnectionState.USED, 

3502 counter=-in_use_count, 

3503 ) 

3504 except Exception: 

3505 pass 

3506 

3507 def make_connection(self): 

3508 "Make a fresh connection." 

3509 try: 

3510 if self._in_maintenance: 

3511 self._lock.acquire() 

3512 self._locked = True 

3513 

3514 if self.cache is not None: 

3515 connection = CacheProxyConnection( 

3516 self.connection_class(**self.connection_kwargs), 

3517 self.cache, 

3518 self._lock, 

3519 ) 

3520 else: 

3521 connection = self.connection_class(**self.connection_kwargs) 

3522 self._connections.append(connection) 

3523 

3524 # Record new connection created (starts as IDLE) 

3525 record_connection_count( 

3526 pool_name=get_pool_name(self), 

3527 connection_state=ConnectionState.IDLE, 

3528 counter=1, 

3529 ) 

3530 

3531 return connection 

3532 finally: 

3533 if self._locked: 

3534 try: 

3535 self._lock.release() 

3536 except Exception: 

3537 pass 

3538 self._locked = False 

3539 

3540 @deprecated_args( 

3541 args_to_warn=["*"], 

3542 reason="Use get_connection() without args instead", 

3543 version="5.3.0", 

3544 ) 

3545 def get_connection(self, command_name=None, *keys, **options): 

3546 """ 

3547 Get a connection, blocking for ``self.timeout`` until a connection 

3548 is available from the pool. 

3549 

3550 If the connection returned is ``None`` then creates a new connection. 

3551 Because we use a last-in first-out queue, the existing connections 

3552 (having been returned to the pool after the initial ``None`` values 

3553 were added) will be returned before ``None`` values. This means we only 

3554 create new connections when we need to, i.e.: the actual number of 

3555 connections will only increase in response to demand. 

3556 """ 

3557 start_time_acquired = time.monotonic() 

3558 # Make sure we haven't changed process. 

3559 self._checkpid() 

3560 is_created = False 

3561 

3562 # Try and get a connection from the pool. If one isn't available within 

3563 # self.timeout then raise a ``ConnectionError``. 

3564 connection = None 

3565 try: 

3566 if self._in_maintenance: 

3567 self._lock.acquire() 

3568 self._locked = True 

3569 try: 

3570 connection = self.pool.get(block=True, timeout=self.timeout) 

3571 except Empty: 

3572 # Note that this is not caught by the redis client and will be 

3573 # raised unless handled by application code. If you want never to 

3574 raise ConnectionError("No connection available.") 

3575 

3576 # If the ``connection`` is actually ``None`` then that's a cue to make 

3577 # a new connection to add to the pool. 

3578 if connection is None: 

3579 # Start timing for observability 

3580 start_time_created = time.monotonic() 

3581 connection = self.make_connection() 

3582 is_created = True 

3583 finally: 

3584 if self._locked: 

3585 try: 

3586 self._lock.release() 

3587 except Exception: 

3588 pass 

3589 self._locked = False 

3590 

3591 # Record state transition: IDLE -> USED 

3592 # (make_connection already recorded IDLE +1 for new connections) 

3593 # This ensures counters stay balanced if connect() fails and release() is called 

3594 pool_name = get_pool_name(self) 

3595 record_connection_count( 

3596 pool_name=pool_name, 

3597 connection_state=ConnectionState.IDLE, 

3598 counter=-1, 

3599 ) 

3600 record_connection_count( 

3601 pool_name=pool_name, 

3602 connection_state=ConnectionState.USED, 

3603 counter=1, 

3604 ) 

3605 

3606 try: 

3607 # ensure this connection is connected to Redis 

3608 connection.connect() 

3609 # connections that the pool provides should be ready to send 

3610 # a command. if not, the connection was either returned to the 

3611 # pool before all data has been read or the socket has been 

3612 # closed. either way, reconnect and verify everything is good. 

3613 try: 

3614 if connection.can_read(): 

3615 raise ConnectionError("Connection has data") 

3616 except (ConnectionError, TimeoutError, OSError): 

3617 connection.disconnect() 

3618 connection.connect() 

3619 if connection.can_read(): 

3620 raise ConnectionError("Connection not ready") 

3621 except BaseException: 

3622 # release the connection back to the pool so that we don't leak it 

3623 self.release(connection) 

3624 raise 

3625 

3626 if is_created: 

3627 record_connection_create_time( 

3628 connection_pool=self, 

3629 duration_seconds=time.monotonic() - start_time_created, 

3630 ) 

3631 

3632 record_connection_wait_time( 

3633 pool_name=pool_name, 

3634 duration_seconds=time.monotonic() - start_time_acquired, 

3635 ) 

3636 

3637 return connection 

3638 

3639 def release(self, connection): 

3640 "Releases the connection back to the pool." 

3641 # Make sure we haven't changed process. 

3642 self._checkpid() 

3643 

3644 try: 

3645 if self._in_maintenance: 

3646 self._lock.acquire() 

3647 self._locked = True 

3648 if not self.owns_connection(connection): 

3649 # pool doesn't own this connection. do not add it back 

3650 # to the pool. instead add a None value which is a placeholder 

3651 # that will cause the pool to recreate the connection if 

3652 # its needed. 

3653 connection.disconnect() 

3654 self.pool.put_nowait(None) 

3655 # Still need to decrement USED since it was counted in get_connection() 

3656 record_connection_count( 

3657 pool_name="unknown_pool", 

3658 connection_state=ConnectionState.USED, 

3659 counter=-1, 

3660 ) 

3661 return 

3662 if connection.should_reconnect(): 

3663 connection.disconnect() 

3664 # Put the connection back into the pool. 

3665 pool_name = get_pool_name(self) 

3666 try: 

3667 self.pool.put_nowait(connection) 

3668 

3669 # Record state transition: USED -> IDLE 

3670 record_connection_count( 

3671 pool_name=pool_name, 

3672 connection_state=ConnectionState.USED, 

3673 counter=-1, 

3674 ) 

3675 record_connection_count( 

3676 pool_name=pool_name, 

3677 connection_state=ConnectionState.IDLE, 

3678 counter=1, 

3679 ) 

3680 except Full: 

3681 pass 

3682 finally: 

3683 if self._locked: 

3684 try: 

3685 self._lock.release() 

3686 except Exception: 

3687 pass 

3688 self._locked = False 

3689 

3690 def disconnect(self, inuse_connections: bool = True): 

3691 """ 

3692 Disconnects either all connections in the pool or just the free connections. 

3693 """ 

3694 self._checkpid() 

3695 try: 

3696 if self._in_maintenance: 

3697 self._lock.acquire() 

3698 self._locked = True 

3699 

3700 if inuse_connections: 

3701 connections = self._connections 

3702 else: 

3703 connections = self._get_free_connections() 

3704 

3705 for connection in connections: 

3706 connection.disconnect() 

3707 finally: 

3708 if self._locked: 

3709 try: 

3710 self._lock.release() 

3711 except Exception: 

3712 pass 

3713 self._locked = False 

3714 

3715 def _get_free_connections(self): 

3716 with self._lock: 

3717 return {conn for conn in self.pool.queue if conn} 

3718 

3719 def _get_in_use_connections(self): 

3720 with self._lock: 

3721 # free connections 

3722 connections_in_queue = {conn for conn in self.pool.queue if conn} 

3723 # in self._connections we keep all created connections 

3724 # so the ones that are not in the queue are the in use ones 

3725 return { 

3726 conn for conn in self._connections if conn not in connections_in_queue 

3727 } 

3728 

3729 def set_in_maintenance(self, in_maintenance: bool): 

3730 """ 

3731 Sets a flag that this Blocking ConnectionPool is in maintenance mode. 

3732 

3733 This is used to prevent new connections from being created while we are in maintenance mode. 

3734 The pool will be in maintenance mode only when we are processing a MOVING notification. 

3735 """ 

3736 self._in_maintenance = in_maintenance