Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1540 statements  

1import copy 

2import os 

3import socket 

4import sys 

5import threading 

6import time 

7import weakref 

8from abc import ABC, abstractmethod 

9from itertools import chain 

10from queue import Empty, Full, LifoQueue 

11from typing import ( 

12 Any, 

13 Callable, 

14 Dict, 

15 Iterable, 

16 List, 

17 Literal, 

18 Optional, 

19 Type, 

20 TypeVar, 

21 Union, 

22) 

23from urllib.parse import parse_qs, unquote, urlparse 

24 

25from redis.cache import ( 

26 CacheEntry, 

27 CacheEntryStatus, 

28 CacheFactory, 

29 CacheFactoryInterface, 

30 CacheInterface, 

31 CacheKey, 

32 CacheProxy, 

33) 

34 

35from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser 

36from .auth.token import TokenInterface 

37from .backoff import NoBackoff 

38from .credentials import CredentialProvider, UsernamePasswordCredentialProvider 

39from .driver_info import DriverInfo, resolve_driver_info 

40from .event import AfterConnectionReleasedEvent, EventDispatcher 

41from .exceptions import ( 

42 AuthenticationError, 

43 AuthenticationWrongNumberOfArgsError, 

44 ChildDeadlockedError, 

45 ConnectionError, 

46 DataError, 

47 MaxConnectionsError, 

48 RedisError, 

49 ResponseError, 

50 TimeoutError, 

51) 

52from .maint_notifications import ( 

53 MaintenanceState, 

54 MaintNotificationsConfig, 

55 MaintNotificationsConnectionHandler, 

56 MaintNotificationsPoolHandler, 

57 OSSMaintNotificationsHandler, 

58) 

59from .observability.attributes import ( 

60 DB_CLIENT_CONNECTION_POOL_NAME, 

61 DB_CLIENT_CONNECTION_STATE, 

62 AttributeBuilder, 

63 ConnectionState, 

64 CSCReason, 

65 CSCResult, 

66 get_pool_name, 

67) 

68from .observability.metrics import CloseReason 

69from .observability.recorder import ( 

70 init_csc_items, 

71 record_connection_closed, 

72 record_connection_create_time, 

73 record_connection_wait_time, 

74 record_csc_eviction, 

75 record_csc_network_saved, 

76 record_csc_request, 

77 record_error_count, 

78 register_csc_items_callback, 

79) 

80from .retry import Retry 

81from .utils import ( 

82 CRYPTOGRAPHY_AVAILABLE, 

83 HIREDIS_AVAILABLE, 

84 SSL_AVAILABLE, 

85 check_protocol_version, 

86 compare_versions, 

87 deprecated_args, 

88 ensure_string, 

89 format_error_message, 

90 str_if_bytes, 

91) 

92 

93if SSL_AVAILABLE: 

94 import ssl 

95 from ssl import VerifyFlags 

96else: 

97 ssl = None 

98 VerifyFlags = None 

99 

100if HIREDIS_AVAILABLE: 

101 import hiredis 

102 

# RESP wire-protocol building blocks.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

# RESP version used when the caller passes ``protocol=None``.
DEFAULT_RESP_VERSION = 2

# Unique marker meaning "argument not supplied" where ``None`` is meaningful.
SENTINEL = object()

# Prefer the C-accelerated hiredis parser when it is installed.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] = (
    _HiredisParser if HIREDIS_AVAILABLE else _RESP2Parser
)

117 

118 

class HiredisRespSerializer:
    """Command serializer that delegates RESP encoding to ``hiredis``."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol"""
        command = args[0]
        # Multi-word command names (e.g. 'CONFIG GET') are sent to the server
        # as separate arguments, so split the first argument on whitespace.
        if isinstance(command, str):
            args = (*command.encode().split(), *args[1:])
        elif b" " in command:
            args = (*command.split(), *args[1:])
        try:
            packed = hiredis.pack_command(args)
        except TypeError as exc:
            # Surface unsupported argument types as the library's DataError
            # while keeping the original traceback.
            raise DataError(exc).with_traceback(exc.__traceback__)
        return [packed]

135 

136 

class PythonRespSerializer:
    """
    Pure-Python encoder that turns command arguments into RESP bytes.

    Small arguments are accumulated into one bytes buffer. Once the buffer or
    an argument exceeds ``buffer_cutoff`` bytes (or the argument is a
    memoryview), the argument is emitted as its own list entry so it is not
    copied into a larger bytes object.
    """

    def __init__(self, buffer_cutoff, encode) -> None:
        # Byte-size threshold above which output is split into chunks.
        self._buffer_cutoff = buffer_cutoff
        # Callable converting one argument into a bytes-like value.
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        head = args[0]
        # The client may embed literal arguments in the command name, e.g.
        # 'CONFIG GET'. The server expects them as separate arguments, so
        # split them here as bytestrings (which the encoder leaves untouched).
        if isinstance(head, str):
            args = (*head.encode().split(), *args[1:])
        elif b" " in head:
            args = (*head.split(), *args[1:])

        limit = self._buffer_cutoff
        chunks = []
        pending = SYM_STAR + str(len(args)).encode() + SYM_CRLF
        for encoded in map(self.encode, args):
            size = len(encoded)
            prefix = SYM_DOLLAR + str(size).encode() + SYM_CRLF
            # Avoid large string mallocs: flush the pending buffer and emit
            # large values / memoryviews as separate chunks.
            must_chunk = (
                isinstance(encoded, memoryview)
                or size > limit
                or len(pending) > limit
            )
            if must_chunk:
                chunks.append(pending + prefix)
                chunks.append(encoded)
                pending = SYM_CRLF
            else:
                pending = SYM_EMPTY.join((pending, prefix, encoded, SYM_CRLF))
        chunks.append(pending)
        return chunks

186 

187 

class ConnectionInterface:
    """
    Contract shared by Redis connection implementations.

    NOTE: this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers document the contract but are not enforced
    when a subclass is instantiated.
    """

    @abstractmethod
    def repr_pieces(self):
        """Return ``(name, value)`` pairs describing the connection for ``repr()``."""

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register ``callback`` to run when the connection is (re)established."""

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a callback previously passed to ``register_connect_callback``."""

    @abstractmethod
    def set_parser(self, parser_class):
        """Install a new parser instance built from ``parser_class``."""

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version used by this connection."""

    @abstractmethod
    def connect(self):
        """Open the connection to the server."""

    @abstractmethod
    def on_connect(self):
        """Hook invoked after the underlying connection has been opened."""

    @abstractmethod
    def disconnect(self, *args, **kwargs):
        """Close the connection."""

    @abstractmethod
    def check_health(self):
        """Verify that the connection is still usable."""

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send a command that has already been packed for the wire."""

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Pack and send a command built from ``args``."""

    @abstractmethod
    def can_read(self, timeout=0):
        """Report whether data can be read within ``timeout``."""

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read and return the next reply from the server."""

    @abstractmethod
    def pack_command(self, *args):
        """Encode a single command into its wire representation."""

    @abstractmethod
    def pack_commands(self, commands):
        """Encode several commands into their wire representation."""

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Metadata captured during the connection handshake."""

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store ``token`` to be used for re-authentication."""

    @abstractmethod
    def re_auth(self):
        """Re-authenticate the connection."""

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Flag the connection so it reconnects before the next command.
        Useful when the connection has been moved to a different node.
        """

    @abstractmethod
    def should_reconnect(self):
        """
        Return True when the connection has been flagged for reconnection.
        """

    @abstractmethod
    def reset_should_reconnect(self):
        """
        Clear the reconnection flag (set it back to False).
        """

    @abstractmethod
    def extract_connection_details(self) -> str:
        """Return a textual description of the connection endpoint."""

293 

294 

class MaintNotificationsAbstractConnection:
    """
    Abstract class for handling maintenance notifications logic.
    This class is expected to be used as base class together with ConnectionInterface.

    This class is intended to be used with multiple inheritance!

    All logic related to maintenance notifications is encapsulated in this class.
    Concrete subclasses supply the parser, the socket, the ``host`` and timeout
    properties and the command I/O primitives declared abstract below.

    NOTE: the class does not use ``ABCMeta``, so ``@abstractmethod`` is not
    enforced at instantiation time.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig],
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize the maintenance notifications for the connection.

        Args:
            maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
            maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
            maintenance_state (MaintenanceState): The current maintenance state of the connection.
            maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
            orig_host_address (Optional[str]): The original host address of the connection.
                Falls back to ``self.host`` when falsy.
            orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
                Falls back to ``self.socket_timeout`` when falsy.
            orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
                Falls back to ``self.socket_connect_timeout`` when falsy.
            oss_cluster_maint_notifications_handler (Optional[OSSMaintNotificationsHandler]): The OSS cluster handler for maintenance notifications.
            parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser the
                notification handlers are attached to. Required when maintenance
                notifications are enabled.
            event_dispatcher (Optional[EventDispatcher]): Stored as
                ``self.event_dispatcher``; a new ``EventDispatcher`` is created
                when omitted.

        Raises:
            RedisError: if notifications are enabled and ``parser`` is missing or
                is neither a hiredis nor a RESP3 parser.
        """
        self.maint_notifications_config = maint_notifications_config
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash

        if event_dispatcher is not None:
            self.event_dispatcher = event_dispatcher
        else:
            self.event_dispatcher = EventDispatcher()

        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            oss_cluster_maint_notifications_handler,
            parser,
        )
        # IDs of notifications already acted upon / deliberately skipped.
        self._processed_start_maint_notifications = set()
        self._skipped_end_maint_notifications = set()

    @abstractmethod
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
        pass

    @abstractmethod
    def _get_socket(self) -> Optional[socket.socket]:
        pass

    @abstractmethod
    def get_protocol(self) -> Union[int, str]:
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        pass

    @property
    @abstractmethod
    def host(self) -> str:
        pass

    @host.setter
    @abstractmethod
    def host(self, value: str):
        pass

    @property
    @abstractmethod
    def socket_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_timeout.setter
    @abstractmethod
    def socket_timeout(self, value: Optional[Union[float, int]]):
        pass

    @property
    @abstractmethod
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_connect_timeout.setter
    @abstractmethod
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        pass

    @abstractmethod
    def disconnect(self, *args, **kwargs):
        pass

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Enable maintenance notifications by setting up
        handlers and storing original connection parameters.

        Should be used ONLY with parsers that support push notifications.

        Raises:
            RedisError: if notifications are enabled and ``parser`` is missing
                or unsupported.
        """
        # Remember the original connection parameters before any early exit.
        # Handlers can still be attached later via
        # set_maint_notifications_pool_handler_for_connection() /
        # set_maint_notifications_cluster_handler_for_connection() even when
        # notifications are disabled here, and reset_tmp_settings() reads
        # these attributes unconditionally.
        self.orig_host_address = orig_host_address if orig_host_address else self.host
        self.orig_socket_timeout = (
            orig_socket_timeout if orig_socket_timeout else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout
            else self.socket_connect_timeout
        )

        if (
            not self.maint_notifications_config
            or not self.maint_notifications_config.enabled
        ):
            self._maint_notifications_pool_handler = None
            self._maint_notifications_connection_handler = None
            self._oss_cluster_maint_notifications_handler = None
            return

        if not parser:
            raise RedisError(
                "To configure maintenance notifications, a parser must be provided!"
            )

        if not isinstance(parser, _HiredisParser) and not isinstance(
            parser, _RESP3Parser
        ):
            raise RedisError(
                "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
            )

        if maint_notifications_pool_handler:
            # Extract a reference to a new pool handler that copies all properties
            # of the original one and has a different connection reference
            # This is needed because when we attach the handler to the parser
            # we need to make sure that the handler has a reference to the
            # connection that the parser is attached to.
            self._maint_notifications_pool_handler = (
                maint_notifications_pool_handler.get_handler_for_connection()
            )
            self._maint_notifications_pool_handler.set_connection(self)
        else:
            self._maint_notifications_pool_handler = None

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )

        if oss_cluster_maint_notifications_handler:
            self._oss_cluster_maint_notifications_handler = (
                oss_cluster_maint_notifications_handler
            )
        else:
            self._oss_cluster_maint_notifications_handler = None

        # Set up OSS cluster handler to parser if available
        if self._oss_cluster_maint_notifications_handler:
            parser.set_oss_cluster_maint_push_handler(
                self._oss_cluster_maint_notifications_handler.handle_notification
            )

        # Set up pool handler to parser if available
        if self._maint_notifications_pool_handler:
            parser.set_node_moving_push_handler(
                self._maint_notifications_pool_handler.handle_notification
            )

        # Set up connection handler
        parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """
        Attach a per-connection copy of ``maint_notifications_pool_handler`` to
        this connection's parser, creating (or re-configuring) the connection
        handler with the pool handler's config.
        """
        # Deep copy the pool handler to avoid sharing the same pool handler
        # between multiple connections, because otherwise each connection will override
        # the connection reference and the pool handler will only hold a reference
        # to the last connection that was set.
        maint_notifications_pool_handler_copy = (
            maint_notifications_pool_handler.get_handler_for_connection()
        )

        maint_notifications_pool_handler_copy.set_connection(self)
        self._get_parser().set_node_moving_push_handler(
            maint_notifications_pool_handler_copy.handle_notification
        )

        self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, maint_notifications_pool_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )

    def set_maint_notifications_cluster_handler_for_connection(
        self, oss_cluster_maint_notifications_handler: OSSMaintNotificationsHandler
    ):
        """
        Attach ``oss_cluster_maint_notifications_handler`` to this connection's
        parser, creating (or re-configuring) the connection handler with the
        cluster handler's config.
        """
        self._get_parser().set_oss_cluster_maint_push_handler(
            oss_cluster_maint_notifications_handler.handle_notification
        )

        self._oss_cluster_maint_notifications_handler = (
            oss_cluster_maint_notifications_handler
        )

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, oss_cluster_maint_notifications_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            self._maint_notifications_connection_handler.config = (
                oss_cluster_maint_notifications_handler.config
            )

    def activate_maint_notifications_handling_if_enabled(self, check_health=True):
        # Send maintenance notifications handshake if RESP3 is active
        # and maintenance notifications are enabled
        # and we have a host to determine the endpoint type from
        # When the maint_notifications_config enabled mode is "auto",
        # a ResponseError from the handshake is only logged (at debug level)
        # When the mode is enabled=True, we raise an exception in case of failure
        if (
            self.get_protocol() not in [2, "2"]
            and self.maint_notifications_config
            and self.maint_notifications_config.enabled
            and self._maint_notifications_connection_handler
            and hasattr(self, "host")
        ):
            self._enable_maintenance_notifications(
                maint_notifications_config=self.maint_notifications_config,
                check_health=check_health,
            )

    def _enable_maintenance_notifications(
        self, maint_notifications_config: MaintNotificationsConfig, check_health=True
    ):
        """
        Send ``CLIENT MAINT_NOTIFICATIONS ON moving-endpoint-type <type>`` and
        require an ``OK`` reply.

        Raises:
            ValueError: if the connection has no host.
            ResponseError: if the server does not reply ``OK`` (swallowed and
                logged when ``enabled == "auto"``).
        """
        try:
            host = getattr(self, "host", None)
            if host is None:
                raise ValueError(
                    "Cannot enable maintenance notifications for connection"
                    " object that doesn't have a host attribute."
                )
            endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
            self.send_command(
                "CLIENT",
                "MAINT_NOTIFICATIONS",
                "ON",
                "moving-endpoint-type",
                endpoint_type.value,
                check_health=check_health,
            )
            response = self.read_response()
            if not response or str_if_bytes(response) != "OK":
                raise ResponseError(
                    "The server doesn't support maintenance notifications"
                )
        except Exception as e:
            if (
                isinstance(e, ResponseError)
                and maint_notifications_config.enabled == "auto"
            ):
                # In "auto" mode, don't fail the connection; just log (debug)
                import logging

                logger = logging.getLogger(__name__)
                logger.debug(f"Failed to enable maintenance notifications: {e}")
            else:
                raise

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution if needed.

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """

        # Method 1: Try to get the actual IP from the established socket connection
        # This is most accurate as it shows the exact IP being used
        try:
            conn_socket = self._get_socket()
            if conn_socket is not None:
                peer_addr = conn_socket.getpeername()
                # AF_INET / AF_INET6 peers are (host, port, ...) tuples.
                # AF_UNIX peers are path strings, where [0] would be just the
                # first character of the path, so those fall through instead.
                if isinstance(peer_addr, tuple) and peer_addr:
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host
        # This is less accurate but works when socket is not available
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP
                # This mimics what the connection would do during _connect()
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # Return the IP from the first result
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # sockaddr[0] is the IP address
                    return str(addr_info[0][4][0])
        except (AttributeError, OSError):
            # DNS resolution might fail (socket.gaierror subclasses OSError)
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def add_maint_start_notification(self, id: int):
        self._processed_start_maint_notifications.add(id)

    def get_processed_start_notifications(self) -> set:
        return self._processed_start_maint_notifications

    def add_skipped_end_notification(self, id: int):
        self._skipped_end_maint_notifications.add(id)

    def get_skipped_end_notifications(self) -> set:
        return self._skipped_end_maint_notifications

    def reset_received_notifications(self):
        self._processed_start_maint_notifications.clear()
        self._skipped_end_maint_notifications.clear()

    def getpeername(self):
        """
        Returns the peer name of the connection.

        NOTE(review): for AF_UNIX sockets ``getpeername()`` returns a path
        string, so ``[0]`` yields its first character; confirm callers only
        use this on TCP connections.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            return conn_socket.getpeername()[0]
        return None

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """
        Apply ``relaxed_timeout`` (or, when it is -1, the current
        ``socket_timeout``) to the live socket and to the parser.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
            # if the current timeout is 0 it means we are in the middle of a can_read call
            # in this case we don't want to change the timeout because the operation
            # is non-blocking and should return immediately
            # Changing the state from non-blocking to blocking in the middle of a read operation
            # will lead to a deadlock
            if conn_socket.gettimeout() != 0:
                conn_socket.settimeout(timeout)
            self.update_parser_timeout(timeout)

    def update_parser_timeout(self, timeout: Optional[float] = None):
        """Propagate ``timeout`` to the parser's own timeout bookkeeping."""
        parser = self._get_parser()
        if parser and parser._buffer:
            if isinstance(parser, _RESP3Parser):
                # ``None`` (no timeout) must be propagated too: otherwise
                # restoring an unbounded original timeout after a relaxed one
                # leaves the buffer holding the stale relaxed value.
                parser._buffer.socket_timeout = timeout
            elif isinstance(parser, _HiredisParser):
                parser._socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        The value of SENTINEL is used to indicate that the property should not be updated.

        A ``tmp_relaxed_timeout`` of -1 leaves the timeouts untouched; any other
        value (including the default ``None``) is written to both
        ``socket_timeout`` and ``socket_connect_timeout``.
        """
        if tmp_host_address and tmp_host_address is not SENTINEL:
            self.host = str(tmp_host_address)
        if tmp_relaxed_timeout != -1:
            self.socket_timeout = tmp_relaxed_timeout
            self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """Restore the host and/or timeouts saved as ``orig_*`` attributes."""
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relaxed_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout

761 

762class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface): 

763 "Manages communication to and from a Redis server" 

764 

    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        db: int = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class=DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = None,
        lib_version: Optional[str] = None,
        driver_info: Optional[DriverInfo] = None,
        username: Optional[str] = None,
        retry: Union[Any, None] = None,
        redis_connect_func: Optional[Callable[[], None]] = None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        command_packer: Optional[Callable[[], None]] = None,
        event_dispatcher: Optional[EventDispatcher] = None,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Initialize a new Connection.

        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.

        Parameters
        ----------
        driver_info : DriverInfo, optional
            Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version
            are ignored. If not provided, a DriverInfo will be created from lib_name
            and lib_version (or defaults if those are also None).
        lib_name : str, optional
            **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO.
        lib_version : str, optional
            **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO.
        socket_connect_timeout : float, optional
            Defaults to ``socket_timeout`` when None.
        retry : Retry, optional
            Deep-copied before use. When omitted but retryable errors are
            configured, ``Retry(NoBackoff(), 1)`` is used; with neither, the
            connection gets ``Retry(NoBackoff(), 0)``.
        protocol : int, optional
            RESP version; None selects ``DEFAULT_RESP_VERSION``. With 3, a
            ``_RESP2Parser`` ``parser_class`` is replaced by ``_RESP3Parser``.
        command_packer : optional
            Custom serializer; when None one is chosen by
            ``_construct_command_packer``.

        Raises
        ------
        DataError
            If ``username``/``password`` are combined with ``credential_provider``.
        ConnectionError
            If ``protocol`` is not an integer, or not 2 or 3.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        # PID of the creating process (presumably used elsewhere to detect forks).
        self.pid = os.getpid()
        self.db = db
        self.client_name = client_name

        # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
        self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version)

        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self._socket_timeout = socket_timeout
        # The connect timeout inherits the regular socket timeout when unset.
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self._socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        # SENTINEL means "not supplied"; start from an empty error list.
        if retry_on_error is SENTINEL:
            retry_on_errors_list = []
        else:
            retry_on_errors_list = list(retry_on_error)
        if retry_on_timeout:
            # Add TimeoutError to the errors list to retry on
            retry_on_errors_list.append(TimeoutError)
        self.retry_on_error = retry_on_errors_list
        if retry or self.retry_on_error:
            if retry is None:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            if self.retry_on_error:
                # Update the retry's supported errors with the specified errors
                self.retry.update_supported_errors(self.retry_on_error)
        else:
            # No retry policy and no retryable errors: zero retries.
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.redis_connect_func = redis_connect_func
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self.handshake_metadata = None
        self._sock = None
        self._socket_read_size = socket_read_size
        # Weak references to connect callbacks (see register_connect_callback).
        self._connect_callbacks = []
        # Byte threshold handed to PythonRespSerializer for chunking output.
        self._buffer_cutoff = 6000
        self._re_auth_token: Optional[TokenInterface] = None
        try:
            p = int(protocol)
        except TypeError:
            # protocol=None -> library default RESP version.
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        else:
            if p < 2 or p > 3:
                raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p
        if self.protocol == 3 and parser_class == _RESP2Parser:
            # If the protocol is 3 but the parser is RESP2, change it to RESP3
            # This is needed because the parser might be set before the protocol
            # or might be provided as a kwarg to the constructor
            # We need to react on discrepancy only for RESP2 and RESP3
            # as hiredis supports both
            parser_class = _RESP3Parser
        self.set_parser(parser_class)

        self._command_packer = self._construct_command_packer(command_packer)
        self._should_reconnect = False

        # Set up maintenance notifications
        # (must run after set_parser(): it attaches handlers to self._parser)
        MaintNotificationsAbstractConnection.__init__(
            self,
            maint_notifications_config,
            maint_notifications_pool_handler,
            maintenance_state,
            maintenance_notification_hash,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            oss_cluster_maint_notifications_handler,
            self._parser,
            event_dispatcher=self._event_dispatcher,
        )

918 

919 def __repr__(self): 

920 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) 

921 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" 

922 

923 @abstractmethod 

924 def repr_pieces(self): 

925 pass 

926 

927 def __del__(self): 

928 try: 

929 self.disconnect() 

930 except Exception: 

931 pass 

932 

933 def _construct_command_packer(self, packer): 

934 if packer is not None: 

935 return packer 

936 elif HIREDIS_AVAILABLE: 

937 return HiredisRespSerializer() 

938 else: 

939 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode) 

940 

def register_connect_callback(self, callback):
    """
    Register a callback to run whenever the connection is established,
    initially or after a reconnect. Listeners use this to re-issue
    connection-scoped commands such as pub/sub subscriptions or key tracking.

    *callback* must be a bound method; only a weak reference to it is kept,
    and registering the same method twice has no effect.
    """
    weak_cb = weakref.WeakMethod(callback)
    if weak_cb in self._connect_callbacks:
        return
    self._connect_callbacks.append(weak_cb)

952 

def deregister_connect_callback(self, callback):
    """
    Stop notifying *callback* about connection events.

    Optional: callbacks are held as weak methods, so a listener that goes
    away is dropped automatically. Removing an unregistered callback is a
    no-op.
    """
    weak_cb = weakref.WeakMethod(callback)
    try:
        self._connect_callbacks.remove(weak_cb)
    except ValueError:
        # Not registered (or already removed): nothing to do.
        return

963 

def set_parser(self, parser_class):
    """
    Replace this connection's parser with a fresh ``parser_class`` instance.

    :param parser_class: parser type; instantiated with
        ``socket_read_size=self._socket_read_size``.
    """
    read_size = self._socket_read_size
    self._parser = parser_class(socket_read_size=read_size)

971 

def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
    """Return the parser instance currently attached to this connection."""
    current = self._parser
    return current

974 

def connect(self):
    """Connect to the Redis server if not already connected.

    One attempt is a single socket connect plus the handshake
    (``connect_check_health(check_health=True, retry_socket_connect=False)``);
    the whole attempt is retried according to ``self.retry``.

    Between attempts the connection is torn down with
    ``disconnect(error=...)``. The error is passed by keyword because
    ``disconnect`` only reads it from its keyword arguments. A positional
    argument would be ignored and the close would be recorded as an
    application close.
    """

    def _attempt():
        return self.connect_check_health(
            check_health=True, retry_socket_connect=False
        )

    def _on_failure(error):
        self.disconnect(error=error)

    self.retry.call_with_retry(_attempt, _on_failure)

985 

def connect_check_health(
    self, check_health: bool = True, retry_socket_connect: bool = True
):
    """Open the socket (if not already open) and run the connection handshake.

    :param check_health: forwarded to ``on_connect_check_health`` when the
        default handshake is used (``redis_connect_func`` is None).
    :param retry_socket_connect: when true, socket creation is retried via
        ``self.retry``; otherwise a single ``_connect()`` attempt is made.
    :raises TimeoutError: the socket connect timed out.
    :raises ConnectionError: any other OS-level error while connecting.

    A ``RedisError`` raised by the handshake disconnects the socket and is
    re-raised. On success, registered connect callbacks are invoked with
    this connection.
    """
    if self._sock:
        return

    # Latest failure count reported by the retry loop; attached to error metrics.
    retry_attempts = 0

    def failure_callback(error, failure_count):
        nonlocal retry_attempts
        retry_attempts = failure_count
        self.disconnect(error=error, failure_count=failure_count)

    def record_connect_error(error):
        # getattr keeps a metrics failure from masking the connect error
        # (NOTE(review): presumably some connection types lack host/port).
        host = getattr(self, "host", None)
        port = getattr(self, "port", None)
        record_error_count(
            server_address=host,
            server_port=port,
            network_peer_address=host,
            network_peer_port=port,
            error_type=error,
            retry_attempts=retry_attempts,
        )

    try:
        if retry_socket_connect:
            sock = self.retry.call_with_retry(
                self._connect,
                failure_callback,
                with_failure_count=True,
            )
        else:
            sock = self._connect()
    except socket.timeout:
        err = TimeoutError("Timeout connecting to server")
        record_connect_error(err)
        raise err
    except OSError as exc:
        err = ConnectionError(self._error_message(exc))
        record_connect_error(err)
        raise err

    self._sock = sock
    try:
        if self.redis_connect_func is None:
            # Use the default on_connect function
            self.on_connect_check_health(check_health=check_health)
        else:
            # Use the passed function redis_connect_func
            self.redis_connect_func(self)
    except RedisError:
        # clean up after any error in on_connect
        self.disconnect()
        raise

    # Run user callbacks (internally: pub/sub channel/pattern resubscription).
    # First drop weak references whose target has been collected.
    self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
    for ref in self._connect_callbacks:
        callback = ref()
        if callback:
            callback(self)

1051 

@abstractmethod
def _connect(self):
    """Create and return a connected socket; implemented by subclasses."""

1055 

@abstractmethod
def _host_error(self):
    """Return a short description of the server endpoint for error messages."""

1059 

def _error_message(self, exception):
    """Format *exception* together with this connection's endpoint description."""
    endpoint = self._host_error()
    return format_error_message(endpoint, exception)

1062 

def on_connect(self):
    """Run the post-connect handshake with health checks enabled."""
    self.on_connect_check_health(check_health=True)

1065 

def on_connect_check_health(self, check_health: bool = True):
    """Initialize the connection, authenticate and select a database.

    Runs the post-connect handshake in this order:

    1. parser setup;
    2. authentication: ``HELLO <proto> AUTH ...`` when credentials exist and
       the protocol is not 2, plain ``AUTH`` when credentials exist with
       protocol 2, or a bare ``HELLO <proto>`` when there are no credentials
       and the protocol is not 2;
    3. maintenance-notification activation;
    4. ``CLIENT SETNAME``;
    5. ``CLIENT SETINFO`` (LIB-NAME / LIB-VER);
    6. ``SELECT``.

    :param check_health: forwarded as ``check_health`` to every command sent
        here except the ``HELLO ... AUTH`` / ``AUTH`` commands. Those always
        skip the health check because a PING would fail before
        authentication.
    :raises AuthenticationError: plain ``AUTH`` did not reply ``OK``.
    :raises ConnectionError: the bare ``HELLO`` reply reports a different
        protocol version, or ``CLIENT SETNAME`` / ``SELECT`` did not reply
        ``OK``.
    """
    self._parser.on_connect(self)
    # Keep the original parser so its EXCEPTION_CLASSES can be copied onto a
    # replacement RESP3 parser below.
    parser = self._parser

    auth_args = None
    # if credential provider or username and/or password are set, authenticate
    if self.credential_provider or (self.username or self.password):
        cred_provider = (
            self.credential_provider
            or UsernamePasswordCredentialProvider(self.username, self.password)
        )
        auth_args = cred_provider.get_credentials()

    # if resp version is specified and we have auth args,
    # we need to send them via HELLO
    if auth_args and self.protocol not in [2, "2"]:
        if isinstance(self._parser, _RESP2Parser):
            self.set_parser(_RESP3Parser)
            # update cluster exception classes
            self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
            self._parser.on_connect(self)
        # A password-only credential is sent as the "default" user.
        if len(auth_args) == 1:
            auth_args = ["default", auth_args[0]]
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        self.send_command(
            "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
        )
        self.handshake_metadata = self.read_response()
        # NOTE: unlike the bare-HELLO branch below, the protocol version in the
        # HELLO reply is not validated here.
    elif auth_args:
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        self.send_command("AUTH", *auth_args, check_health=False)

        try:
            auth_response = self.read_response()
        except AuthenticationWrongNumberOfArgsError:
            # a username and password were specified but the Redis
            # server seems to be < 6.0.0 which expects a single password
            # arg. retry auth with just the password.
            # https://github.com/andymccurdy/redis-py/issues/1274
            self.send_command("AUTH", auth_args[-1], check_health=False)
            auth_response = self.read_response()

        if str_if_bytes(auth_response) != "OK":
            raise AuthenticationError("Invalid Username or Password")

    # if resp version is specified, switch to it
    elif self.protocol not in [2, "2"]:
        if isinstance(self._parser, _RESP2Parser):
            self.set_parser(_RESP3Parser)
            # update cluster exception classes
            self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
            self._parser.on_connect(self)
        self.send_command("HELLO", self.protocol, check_health=check_health)
        self.handshake_metadata = self.read_response()
        # The reply's "proto" key may be bytes or str depending on decoding.
        if (
            self.handshake_metadata.get(b"proto") != self.protocol
            and self.handshake_metadata.get("proto") != self.protocol
        ):
            raise ConnectionError("Invalid RESP version")

    # Activate maintenance notifications for this connection
    # if enabled in the configuration
    # This is a no-op if maintenance notifications are not enabled
    self.activate_maint_notifications_handling_if_enabled(check_health=check_health)

    # if a client_name is given, set it
    if self.client_name:
        self.send_command(
            "CLIENT",
            "SETNAME",
            self.client_name,
            check_health=check_health,
        )
        if str_if_bytes(self.read_response()) != "OK":
            raise ConnectionError("Error setting client name")

    # Set the library name and version from driver_info.
    # A ResponseError reply is ignored (NOTE(review): presumably servers that
    # do not support CLIENT SETINFO).
    try:
        if self.driver_info and self.driver_info.formatted_name:
            self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.driver_info.formatted_name,
                check_health=check_health,
            )
            self.read_response()
    except ResponseError:
        pass

    try:
        if self.driver_info and self.driver_info.lib_version:
            self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.driver_info.lib_version,
                check_health=check_health,
            )
            self.read_response()
    except ResponseError:
        pass

    # if a database is specified, switch to it
    if self.db:
        self.send_command("SELECT", self.db, check_health=check_health)
        if str_if_bytes(self.read_response()) != "OK":
            raise ConnectionError("Invalid Database")

1181 

def disconnect(self, *args, **kwargs):
    """Disconnect from the Redis server and record why the connection closed.

    Positional ``args`` are accepted but ignored. Only these keyword
    arguments are read:

    :param error: exception that caused the close. When truthy the close is
        recorded as ``CloseReason.HEALTHCHECK_FAILED`` (if
        ``health_check_failed``) or ``CloseReason.ERROR``; otherwise as
        ``CloseReason.APPLICATION_CLOSE``.
    :param failure_count: retry attempt number. When it exceeds
        ``self.retry.get_retries()`` an error count is also recorded.
    :param health_check_failed: marks the close as caused by a failed PING.

    If no socket is open, only the parser and reconnect flag are reset and
    nothing is recorded.
    """
    self._parser.on_disconnect()

    conn_sock = self._sock
    self._sock = None
    # reset the reconnect flag
    self.reset_should_reconnect()

    if conn_sock is None:
        return

    # shutdown() only runs when the current pid equals self.pid.
    # NOTE(review): presumably self.pid is the creating process, so a forked
    # child only closes its copy of the socket — confirm.
    if os.getpid() == self.pid:
        try:
            conn_sock.shutdown(socket.SHUT_RDWR)
        except (OSError, TypeError):
            pass

    try:
        conn_sock.close()
    except OSError:
        pass

    error = kwargs.get("error")
    failure_count = kwargs.get("failure_count")
    health_check_failed = kwargs.get("health_check_failed")

    if error:
        if health_check_failed:
            close_reason = CloseReason.HEALTHCHECK_FAILED
        else:
            close_reason = CloseReason.ERROR

        # Record the error once retries are exhausted.
        if failure_count is not None and failure_count > self.retry.get_retries():
            record_error_count(
                server_address=self.host,
                server_port=self.port,
                network_peer_address=self.host,
                network_peer_port=self.port,
                error_type=error,
                retry_attempts=failure_count,
            )

        record_connection_closed(
            close_reason=close_reason,
            error_type=error,
        )
    else:
        record_connection_closed(
            close_reason=CloseReason.APPLICATION_CLOSE,
        )

    if self.maintenance_state == MaintenanceState.MAINTENANCE:
        # this block will be executed only if the connection was in maintenance state
        # and the connection was closed.
        # The state change won't be applied on connections that are in Moving state
        # because their state and configurations will be handled when the moving ttl expires.
        self.reset_tmp_settings(reset_relaxed_timeout=True)
        self.maintenance_state = MaintenanceState.NONE
        # reset the sets that keep track of received start maint
        # notifications and skipped end maint notifications
        self.reset_received_notifications()

1244 

def mark_for_reconnect(self):
    """Set the flag reported by :meth:`should_reconnect`."""
    setattr(self, "_should_reconnect", True)

1247 

def should_reconnect(self):
    """Return the reconnect flag set by :meth:`mark_for_reconnect`."""
    flag = self._should_reconnect
    return flag

1250 

def reset_should_reconnect(self):
    """Clear the reconnect flag."""
    setattr(self, "_should_reconnect", False)

1253 

def _send_ping(self):
    """Send PING without a nested health check and require a PONG reply.

    :raises ConnectionError: the reply was anything other than PONG.
    """
    self.send_command("PING", check_health=False)
    reply = str_if_bytes(self.read_response())
    if reply == "PONG":
        return
    raise ConnectionError("Bad response from PING health check")

1259 

def _ping_failed(self, error, failure_count):
    """Retry failure hook for the PING health check: disconnect, tagging the cause."""
    close_info = {
        "error": error,
        "failure_count": failure_count,
        "health_check_failed": True,
    }
    self.disconnect(**close_info)

1265 

def check_health(self):
    """PING the server if health checks are enabled and one is due.

    Nothing happens when ``health_check_interval`` is falsy or the current
    monotonic time has not passed ``next_health_check``. Otherwise the PING
    runs under ``self.retry``, with ``_ping_failed`` as the failure hook.
    """
    if not self.health_check_interval:
        return
    due = time.monotonic() > self.next_health_check
    if not due:
        return
    self.retry.call_with_retry(
        self._send_ping,
        self._ping_failed,
        with_failure_count=True,
    )

1274 

def send_packed_command(self, command, check_health=True):
    """Send an already packed command to the Redis server.

    Connects first (without a health check) if no socket is open.

    :param command: iterable of packed chunks, each passed to
        ``sock.sendall``; a single ``str`` is wrapped in a list.
    :param check_health: run :meth:`check_health` before sending. Internal
        callers pass False to avoid recursing into the PING health check.
    :raises TimeoutError: the socket write timed out.
    :raises ConnectionError: any other OS-level write error.

    On any failure the connection is disconnected before the error
    propagates.
    """
    if not self._sock:
        self.connect_check_health(check_health=False)
    # guard against health check recursion
    if check_health:
        self.check_health()
    try:
        if isinstance(command, str):
            command = [command]
        for item in command:
            self._sock.sendall(item)
    except socket.timeout:
        self.disconnect()
        raise TimeoutError("Timeout writing to socket")
    except OSError as e:
        self.disconnect()
        # OSError args are usually (errno, strerror), but may also be a
        # single message or even empty; never let unpacking them mask the
        # original failure.
        if len(e.args) > 1:
            err_code, err_msg = e.args[0], e.args[1]
        else:
            err_code = "UNKNOWN"
            err_msg = e.args[0] if e.args else str(e)
        raise ConnectionError(
            f"Error {err_code} while writing to socket. {err_msg}."
        ) from e
    except BaseException:
        # BaseExceptions can be raised when a socket send operation is not
        # finished, e.g. due to a timeout. Ideally, a caller could then re-try
        # to send un-sent data. However, the send_packed_command() API
        # does not support it so there is no point in keeping the connection open.
        self.disconnect()
        raise

1305 

def send_command(self, *args, **kwargs):
    """Pack *args* into the Redis protocol and send them.

    Only the ``check_health`` keyword (default True) is used; it is
    forwarded to :meth:`send_packed_command`.
    """
    packed = self._command_packer.pack(*args)
    run_health_check = kwargs.get("check_health", True)
    self.send_packed_command(packed, check_health=run_health_check)

1312 

def can_read(self, timeout=0):
    """Poll the socket to see if there's data that can be read.

    Connects first if no socket is open.

    :param timeout: passed to the parser's ``can_read``.
    :raises ConnectionError: an OS error while polling (after disconnecting).
    """
    if not self._sock:
        self.connect()

    where = self._host_error()

    try:
        return self._parser.can_read(timeout)
    except OSError as exc:
        self.disconnect()
        raise ConnectionError(f"Error while reading from {where}: {exc.args}")

1327 

def read_response(
    self,
    disable_decoding=False,
    *,
    disconnect_on_error=True,
    push_request=False,
):
    """Read the reply to a previously sent command.

    :param disable_decoding: forwarded to the parser.
    :param disconnect_on_error: disconnect before re-raising any read error
        (callers rely on this for command/response pairing, see #1128).
    :param push_request: forwarded to the parser only when the protocol is 3.
    :raises TimeoutError: the socket read timed out.
    :raises ConnectionError: any other OS-level read error.
    :raises ResponseError: the server replied with an error.
    """
    host_error = self._host_error()

    try:
        parser_kwargs = {"disable_decoding": disable_decoding}
        if self.protocol in ("3", 3):
            parser_kwargs["push_request"] = push_request
        response = self._parser.read_response(**parser_kwargs)
    except socket.timeout:
        if disconnect_on_error:
            self.disconnect()
        raise TimeoutError(f"Timeout reading from {host_error}")
    except OSError as exc:
        if disconnect_on_error:
            self.disconnect()
        raise ConnectionError(f"Error while reading from {host_error} : {exc.args}")
    except BaseException:
        # Any other interruption also closes the connection by default so the
        # next command does not read this command's reply.
        if disconnect_on_error:
            self.disconnect()
        raise

    if self.health_check_interval:
        # A successful read counts as proof of liveness.
        self.next_health_check = time.monotonic() + self.health_check_interval

    if not isinstance(response, ResponseError):
        return response
    try:
        raise response
    finally:
        del response  # avoid creating ref cycles

1371 

def pack_command(self, *args):
    """Pack a series of arguments into the Redis protocol via the command packer."""
    packer = self._command_packer
    return packer.pack(*args)

1375 

def pack_commands(self, commands):
    """Pack multiple commands into the Redis protocol.

    Small chunks are concatenated into buffers of roughly
    ``self._buffer_cutoff`` bytes. Chunks larger than the cutoff, and
    ``memoryview`` chunks, are emitted on their own so they are never copied.

    :param commands: iterable of argument sequences, one per command.
    :returns: list of byte buffers/views to send in order.
    """
    cutoff = self._buffer_cutoff
    packed = []
    pending = []
    pending_size = 0

    for command in commands:
        for chunk in self._command_packer.pack(*command):
            size = len(chunk)
            standalone = size > cutoff or isinstance(chunk, memoryview)

            # Flush the accumulated small chunks before a standalone chunk,
            # or once the pending buffer has grown past the cutoff.
            if standalone or pending_size > cutoff:
                if pending:
                    packed.append(SYM_EMPTY.join(pending))
                pending_size = 0
                pending = []

            if standalone:
                packed.append(chunk)
            else:
                pending.append(chunk)
                pending_size += size

    if pending:
        packed.append(SYM_EMPTY.join(pending))
    return packed

1405 

def get_protocol(self) -> Union[int, str]:
    """Return the configured RESP protocol version."""
    proto = self.protocol
    return proto

1408 

@property
def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
    """Reply of the last HELLO handshake (keys may be bytes or str)."""
    metadata = self._handshake_metadata
    return metadata

@handshake_metadata.setter
def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
    setattr(self, "_handshake_metadata", value)

1416 

def set_re_auth_token(self, token: TokenInterface):
    """Store *token* so the next :meth:`re_auth` call re-authenticates with it."""
    setattr(self, "_re_auth_token", token)

1419 

def re_auth(self):
    """Re-authenticate with the pending token, if any, then clear it.

    Sends ``AUTH <token oid> <token value>`` and reads the reply.
    """
    token = self._re_auth_token
    if token is None:
        return
    self.send_command("AUTH", token.try_get("oid"), token.get_value())
    self.read_response()
    self._re_auth_token = None

1429 

def _get_socket(self) -> Optional[socket.socket]:
    """Return the underlying socket, or None when disconnected."""
    current = self._sock
    return current

1432 

@property
def socket_timeout(self) -> Optional[Union[float, int]]:
    """Socket read/write timeout in effect for this connection."""
    timeout = self._socket_timeout
    return timeout

@socket_timeout.setter
def socket_timeout(self, new_timeout: Optional[Union[float, int]]):
    setattr(self, "_socket_timeout", new_timeout)

1440 

@property
def socket_connect_timeout(self) -> Optional[Union[float, int]]:
    """Timeout applied while establishing the socket connection."""
    timeout = self._socket_connect_timeout
    return timeout

@socket_connect_timeout.setter
def socket_connect_timeout(self, new_timeout: Optional[Union[float, int]]):
    setattr(self, "_socket_connect_timeout", new_timeout)

1448 

def extract_connection_details(self) -> str:
    """Describe the connection for diagnostics.

    :returns: ``"not connected"`` without a socket, otherwise the resolved
        server IP and the local port. The port is ``None`` when
        ``getsockname()`` fails or its result carries no port.
    """
    sock = self._sock
    if sock is None:
        return "not connected"

    try:
        local_address = sock.getsockname()
    except (AttributeError, OSError):
        local_address = None

    # AF_INET / AF_INET6 report (host, port, ...) tuples. Other families
    # (e.g. AF_UNIX, which reports a path string) have no port, so they must
    # not be indexed.
    local_port = None
    if isinstance(local_address, tuple) and len(local_address) > 1:
        local_port = local_address[1]

    return f"connected to ip {self.get_resolved_ip()}, local socket port: {local_port}"

1460 

1461 

class Connection(AbstractConnection):
    """Manages TCP communication to and from a Redis server."""

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        """
        :param host: server hostname or IP address.
        :param port: server port, coerced with ``int()``.
        :param socket_keepalive: enable ``SO_KEEPALIVE`` on the socket.
        :param socket_keepalive_options: mapping of ``IPPROTO_TCP`` option
            to value, applied only when *socket_keepalive* is true.
        :param socket_type: passed as the *family* argument of
            ``socket.getaddrinfo`` (0 accepts any address family).

        Remaining keyword arguments are forwarded to the base class.
        """
        self._host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return host, port, db and (if set) client_name for ``__repr__``."""
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        """Create a TCP socket connection.

        Mimics ``socket.create_connection`` (tries every address returned by
        ``getaddrinfo``, so IPv4 and IPv6 both work), but sets socket options
        before ``connect()`` is called.

        :returns: the first socket that connects successfully.
        :raises OSError: the last connection error, or an OSError when
            ``getaddrinfo`` returned no addresses.
        """
        err = None

        for family, socktype, proto, _canonname, socket_address in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as exc:
                # Remember the error and try the next address.
                err = exc
                self._discard_socket(sock)
            except BaseException:
                # Non-OS errors (e.g. a TypeError from malformed
                # socket_keepalive_options) must not leak the file descriptor.
                self._discard_socket(sock)
                raise

        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    @staticmethod
    def _discard_socket(sock):
        """Shut down and close a socket that failed to connect; no-op for None."""
        if sock is None:
            return
        try:
            sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
        except OSError:
            pass
        sock.close()

    def _host_error(self):
        """Return ``"host:port"`` for error messages."""
        return f"{self.host}:{self.port}"

    @property
    def host(self) -> str:
        return self._host

    @host.setter
    def host(self, value: str):
        self._host = value

1543 

1544 

1545class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface): 

1546 DUMMY_CACHE_VALUE = b"foo" 

1547 MIN_ALLOWED_VERSION = "7.4.0" 

1548 DEFAULT_SERVER_NAME = "redis" 

1549 

1550 def __init__( 

1551 self, 

1552 conn: ConnectionInterface, 

1553 cache: CacheInterface, 

1554 pool_lock: threading.RLock, 

1555 ): 

1556 self.pid = os.getpid() 

1557 self._conn = conn 

1558 self.retry = self._conn.retry 

1559 self.host = self._conn.host 

1560 self.port = self._conn.port 

1561 self.db = self._conn.db 

1562 self._event_dispatcher = self._conn._event_dispatcher 

1563 self.credential_provider = conn.credential_provider 

1564 self._pool_lock = pool_lock 

1565 self._cache = cache 

1566 self._cache_lock = threading.RLock() 

1567 self._current_command_cache_key = None 

1568 self._current_options = None 

1569 self.register_connect_callback(self._enable_tracking_callback) 

1570 

1571 if isinstance(self._conn, MaintNotificationsAbstractConnection): 

1572 MaintNotificationsAbstractConnection.__init__( 

1573 self, 

1574 self._conn.maint_notifications_config, 

1575 self._conn._maint_notifications_pool_handler, 

1576 self._conn.maintenance_state, 

1577 self._conn.maintenance_notification_hash, 

1578 self._conn.host, 

1579 self._conn.socket_timeout, 

1580 self._conn.socket_connect_timeout, 

1581 self._conn._oss_cluster_maint_notifications_handler, 

1582 self._conn._get_parser(), 

1583 event_dispatcher=self._conn.event_dispatcher, 

1584 ) 

1585 

1586 def repr_pieces(self): 

1587 return self._conn.repr_pieces() 

1588 

1589 def register_connect_callback(self, callback): 

1590 self._conn.register_connect_callback(callback) 

1591 

1592 def deregister_connect_callback(self, callback): 

1593 self._conn.deregister_connect_callback(callback) 

1594 

1595 def set_parser(self, parser_class): 

1596 self._conn.set_parser(parser_class) 

1597 

1598 def set_maint_notifications_pool_handler_for_connection( 

1599 self, maint_notifications_pool_handler 

1600 ): 

1601 if isinstance(self._conn, MaintNotificationsAbstractConnection): 

1602 self._conn.set_maint_notifications_pool_handler_for_connection( 

1603 maint_notifications_pool_handler 

1604 ) 

1605 

1606 def set_maint_notifications_cluster_handler_for_connection( 

1607 self, oss_cluster_maint_notifications_handler 

1608 ): 

1609 if isinstance(self._conn, MaintNotificationsAbstractConnection): 

1610 self._conn.set_maint_notifications_cluster_handler_for_connection( 

1611 oss_cluster_maint_notifications_handler 

1612 ) 

1613 

1614 def get_protocol(self): 

1615 return self._conn.get_protocol() 

1616 

1617 def connect(self): 

1618 self._conn.connect() 

1619 

1620 server_name = self._conn.handshake_metadata.get(b"server", None) 

1621 if server_name is None: 

1622 server_name = self._conn.handshake_metadata.get("server", None) 

1623 server_ver = self._conn.handshake_metadata.get(b"version", None) 

1624 if server_ver is None: 

1625 server_ver = self._conn.handshake_metadata.get("version", None) 

1626 if server_ver is None or server_name is None: 

1627 raise ConnectionError("Cannot retrieve information about server version") 

1628 

1629 server_ver = ensure_string(server_ver) 

1630 server_name = ensure_string(server_name) 

1631 

1632 if ( 

1633 server_name != self.DEFAULT_SERVER_NAME 

1634 or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1 

1635 ): 

1636 raise ConnectionError( 

1637 "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later" # noqa: E501 

1638 ) 

1639 

1640 def on_connect(self): 

1641 self._conn.on_connect() 

1642 

1643 def disconnect(self, *args, **kwargs): 

1644 with self._cache_lock: 

1645 self._cache.flush() 

1646 self._conn.disconnect(*args, **kwargs) 

1647 

1648 def check_health(self): 

1649 self._conn.check_health() 

1650 

1651 def send_packed_command(self, command, check_health=True): 

1652 # TODO: Investigate if it's possible to unpack command 

1653 # or extract keys from packed command 

1654 self._conn.send_packed_command(command) 

1655 

1656 def send_command(self, *args, **kwargs): 

1657 self._process_pending_invalidations() 

1658 

1659 with self._cache_lock: 

1660 # Command is write command or not allowed 

1661 # to be cached. 

1662 if not self._cache.is_cachable( 

1663 CacheKey(command=args[0], redis_keys=(), redis_args=()) 

1664 ): 

1665 self._current_command_cache_key = None 

1666 self._conn.send_command(*args, **kwargs) 

1667 return 

1668 

1669 if kwargs.get("keys") is None: 

1670 raise ValueError("Cannot create cache key.") 

1671 

1672 # Creates cache key. 

1673 self._current_command_cache_key = CacheKey( 

1674 command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args 

1675 ) 

1676 

1677 with self._cache_lock: 

1678 # We have to trigger invalidation processing in case if 

1679 # it was cached by another connection to avoid 

1680 # queueing invalidations in stale connections. 

1681 if self._cache.get(self._current_command_cache_key): 

1682 entry = self._cache.get(self._current_command_cache_key) 

1683 

1684 if entry.connection_ref != self._conn: 

1685 with self._pool_lock: 

1686 while entry.connection_ref.can_read(): 

1687 entry.connection_ref.read_response(push_request=True) 

1688 

1689 return 

1690 

1691 # Set temporary entry value to prevent 

1692 # race condition from another connection. 

1693 self._cache.set( 

1694 CacheEntry( 

1695 cache_key=self._current_command_cache_key, 

1696 cache_value=self.DUMMY_CACHE_VALUE, 

1697 status=CacheEntryStatus.IN_PROGRESS, 

1698 connection_ref=self._conn, 

1699 ) 

1700 ) 

1701 

1702 # Send command over socket only if it's allowed 

1703 # read-only command that not yet cached. 

1704 self._conn.send_command(*args, **kwargs) 

1705 

1706 def can_read(self, timeout=0): 

1707 return self._conn.can_read(timeout) 

1708 

1709 def read_response( 

1710 self, disable_decoding=False, *, disconnect_on_error=True, push_request=False 

1711 ): 

1712 with self._cache_lock: 

1713 # Check if command response exists in a cache and it's not in progress. 

1714 if self._current_command_cache_key is not None: 

1715 if ( 

1716 self._cache.get(self._current_command_cache_key) is not None 

1717 and self._cache.get(self._current_command_cache_key).status 

1718 != CacheEntryStatus.IN_PROGRESS 

1719 ): 

1720 res = copy.deepcopy( 

1721 self._cache.get(self._current_command_cache_key).cache_value 

1722 ) 

1723 self._current_command_cache_key = None 

1724 record_csc_request( 

1725 result=CSCResult.HIT, 

1726 ) 

1727 record_csc_network_saved( 

1728 bytes_saved=len(res), 

1729 ) 

1730 return res 

1731 record_csc_request( 

1732 result=CSCResult.MISS, 

1733 ) 

1734 

1735 response = self._conn.read_response( 

1736 disable_decoding=disable_decoding, 

1737 disconnect_on_error=disconnect_on_error, 

1738 push_request=push_request, 

1739 ) 

1740 

1741 with self._cache_lock: 

1742 # Prevent not-allowed command from caching. 

1743 if self._current_command_cache_key is None: 

1744 return response 

1745 # If response is None prevent from caching. 

1746 if response is None: 

1747 self._cache.delete_by_cache_keys([self._current_command_cache_key]) 

1748 return response 

1749 

1750 cache_entry = self._cache.get(self._current_command_cache_key) 

1751 

1752 # Cache only responses that still valid 

1753 # and wasn't invalidated by another connection in meantime. 

1754 if cache_entry is not None: 

1755 cache_entry.status = CacheEntryStatus.VALID 

1756 cache_entry.cache_value = response 

1757 self._cache.set(cache_entry) 

1758 

1759 self._current_command_cache_key = None 

1760 

1761 return response 

1762 

1763 def pack_command(self, *args): 

1764 return self._conn.pack_command(*args) 

1765 

1766 def pack_commands(self, commands): 

1767 return self._conn.pack_commands(commands) 

1768 

1769 @property 

1770 def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: 

1771 return self._conn.handshake_metadata 

1772 

1773 def set_re_auth_token(self, token: TokenInterface): 

1774 self._conn.set_re_auth_token(token) 

1775 

1776 def re_auth(self): 

1777 self._conn.re_auth() 

1778 

1779 def mark_for_reconnect(self): 

1780 self._conn.mark_for_reconnect() 

1781 

1782 def should_reconnect(self): 

1783 return self._conn.should_reconnect() 

1784 

1785 def reset_should_reconnect(self): 

1786 self._conn.reset_should_reconnect() 

1787 

1788 @property 

1789 def host(self) -> str: 

1790 return self._conn.host 

1791 

1792 @host.setter 

1793 def host(self, value: str): 

1794 self._conn.host = value 

1795 

1796 @property 

1797 def socket_timeout(self) -> Optional[Union[float, int]]: 

1798 return self._conn.socket_timeout 

1799 

1800 @socket_timeout.setter 

1801 def socket_timeout(self, value: Optional[Union[float, int]]): 

1802 self._conn.socket_timeout = value 

1803 

1804 @property 

1805 def socket_connect_timeout(self) -> Optional[Union[float, int]]: 

1806 return self._conn.socket_connect_timeout 

1807 

1808 @socket_connect_timeout.setter 

1809 def socket_connect_timeout(self, value: Optional[Union[float, int]]): 

1810 self._conn.socket_connect_timeout = value 

1811 

1812 @property 

1813 def _maint_notifications_connection_handler( 

1814 self, 

1815 ) -> Optional[MaintNotificationsConnectionHandler]: 

1816 if isinstance(self._conn, MaintNotificationsAbstractConnection): 

1817 return self._conn._maint_notifications_connection_handler 

1818 

1819 @_maint_notifications_connection_handler.setter 

1820 def _maint_notifications_connection_handler( 

1821 self, value: Optional[MaintNotificationsConnectionHandler] 

1822 ): 

1823 self._conn._maint_notifications_connection_handler = value 

1824 

1825 def _get_socket(self) -> Optional[socket.socket]: 

1826 if isinstance(self._conn, MaintNotificationsAbstractConnection): 

1827 return self._conn._get_socket() 

1828 else: 

1829 raise NotImplementedError( 

1830 "Maintenance notifications are not supported by this connection type" 

1831 ) 

1832 

1833 def _get_maint_notifications_connection_instance( 

1834 self, connection 

1835 ) -> MaintNotificationsAbstractConnection: 

1836 """ 

1837 Validate that connection instance supports maintenance notifications. 

1838 With this helper method we ensure that we are working 

1839 with the correct connection type. 

1840 After twe validate that connection instance supports maintenance notifications 

1841 we can safely return the connection instance 

1842 as MaintNotificationsAbstractConnection. 

1843 """ 

1844 if not isinstance(connection, MaintNotificationsAbstractConnection): 

1845 raise NotImplementedError( 

1846 "Maintenance notifications are not supported by this connection type" 

1847 ) 

1848 else: 

1849 return connection 

1850 

1851 @property 

1852 def maintenance_state(self) -> MaintenanceState: 

1853 con = self._get_maint_notifications_connection_instance(self._conn) 

1854 return con.maintenance_state 

1855 

1856 @maintenance_state.setter 

1857 def maintenance_state(self, state: MaintenanceState): 

1858 con = self._get_maint_notifications_connection_instance(self._conn) 

1859 con.maintenance_state = state 

1860 

1861 def getpeername(self): 

1862 con = self._get_maint_notifications_connection_instance(self._conn) 

1863 return con.getpeername() 

1864 

1865 def get_resolved_ip(self): 

1866 con = self._get_maint_notifications_connection_instance(self._conn) 

1867 return con.get_resolved_ip() 

1868 

1869 def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None): 

1870 con = self._get_maint_notifications_connection_instance(self._conn) 

1871 con.update_current_socket_timeout(relaxed_timeout) 

1872 

1873 def set_tmp_settings( 

1874 self, 

1875 tmp_host_address: Optional[str] = None, 

1876 tmp_relaxed_timeout: Optional[float] = None, 

1877 ): 

1878 con = self._get_maint_notifications_connection_instance(self._conn) 

1879 con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout) 

1880 

1881 def reset_tmp_settings( 

1882 self, 

1883 reset_host_address: bool = False, 

1884 reset_relaxed_timeout: bool = False, 

1885 ): 

1886 con = self._get_maint_notifications_connection_instance(self._conn) 

1887 con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout) 

1888 

1889 def _connect(self): 

1890 self._conn._connect() 

1891 

1892 def _host_error(self): 

1893 self._conn._host_error() 

1894 

1895 def _enable_tracking_callback(self, conn: ConnectionInterface) -> None: 

1896 conn.send_command("CLIENT", "TRACKING", "ON") 

1897 conn.read_response() 

1898 conn._parser.set_invalidation_push_handler(self._on_invalidation_callback) 

1899 

    def _process_pending_invalidations(self):
        # Drain responses while can_read() reports readable data, reading each
        # with push_request=True on the wrapped connection.
        # NOTE(review): presumably this lets the invalidation push handler
        # registered in _enable_tracking_callback run before the cache is
        # consulted — confirm against the parser implementation.
        while self.can_read():
            self._conn.read_response(push_request=True)

1903 

1904 def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]): 

1905 with self._cache_lock: 

1906 # Flush cache when DB flushed on server-side 

1907 if data[1] is None: 

1908 self._cache.flush() 

1909 else: 

1910 keys_deleted = self._cache.delete_by_redis_keys(data[1]) 

1911 

1912 if len(keys_deleted) > 0: 

1913 record_csc_eviction( 

1914 count=len(keys_deleted), 

1915 reason=CSCReason.INVALIDATION, 

1916 ) 

1917 

1918 def extract_connection_details(self) -> str: 

1919 return self._conn.extract_connection_details() 

1920 

1921 

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
                or an ssl.VerifyMode. Defaults to "required".
            ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
            ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ssl_include_verify_flags = ssl_include_verify_flags
        self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # Hostname checking cannot be enabled when certificates are not
        # verified, so it is forced off for CERT_NONE.
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Open the underlying TCP socket and wrap it with SSL.

        If wrapping fails with an OSError or RedisError, the plain socket is
        closed before the exception is re-raised. Failures after the TLS
        handshake are cleaned up inside ``_wrap_socket_with_ssl``. By that
        point ``wrap_socket()`` has detached the plain socket, so closing it
        here is a no-op.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except (OSError, RedisError):
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Builds an ``ssl.SSLContext`` from this connection's settings, performs
        the TLS handshake and, when requested, OCSP validation (either of a
        stapled response or a full OCSP lookup). If validation fails after the
        handshake, the SSL socket is closed before the error propagates.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.

        Raises:
            RedisError: if full OCSP validation is requested but cryptography
                is not installed, or if both OCSP modes are requested.
            ConnectionError: if full OCSP validation fails.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.ssl_include_verify_flags:
            for flag in self.ssl_include_verify_flags:
                context.verify_flags |= flag
        if self.ssl_exclude_verify_flags:
            for flag in self.ssl_exclude_verify_flags:
                context.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        try:
            if self.ssl_validate_ocsp_stapled:
                # validation for the stapled case
                self._verify_ocsp_staple()
            elif self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
                # pure ocsp validation
                from .ocsp import OCSPVerifier

                o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
                if not o.is_valid():
                    raise ConnectionError("ocsp validation error")
        except Exception:
            # wrap_socket() detached ``sock``: sslsock is now the only owner
            # of the file descriptor, so close it or it leaks on failure.
            sslsock.close()
            raise
        return sslsock

    def _verify_ocsp_staple(self):
        """Check the server's stapled OCSP response using a separate pyOpenSSL connection."""
        import OpenSSL

        from .ocsp import ocsp_staple_verifier

        # if a context is provided use it - otherwise, a basic context
        if self.ssl_ocsp_context is None:
            staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
            staple_ctx.use_certificate_file(self.certfile)
            staple_ctx.use_privatekey_file(self.keyfile)
        else:
            staple_ctx = self.ssl_ocsp_context

        staple_ctx.set_ocsp_client_callback(
            ocsp_staple_verifier, self.ssl_ocsp_expected_cert
        )

        # need another socket; always close it afterwards so every
        # connection attempt does not leak a file descriptor
        staple_sock = socket.socket()
        try:
            con = OpenSSL.SSL.Connection(staple_ctx, staple_sock)
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
        finally:
            staple_sock.close()

2104 

2105 

class UnixDomainSocketConnection(AbstractConnection):
    """Connection to a Redis server over a Unix domain socket (UDS)."""

    def __init__(self, path="", socket_timeout=None, **kwargs):
        """
        Args:
            path: Filesystem path of the server's Unix socket.
            socket_timeout: Timeout applied to the socket after it connects.
            **kwargs: Forwarded to AbstractConnection.
        """
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return (name, value) pairs: path, db, and client_name when it is set."""
        described = [("path", self.path), ("db", self.db)]
        if not self.client_name:
            return described
        return described + [("client_name", self.client_name)]

    def _connect(self):
        """
        Create a Unix domain socket connection.

        ``socket_connect_timeout`` bounds ``connect()``. On success the socket
        switches to ``socket_timeout``. On failure it is shut down and closed
        (preventing ResourceWarnings) before the error is re-raised.
        """
        uds = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        uds.settimeout(self.socket_connect_timeout)
        try:
            uds.connect(self.path)
        except OSError:
            try:
                uds.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                pass  # best effort: the socket may never have connected
            uds.close()
            raise
        uds.settimeout(self.socket_timeout)
        return uds

    def _host_error(self):
        """Describe the endpoint for error messages: the socket path."""
        return self.path

2139 

2140 

FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """
    Convert a loosely-typed flag value (e.g. from a URL query string) to bool.

    Returns None for None or the empty string. Returns False for strings
    whose upper-cased form is in FALSE_STRINGS. Otherwise returns
    ``bool(value)``.
    """
    if value is None or value == "":
        return None
    if not isinstance(value, str):
        return bool(value)
    return False if value.upper() in FALSE_STRINGS else bool(value)

2150 

2151 

def parse_ssl_verify_flags(value):
    """
    Parse a URL query value into a list of ``ssl.VerifyFlags`` members.

    Flags are passed in as a string representation of a list, e.g.
    ``"[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]"``; the brackets are
    optional. Empty entries (e.g. from a trailing comma or ``"[]"``) are
    ignored.

    Returns:
        A list of VerifyFlags members, in the order given.

    Raises:
        ValueError: if a name is not a member of VerifyFlags.
    """
    verify_flags_str = value.replace("[", "").replace("]", "")
    # Look up enum members only. hasattr() would also accept arbitrary class
    # attributes such as "__doc__" and return objects that are not flags.
    members = VerifyFlags.__members__

    verify_flags = []
    for flag in verify_flags_str.split(","):
        flag = flag.strip()
        if not flag:
            continue
        if flag not in members:
            raise ValueError(f"Invalid ssl verify flag: {flag}")
        verify_flags.append(members[flag])
    return verify_flags

2164 

2165 

# Maps URL query-string parameter names to callables that convert the raw
# string value into the type expected downstream. parse_url() passes values
# for names not listed here through as strings. It turns a TypeError or
# ValueError raised by one of these parsers into a ValueError naming the
# offending parameter.
URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    # NOTE(review): list() on a string splits it into single characters, so
    # "?retry_on_error=TimeoutError" yields ['T', 'i', ...] rather than a list
    # of exception classes. Looks unintended — confirm before relying on it.
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
    "ssl_include_verify_flags": parse_ssl_verify_flags,
    "ssl_exclude_verify_flags": parse_ssl_verify_flags,
    "timeout": float,
}

2180 

2181 

def parse_url(url):
    """
    Turn a ``redis://``, ``rediss://`` or ``unix://`` URL into connection kwargs.

    Query-string values are unquoted and, when a parser is registered in
    URL_QUERY_ARGUMENT_PARSERS, converted. Username and password come from
    the netloc. For ``unix://`` the path becomes ``path`` and
    ``connection_class`` is UnixDomainSocketConnection. For TCP schemes the
    host/port are used, and the URL path is used as ``db`` unless a ``db``
    query value was given. ``rediss://`` selects SSLConnection.

    Raises:
        ValueError: for an unsupported scheme or an unparseable query value.
    """
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    parsed = urlparse(url)
    kwargs = {}

    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        raw = unquote(values[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if not parser:
            kwargs[name] = raw
            continue
        try:
            kwargs[name] = parser(raw)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid value for '{name}' in connection URL.")

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # Only redis://, rediss:// and unix:// get this far.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
        return kwargs

    # TCP schemes: redis:// or rediss://
    if parsed.hostname:
        kwargs["host"] = unquote(parsed.hostname)
    if parsed.port:
        kwargs["port"] = int(parsed.port)

    # The URL path doubles as the db number unless ?db= was supplied;
    # a non-numeric path is deliberately ignored.
    if parsed.path and "db" not in kwargs:
        try:
            kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
        except (AttributeError, ValueError):
            pass

    if parsed.scheme == "rediss":
        kwargs["connection_class"] = SSLConnection

    return kwargs

2237 

2238 

# TypeVar bound to ConnectionPool; used in ConnectionPool.from_url's signature
# (cls: Type[_CP]) -> _CP so a subclass calling it is typed as returning itself.
_CP = TypeVar("_CP", bound="ConnectionPool")

2240 

2241 

class ConnectionPoolInterface(ABC):
    """Abstract methods that every connection pool implementation must provide."""

    @abstractmethod
    def get_protocol(self):
        """Return the pool's protocol setting."""

    @abstractmethod
    def reset(self):
        """Reset the pool."""

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(
        self, command_name: Optional[str], *keys, **options
    ) -> ConnectionInterface:
        """
        Return a connection from the pool.

        Passing any arguments is deprecated since 5.3.0 (see the decorator).
        """

    @abstractmethod
    def get_encoder(self):
        """Return the encoder used by the pool."""

    @abstractmethod
    def release(self, connection: ConnectionInterface):
        """Give ``connection`` back to the pool."""

    @abstractmethod
    def disconnect(self, inuse_connections: bool = True):
        """Disconnect pooled connections; ``inuse_connections`` defaults to True."""

    @abstractmethod
    def close(self):
        """Close the pool."""

    @abstractmethod
    def set_retry(self, retry: Retry):
        """Set the retry policy to ``retry``."""

    @abstractmethod
    def re_auth_callback(self, token: TokenInterface):
        """Handle re-authentication with ``token``."""

    @abstractmethod
    def get_connection_count(self) -> list[tuple[int, dict]]:
        """
        Return the connection count, covering both idle and in-use connections.
        """

2292 

2293 

class MaintNotificationsAbstractConnectionPool:
    """
    Abstract class for handling maintenance notifications logic.
    This class is mixed into the ConnectionPool classes.

    This class is not intended to be used directly!

    All logic related to maintenance notifications and
    connection pool handling is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
        **kwargs,
    ):
        """
        Set up maintenance-notification handling for the pool.

        Args:
            maint_notifications_config: Configuration to use. When None and
                ``kwargs["protocol"]`` passes the RESP3 check, a default
                ``MaintNotificationsConfig()`` is created.
            oss_cluster_maint_notifications_handler: Optional cluster-level
                handler. When given, it is used instead of a per-pool
                ``MaintNotificationsPoolHandler``.
            **kwargs: Pool kwargs; ``protocol`` and ``event_dispatcher`` are
                read from here.

        Raises:
            RedisError: if notifications are enabled but the protocol is not
                RESP version 3.
        """
        # Initialize maintenance notifications
        is_protocol_supported = check_protocol_version(kwargs.get("protocol"), 3)

        if maint_notifications_config is None and is_protocol_supported:
            maint_notifications_config = MaintNotificationsConfig()

        if maint_notifications_config and maint_notifications_config.enabled:
            if not is_protocol_supported:
                raise RedisError(
                    "Maintenance notifications handlers on connection are only supported with RESP version 3"
                )

            self._event_dispatcher = kwargs.get("event_dispatcher", None)
            if self._event_dispatcher is None:
                self._event_dispatcher = EventDispatcher()

            # NOTE(review): this handler is replaced in both branches below
            # (set to None on the OSS-cluster path, re-created otherwise);
            # kept in case its constructor has side effects — confirm.
            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )
            if oss_cluster_maint_notifications_handler:
                # The cluster handler must be set before the kwargs update,
                # which checks maint_notifications_enabled().
                self._oss_cluster_maint_notifications_handler = (
                    oss_cluster_maint_notifications_handler
                )
                self._update_connection_kwargs_for_maint_notifications(
                    oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler
                )
                self._maint_notifications_pool_handler = None
            else:
                self._oss_cluster_maint_notifications_handler = None
                self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                    self, maint_notifications_config
                )

                self._update_connection_kwargs_for_maint_notifications(
                    maint_notifications_pool_handler=self._maint_notifications_pool_handler
                )
        else:
            self._maint_notifications_pool_handler = None
            self._oss_cluster_maint_notifications_handler = None

    @property
    @abstractmethod
    def connection_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments used to create new connections."""
        pass

    @connection_kwargs.setter
    @abstractmethod
    def connection_kwargs(self, value: Dict[str, Any]):
        pass

    @abstractmethod
    def _get_pool_lock(self) -> threading.RLock:
        """Return the lock guarding the pool's connection collections."""
        pass

    @abstractmethod
    def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
        """Return the pool's idle (available) connections."""
        pass

    @abstractmethod
    def _get_in_use_connections(
        self,
    ) -> Iterable["MaintNotificationsAbstractConnection"]:
        """Return the pool's checked-out connections."""
        pass

    def maint_notifications_enabled(self):
        """
        Report whether maintenance notifications are enabled.

        The config is taken from the OSS-cluster handler if one is set,
        otherwise from the pool handler. With neither handler set,
        notifications are not enabled.

        Returns:
            A truthy value when enabled (the config's ``enabled`` value).
            Otherwise a falsy value: None when no handler/config exists, or
            the config's ``enabled`` value.
        """
        if self._oss_cluster_maint_notifications_handler:
            maint_notifications_config = (
                self._oss_cluster_maint_notifications_handler.config
            )
        else:
            maint_notifications_config = (
                self._maint_notifications_pool_handler.config
                if self._maint_notifications_pool_handler
                else None
            )

        return maint_notifications_config and maint_notifications_config.enabled

    def update_maint_notifications_config(
        self,
        maint_notifications_config: MaintNotificationsConfig,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Updates the maintenance notifications configuration.
        This method should be called only if the pool was created
        without enabling the maintenance notifications and
        in a later point in time maintenance notifications
        are requested to be enabled.

        Connection kwargs for future connections are updated. Existing idle
        connections are re-configured and disconnected; in-use connections
        are re-configured and marked for reconnect.

        Raises:
            ValueError: if notifications are currently enabled and the new
                config would disable them.
        """
        if (
            self.maint_notifications_enabled()
            and not maint_notifications_config.enabled
        ):
            raise ValueError(
                "Cannot disable maintenance notifications after enabling them"
            )
        if oss_cluster_maint_notifications_handler:
            self._oss_cluster_maint_notifications_handler = (
                oss_cluster_maint_notifications_handler
            )
        else:
            # first update pool settings
            if not self._maint_notifications_pool_handler:
                self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                    self, maint_notifications_config
                )
            else:
                self._maint_notifications_pool_handler.config = (
                    maint_notifications_config
                )

        # then update connection kwargs and existing connections
        self._update_connection_kwargs_for_maint_notifications(
            maint_notifications_pool_handler=self._maint_notifications_pool_handler,
            oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler,
        )
        self._update_maint_notifications_configs_for_connections(
            maint_notifications_pool_handler=self._maint_notifications_pool_handler,
            oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler,
        )

    def _update_connection_kwargs_for_maint_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Update the connection kwargs for all future connections.

        This is a no-op unless maintenance notifications are enabled. When
        both handlers are given, the OSS-cluster handler's config wins for
        ``maint_notifications_config`` because it is applied last. The
        original host and timeouts are captured once, on the first call that
        gets past the enabled check.
        """
        if not self.maint_notifications_enabled():
            return
        if maint_notifications_pool_handler:
            self.connection_kwargs.update(
                {
                    "maint_notifications_pool_handler": maint_notifications_pool_handler,
                    "maint_notifications_config": maint_notifications_pool_handler.config,
                }
            )
        if oss_cluster_maint_notifications_handler:
            self.connection_kwargs.update(
                {
                    "oss_cluster_maint_notifications_handler": oss_cluster_maint_notifications_handler,
                    "maint_notifications_config": oss_cluster_maint_notifications_handler.config,
                }
            )

        # Store original connection parameters for maintenance notifications.
        if self.connection_kwargs.get("orig_host_address", None) is None:
            # If orig_host_address is None it means we haven't
            # configured the original values yet
            self.connection_kwargs.update(
                {
                    "orig_host_address": self.connection_kwargs.get("host"),
                    "orig_socket_timeout": self.connection_kwargs.get(
                        "socket_timeout", None
                    ),
                    "orig_socket_connect_timeout": self.connection_kwargs.get(
                        "socket_connect_timeout", None
                    ),
                }
            )

    def _update_maint_notifications_configs_for_connections(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        oss_cluster_maint_notifications_handler: Optional[
            OSSMaintNotificationsHandler
        ] = None,
    ):
        """
        Update the maintenance notifications config for all connections in the pool.

        The OSS-cluster handler takes precedence over the pool handler. Idle
        connections are re-configured and then disconnected. In-use
        connections are re-configured and marked for reconnect.

        Raises:
            ValueError: if neither handler is provided.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if oss_cluster_maint_notifications_handler:
                    # set cluster handler for conn
                    conn.set_maint_notifications_cluster_handler_for_connection(
                        oss_cluster_maint_notifications_handler
                    )
                    conn.maint_notifications_config = (
                        oss_cluster_maint_notifications_handler.config
                    )
                elif maint_notifications_pool_handler:
                    conn.set_maint_notifications_pool_handler_for_connection(
                        maint_notifications_pool_handler
                    )
                    conn.maint_notifications_config = (
                        maint_notifications_pool_handler.config
                    )
                else:
                    raise ValueError(
                        "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set"
                    )
                conn.disconnect()
            for conn in self._get_in_use_connections():
                if oss_cluster_maint_notifications_handler:
                    conn.maint_notifications_config = (
                        oss_cluster_maint_notifications_handler.config
                    )
                    conn._configure_maintenance_notifications(
                        oss_cluster_maint_notifications_handler=oss_cluster_maint_notifications_handler
                    )
                elif maint_notifications_pool_handler:
                    conn.set_maint_notifications_pool_handler_for_connection(
                        maint_notifications_pool_handler
                    )
                    conn.maint_notifications_config = (
                        maint_notifications_pool_handler.config
                    )
                else:
                    raise ValueError(
                        "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set"
                    )
                conn.mark_for_reconnect()

    def _should_update_connection(
        self,
        conn: "MaintNotificationsAbstractConnection",
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
    ) -> bool:
        """
        Check if the connection should be updated based on the matching criteria.

        - "connected_address": match ``conn.getpeername()`` against
          ``matching_address``.
        - "configured_address": match ``conn.host`` against
          ``matching_address``.
        - "notification_hash": match ``conn.maintenance_notification_hash``
          against ``matching_notification_hash``.

        An empty/None address, or a None hash, means "no filter" (match all).
        A hash of ``0`` is a real value and is compared, not treated as
        "no filter".
        """
        if matching_pattern == "connected_address":
            if matching_address and conn.getpeername() != matching_address:
                return False
        elif matching_pattern == "configured_address":
            if matching_address and conn.host != matching_address:
                return False
        elif matching_pattern == "notification_hash":
            if (
                matching_notification_hash is not None
                and conn.maintenance_notification_hash != matching_notification_hash
            ):
                return False
        return True

    def update_connection_settings(
        self,
        conn: "MaintNotificationsAbstractConnection",
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Update the settings for a single connection.

        Temporary overrides are applied before any requested reset.
        ``update_current_socket_timeout(relaxed_timeout)`` is always called
        last, including when ``relaxed_timeout`` is None.
        """
        if state:
            conn.maintenance_state = state

        if update_notification_hash:
            # update the notification hash only if requested
            conn.maintenance_notification_hash = maintenance_notification_hash

        if host_address is not None:
            conn.set_tmp_settings(tmp_host_address=host_address)

        if relaxed_timeout is not None:
            conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)

        if reset_relaxed_timeout or reset_host_address:
            conn.reset_tmp_settings(
                reset_host_address=reset_host_address,
                reset_relaxed_timeout=reset_relaxed_timeout,
            )

        conn.update_current_socket_timeout(relaxed_timeout)

    def update_connections_settings(
        self,
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
        include_free_connections: bool = True,
    ):
        """
        Update the settings for all matching connections in the pool.

        This method does not create new connections.
        This method does not affect the connection kwargs.

        :param state: The maintenance state to set for the connection.
        :param maintenance_notification_hash: The hash of the maintenance notification
            to set for the connection.
        :param host_address: The host address to set for the connection.
        :param relaxed_timeout: The relaxed timeout to set for the connection.
        :param matching_address: The address to match for the connection.
        :param matching_notification_hash: The notification hash to match for the connection.
        :param matching_pattern: The pattern to match for the connection.
        :param update_notification_hash: Whether to update the notification hash for the connection.
        :param reset_host_address: Whether to reset the host address to the original address.
        :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
        :param include_free_connections: Whether to include free/available connections.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn,
                    matching_pattern,
                    matching_address,
                    matching_notification_hash,
                ):
                    self.update_connection_settings(
                        conn,
                        state=state,
                        maintenance_notification_hash=maintenance_notification_hash,
                        host_address=host_address,
                        relaxed_timeout=relaxed_timeout,
                        update_notification_hash=update_notification_hash,
                        reset_host_address=reset_host_address,
                        reset_relaxed_timeout=reset_relaxed_timeout,
                    )

            if include_free_connections:
                for conn in self._get_free_connections():
                    if self._should_update_connection(
                        conn,
                        matching_pattern,
                        matching_address,
                        matching_notification_hash,
                    ):
                        self.update_connection_settings(
                            conn,
                            state=state,
                            maintenance_notification_hash=maintenance_notification_hash,
                            host_address=host_address,
                            relaxed_timeout=relaxed_timeout,
                            update_notification_hash=update_notification_hash,
                            reset_host_address=reset_host_address,
                            reset_relaxed_timeout=reset_relaxed_timeout,
                        )

    def update_connection_kwargs(
        self,
        **kwargs,
    ):
        """
        Update the connection kwargs for all future connections.

        This method updates the connection kwargs for all future connections created by the pool.
        Existing connections are not affected.
        """
        self.connection_kwargs.update(kwargs)

    def update_active_connections_for_reconnect(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Mark all active connections for reconnect.
        This is used when a cluster node is migrated to a different address.

        Only in-use connections whose peer address matches
        ``moving_address_src`` are marked; with no address, all are marked.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.mark_for_reconnect()

    def disconnect_free_connections(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Disconnect all free/available connections.
        This is used when a cluster node is migrated to a different address.

        Only idle connections whose peer address matches
        ``moving_address_src`` are disconnected; with no address, all are.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.disconnect()

2720 

2721 

2722class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface): 

2723 """ 

2724 Create a connection pool. ``If max_connections`` is set, then this 

2725 object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's 

2726 limit is reached. 

2727 

2728 By default, TCP connections are created unless ``connection_class`` 

2729 is specified. Use class:`.UnixDomainSocketConnection` for 

2730 unix sockets. 

2731 :py:class:`~redis.SSLConnection` can be used for SSL enabled connections. 

2732 

2733 If ``maint_notifications_config`` is provided, the connection pool will support 

2734 maintenance notifications. 

2735 Maintenance notifications are supported only with RESP3. 

2736 If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3, 

2737 the maintenance notifications will be enabled by default. 

2738 

2739 Any additional keyword arguments are passed to the constructor of 

2740 ``connection_class``. 

2741 """ 

2742 

2743 @classmethod 

2744 def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: 

2745 """ 

2746 Return a connection pool configured from the given URL. 

2747 

2748 For example:: 

2749 

2750 redis://[[username]:[password]]@localhost:6379/0 

2751 rediss://[[username]:[password]]@localhost:6379/0 

2752 unix://[username@]/path/to/socket.sock?db=0[&password=password] 

2753 

2754 Three URL schemes are supported: 

2755 

2756 - `redis://` creates a TCP socket connection. See more at: 

2757 <https://www.iana.org/assignments/uri-schemes/prov/redis> 

2758 - `rediss://` creates a SSL wrapped TCP socket connection. See more at: 

2759 <https://www.iana.org/assignments/uri-schemes/prov/rediss> 

2760 - ``unix://``: creates a Unix Domain Socket connection. 

2761 

2762 The username, password, hostname, path and all querystring values 

2763 are passed through urllib.parse.unquote in order to replace any 

2764 percent-encoded values with their corresponding characters. 

2765 

2766 There are several ways to specify a database number. The first value 

2767 found will be used: 

2768 

2769 1. A ``db`` querystring option, e.g. redis://localhost?db=0 

2770 2. If using the redis:// or rediss:// schemes, the path argument 

2771 of the url, e.g. redis://localhost/0 

2772 3. A ``db`` keyword argument to this function. 

2773 

2774 If none of these options are specified, the default db=0 is used. 

2775 

2776 All querystring options are cast to their appropriate Python types. 

2777 Boolean arguments can be specified with string values "True"/"False" 

2778 or "Yes"/"No". Values that cannot be properly cast cause a 

2779 ``ValueError`` to be raised. Once parsed, the querystring arguments 

2780 and keyword arguments are passed to the ``ConnectionPool``'s 

2781 class initializer. In the case of conflicting arguments, querystring 

2782 arguments always win. 

2783 """ 

2784 url_options = parse_url(url) 

2785 

2786 if "connection_class" in kwargs: 

2787 url_options["connection_class"] = kwargs["connection_class"] 

2788 

2789 kwargs.update(url_options) 

2790 return cls(**kwargs) 

2791 

2792 def __init__( 

2793 self, 

2794 connection_class=Connection, 

2795 max_connections: Optional[int] = None, 

2796 cache_factory: Optional[CacheFactoryInterface] = None, 

2797 maint_notifications_config: Optional[MaintNotificationsConfig] = None, 

2798 **connection_kwargs, 

2799 ): 

2800 max_connections = max_connections or 2**31 

2801 if not isinstance(max_connections, int) or max_connections < 0: 

2802 raise ValueError('"max_connections" must be a positive integer') 

2803 

2804 self.connection_class = connection_class 

2805 self._connection_kwargs = connection_kwargs 

2806 self.max_connections = max_connections 

2807 self.cache = None 

2808 self._cache_factory = cache_factory 

2809 

2810 self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None) 

2811 if self._event_dispatcher is None: 

2812 self._event_dispatcher = EventDispatcher() 

2813 

2814 if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"): 

2815 if not check_protocol_version(self._connection_kwargs.get("protocol"), 3): 

2816 raise RedisError("Client caching is only supported with RESP version 3") 

2817 

2818 cache = self._connection_kwargs.get("cache") 

2819 

2820 if cache is not None: 

2821 if not isinstance(cache, CacheInterface): 

2822 raise ValueError("Cache must implement CacheInterface") 

2823 

2824 self.cache = cache 

2825 else: 

2826 if self._cache_factory is not None: 

2827 self.cache = CacheProxy(self._cache_factory.get_cache()) 

2828 else: 

2829 self.cache = CacheFactory( 

2830 self._connection_kwargs.get("cache_config") 

2831 ).get_cache() 

2832 

2833 init_csc_items() 

2834 register_csc_items_callback( 

2835 callback=lambda: self.cache.size, 

2836 pool_name=get_pool_name(self), 

2837 ) 

2838 

2839 connection_kwargs.pop("cache", None) 

2840 connection_kwargs.pop("cache_config", None) 

2841 

2842 # a lock to protect the critical section in _checkpid(). 

2843 # this lock is acquired when the process id changes, such as 

2844 # after a fork. during this time, multiple threads in the child 

2845 # process could attempt to acquire this lock. the first thread 

2846 # to acquire the lock will reset the data structures and lock 

2847 # object of this pool. subsequent threads acquiring this lock 

2848 # will notice the first thread already did the work and simply 

2849 # release the lock. 

2850 

2851 self._fork_lock = threading.RLock() 

2852 self._lock = threading.RLock() 

2853 

2854 # Generate unique pool ID for observability (matches go-redis behavior) 

2855 import secrets 

2856 

2857 self._pool_id = secrets.token_hex(4) 

2858 

2859 MaintNotificationsAbstractConnectionPool.__init__( 

2860 self, 

2861 maint_notifications_config=maint_notifications_config, 

2862 **connection_kwargs, 

2863 ) 

2864 

2865 self.reset() 

2866 

2867 def __repr__(self) -> str: 

2868 conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()]) 

2869 return ( 

2870 f"<{self.__class__.__module__}.{self.__class__.__name__}" 

2871 f"(<{self.connection_class.__module__}.{self.connection_class.__name__}" 

2872 f"({conn_kwargs})>)>" 

2873 ) 

2874 

2875 @property 

2876 def connection_kwargs(self) -> Dict[str, Any]: 

2877 return self._connection_kwargs 

2878 

2879 @connection_kwargs.setter 

2880 def connection_kwargs(self, value: Dict[str, Any]): 

2881 self._connection_kwargs = value 

2882 

2883 def get_protocol(self): 

2884 """ 

2885 Returns: 

2886 The RESP protocol version, or ``None`` if the protocol is not specified, 

2887 in which case the server default will be used. 

2888 """ 

2889 return self.connection_kwargs.get("protocol", None) 

2890 

2891 def reset(self) -> None: 

2892 self._created_connections = 0 

2893 self._available_connections = [] 

2894 self._in_use_connections = set() 

2895 

2896 # this must be the last operation in this method. while reset() is 

2897 # called when holding _fork_lock, other threads in this process 

2898 # can call _checkpid() which compares self.pid and os.getpid() without 

2899 # holding any lock (for performance reasons). keeping this assignment 

2900 # as the last operation ensures that those other threads will also 

2901 # notice a pid difference and block waiting for the first thread to 

2902 # release _fork_lock. when each of these threads eventually acquire 

2903 # _fork_lock, they will notice that another thread already called 

2904 # reset() and they will immediately release _fork_lock and continue on. 

2905 self.pid = os.getpid() 

2906 

2907 def _checkpid(self) -> None: 

2908 # _checkpid() attempts to keep ConnectionPool fork-safe on modern 

2909 # systems. this is called by all ConnectionPool methods that 

2910 # manipulate the pool's state such as get_connection() and release(). 

2911 # 

2912 # _checkpid() determines whether the process has forked by comparing 

2913 # the current process id to the process id saved on the ConnectionPool 

2914 # instance. if these values are the same, _checkpid() simply returns. 

2915 # 

2916 # when the process ids differ, _checkpid() assumes that the process 

2917 # has forked and that we're now running in the child process. the child 

2918 # process cannot use the parent's file descriptors (e.g., sockets). 

2919 # therefore, when _checkpid() sees the process id change, it calls 

2920 # reset() in order to reinitialize the child's ConnectionPool. this 

2921 # will cause the child to make all new connection objects. 

2922 # 

2923 # _checkpid() is protected by self._fork_lock to ensure that multiple 

2924 # threads in the child process do not call reset() multiple times. 

2925 # 

2926 # there is an extremely small chance this could fail in the following 

2927 # scenario: 

2928 # 1. process A calls _checkpid() for the first time and acquires 

2929 # self._fork_lock. 

2930 # 2. while holding self._fork_lock, process A forks (the fork() 

2931 # could happen in a different thread owned by process A) 

2932 # 3. process B (the forked child process) inherits the 

2933 # ConnectionPool's state from the parent. that state includes 

2934 # a locked _fork_lock. process B will not be notified when 

2935 # process A releases the _fork_lock and will thus never be 

2936 # able to acquire the _fork_lock. 

2937 # 

2938 # to mitigate this possible deadlock, _checkpid() will only wait 5 

2939 # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in 

2940 # that time it is assumed that the child is deadlocked and a 

2941 # redis.ChildDeadlockedError error is raised. 

2942 if self.pid != os.getpid(): 

2943 acquired = self._fork_lock.acquire(timeout=5) 

2944 if not acquired: 

2945 raise ChildDeadlockedError 

2946 # reset() the instance for the new process if another thread 

2947 # hasn't already done so 

2948 try: 

2949 if self.pid != os.getpid(): 

2950 self.reset() 

2951 finally: 

2952 self._fork_lock.release() 

2953 

2954 @deprecated_args( 

2955 args_to_warn=["*"], 

2956 reason="Use get_connection() without args instead", 

2957 version="5.3.0", 

2958 ) 

2959 def get_connection(self, command_name=None, *keys, **options) -> "Connection": 

2960 "Get a connection from the pool" 

2961 

2962 # Start timing for observability 

2963 self._checkpid() 

2964 is_created = False 

2965 

2966 with self._lock: 

2967 try: 

2968 connection = self._available_connections.pop() 

2969 except IndexError: 

2970 # Start timing for observability 

2971 start_time_created = time.monotonic() 

2972 

2973 connection = self.make_connection() 

2974 is_created = True 

2975 self._in_use_connections.add(connection) 

2976 

2977 try: 

2978 # ensure this connection is connected to Redis 

2979 connection.connect() 

2980 # connections that the pool provides should be ready to send 

2981 # a command. if not, the connection was either returned to the 

2982 # pool before all data has been read or the socket has been 

2983 # closed. either way, reconnect and verify everything is good. 

2984 try: 

2985 if ( 

2986 connection.can_read() 

2987 and self.cache is None 

2988 and not self.maint_notifications_enabled() 

2989 ): 

2990 raise ConnectionError("Connection has data") 

2991 except (ConnectionError, TimeoutError, OSError): 

2992 connection.disconnect() 

2993 connection.connect() 

2994 if connection.can_read(): 

2995 raise ConnectionError("Connection not ready") 

2996 except BaseException: 

2997 # release the connection back to the pool so that we don't 

2998 # leak it 

2999 self.release(connection) 

3000 raise 

3001 

3002 if is_created: 

3003 record_connection_create_time( 

3004 connection_pool=self, 

3005 duration_seconds=time.monotonic() - start_time_created, 

3006 ) 

3007 return connection 

3008 

3009 def get_encoder(self) -> Encoder: 

3010 "Return an encoder based on encoding settings" 

3011 kwargs = self.connection_kwargs 

3012 return Encoder( 

3013 encoding=kwargs.get("encoding", "utf-8"), 

3014 encoding_errors=kwargs.get("encoding_errors", "strict"), 

3015 decode_responses=kwargs.get("decode_responses", False), 

3016 ) 

3017 

3018 def make_connection(self) -> "ConnectionInterface": 

3019 "Create a new connection" 

3020 if self._created_connections >= self.max_connections: 

3021 raise MaxConnectionsError("Too many connections") 

3022 self._created_connections += 1 

3023 

3024 kwargs = dict(self.connection_kwargs) 

3025 

3026 if self.cache is not None: 

3027 return CacheProxyConnection( 

3028 self.connection_class(**kwargs), self.cache, self._lock 

3029 ) 

3030 return self.connection_class(**kwargs) 

3031 

3032 def release(self, connection: "Connection") -> None: 

3033 "Releases the connection back to the pool" 

3034 self._checkpid() 

3035 with self._lock: 

3036 try: 

3037 self._in_use_connections.remove(connection) 

3038 except KeyError: 

3039 # Gracefully fail when a connection is returned to this pool 

3040 # that the pool doesn't actually own 

3041 return 

3042 

3043 if self.owns_connection(connection): 

3044 if connection.should_reconnect(): 

3045 connection.disconnect() 

3046 self._available_connections.append(connection) 

3047 self._event_dispatcher.dispatch( 

3048 AfterConnectionReleasedEvent(connection) 

3049 ) 

3050 else: 

3051 # Pool doesn't own this connection, do not add it back 

3052 # to the pool. 

3053 # The created connections count should not be changed, 

3054 # because the connection was not created by the pool. 

3055 connection.disconnect() 

3056 return 

3057 

3058 def owns_connection(self, connection: "Connection") -> int: 

3059 return connection.pid == self.pid 

3060 

3061 def disconnect(self, inuse_connections: bool = True) -> None: 

3062 """ 

3063 Disconnects connections in the pool 

3064 

3065 If ``inuse_connections`` is True, disconnect connections that are 

3066 currently in use, potentially by other threads. Otherwise only disconnect 

3067 connections that are idle in the pool. 

3068 """ 

3069 self._checkpid() 

3070 with self._lock: 

3071 if inuse_connections: 

3072 connections = chain( 

3073 self._available_connections, self._in_use_connections 

3074 ) 

3075 else: 

3076 connections = self._available_connections 

3077 

3078 for connection in connections: 

3079 connection.disconnect() 

3080 

3081 def close(self) -> None: 

3082 """Close the pool, disconnecting all connections""" 

3083 self.disconnect() 

3084 

3085 def set_retry(self, retry: Retry) -> None: 

3086 self.connection_kwargs.update({"retry": retry}) 

3087 for conn in self._available_connections: 

3088 conn.retry = retry 

3089 for conn in self._in_use_connections: 

3090 conn.retry = retry 

3091 

3092 def re_auth_callback(self, token: TokenInterface): 

3093 with self._lock: 

3094 for conn in self._available_connections: 

3095 conn.retry.call_with_retry( 

3096 lambda: conn.send_command( 

3097 "AUTH", token.try_get("oid"), token.get_value() 

3098 ), 

3099 lambda error: self._mock(error), 

3100 ) 

3101 conn.retry.call_with_retry( 

3102 lambda: conn.read_response(), lambda error: self._mock(error) 

3103 ) 

3104 for conn in self._in_use_connections: 

3105 conn.set_re_auth_token(token) 

3106 

3107 def _get_pool_lock(self): 

3108 return self._lock 

3109 

3110 def _get_free_connections(self): 

3111 with self._lock: 

3112 return list(self._available_connections) 

3113 

3114 def _get_in_use_connections(self): 

3115 with self._lock: 

3116 return set(self._in_use_connections) 

3117 

3118 def _mock(self, error: RedisError): 

3119 """ 

3120 Dummy functions, needs to be passed as error callback to retry object. 

3121 :param error: 

3122 :return: 

3123 """ 

3124 pass 

3125 

3126 def get_connection_count(self) -> List[tuple[int, dict]]: 

3127 from redis.observability.attributes import get_pool_name 

3128 

3129 attributes = AttributeBuilder.build_base_attributes() 

3130 attributes[DB_CLIENT_CONNECTION_POOL_NAME] = get_pool_name(self) 

3131 free_connections_attributes = attributes.copy() 

3132 in_use_connections_attributes = attributes.copy() 

3133 

3134 free_connections_attributes[DB_CLIENT_CONNECTION_STATE] = ( 

3135 ConnectionState.IDLE.value 

3136 ) 

3137 in_use_connections_attributes[DB_CLIENT_CONNECTION_STATE] = ( 

3138 ConnectionState.USED.value 

3139 ) 

3140 

3141 return [ 

3142 (len(self._get_free_connections()), free_connections_attributes), 

3143 (len(self._get_in_use_connections()), in_use_connections_attributes), 

3144 ] 

3145 

3146 

3147class BlockingConnectionPool(ConnectionPool): 

3148 """ 

3149 Thread-safe blocking connection pool:: 

3150 

3151 >>> from redis.client import Redis 

3152 >>> client = Redis(connection_pool=BlockingConnectionPool()) 

3153 

3154 It performs the same function as the default 

3155 :py:class:`~redis.ConnectionPool` implementation, in that, 

3156 it maintains a pool of reusable connections that can be shared by 

3157 multiple redis clients (safely across threads if required). 

3158 

3159 The difference is that, in the event that a client tries to get a 

3160 connection from the pool when all of connections are in use, rather than 

3161 raising a :py:class:`~redis.ConnectionError` (as the default 

3162 :py:class:`~redis.ConnectionPool` implementation does), it 

3163 makes the client wait ("blocks") for a specified number of seconds until 

3164 a connection becomes available. 

3165 

3166 Use ``max_connections`` to increase / decrease the pool size:: 

3167 

3168 >>> pool = BlockingConnectionPool(max_connections=10) 

3169 

3170 Use ``timeout`` to tell it either how many seconds to wait for a connection 

3171 to become available, or to block forever: 

3172 

3173 >>> # Block forever. 

3174 >>> pool = BlockingConnectionPool(timeout=None) 

3175 

3176 >>> # Raise a ``ConnectionError`` after five seconds if a connection is 

3177 >>> # not available. 

3178 >>> pool = BlockingConnectionPool(timeout=5) 

3179 """ 

3180 

3181 def __init__( 

3182 self, 

3183 max_connections=50, 

3184 timeout=20, 

3185 connection_class=Connection, 

3186 queue_class=LifoQueue, 

3187 **connection_kwargs, 

3188 ): 

3189 self.queue_class = queue_class 

3190 self.timeout = timeout 

3191 self._in_maintenance = False 

3192 self._locked = False 

3193 super().__init__( 

3194 connection_class=connection_class, 

3195 max_connections=max_connections, 

3196 **connection_kwargs, 

3197 ) 

3198 

3199 def reset(self): 

3200 # Create and fill up a thread safe queue with ``None`` values. 

3201 try: 

3202 if self._in_maintenance: 

3203 self._lock.acquire() 

3204 self._locked = True 

3205 self.pool = self.queue_class(self.max_connections) 

3206 while True: 

3207 try: 

3208 self.pool.put_nowait(None) 

3209 except Full: 

3210 break 

3211 

3212 # Keep a list of actual connection instances so that we can 

3213 # disconnect them later. 

3214 self._connections = [] 

3215 finally: 

3216 if self._locked: 

3217 try: 

3218 self._lock.release() 

3219 except Exception: 

3220 pass 

3221 self._locked = False 

3222 

3223 # this must be the last operation in this method. while reset() is 

3224 # called when holding _fork_lock, other threads in this process 

3225 # can call _checkpid() which compares self.pid and os.getpid() without 

3226 # holding any lock (for performance reasons). keeping this assignment 

3227 # as the last operation ensures that those other threads will also 

3228 # notice a pid difference and block waiting for the first thread to 

3229 # release _fork_lock. when each of these threads eventually acquire 

3230 # _fork_lock, they will notice that another thread already called 

3231 # reset() and they will immediately release _fork_lock and continue on. 

3232 self.pid = os.getpid() 

3233 

3234 def make_connection(self): 

3235 "Make a fresh connection." 

3236 try: 

3237 if self._in_maintenance: 

3238 self._lock.acquire() 

3239 self._locked = True 

3240 

3241 if self.cache is not None: 

3242 connection = CacheProxyConnection( 

3243 self.connection_class(**self.connection_kwargs), 

3244 self.cache, 

3245 self._lock, 

3246 ) 

3247 else: 

3248 connection = self.connection_class(**self.connection_kwargs) 

3249 self._connections.append(connection) 

3250 return connection 

3251 finally: 

3252 if self._locked: 

3253 try: 

3254 self._lock.release() 

3255 except Exception: 

3256 pass 

3257 self._locked = False 

3258 

3259 @deprecated_args( 

3260 args_to_warn=["*"], 

3261 reason="Use get_connection() without args instead", 

3262 version="5.3.0", 

3263 ) 

3264 def get_connection(self, command_name=None, *keys, **options): 

3265 """ 

3266 Get a connection, blocking for ``self.timeout`` until a connection 

3267 is available from the pool. 

3268 

3269 If the connection returned is ``None`` then creates a new connection. 

3270 Because we use a last-in first-out queue, the existing connections 

3271 (having been returned to the pool after the initial ``None`` values 

3272 were added) will be returned before ``None`` values. This means we only 

3273 create new connections when we need to, i.e.: the actual number of 

3274 connections will only increase in response to demand. 

3275 """ 

3276 start_time_acquired = time.monotonic() 

3277 # Make sure we haven't changed process. 

3278 self._checkpid() 

3279 is_created = False 

3280 

3281 # Try and get a connection from the pool. If one isn't available within 

3282 # self.timeout then raise a ``ConnectionError``. 

3283 connection = None 

3284 try: 

3285 if self._in_maintenance: 

3286 self._lock.acquire() 

3287 self._locked = True 

3288 try: 

3289 connection = self.pool.get(block=True, timeout=self.timeout) 

3290 except Empty: 

3291 # Note that this is not caught by the redis client and will be 

3292 # raised unless handled by application code. If you want never to 

3293 raise ConnectionError("No connection available.") 

3294 

3295 # If the ``connection`` is actually ``None`` then that's a cue to make 

3296 # a new connection to add to the pool. 

3297 if connection is None: 

3298 # Start timing for observability 

3299 start_time_created = time.monotonic() 

3300 connection = self.make_connection() 

3301 is_created = True 

3302 finally: 

3303 if self._locked: 

3304 try: 

3305 self._lock.release() 

3306 except Exception: 

3307 pass 

3308 self._locked = False 

3309 

3310 try: 

3311 # ensure this connection is connected to Redis 

3312 connection.connect() 

3313 # connections that the pool provides should be ready to send 

3314 # a command. if not, the connection was either returned to the 

3315 # pool before all data has been read or the socket has been 

3316 # closed. either way, reconnect and verify everything is good. 

3317 try: 

3318 if connection.can_read(): 

3319 raise ConnectionError("Connection has data") 

3320 except (ConnectionError, TimeoutError, OSError): 

3321 connection.disconnect() 

3322 connection.connect() 

3323 if connection.can_read(): 

3324 raise ConnectionError("Connection not ready") 

3325 except BaseException: 

3326 # release the connection back to the pool so that we don't leak it 

3327 self.release(connection) 

3328 raise 

3329 

3330 if is_created: 

3331 record_connection_create_time( 

3332 connection_pool=self, 

3333 duration_seconds=time.monotonic() - start_time_created, 

3334 ) 

3335 

3336 record_connection_wait_time( 

3337 pool_name=get_pool_name(self), 

3338 duration_seconds=time.monotonic() - start_time_acquired, 

3339 ) 

3340 

3341 return connection 

3342 

3343 def release(self, connection): 

3344 "Releases the connection back to the pool." 

3345 # Make sure we haven't changed process. 

3346 self._checkpid() 

3347 

3348 try: 

3349 if self._in_maintenance: 

3350 self._lock.acquire() 

3351 self._locked = True 

3352 if not self.owns_connection(connection): 

3353 # pool doesn't own this connection. do not add it back 

3354 # to the pool. instead add a None value which is a placeholder 

3355 # that will cause the pool to recreate the connection if 

3356 # its needed. 

3357 connection.disconnect() 

3358 self.pool.put_nowait(None) 

3359 return 

3360 if connection.should_reconnect(): 

3361 connection.disconnect() 

3362 # Put the connection back into the pool. 

3363 try: 

3364 self.pool.put_nowait(connection) 

3365 except Full: 

3366 # perhaps the pool has been reset() after a fork? regardless, 

3367 # we don't want this connection 

3368 pass 

3369 finally: 

3370 if self._locked: 

3371 try: 

3372 self._lock.release() 

3373 except Exception: 

3374 pass 

3375 self._locked = False 

3376 

3377 def disconnect(self, inuse_connections: bool = True): 

3378 "Disconnects either all connections in the pool or just the free connections." 

3379 self._checkpid() 

3380 try: 

3381 if self._in_maintenance: 

3382 self._lock.acquire() 

3383 self._locked = True 

3384 if inuse_connections: 

3385 connections = self._connections 

3386 else: 

3387 connections = self._get_free_connections() 

3388 for connection in connections: 

3389 connection.disconnect() 

3390 finally: 

3391 if self._locked: 

3392 try: 

3393 self._lock.release() 

3394 except Exception: 

3395 pass 

3396 self._locked = False 

3397 

3398 def _get_free_connections(self): 

3399 with self._lock: 

3400 return {conn for conn in self.pool.queue if conn} 

3401 

3402 def _get_in_use_connections(self): 

3403 with self._lock: 

3404 # free connections 

3405 connections_in_queue = {conn for conn in self.pool.queue if conn} 

3406 # in self._connections we keep all created connections 

3407 # so the ones that are not in the queue are the in use ones 

3408 return { 

3409 conn for conn in self._connections if conn not in connections_in_queue 

3410 } 

3411 

3412 def set_in_maintenance(self, in_maintenance: bool): 

3413 """ 

3414 Sets a flag that this Blocking ConnectionPool is in maintenance mode. 

3415 

3416 This is used to prevent new connections from being created while we are in maintenance mode. 

3417 The pool will be in maintenance mode only when we are processing a MOVING notification. 

3418 """ 

3419 self._in_maintenance = in_maintenance