Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1395 statements  

1import copy 

2import os 

3import socket 

4import sys 

5import threading 

6import time 

7import weakref 

8from abc import ABC, abstractmethod 

9from itertools import chain 

10from queue import Empty, Full, LifoQueue 

11from typing import ( 

12 Any, 

13 Callable, 

14 Dict, 

15 Iterable, 

16 List, 

17 Literal, 

18 Optional, 

19 Type, 

20 TypeVar, 

21 Union, 

22) 

23from urllib.parse import parse_qs, unquote, urlparse 

24 

25from redis.cache import ( 

26 CacheEntry, 

27 CacheEntryStatus, 

28 CacheFactory, 

29 CacheFactoryInterface, 

30 CacheInterface, 

31 CacheKey, 

32) 

33 

34from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser 

35from .auth.token import TokenInterface 

36from .backoff import NoBackoff 

37from .credentials import CredentialProvider, UsernamePasswordCredentialProvider 

38from .event import AfterConnectionReleasedEvent, EventDispatcher 

39from .exceptions import ( 

40 AuthenticationError, 

41 AuthenticationWrongNumberOfArgsError, 

42 ChildDeadlockedError, 

43 ConnectionError, 

44 DataError, 

45 MaxConnectionsError, 

46 RedisError, 

47 ResponseError, 

48 TimeoutError, 

49) 

50from .maint_notifications import ( 

51 MaintenanceState, 

52 MaintNotificationsConfig, 

53 MaintNotificationsConnectionHandler, 

54 MaintNotificationsPoolHandler, 

55) 

56from .retry import Retry 

57from .utils import ( 

58 CRYPTOGRAPHY_AVAILABLE, 

59 HIREDIS_AVAILABLE, 

60 SSL_AVAILABLE, 

61 compare_versions, 

62 deprecated_args, 

63 ensure_string, 

64 format_error_message, 

65 get_lib_version, 

66 str_if_bytes, 

67) 

68 

69if SSL_AVAILABLE: 

70 import ssl 

71 from ssl import VerifyFlags 

72else: 

73 ssl = None 

74 VerifyFlags = None 

75 

76if HIREDIS_AVAILABLE: 

77 import hiredis 

78 

# RESP wire-protocol tokens used when serializing commands by hand.
SYM_STAR = b"*"  # array header prefix: *<count>\r\n
SYM_DOLLAR = b"$"  # bulk-string length prefix: $<len>\r\n
SYM_CRLF = b"\r\n"  # RESP line terminator
SYM_EMPTY = b""  # separator used with bytes.join

# RESP version used when ``protocol`` is None (``int(None)`` raises TypeError).
DEFAULT_RESP_VERSION = 2

# Unique marker meaning "argument not supplied" where None is a meaningful value.
SENTINEL = object()

# Prefer the C-accelerated hiredis parser whenever the extension is importable.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] = (
    _HiredisParser if HIREDIS_AVAILABLE else _RESP2Parser
)

93 

94 

class HiredisRespSerializer:
    """Serialize commands to RESP using the hiredis C extension."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol.

        A command name carrying literal sub-arguments (e.g. ``"CONFIG GET"``)
        is split on whitespace into separate arguments first.

        Returns:
            A one-element list holding the packed command.

        Raises:
            DataError: if hiredis rejects an argument with ``TypeError``; the
                original traceback is preserved.
        """
        head, rest = args[0], args[1:]
        if isinstance(head, str):
            args = (*head.encode().split(), *rest)
        elif b" " in head:
            args = (*head.split(), *rest)

        try:
            packed = hiredis.pack_command(args)
        except TypeError:
            _, err, tb = sys.exc_info()
            raise DataError(err).with_traceback(tb)
        return [packed]

111 

112 

class PythonRespSerializer:
    """Pure-Python RESP command serializer (used when hiredis is unavailable).

    Args:
        buffer_cutoff: size in bytes above which the pending buffer or an
            argument is emitted as its own chunk instead of being joined.
        encode: callable turning each argument into bytes (or a memoryview).
    """

    def __init__(self, buffer_cutoff, encode) -> None:
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol.

        Returns a list of byte chunks whose concatenation is the RESP array
        ``*<n>\\r\\n`` followed by ``$<len>\\r\\n<arg>\\r\\n`` for every argument.
        Large values and memoryviews are appended as standalone chunks so they
        are never copied into a larger buffer.
        """
        # A command such as 'CONFIG GET' carries literal arguments inside its
        # name. The server wants them sent separately, so split the first
        # argument here; the pieces are already bytes and are not re-encoded.
        first = args[0]
        if isinstance(first, str):
            args = (*first.encode().split(), *args[1:])
        elif b" " in first:
            args = (*first.split(), *args[1:])

        chunks = []
        cutoff = self._buffer_cutoff
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        for encoded in map(self.encode, args):
            size = len(encoded)
            # Avoid large string mallocs: flush the header and hand the value
            # over as its own chunk when sizes are large or it is a memoryview.
            emit_separately = (
                len(pending) > cutoff
                or size > cutoff
                or isinstance(encoded, memoryview)
            )
            if not emit_separately:
                pending = SYM_EMPTY.join(
                    (pending, SYM_DOLLAR, str(size).encode(), SYM_CRLF, encoded, SYM_CRLF)
                )
                continue
            chunks.append(
                SYM_EMPTY.join((pending, SYM_DOLLAR, str(size).encode(), SYM_CRLF))
            )
            chunks.append(encoded)
            pending = SYM_CRLF

        chunks.append(pending)
        return chunks

162 

163 

class ConnectionInterface:
    """API every Redis connection implementation provides.

    NOTE: this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers document the required methods but are not
    enforced when a subclass is instantiated.
    """

    @abstractmethod
    def repr_pieces(self):
        """Return the ``(key, value)`` pairs rendered by ``__repr__``."""

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register ``callback`` to run whenever the connection is established."""

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a callback added with ``register_connect_callback``."""

    @abstractmethod
    def set_parser(self, parser_class):
        """Install a fresh response parser built from ``parser_class``."""

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version configured for this connection."""

    @abstractmethod
    def connect(self):
        """Connect to the server if not already connected."""

    @abstractmethod
    def on_connect(self):
        """Initialize a freshly opened connection (authentication, db, ...)."""

    @abstractmethod
    def disconnect(self, *args):
        """Close the connection; the triggering error may be passed positionally."""

    @abstractmethod
    def check_health(self):
        """Check that the connection is usable (policy is implementation-defined)."""

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send an already-serialized command."""

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Serialize and send a command."""

    @abstractmethod
    def can_read(self, timeout=0):
        """Expected to report whether a reply is ready, waiting at most ``timeout``."""

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read and return the next reply from the server."""

    @abstractmethod
    def pack_command(self, *args):
        """Serialize a single command into wire-protocol chunks."""

    @abstractmethod
    def pack_commands(self, commands):
        """Serialize several commands (e.g. for a pipeline) into wire-protocol chunks."""

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Server metadata returned by the HELLO handshake, if one was sent."""

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store ``token`` for a later ``re_auth`` call."""

    @abstractmethod
    def re_auth(self):
        """Re-authenticate the connection with the stored token."""

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """

    @abstractmethod
    def reset_should_reconnect(self):
        """
        Reset the internal flag to False.
        """

265 

266 

class MaintNotificationsAbstractConnection:
    """
    Abstract class for handling maintenance notifications logic.
    This class is expected to be used as base class together with ConnectionInterface.

    This class is intended to be used with multiple inheritance!

    All logic related to maintenance notifications is encapsulated in this class.
    Subclasses must provide the parser, the socket, the protocol, the ``host`` /
    timeout properties and the send/read/disconnect primitives declared below.

    NOTE: this class does not derive from ``abc.ABC``; the ``@abstractmethod``
    markers document the contract but are not enforced on instantiation.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig],
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Initialize the maintenance notifications for the connection.

        Args:
            maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
                When it is None or not enabled, no handlers are installed.
            maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
                A per-connection copy is taken from it and bound to this connection.
            maintenance_state (MaintenanceState): The current maintenance state of the connection.
            maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
            orig_host_address (Optional[str]): The original host address of the connection.
                Defaults to the current ``host`` (only stored when notifications are enabled).
            orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
                Defaults to the current ``socket_timeout`` (only stored when notifications are enabled).
            orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
                Defaults to the current ``socket_connect_timeout`` (only stored when notifications are enabled).
            parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser the notification push handlers
                are attached to. Required when maintenance notifications are enabled.

        Raises:
            RedisError: if notifications are enabled and ``parser`` is missing or is
                neither a hiredis nor a RESP3 parser.
        """
        self.maint_notifications_config = maint_notifications_config
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash
        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            parser,
        )

    @abstractmethod
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
        pass

    @abstractmethod
    def _get_socket(self) -> Optional[socket.socket]:
        pass

    @abstractmethod
    def get_protocol(self) -> Union[int, str]:
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        pass

    @property
    @abstractmethod
    def host(self) -> str:
        pass

    @host.setter
    @abstractmethod
    def host(self, value: str):
        pass

    @property
    @abstractmethod
    def socket_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_timeout.setter
    @abstractmethod
    def socket_timeout(self, value: Optional[Union[float, int]]):
        pass

    @property
    @abstractmethod
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_connect_timeout.setter
    @abstractmethod
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        pass

    @abstractmethod
    def disconnect(self, *args):
        pass

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Enable maintenance notifications by setting up
        handlers and storing original connection parameters.

        Should be used ONLY with parsers that support push notifications.
        When notifications are disabled, both handler attributes are set to
        None and the ``orig_*`` attributes are not created.
        """
        config = self.maint_notifications_config
        if not config or not config.enabled:
            self._maint_notifications_pool_handler = None
            self._maint_notifications_connection_handler = None
            return

        if not parser:
            raise RedisError(
                "To configure maintenance notifications, a parser must be provided!"
            )

        if not isinstance(parser, (_HiredisParser, _RESP3Parser)):
            raise RedisError(
                "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
            )

        if maint_notifications_pool_handler:
            # Extract a reference to a new pool handler that copies all properties
            # of the original one and has a different connection reference.
            # This is needed because when we attach the handler to the parser
            # we need to make sure that the handler has a reference to the
            # connection that the parser is attached to.
            self._maint_notifications_pool_handler = (
                maint_notifications_pool_handler.get_handler_for_connection()
            )
            self._maint_notifications_pool_handler.set_connection(self)
        else:
            self._maint_notifications_pool_handler = None

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )

        # Set up pool handler if available
        if self._maint_notifications_pool_handler:
            parser.set_node_moving_push_handler(
                self._maint_notifications_pool_handler.handle_notification
            )

        # Set up connection handler
        parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

        # Store original connection parameters so that reset_tmp_settings can
        # undo the temporary overrides applied by set_tmp_settings.
        # Timeouts are tested against None rather than by truthiness: 0 is a
        # legitimate (non-blocking) timeout and must not be replaced.
        self.orig_host_address = orig_host_address if orig_host_address else self.host
        self.orig_socket_timeout = (
            orig_socket_timeout
            if orig_socket_timeout is not None
            else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout is not None
            else self.socket_connect_timeout
        )

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """
        Attach a per-connection copy of ``maint_notifications_pool_handler``.

        The copy is installed as the parser's node-moving push handler. The
        connection handler is created (and registered with the parser) if
        missing; otherwise its config is replaced by the pool handler's config.
        """
        # Take a per-connection copy of the pool handler so that handlers are
        # not shared between connections; otherwise each connection would
        # override the connection reference and the pool handler would only
        # hold a reference to the last connection that was set.
        maint_notifications_pool_handler_copy = (
            maint_notifications_pool_handler.get_handler_for_connection()
        )

        maint_notifications_pool_handler_copy.set_connection(self)
        self._get_parser().set_node_moving_push_handler(
            maint_notifications_pool_handler_copy.handle_notification
        )

        self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, maint_notifications_pool_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )

    def activate_maint_notifications_handling_if_enabled(self, check_health=True):
        """
        Send the maintenance-notifications handshake when it applies.

        The handshake is sent only if RESP3 is active (protocol is not 2),
        maintenance notifications are enabled in the config, a connection
        handler exists, and the connection has a ``host`` attribute.
        When the config's enabled mode is "auto", a server-side failure is only
        logged; when enabled is True, the failure is raised.
        """
        if (
            self.get_protocol() not in [2, "2"]
            and self.maint_notifications_config
            and self.maint_notifications_config.enabled
            and self._maint_notifications_connection_handler
            and hasattr(self, "host")
        ):
            self._enable_maintenance_notifications(
                maint_notifications_config=self.maint_notifications_config,
                check_health=check_health,
            )

    def _enable_maintenance_notifications(
        self, maint_notifications_config: MaintNotificationsConfig, check_health=True
    ):
        """
        Send ``CLIENT MAINT_NOTIFICATIONS ON moving-endpoint-type <type>``.

        Raises:
            ValueError: if the connection has no (or a None) ``host``.
            ResponseError: if the server does not reply "OK", unless the
                config's ``enabled`` is "auto", in which case the failure is
                logged at debug level and swallowed.
        """
        try:
            host = getattr(self, "host", None)
            if host is None:
                raise ValueError(
                    "Cannot enable maintenance notifications for connection"
                    " object that doesn't have a host attribute."
                )
            endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
            self.send_command(
                "CLIENT",
                "MAINT_NOTIFICATIONS",
                "ON",
                "moving-endpoint-type",
                endpoint_type.value,
                check_health=check_health,
            )
            response = self.read_response()
            if not response or str_if_bytes(response) != "OK":
                raise ResponseError(
                    "The server doesn't support maintenance notifications"
                )
        except ResponseError as e:
            if maint_notifications_config.enabled != "auto":
                raise
            # "auto" mode: log the failure but don't fail the connection.
            import logging

            logger = logging.getLogger(__name__)
            logger.debug("Failed to enable maintenance notifications: %s", e)

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution of ``host``/``port`` if needed.

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """
        # Method 1: Try to get the actual IP from the established socket connection.
        # This is most accurate as it shows the exact IP being used.
        try:
            conn_socket = self._get_socket()
            if conn_socket is not None:
                peer_addr = conn_socket.getpeername()
                # Only TCP peers are (host, port, ...) tuples whose first item
                # is the IP. An AF_UNIX peer is a path string, where [0] would
                # yield just its first character, so skip it.
                if isinstance(peer_addr, tuple) and peer_addr:
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host.
        # This is less accurate but works when the socket is not available.
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP.
                # This mimics what the connection would do during _connect().
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # and sockaddr[0] is the IP address.
                    return str(addr_info[0][4][0])
        except (AttributeError, OSError):
            # DNS resolution might fail (socket.gaierror subclasses OSError)
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def getpeername(self):
        """
        Returns the peer name of the connection.

        For TCP sockets this is the first element of the peer address (the
        peer's IP). For address families whose peer address is not a tuple
        (e.g. AF_UNIX, a path string) the address is returned unchanged.
        Returns None when there is no socket; errors raised by
        ``socket.getpeername()`` propagate.
        """
        conn_socket = self._get_socket()
        if not conn_socket:
            return None
        peer_addr = conn_socket.getpeername()
        if isinstance(peer_addr, tuple):
            return peer_addr[0]
        return peer_addr

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """
        Apply a timeout to the live socket and mirror it into the parser.

        ``relaxed_timeout == -1`` means "use the connection's current
        ``socket_timeout``"; any other value (including None) is applied as-is.
        Does nothing when there is no socket.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
            conn_socket.settimeout(timeout)
            self.update_parser_timeout(timeout)

    def update_parser_timeout(self, timeout: Optional[float] = None):
        """
        Keep the parser's notion of the socket timeout in sync with the socket.

        The value is copied verbatim, including None (blocking) and 0. The RESP3
        buffer re-applies its stored timeout after reads that use a custom
        timeout, so skipping falsy values would leave a stale relaxed timeout
        in place after a reset to an unbounded original timeout.
        """
        parser = self._get_parser()
        if parser and parser._buffer:
            if isinstance(parser, _RESP3Parser):
                parser._buffer.socket_timeout = timeout
            elif isinstance(parser, _HiredisParser):
                parser._socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        Temporarily override the host and/or the timeouts.

        The value of SENTINEL is used to indicate that the property should not be updated.
        A falsy ``tmp_host_address`` likewise leaves ``host`` unchanged.
        ``tmp_relaxed_timeout == -1`` leaves both timeouts unchanged; any other
        value -- including the default None -- overwrites ``socket_timeout``
        and ``socket_connect_timeout``. Undo with ``reset_tmp_settings``.
        """
        if tmp_host_address and tmp_host_address is not SENTINEL:
            self.host = str(tmp_host_address)
        if tmp_relaxed_timeout != -1:
            self.socket_timeout = tmp_relaxed_timeout
            self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Restore the values saved when maintenance notifications were configured.

        NOTE(review): the ``orig_*`` attributes exist only when notifications
        were enabled at configuration time; otherwise this raises AttributeError.
        """
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relaxed_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout

653 

654 

655class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface): 

656 "Manages communication to and from a Redis server" 

657 

658 def __init__( 

659 self, 

660 db: int = 0, 

661 password: Optional[str] = None, 

662 socket_timeout: Optional[float] = None, 

663 socket_connect_timeout: Optional[float] = None, 

664 retry_on_timeout: bool = False, 

665 retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL, 

666 encoding: str = "utf-8", 

667 encoding_errors: str = "strict", 

668 decode_responses: bool = False, 

669 parser_class=DefaultParser, 

670 socket_read_size: int = 65536, 

671 health_check_interval: int = 0, 

672 client_name: Optional[str] = None, 

673 lib_name: Optional[str] = "redis-py", 

674 lib_version: Optional[str] = get_lib_version(), 

675 username: Optional[str] = None, 

676 retry: Union[Any, None] = None, 

677 redis_connect_func: Optional[Callable[[], None]] = None, 

678 credential_provider: Optional[CredentialProvider] = None, 

679 protocol: Optional[int] = 2, 

680 command_packer: Optional[Callable[[], None]] = None, 

681 event_dispatcher: Optional[EventDispatcher] = None, 

682 maint_notifications_config: Optional[MaintNotificationsConfig] = None, 

683 maint_notifications_pool_handler: Optional[ 

684 MaintNotificationsPoolHandler 

685 ] = None, 

686 maintenance_state: "MaintenanceState" = MaintenanceState.NONE, 

687 maintenance_notification_hash: Optional[int] = None, 

688 orig_host_address: Optional[str] = None, 

689 orig_socket_timeout: Optional[float] = None, 

690 orig_socket_connect_timeout: Optional[float] = None, 

691 ): 

692 """ 

693 Initialize a new Connection. 

694 To specify a retry policy for specific errors, first set 

695 `retry_on_error` to a list of the error/s to retry on, then set 

696 `retry` to a valid `Retry` object. 

697 To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. 

698 """ 

699 if (username or password) and credential_provider is not None: 

700 raise DataError( 

701 "'username' and 'password' cannot be passed along with 'credential_" 

702 "provider'. Please provide only one of the following arguments: \n" 

703 "1. 'password' and (optional) 'username'\n" 

704 "2. 'credential_provider'" 

705 ) 

706 if event_dispatcher is None: 

707 self._event_dispatcher = EventDispatcher() 

708 else: 

709 self._event_dispatcher = event_dispatcher 

710 self.pid = os.getpid() 

711 self.db = db 

712 self.client_name = client_name 

713 self.lib_name = lib_name 

714 self.lib_version = lib_version 

715 self.credential_provider = credential_provider 

716 self.password = password 

717 self.username = username 

718 self._socket_timeout = socket_timeout 

719 if socket_connect_timeout is None: 

720 socket_connect_timeout = socket_timeout 

721 self._socket_connect_timeout = socket_connect_timeout 

722 self.retry_on_timeout = retry_on_timeout 

723 if retry_on_error is SENTINEL: 

724 retry_on_errors_list = [] 

725 else: 

726 retry_on_errors_list = list(retry_on_error) 

727 if retry_on_timeout: 

728 # Add TimeoutError to the errors list to retry on 

729 retry_on_errors_list.append(TimeoutError) 

730 self.retry_on_error = retry_on_errors_list 

731 if retry or self.retry_on_error: 

732 if retry is None: 

733 self.retry = Retry(NoBackoff(), 1) 

734 else: 

735 # deep-copy the Retry object as it is mutable 

736 self.retry = copy.deepcopy(retry) 

737 if self.retry_on_error: 

738 # Update the retry's supported errors with the specified errors 

739 self.retry.update_supported_errors(self.retry_on_error) 

740 else: 

741 self.retry = Retry(NoBackoff(), 0) 

742 self.health_check_interval = health_check_interval 

743 self.next_health_check = 0 

744 self.redis_connect_func = redis_connect_func 

745 self.encoder = Encoder(encoding, encoding_errors, decode_responses) 

746 self.handshake_metadata = None 

747 self._sock = None 

748 self._socket_read_size = socket_read_size 

749 self._connect_callbacks = [] 

750 self._buffer_cutoff = 6000 

751 self._re_auth_token: Optional[TokenInterface] = None 

752 try: 

753 p = int(protocol) 

754 except TypeError: 

755 p = DEFAULT_RESP_VERSION 

756 except ValueError: 

757 raise ConnectionError("protocol must be an integer") 

758 finally: 

759 if p < 2 or p > 3: 

760 raise ConnectionError("protocol must be either 2 or 3") 

761 # p = DEFAULT_RESP_VERSION 

762 self.protocol = p 

763 if self.protocol == 3 and parser_class == _RESP2Parser: 

764 # If the protocol is 3 but the parser is RESP2, change it to RESP3 

765 # This is needed because the parser might be set before the protocol 

766 # or might be provided as a kwarg to the constructor 

767 # We need to react on discrepancy only for RESP2 and RESP3 

768 # as hiredis supports both 

769 parser_class = _RESP3Parser 

770 self.set_parser(parser_class) 

771 

772 self._command_packer = self._construct_command_packer(command_packer) 

773 self._should_reconnect = False 

774 

775 # Set up maintenance notifications 

776 MaintNotificationsAbstractConnection.__init__( 

777 self, 

778 maint_notifications_config, 

779 maint_notifications_pool_handler, 

780 maintenance_state, 

781 maintenance_notification_hash, 

782 orig_host_address, 

783 orig_socket_timeout, 

784 orig_socket_connect_timeout, 

785 self._parser, 

786 ) 

787 

788 def __repr__(self): 

789 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) 

790 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" 

791 

792 @abstractmethod 

793 def repr_pieces(self): 

794 pass 

795 

796 def __del__(self): 

797 try: 

798 self.disconnect() 

799 except Exception: 

800 pass 

801 

802 def _construct_command_packer(self, packer): 

803 if packer is not None: 

804 return packer 

805 elif HIREDIS_AVAILABLE: 

806 return HiredisRespSerializer() 

807 else: 

808 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode) 

809 

810 def register_connect_callback(self, callback): 

811 """ 

812 Register a callback to be called when the connection is established either 

813 initially or reconnected. This allows listeners to issue commands that 

814 are ephemeral to the connection, for example pub/sub subscription or 

815 key tracking. The callback must be a _method_ and will be kept as 

816 a weak reference. 

817 """ 

818 wm = weakref.WeakMethod(callback) 

819 if wm not in self._connect_callbacks: 

820 self._connect_callbacks.append(wm) 

821 

822 def deregister_connect_callback(self, callback): 

823 """ 

824 De-register a previously registered callback. It will no-longer receive 

825 notifications on connection events. Calling this is not required when the 

826 listener goes away, since the callbacks are kept as weak methods. 

827 """ 

828 try: 

829 self._connect_callbacks.remove(weakref.WeakMethod(callback)) 

830 except ValueError: 

831 pass 

832 

833 def set_parser(self, parser_class): 

834 """ 

835 Creates a new instance of parser_class with socket size: 

836 _socket_read_size and assigns it to the parser for the connection 

837 :param parser_class: The required parser class 

838 """ 

839 self._parser = parser_class(socket_read_size=self._socket_read_size) 

840 

841 def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]: 

842 return self._parser 

843 

844 def connect(self): 

845 "Connects to the Redis server if not already connected" 

846 self.connect_check_health(check_health=True) 

847 

848 def connect_check_health( 

849 self, check_health: bool = True, retry_socket_connect: bool = True 

850 ): 

851 if self._sock: 

852 return 

853 try: 

854 if retry_socket_connect: 

855 sock = self.retry.call_with_retry( 

856 lambda: self._connect(), lambda error: self.disconnect(error) 

857 ) 

858 else: 

859 sock = self._connect() 

860 except socket.timeout: 

861 raise TimeoutError("Timeout connecting to server") 

862 except OSError as e: 

863 raise ConnectionError(self._error_message(e)) 

864 

865 self._sock = sock 

866 try: 

867 if self.redis_connect_func is None: 

868 # Use the default on_connect function 

869 self.on_connect_check_health(check_health=check_health) 

870 else: 

871 # Use the passed function redis_connect_func 

872 self.redis_connect_func(self) 

873 except RedisError: 

874 # clean up after any error in on_connect 

875 self.disconnect() 

876 raise 

877 

878 # run any user callbacks. right now the only internal callback 

879 # is for pubsub channel/pattern resubscription 

880 # first, remove any dead weakrefs 

881 self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()] 

882 for ref in self._connect_callbacks: 

883 callback = ref() 

884 if callback: 

885 callback(self) 

886 

887 @abstractmethod 

888 def _connect(self): 

889 pass 

890 

891 @abstractmethod 

892 def _host_error(self): 

893 pass 

894 

895 def _error_message(self, exception): 

896 return format_error_message(self._host_error(), exception) 

897 

898 def on_connect(self): 

899 self.on_connect_check_health(check_health=True) 

900 

901 def on_connect_check_health(self, check_health: bool = True): 

902 "Initialize the connection, authenticate and select a database" 

903 self._parser.on_connect(self) 

904 parser = self._parser 

905 

906 auth_args = None 

907 # if credential provider or username and/or password are set, authenticate 

908 if self.credential_provider or (self.username or self.password): 

909 cred_provider = ( 

910 self.credential_provider 

911 or UsernamePasswordCredentialProvider(self.username, self.password) 

912 ) 

913 auth_args = cred_provider.get_credentials() 

914 

915 # if resp version is specified and we have auth args, 

916 # we need to send them via HELLO 

917 if auth_args and self.protocol not in [2, "2"]: 

918 if isinstance(self._parser, _RESP2Parser): 

919 self.set_parser(_RESP3Parser) 

920 # update cluster exception classes 

921 self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES 

922 self._parser.on_connect(self) 

923 if len(auth_args) == 1: 

924 auth_args = ["default", auth_args[0]] 

925 # avoid checking health here -- PING will fail if we try 

926 # to check the health prior to the AUTH 

927 self.send_command( 

928 "HELLO", self.protocol, "AUTH", *auth_args, check_health=False 

929 ) 

930 self.handshake_metadata = self.read_response() 

931 # if response.get(b"proto") != self.protocol and response.get( 

932 # "proto" 

933 # ) != self.protocol: 

934 # raise ConnectionError("Invalid RESP version") 

935 elif auth_args: 

936 # avoid checking health here -- PING will fail if we try 

937 # to check the health prior to the AUTH 

938 self.send_command("AUTH", *auth_args, check_health=False) 

939 

940 try: 

941 auth_response = self.read_response() 

942 except AuthenticationWrongNumberOfArgsError: 

943 # a username and password were specified but the Redis 

944 # server seems to be < 6.0.0 which expects a single password 

945 # arg. retry auth with just the password. 

946 # https://github.com/andymccurdy/redis-py/issues/1274 

947 self.send_command("AUTH", auth_args[-1], check_health=False) 

948 auth_response = self.read_response() 

949 

950 if str_if_bytes(auth_response) != "OK": 

951 raise AuthenticationError("Invalid Username or Password") 

952 

953 # if resp version is specified, switch to it 

954 elif self.protocol not in [2, "2"]: 

955 if isinstance(self._parser, _RESP2Parser): 

956 self.set_parser(_RESP3Parser) 

957 # update cluster exception classes 

958 self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES 

959 self._parser.on_connect(self) 

960 self.send_command("HELLO", self.protocol, check_health=check_health) 

961 self.handshake_metadata = self.read_response() 

962 if ( 

963 self.handshake_metadata.get(b"proto") != self.protocol 

964 and self.handshake_metadata.get("proto") != self.protocol 

965 ): 

966 raise ConnectionError("Invalid RESP version") 

967 

968 # Activate maintenance notifications for this connection 

969 # if enabled in the configuration 

970 # This is a no-op if maintenance notifications are not enabled 

971 self.activate_maint_notifications_handling_if_enabled(check_health=check_health) 

972 

973 # if a client_name is given, set it 

974 if self.client_name: 

975 self.send_command( 

976 "CLIENT", 

977 "SETNAME", 

978 self.client_name, 

979 check_health=check_health, 

980 ) 

981 if str_if_bytes(self.read_response()) != "OK": 

982 raise ConnectionError("Error setting client name") 

983 

984 try: 

985 # set the library name and version 

986 if self.lib_name: 

987 self.send_command( 

988 "CLIENT", 

989 "SETINFO", 

990 "LIB-NAME", 

991 self.lib_name, 

992 check_health=check_health, 

993 ) 

994 self.read_response() 

995 except ResponseError: 

996 pass 

997 

998 try: 

999 if self.lib_version: 

1000 self.send_command( 

1001 "CLIENT", 

1002 "SETINFO", 

1003 "LIB-VER", 

1004 self.lib_version, 

1005 check_health=check_health, 

1006 ) 

1007 self.read_response() 

1008 except ResponseError: 

1009 pass 

1010 

1011 # if a database is specified, switch to it 

1012 if self.db: 

1013 self.send_command("SELECT", self.db, check_health=check_health) 

1014 if str_if_bytes(self.read_response()) != "OK": 

1015 raise ConnectionError("Invalid Database") 

1016 

def disconnect(self, *args):
    """Disconnect from the Redis server.

    The parser is told about the disconnect first, the stored socket
    reference is cleared and the reconnect flag is reset. If a socket was
    open it is shut down (only when ``os.getpid()`` matches ``self.pid``)
    and then closed; OSError (and TypeError from shutdown) is ignored.
    Extra positional arguments are accepted and ignored.
    """
    self._parser.on_disconnect()

    old_sock, self._sock = self._sock, None
    # the reconnect request (if any) is satisfied by this disconnect
    self.reset_should_reconnect()
    if old_sock is not None:
        # shutdown is only attempted from the process recorded in self.pid
        # (presumably the one that opened the socket); close always runs
        if self.pid == os.getpid():
            try:
                old_sock.shutdown(socket.SHUT_RDWR)
            except (OSError, TypeError):
                pass
        try:
            old_sock.close()
        except OSError:
            pass

1038 

def mark_for_reconnect(self):
    """Set the flag reported by should_reconnect() to True."""
    self._should_reconnect = True

1041 

def should_reconnect(self):
    """Return the current value of the reconnect flag."""
    flag = self._should_reconnect
    return flag

1044 

def reset_should_reconnect(self):
    """Clear the reconnect flag (should_reconnect() then returns False)."""
    self._should_reconnect = False

1047 

def _send_ping(self):
    """Send PING without a health check and require a PONG reply.

    Raises:
        ConnectionError: the server replied with anything other than PONG.
    """
    self.send_command("PING", check_health=False)
    reply = str_if_bytes(self.read_response())
    if reply != "PONG":
        raise ConnectionError("Bad response from PING health check")

1053 

def _ping_failed(self, error):
    """Failure callback for the PING health check: drop the connection.

    ``error`` (the exception handed over by the caller) is ignored.
    """
    self.disconnect()

1057 

def check_health(self):
    """PING the server when a health check is due.

    A check is due when ``health_check_interval`` is truthy and the
    monotonic clock has passed ``next_health_check``. The PING runs through
    ``self.retry`` with ``_ping_failed`` as the failure callback.
    """
    due = self.health_check_interval and time.monotonic() > self.next_health_check
    if due:
        self.retry.call_with_retry(self._send_ping, self._ping_failed)

1062 

def send_packed_command(self, command, check_health=True):
    """Write an already packed command to the Redis server.

    Args:
        command: a str (sent as a single item) or an iterable of chunks,
            each passed to ``sendall`` in order.
        check_health: when true, run check_health() before writing.

    If no socket is open, ``connect_check_health(check_health=False)`` is
    called first. Any failure while writing disconnects the connection:
    a socket timeout is raised as TimeoutError, any other OSError as
    ConnectionError, and every other exception is re-raised unchanged.
    """
    if not self._sock:
        self.connect_check_health(check_health=False)
    # _send_ping sends its PING with check_health=False, so honouring the
    # flag here keeps a health check from recursing into itself.
    if check_health:
        self.check_health()
    try:
        chunks = [command] if isinstance(command, str) else command
        for chunk in chunks:
            self._sock.sendall(chunk)
    except socket.timeout:
        self.disconnect()
        raise TimeoutError("Timeout writing to socket")
    except OSError as exc:
        self.disconnect()
        if len(exc.args) == 1:
            code, message = "UNKNOWN", exc.args[0]
        else:
            code, message = exc.args[0], exc.args[1]
        raise ConnectionError(f"Error {code} while writing to socket. {message}.")
    except BaseException:
        # A send may be interrupted part-way (e.g. by a timeout). This API
        # cannot resume a partial write, so the connection is dropped.
        self.disconnect()
        raise

1093 

def send_command(self, *args, **kwargs):
    """Pack ``args`` with the command packer and send them.

    Only the ``check_health`` keyword (default True) is used; any other
    keyword arguments are ignored here.
    """
    packed = self._command_packer.pack(*args)
    wants_check = kwargs.get("check_health", True)
    self.send_packed_command(packed, check_health=wants_check)

1100 

def can_read(self, timeout=0):
    """Ask the parser whether data can be read within ``timeout``.

    Connects first when there is no socket. An OSError while polling
    disconnects and is re-raised as ConnectionError.
    """
    if not self._sock:
        self.connect()

    where = self._host_error()
    try:
        return self._parser.can_read(timeout)
    except OSError as exc:
        self.disconnect()
        raise ConnectionError(f"Error while reading from {where}: {exc.args}")

1115 

def read_response(
    self,
    disable_decoding=False,
    *,
    disconnect_on_error=True,
    push_request=False,
):
    """Read the reply to a previously sent command.

    ``push_request`` is forwarded to the parser only when ``self.protocol``
    is 3 or "3". On failure the connection is disconnected (unless
    ``disconnect_on_error`` is false), then a socket timeout is raised as
    TimeoutError, any other OSError as ConnectionError, and every other
    exception (BaseException included) is re-raised unchanged.

    After a successful read, ``next_health_check`` is pushed forward by
    ``health_check_interval`` when that is set. A ResponseError returned by
    the parser is raised rather than returned.
    """
    host_error = self._host_error()

    try:
        parser_kwargs = {"disable_decoding": disable_decoding}
        if self.protocol in ("3", 3):
            parser_kwargs["push_request"] = push_request
        response = self._parser.read_response(**parser_kwargs)
    except BaseException as exc:
        # Also by default close in case of BaseException. A lot of code
        # relies on this behaviour when doing Command/Response pairs.
        # See #1128.
        if disconnect_on_error:
            self.disconnect()
        if isinstance(exc, socket.timeout):
            raise TimeoutError(f"Timeout reading from {host_error}")
        if isinstance(exc, OSError):
            raise ConnectionError(
                f"Error while reading from {host_error} : {exc.args}"
            )
        raise

    if self.health_check_interval:
        self.next_health_check = time.monotonic() + self.health_check_interval

    if isinstance(response, ResponseError):
        try:
            raise response
        finally:
            del response  # avoid creating ref cycles
    return response

1159 

def pack_command(self, *args):
    """Return ``args`` encoded by this connection's command packer."""
    packer = self._command_packer
    return packer.pack(*args)

1163 

def pack_commands(self, commands):
    """Pack several commands into a list of send-ready buffers.

    Each element of ``commands`` is an argument sequence packed with
    ``self._command_packer``. Small chunks are grouped and joined with
    SYM_EMPTY. A pending group is flushed before the next chunk once its
    total length exceeds ``self._buffer_cutoff``. Chunks longer than the
    cutoff, and memoryview chunks, flush the pending group and are then
    appended on their own, unjoined.
    """
    cutoff = self._buffer_cutoff
    buffers = []
    group = []
    group_len = 0

    for args in commands:
        for chunk in self._command_packer.pack(*args):
            size = len(chunk)
            standalone = size > cutoff or isinstance(chunk, memoryview)
            if group_len > cutoff or standalone:
                if group:
                    buffers.append(SYM_EMPTY.join(group))
                group, group_len = [], 0
            if standalone:
                buffers.append(chunk)
            else:
                group.append(chunk)
                group_len += size

    if group:
        buffers.append(SYM_EMPTY.join(group))
    return buffers

1193 

def get_protocol(self) -> Union[int, str]:
    """Return the configured RESP protocol value (``self.protocol``)."""
    protocol = self.protocol
    return protocol

1196 

@property
def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
    """Reply stored from the connection handshake (``_handshake_metadata``)."""
    return self._handshake_metadata

@handshake_metadata.setter
def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
    # store verbatim; the getter returns the same object
    self._handshake_metadata = value

1204 

def set_re_auth_token(self, token: TokenInterface):
    """Store ``token`` so the next re_auth() call authenticates with it."""
    self._re_auth_token = token

1207 

def re_auth(self):
    """Authenticate again with the pending token, if one is set.

    Sends ``AUTH <token.try_get("oid")> <token.get_value()>``, reads the
    reply and then clears the pending token. Does nothing without a token.
    """
    token = self._re_auth_token
    if token is None:
        return
    self.send_command("AUTH", token.try_get("oid"), token.get_value())
    self.read_response()
    self._re_auth_token = None

1217 

def _get_socket(self) -> Optional[socket.socket]:
    """Return the underlying socket object, or None when not connected."""
    current = self._sock
    return current

1220 

@property
def socket_timeout(self) -> Optional[Union[float, int]]:
    """Timeout stored in ``_socket_timeout`` (None when unset)."""
    return self._socket_timeout

@socket_timeout.setter
def socket_timeout(self, value: Optional[Union[float, int]]):
    # only records the value; nothing is applied to an open socket here
    self._socket_timeout = value

1228 

@property
def socket_connect_timeout(self) -> Optional[Union[float, int]]:
    """Connect timeout stored in ``_socket_connect_timeout``."""
    return self._socket_connect_timeout

@socket_connect_timeout.setter
def socket_connect_timeout(self, value: Optional[Union[float, int]]):
    # only records the value for later connection attempts
    self._socket_connect_timeout = value

1236 

1237 

class Connection(AbstractConnection):
    """Manages TCP communication to and from a Redis server.

    Args:
        host: server host name or address (default "localhost").
        port: server port, converted with ``int()`` (default 6379).
        socket_keepalive: when true, SO_KEEPALIVE is enabled on the socket.
        socket_keepalive_options: mapping of IPPROTO_TCP option -> value,
            applied only when ``socket_keepalive`` is true.
        socket_type: passed as the ``family`` argument of
            ``socket.getaddrinfo`` (0 selects every family).
        **kwargs: forwarded to AbstractConnection.
    """

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        self._host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    @staticmethod
    def _discard_socket(sock):
        """Shut down ``sock`` (ignoring OSError) and close it."""
        try:
            sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
        except OSError:
            pass
        sock.close()

    def _connect(self):
        """Create a TCP socket connection.

        Tries each address returned by ``socket.getaddrinfo`` in turn and
        returns the first socket that connects. Socket options (TCP_NODELAY,
        optional keepalive) are set before connecting, and the connect
        timeout is swapped for the regular socket timeout afterwards.

        Raises:
            OSError: the last connection error when every address failed,
                or a new OSError when getaddrinfo returned no addresses.
        """
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        last_error = None

        for family, socktype, proto, _canonname, address in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as exc:
                # remember the error and move on to the next address
                last_error = exc
                if sock is not None:
                    self._discard_socket(sock)
            except BaseException:
                # e.g. TypeError from a bad keepalive option value, or
                # KeyboardInterrupt: do not leak the half-configured socket
                if sock is not None:
                    self._discard_socket(sock)
                raise

        if last_error is not None:
            raise last_error
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        return f"{self.host}:{self.port}"

    @property
    def host(self) -> str:
        return self._host

    @host.setter
    def host(self, value: str):
        self._host = value

1319 

1320 

class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
    """Connection wrapper that adds client-side caching to another connection.

    Most calls are delegated to the wrapped connection ``conn``. Commands
    accepted by ``cache.is_cachable`` are answered from ``cache`` when a
    finished entry exists; otherwise the reply is read from the server and
    stored. A connect callback enables ``CLIENT TRACKING`` on the wrapped
    connection, and invalidation pushes evict keys from (or flush) the cache.
    """

    DUMMY_CACHE_VALUE = b"foo"
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            MaintNotificationsAbstractConnection.__init__(
                self,
                self._conn.maint_notifications_config,
                self._conn._maint_notifications_pool_handler,
                self._conn.maintenance_state,
                self._conn.maintenance_notification_hash,
                self._conn.host,
                self._conn.socket_timeout,
                self._conn.socket_connect_timeout,
                self._conn._get_parser(),
            )

    def repr_pieces(self):
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_pool_handler_for_connection(
                maint_notifications_pool_handler
            )

    def get_protocol(self):
        return self._conn.get_protocol()

    def connect(self):
        """Connect the wrapped connection and check server compatibility.

        Raises:
            ConnectionError: the handshake metadata lacks the server name or
                version, or the server is not DEFAULT_SERVER_NAME at
                MIN_ALLOWED_VERSION or later.
        """
        self._conn.connect()

        metadata = self._conn.handshake_metadata
        # keys may be bytes or str depending on the reply decoding
        server_name = metadata.get(b"server", None)
        if server_name is None:
            server_name = metadata.get("server", None)
        server_ver = metadata.get(b"version", None)
        if server_ver is None:
            server_ver = metadata.get("version", None)
        # both values are needed for the compatibility check below
        if server_name is None or server_ver is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args):
        # cached replies are dropped together with the connection
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        # Packed commands bypass the cache; forward the health-check flag
        # instead of silently dropping it.
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        """Send a command unless its reply can be served from the cache.

        Non-cachable commands are forwarded directly. For cachable ones a
        ``keys`` keyword is required to build the cache key; if an entry
        already exists nothing is sent (read_response serves it), otherwise
        an IN_PROGRESS placeholder is stored before sending.

        Raises:
            ValueError: a cachable command was sent without ``keys``.
        """
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(
                CacheKey(command=args[0], redis_keys=(), redis_args=())
            ):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

        if kwargs.get("keys") is None:
            raise ValueError("Cannot create cache key.")

        # Creates cache key.
        self._current_command_cache_key = CacheKey(
            command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args
        )

        with self._cache_lock:
            # We have to trigger invalidation processing in case if
            # it was cached by another connection to avoid
            # queueing invalidations in stale connections.
            entry = self._cache.get(self._current_command_cache_key)
            if entry:
                # NOTE(review): an IN_PROGRESS entry owned by another
                # connection also takes this path, so nothing is sent here
                # while read_response will still read from self._conn —
                # confirm this cannot block.
                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)

                return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        """Return the reply for the last command, from the cache if possible.

        A finished cache entry is returned as a deep copy. Otherwise the
        reply is read from the wrapped connection and, for cachable
        commands, stored in the still-present entry (a None reply removes
        the entry instead).
        """
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            if self._current_command_cache_key is not None:
                cached = self._cache.get(self._current_command_cache_key)
                if (
                    cached is not None
                    and cached.status != CacheEntryStatus.IN_PROGRESS
                ):
                    res = copy.deepcopy(cached.cache_value)
                    self._current_command_cache_key = None
                    return res

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()

    def mark_for_reconnect(self):
        self._conn.mark_for_reconnect()

    def should_reconnect(self):
        return self._conn.should_reconnect()

    def reset_should_reconnect(self):
        self._conn.reset_should_reconnect()

    @property
    def host(self) -> str:
        return self._conn.host

    @host.setter
    def host(self, value: str):
        self._conn.host = value

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_connect_timeout = value

    def _get_socket(self) -> Optional[socket.socket]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._get_socket()
        else:
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )

    def _get_maint_notifications_connection_instance(
        self, connection
    ) -> MaintNotificationsAbstractConnection:
        """
        Validate that connection instance supports maintenance notifications.
        With this helper method we ensure that we are working
        with the correct connection type.
        After we validate that connection instance supports maintenance notifications
        we can safely return the connection instance
        as MaintNotificationsAbstractConnection.
        """
        if not isinstance(connection, MaintNotificationsAbstractConnection):
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )
        else:
            return connection

    @property
    def maintenance_state(self) -> MaintenanceState:
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: MaintenanceState):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.maintenance_state = state

    def getpeername(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.getpeername()

    def get_resolved_ip(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.get_resolved_ip()

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.update_current_socket_timeout(relaxed_timeout)

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout)

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout)

    def _connect(self):
        self._conn._connect()

    def _host_error(self):
        # previously missing ``return``: callers always received None
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                self._cache.delete_by_redis_keys(data[1])

1653 

1654 

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
                or an ssl.VerifyMode. Defaults to "required".
            ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
            ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ssl_include_verify_flags = ssl_include_verify_flags
        self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # hostname checking is meaningless without certificate verification
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Wrap the socket with SSL support, handling potential errors.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except (OSError, RedisError):
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.

        Raises:
            RedisError: OCSP validation requested without cryptography, or
                both OCSP modes requested at once.
            ConnectionError: pure OCSP validation failed.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.ssl_include_verify_flags:
            for flag in self.ssl_include_verify_flags:
                context.verify_flags |= flag
        if self.ssl_exclude_verify_flags:
            for flag in self.ssl_exclude_verify_flags:
                context.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        # wrap_socket() takes over the plain socket's descriptor (CPython
        # detaches ``sock``), so the caller's ``sock.close()`` no longer
        # releases it. Close the SSL socket ourselves if validation fails.
        try:
            if self.ssl_validate_ocsp_stapled:
                # validation for the stapled case
                self._verify_ocsp_staple()
            elif self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
                # pure ocsp validation
                self._verify_ocsp(sslsock)
        except BaseException:
            sslsock.close()
            raise
        return sslsock

    def _verify_ocsp_staple(self):
        """Check the server's stapled OCSP response over a separate connection."""
        import OpenSSL

        from .ocsp import ocsp_staple_verifier

        # if a context is provided use it - otherwise, a basic context
        if self.ssl_ocsp_context is None:
            staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
            staple_ctx.use_certificate_file(self.certfile)
            staple_ctx.use_privatekey_file(self.keyfile)
        else:
            staple_ctx = self.ssl_ocsp_context

        staple_ctx.set_ocsp_client_callback(
            ocsp_staple_verifier, self.ssl_ocsp_expected_cert
        )

        # need another socket; it is closed once the check is done
        staple_sock = socket.socket()
        try:
            con = OpenSSL.SSL.Connection(staple_ctx, staple_sock)
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
        finally:
            staple_sock.close()

    def _verify_ocsp(self, sslsock):
        """Run full OCSP validation of ``sslsock``'s peer certificate."""
        from .ocsp import OCSPVerifier

        o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
        if not o.is_valid():
            raise ConnectionError("ocsp validation error")

1837 

1838 

class UnixDomainSocketConnection(AbstractConnection):
    """Manages UDS communication to and from a Redis server.

    Args:
        path: filesystem path of the server's Unix socket.
        socket_timeout: timeout applied once connected (stored after the
            base-class initialiser runs).
        **kwargs: forwarded to AbstractConnection.
    """

    def __init__(self, path="", socket_timeout=None, **kwargs):
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        extra = [("client_name", self.client_name)] if self.client_name else []
        return [("path", self.path), ("db", self.db)] + extra

    def _connect(self):
        """Open, connect and return an AF_UNIX stream socket for ``self.path``."""
        uds = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        uds.settimeout(self.socket_connect_timeout)
        try:
            uds.connect(self.path)
        except OSError:
            # Prevent ResourceWarnings for unclosed sockets.
            try:
                uds.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                pass
            uds.close()
            raise
        else:
            uds.settimeout(self.socket_timeout)
            return uds

    def _host_error(self):
        # the socket path identifies the server in error messages
        return self.path

1872 

1873 

# Upper-cased spellings that to_bool() interprets as False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """
    Convert a URL query-string value into a boolean.

    ``None`` and the empty string yield ``None`` (meaning "not specified").
    A string whose upper-cased form appears in ``FALSE_STRINGS`` yields
    ``False``. Anything else falls back to ``bool(value)``.
    """
    if value is None or value == "":
        return None
    looks_false = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if looks_false else bool(value)

1883 

1884 

def parse_ssl_verify_flags(value):
    """
    Parse a URL query value into a list of ``ssl.VerifyFlags`` members.

    The value is the string form of a list of flag names, with or without
    the surrounding brackets. Examples are
    ``"[VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN]"`` and
    ``"VERIFY_X509_STRICT"``. Empty entries, as in ``"[]"`` or a trailing
    comma, are ignored.

    :param value: the raw query-string value.
    :returns: the matching ``VerifyFlags`` members, in input order.
    :raises ValueError: if a name is not a member of ``ssl.VerifyFlags``.
        This includes every name when the ``ssl`` module is unavailable
        and ``VerifyFlags`` is ``None``.
    """
    # Accept only real enum members. A plain hasattr() check would also let
    # through class attributes such as "__class__" or "mro" and append a
    # method/type instead of a flag.
    known_flags = getattr(VerifyFlags, "__members__", {})
    verify_flags_str = value.replace("[", "").replace("]", "")

    verify_flags = []
    for flag in verify_flags_str.split(","):
        flag = flag.strip()
        if not flag:
            continue
        if flag not in known_flags:
            raise ValueError(f"Invalid ssl verify flag: {flag}")
        verify_flags.append(known_flags[flag])
    return verify_flags

1897 

1898 

# Converters for recognised URL query-string options, keyed by option name.
# Each converter turns the option's string value into the type the
# connection/pool constructors expect.
URL_QUERY_ARGUMENT_PARSERS = dict(
    db=int,
    socket_timeout=float,
    socket_connect_timeout=float,
    socket_keepalive=to_bool,
    retry_on_timeout=to_bool,
    # NOTE(review): list() applied to a query-string value splits it into
    # single characters rather than exception classes. Confirm how
    # retry_on_error is expected to be supplied through a URL.
    retry_on_error=list,
    max_connections=int,
    health_check_interval=int,
    ssl_check_hostname=to_bool,
    ssl_include_verify_flags=parse_ssl_verify_flags,
    ssl_exclude_verify_flags=parse_ssl_verify_flags,
    timeout=float,
)

1913 

1914 

def parse_url(url):
    """
    Translate a ``redis://``, ``rediss://`` or ``unix://`` URL into keyword
    arguments for a connection pool.

    Query-string options are unquoted and converted with
    ``URL_QUERY_ARGUMENT_PARSERS`` when a converter is registered. Otherwise
    they are kept as strings. Username, password, host, port and socket path
    come from the URL components. For TCP schemes, a numeric path such as
    ``/3`` selects the database unless a ``db`` query option was given.

    :raises ValueError: for an unsupported scheme, or when a registered
        converter rejects an option value.
    """
    supported_prefixes = ("redis://", "rediss://", "unix://")
    if not url.startswith(supported_prefixes):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    parsed = urlparse(url)
    kwargs = {}

    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        raw = unquote(values[0])
        convert = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if not convert:
            # Unknown options are forwarded verbatim as strings.
            kwargs[name] = raw
            continue
        try:
            kwargs[name] = convert(raw)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid value for '{name}' in connection URL.")

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
        return kwargs

    # From here on the scheme is redis:// or rediss://.
    if parsed.hostname:
        kwargs["host"] = unquote(parsed.hostname)
    if parsed.port:
        kwargs["port"] = int(parsed.port)

    # A "?db=" query option takes precedence over the path.
    if parsed.path and "db" not in kwargs:
        try:
            kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
        except (AttributeError, ValueError):
            pass  # non-numeric or empty path: leave db unset

    if parsed.scheme == "rediss":
        kwargs["connection_class"] = SSLConnection

    return kwargs

1970 

1971 

# Type variable used by ConnectionPool.from_url (cls: Type[_CP]) -> _CP, so
# that calling from_url on a subclass is typed as returning that subclass.
_CP = TypeVar("_CP", bound="ConnectionPool")

1973 

1974 

class ConnectionPoolInterface(ABC):
    """Abstract operations that every connection pool implementation provides."""

    @abstractmethod
    def get_protocol(self):
        """Return the configured RESP protocol version (``None`` if unset)."""

    @abstractmethod
    def reset(self):
        """Reinitialise the pool's connection bookkeeping."""

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(
        self, command_name: Optional[str], *keys, **options
    ) -> ConnectionInterface:
        """Hand out a connection from the pool; arguments are deprecated."""

    @abstractmethod
    def get_encoder(self):
        """Return an encoder built from the pool's encoding settings."""

    @abstractmethod
    def release(self, connection: ConnectionInterface):
        """Give ``connection`` back to the pool."""

    @abstractmethod
    def disconnect(self, inuse_connections: bool = True):
        """Disconnect pooled connections, optionally including in-use ones."""

    @abstractmethod
    def close(self):
        """Close the pool."""

    @abstractmethod
    def set_retry(self, retry: Retry):
        """Install ``retry`` as the retry policy for the pool's connections."""

    @abstractmethod
    def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate the pool's connections with ``token``."""

2018 

2019 

class MaintNotificationsAbstractConnectionPool:
    """
    Abstract class for handling maintenance notifications logic.
    This class is mixed into the ConnectionPool classes.

    This class is not intended to be used directly!

    All logic related to maintenance notifications and
    connection pool handling is encapsulated in this class.

    Subclasses must provide ``connection_kwargs``, ``_get_pool_lock``,
    ``_get_free_connections`` and ``_get_in_use_connections``.
    ``connection_kwargs`` must already be readable when this class's
    ``__init__`` runs, because ``__init__`` may update it.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **kwargs,
    ):
        """
        Create the maintenance notifications pool handler if enabled.

        :param maint_notifications_config: explicit configuration. When it is
            omitted and ``kwargs["protocol"]`` is 3 (int or str), a default
            ``MaintNotificationsConfig()`` is used.
        :param kwargs: the pool's connection keyword arguments. Only
            ``protocol`` is inspected here.
        :raises RedisError: if an enabled configuration is combined with a
            protocol other than RESP3.
        """
        is_protocol_supported = kwargs.get("protocol") in [3, "3"]
        if maint_notifications_config is None and is_protocol_supported:
            maint_notifications_config = MaintNotificationsConfig()

        if maint_notifications_config and maint_notifications_config.enabled:
            if not is_protocol_supported:
                raise RedisError(
                    "Maintenance notifications handlers on connection are only supported with RESP version 3"
                )

            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )

            self._update_connection_kwargs_for_maint_notifications(
                self._maint_notifications_pool_handler
            )
        else:
            self._maint_notifications_pool_handler = None

    @property
    @abstractmethod
    def connection_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments used to construct new connections."""

    @connection_kwargs.setter
    @abstractmethod
    def connection_kwargs(self, value: Dict[str, Any]):
        """Replace the keyword arguments used to construct new connections."""

    @abstractmethod
    def _get_pool_lock(self) -> threading.RLock:
        """Return the lock that guards the pool's connection collections."""

    @abstractmethod
    def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
        """Return the connections currently idle in the pool."""

    @abstractmethod
    def _get_in_use_connections(
        self,
    ) -> Iterable["MaintNotificationsAbstractConnection"]:
        """Return the connections currently checked out of the pool."""

    def maint_notifications_enabled(self) -> bool:
        """
        Returns:
            True if the maintenance notifications are enabled, False otherwise.
            The maintenance notifications config is stored in the pool handler.
            If the pool handler is not set, the maintenance notifications are not enabled.
        """
        handler = self._maint_notifications_pool_handler
        maint_notifications_config = handler.config if handler else None
        # bool() so callers get the documented True/False instead of None or
        # the config object itself.
        return bool(maint_notifications_config and maint_notifications_config.enabled)

    def update_maint_notifications_config(
        self, maint_notifications_config: MaintNotificationsConfig
    ):
        """
        Updates the maintenance notifications configuration.
        This method should be called only if the pool was created
        without enabling the maintenance notifications and
        in a later point in time maintenance notifications
        are requested to be enabled.

        Existing free connections are disconnected and in-use connections
        are marked for reconnect, so they pick up the new configuration.

        :raises ValueError: when trying to disable notifications that are
            currently enabled.
        """
        if (
            self.maint_notifications_enabled()
            and not maint_notifications_config.enabled
        ):
            raise ValueError(
                "Cannot disable maintenance notifications after enabling them"
            )
        # first update pool settings
        if not self._maint_notifications_pool_handler:
            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )
        else:
            self._maint_notifications_pool_handler.config = maint_notifications_config

        # then update connection kwargs and existing connections
        self._update_connection_kwargs_for_maint_notifications(
            self._maint_notifications_pool_handler
        )
        self._update_maint_notifications_configs_for_connections(
            self._maint_notifications_pool_handler
        )

    def _update_connection_kwargs_for_maint_notifications(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """
        Update the connection kwargs for all future connections.

        Does nothing unless notifications are enabled. Otherwise it stores
        the pool handler and its config in ``connection_kwargs``. On the first
        call only, it also records the original host and socket timeouts under
        ``orig_*`` keys.
        """
        if not self.maint_notifications_enabled():
            return

        self.connection_kwargs.update(
            {
                "maint_notifications_pool_handler": maint_notifications_pool_handler,
                "maint_notifications_config": maint_notifications_pool_handler.config,
            }
        )

        # Store original connection parameters for maintenance notifications.
        if self.connection_kwargs.get("orig_host_address", None) is None:
            # If orig_host_address is None it means we haven't
            # configured the original values yet
            self.connection_kwargs.update(
                {
                    "orig_host_address": self.connection_kwargs.get("host"),
                    "orig_socket_timeout": self.connection_kwargs.get(
                        "socket_timeout", None
                    ),
                    "orig_socket_connect_timeout": self.connection_kwargs.get(
                        "socket_connect_timeout", None
                    ),
                }
            )

    def _update_maint_notifications_configs_for_connections(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Update the maintenance notifications config for all connections in the pool.

        Free connections are disconnected immediately. In-use connections are
        only marked for reconnect, since another caller is using them.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                conn.set_maint_notifications_pool_handler_for_connection(
                    maint_notifications_pool_handler
                )
                conn.maint_notifications_config = (
                    maint_notifications_pool_handler.config
                )
                conn.disconnect()
            for conn in self._get_in_use_connections():
                conn.set_maint_notifications_pool_handler_for_connection(
                    maint_notifications_pool_handler
                )
                conn.maint_notifications_config = (
                    maint_notifications_pool_handler.config
                )
                conn.mark_for_reconnect()

    def _should_update_connection(
        self,
        conn: "MaintNotificationsAbstractConnection",
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
    ) -> bool:
        """
        Check if the connection should be updated based on the matching criteria.

        - ``connected_address``: match ``conn.getpeername()`` against
          ``matching_address``.
        - ``configured_address``: match ``conn.host`` against
          ``matching_address``.
        - ``notification_hash``: match ``conn.maintenance_notification_hash``
          against ``matching_notification_hash``.

        A missing (falsy) address, or a ``None`` hash, matches every connection.
        """
        if matching_pattern == "connected_address":
            return not matching_address or conn.getpeername() == matching_address
        if matching_pattern == "configured_address":
            return not matching_address or conn.host == matching_address
        if matching_pattern == "notification_hash":
            # Compare against None explicitly: 0 is a valid int hash and must
            # not be treated as "no filter".
            return (
                matching_notification_hash is None
                or conn.maintenance_notification_hash == matching_notification_hash
            )
        return True

    def update_connection_settings(
        self,
        conn: "MaintNotificationsAbstractConnection",
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Update the settings for a single connection.

        Applies, in order: the maintenance state (if given), the notification
        hash (only when ``update_notification_hash``), the temporary host
        address and relaxed timeout (if not ``None``), and any requested
        resets. It then always calls
        ``conn.update_current_socket_timeout(relaxed_timeout)``.
        """
        if state:
            conn.maintenance_state = state

        if update_notification_hash:
            # update the notification hash only if requested
            conn.maintenance_notification_hash = maintenance_notification_hash

        if host_address is not None:
            conn.set_tmp_settings(tmp_host_address=host_address)

        if relaxed_timeout is not None:
            conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)

        if reset_relaxed_timeout or reset_host_address:
            conn.reset_tmp_settings(
                reset_host_address=reset_host_address,
                reset_relaxed_timeout=reset_relaxed_timeout,
            )

        conn.update_current_socket_timeout(relaxed_timeout)

    def update_connections_settings(
        self,
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
        include_free_connections: bool = True,
    ):
        """
        Update the settings for all matching connections in the pool.

        This method does not create new connections.
        This method does not affect the connection kwargs.

        :param state: The maintenance state to set for the connection.
        :param maintenance_notification_hash: The hash of the maintenance notification
            to set for the connection.
        :param host_address: The host address to set for the connection.
        :param relaxed_timeout: The relaxed timeout to set for the connection.
        :param matching_address: The address to match for the connection.
        :param matching_notification_hash: The notification hash to match for the connection.
        :param matching_pattern: The pattern to match for the connection.
        :param update_notification_hash: Whether to update the notification hash for the connection.
        :param reset_host_address: Whether to reset the host address to the original address.
        :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
        :param include_free_connections: Whether to include free/available connections.
        """
        with self._get_pool_lock():
            # In-use connections first, then (optionally) the idle ones.
            groups = [self._get_in_use_connections()]
            if include_free_connections:
                groups.append(self._get_free_connections())

            for conn in chain.from_iterable(groups):
                if not self._should_update_connection(
                    conn,
                    matching_pattern,
                    matching_address,
                    matching_notification_hash,
                ):
                    continue
                self.update_connection_settings(
                    conn,
                    state=state,
                    maintenance_notification_hash=maintenance_notification_hash,
                    host_address=host_address,
                    relaxed_timeout=relaxed_timeout,
                    update_notification_hash=update_notification_hash,
                    reset_host_address=reset_host_address,
                    reset_relaxed_timeout=reset_relaxed_timeout,
                )

    def update_connection_kwargs(
        self,
        **kwargs,
    ):
        """
        Update the connection kwargs for all future connections.

        This method updates the connection kwargs for all future connections created by the pool.
        Existing connections are not affected.
        """
        self.connection_kwargs.update(kwargs)

    def update_active_connections_for_reconnect(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Mark all active connections for reconnect.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
            When falsy, every in-use connection is marked.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.mark_for_reconnect()

    def disconnect_free_connections(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Disconnect all free/available connections.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
            When falsy, every free connection is disconnected.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.disconnect()

2362 

2363 

class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface):
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.MaxConnectionsError` when the
    pool's limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    If ``maint_notifications_config`` is provided, the connection pool will support
    maintenance notifications.
    Maintenance notifications are supported only with RESP3.
    If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3,
    the maintenance notifications will be enabled by default.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0
        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0
        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        An explicit ``connection_class`` keyword argument, however, overrides
        the class implied by the URL scheme.
        """
        url_options = parse_url(url)

        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **connection_kwargs,
    ):
        """
        :param connection_class: class instantiated for each new connection.
        :param max_connections: upper bound on connections created by this
            pool. ``None`` or 0 means 2**31.
        :param cache_factory: factory used to build the client-side cache
            when ``cache_config`` is given without an explicit ``cache``.
        :param maint_notifications_config: maintenance notifications
            configuration (RESP3 only).
        :param connection_kwargs: forwarded to ``connection_class``. The
            ``cache`` and ``cache_config`` keys are consumed here.
        :raises ValueError: for a negative or non-int ``max_connections``, or
            a ``cache`` that does not implement ``CacheInterface``.
        :raises RedisError: when caching is requested without RESP3.
        """
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self._connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            if self._connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self._connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                if self._cache_factory is not None:
                    self.cache = self._cache_factory.get_cache()
                else:
                    self.cache = CacheFactory(
                        self._connection_kwargs.get("cache_config")
                    ).get_cache()

        # The cache is owned by the pool, not passed to each connection.
        connection_kwargs.pop("cache", None)
        connection_kwargs.pop("cache_config", None)

        self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        # Must run after _connection_kwargs is set: it may update them.
        MaintNotificationsAbstractConnectionPool.__init__(
            self,
            maint_notifications_config=maint_notifications_config,
            **connection_kwargs,
        )

        self.reset()

    def __repr__(self) -> str:
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    @property
    def connection_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments passed to ``connection_class`` for new connections."""
        return self._connection_kwargs

    @connection_kwargs.setter
    def connection_kwargs(self, value: Dict[str, Any]):
        self._connection_kwargs = value

    def get_protocol(self):
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        return self.connection_kwargs.get("protocol", None)

    def reset(self) -> None:
        """Forget all connections and record the current process id."""
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self) -> None:
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        """
        Get a connection from the pool.

        An idle connection is reused if available, otherwise a new one is
        created. The connection is connected before it is returned. On any
        failure the connection is released back to the pool and the error
        re-raised.
        """
        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            # The "has data" check is skipped when a client cache or
            # maintenance notifications are active (presumably because
            # pending push messages are then expected on an idle socket).
            try:
                if (
                    connection.can_read()
                    and self.cache is None
                    and not self.maint_notifications_enabled()
                ):
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise
        return connection

    def get_encoder(self) -> Encoder:
        "Return an encoder based on encoding settings"
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self) -> "ConnectionInterface":
        """
        Create a new connection.

        When a client cache is configured, the connection is wrapped in a
        ``CacheProxyConnection``.

        :raises MaxConnectionsError: once ``max_connections`` connections
            have been created.
        """
        if self._created_connections >= self.max_connections:
            raise MaxConnectionsError("Too many connections")
        self._created_connections += 1

        kwargs = dict(self.connection_kwargs)

        if self.cache is not None:
            return CacheProxyConnection(
                self.connection_class(**kwargs), self.cache, self._lock
            )
        return self.connection_class(**kwargs)

    def release(self, connection: "Connection") -> None:
        """
        Release the connection back to the pool.

        Connections this pool did not hand out are ignored. Connections created
        in another process (pid mismatch) are disconnected and dropped rather
        than re-pooled.
        """
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own
                return

            if self.owns_connection(connection):
                if connection.should_reconnect():
                    connection.disconnect()
                self._available_connections.append(connection)
                self._event_dispatcher.dispatch(
                    AfterConnectionReleasedEvent(connection)
                )
            else:
                # Pool doesn't own this connection, do not add it back
                # to the pool.
                # The created connections count should not be changed,
                # because the connection was not created by the pool.
                connection.disconnect()
                return

    def owns_connection(self, connection: "Connection") -> bool:
        """Return True if ``connection`` was created in this pool's process."""
        return connection.pid == self.pid

    def disconnect(self, inuse_connections: bool = True) -> None:
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        currently in use, potentially by other threads. Otherwise only disconnect
        connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(
                    self._available_connections, self._in_use_connections
                )
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()

    def close(self) -> None:
        """Close the pool, disconnecting all connections"""
        self.disconnect()

    def set_retry(self, retry: Retry) -> None:
        """Use ``retry`` for future connections and all existing ones."""
        # Hold the pool lock like every other method that walks these
        # collections; otherwise a concurrent get_connection()/release() can
        # raise "Set changed size during iteration".
        with self._lock:
            self.connection_kwargs.update({"retry": retry})
            for conn in self._available_connections:
                conn.retry = retry
            for conn in self._in_use_connections:
                conn.retry = retry

    def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate connections with ``token``.

        Idle connections send ``AUTH`` immediately. In-use connections only
        receive the token via ``set_re_auth_token``, to be applied by their
        current user.
        """
        with self._lock:
            for conn in self._available_connections:
                # The lambdas are invoked immediately within this iteration,
                # so capturing the loop variable is safe.
                conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    self._mock,
                )
                conn.retry.call_with_retry(
                    lambda: conn.read_response(), self._mock
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    def _get_pool_lock(self):
        return self._lock

    def _get_free_connections(self):
        with self._lock:
            return self._available_connections

    def _get_in_use_connections(self):
        with self._lock:
            return self._in_use_connections

    def _mock(self, error: RedisError):
        """
        No-op failure callback for ``Retry.call_with_retry``.

        This must be a plain (synchronous) function. The previous ``async``
        version produced a coroutine that was never awaited, and a
        RuntimeWarning, every time the retry invoked it.

        :param error: the error reported by the retry; ignored.
        """
        return None

2743 

2744 

2745class BlockingConnectionPool(ConnectionPool): 

2746 """ 

2747 Thread-safe blocking connection pool:: 

2748 

2749 >>> from redis.client import Redis 

2750 >>> client = Redis(connection_pool=BlockingConnectionPool()) 

2751 

2752 It performs the same function as the default 

2753 :py:class:`~redis.ConnectionPool` implementation, in that, 

2754 it maintains a pool of reusable connections that can be shared by 

2755 multiple redis clients (safely across threads if required). 

2756 

2757 The difference is that, in the event that a client tries to get a 

2758 connection from the pool when all of connections are in use, rather than 

2759 raising a :py:class:`~redis.ConnectionError` (as the default 

2760 :py:class:`~redis.ConnectionPool` implementation does), it 

2761 makes the client wait ("blocks") for a specified number of seconds until 

2762 a connection becomes available. 

2763 

2764 Use ``max_connections`` to increase / decrease the pool size:: 

2765 

2766 >>> pool = BlockingConnectionPool(max_connections=10) 

2767 

2768 Use ``timeout`` to tell it either how many seconds to wait for a connection 

2769 to become available, or to block forever: 

2770 

2771 >>> # Block forever. 

2772 >>> pool = BlockingConnectionPool(timeout=None) 

2773 

2774 >>> # Raise a ``ConnectionError`` after five seconds if a connection is 

2775 >>> # not available. 

2776 >>> pool = BlockingConnectionPool(timeout=5) 

2777 """ 

2778 

def __init__(
    self,
    max_connections=50,
    timeout=20,
    connection_class=Connection,
    queue_class=LifoQueue,
    **connection_kwargs,
):
    """
    Create a blocking connection pool.

    :param max_connections: passed to the base pool; ``reset()`` sizes
        the slot queue with ``self.max_connections``.
    :param timeout: seconds ``get_connection`` waits for a slot before
        raising ``ConnectionError``; ``None`` waits forever.
    :param connection_class: class used to build new connections.
    :param queue_class: queue type holding the slots; with the LIFO
        default, recently returned connections are handed out before the
        ``None`` placeholders.
    :param connection_kwargs: forwarded to the base pool.
    """
    # reset() reads these attributes; set them before delegating to the
    # base initializer (which presumably calls reset() -- confirm).
    self.timeout = timeout
    self.queue_class = queue_class
    # Maintenance-mode state.
    self._in_maintenance = False
    self._locked = False
    super().__init__(
        max_connections=max_connections,
        connection_class=connection_class,
        **connection_kwargs,
    )

2796 

def reset(self):
    """
    (Re)initialize the pool's bookkeeping.

    Replaces ``self.pool`` with a fresh ``queue_class(max_connections)``
    filled with ``None`` placeholders. Each placeholder is a slot that
    ``get_connection`` turns into a real connection on demand. The method
    then clears the list of created connections and, last of all, records
    the current process id.
    """
    # In maintenance mode, serialize with other pool operations through
    # ``self._lock``. Whether *this call* holds the lock is tracked in a
    # local variable. A flag stored on ``self`` is shared by every thread
    # and by nested calls, so another caller could clear it (leaking our
    # acquisition) or act on it (releasing a lock it never acquired).
    lock = self._lock if self._in_maintenance else None
    if lock is not None:
        lock.acquire()
    try:
        # Create and fill up a thread safe queue with ``None`` values.
        # NOTE(review): a ``queue.Queue`` with maxsize <= 0 is unbounded,
        # so this loop would never end; presumably max_connections is
        # validated as positive by the base class -- confirm.
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break

        # Keep a list of actual connection instances so that we can
        # disconnect them later.
        self._connections = []
    finally:
        if lock is not None:
            lock.release()

    # this must be the last operation in this method. while reset() is
    # called when holding _fork_lock, other threads in this process
    # can call _checkpid() which compares self.pid and os.getpid() without
    # holding any lock (for performance reasons). keeping this assignment
    # as the last operation ensures that those other threads will also
    # notice a pid difference and block waiting for the first thread to
    # release _fork_lock. when each of these threads eventually acquire
    # _fork_lock, they will notice that another thread already called
    # reset() and they will immediately release _fork_lock and continue on.
    self.pid = os.getpid()

2831 

def make_connection(self):
    """
    Make a fresh connection and register it in ``self._connections``.

    When ``self.cache`` is set, the connection is wrapped in a
    ``CacheProxyConnection`` together with the cache and the pool lock.

    :return: the new connection.
    """
    # In maintenance mode, creation is serialized through ``self._lock``.
    # Whether *this call* holds the lock is tracked in a local variable.
    # A flag stored on ``self`` is shared by every thread and by the
    # nested call from ``get_connection``, so one caller could clear it
    # (leaking a lock acquisition) or act on it (releasing a lock it never
    # acquired).
    lock = self._lock if self._in_maintenance else None
    if lock is not None:
        lock.acquire()
    try:
        if self.cache is not None:
            connection = CacheProxyConnection(
                self.connection_class(**self.connection_kwargs),
                self.cache,
                self._lock,
            )
        else:
            connection = self.connection_class(**self.connection_kwargs)
        self._connections.append(connection)
        return connection
    finally:
        if lock is not None:
            lock.release()

2856 

@deprecated_args(
    args_to_warn=["*"],
    reason="Use get_connection() without args instead",
    version="5.3.0",
)
def get_connection(self, command_name=None, *keys, **options):
    """
    Get a connection, blocking for ``self.timeout`` until a connection
    is available from the pool.

    If the connection returned is ``None`` then creates a new connection.
    Because we use a last-in first-out queue, the existing connections
    (having been returned to the pool after the initial ``None`` values
    were added) will be returned before ``None`` values. This means we only
    create new connections when we need to, i.e.: the actual number of
    connections will only increase in response to demand.

    The returned connection has been connected. If it has unread data,
    or the readiness check itself fails, it is disconnected and
    reconnected once.

    :raises ConnectionError: if no slot becomes available within
        ``self.timeout`` seconds, or the connection still has data after
        reconnecting. On any failure after a connection was obtained, it
        is released back to the pool before the error propagates.
    """
    # Make sure we haven't changed process.
    self._checkpid()

    # In maintenance mode the pool lock is held while taking a slot and,
    # if needed, creating a connection. Whether *this call* holds the lock
    # is tracked in a local variable. A flag stored on ``self`` is cleared
    # by the nested ``make_connection`` call, which would leave this
    # acquisition unreleased, and it can also be flipped by other threads.
    # NOTE(review): ``pool.get`` blocks while the lock is held, and
    # ``release`` also takes the lock in maintenance mode. On an empty
    # queue, returns from other threads therefore wait until this call
    # times out (forever with ``timeout=None``). Confirm the lock is meant
    # to cover the wait.
    lock = self._lock if self._in_maintenance else None
    if lock is not None:
        lock.acquire()
    try:
        # Try and get a connection from the pool. If one isn't available
        # within self.timeout then raise a ``ConnectionError``.
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            # Note that this is not caught by the redis client and will be
            # raised unless handled by application code. To wait forever
            # instead, create the pool with ``timeout=None``.
            raise ConnectionError("No connection available.")

        # If the ``connection`` is actually ``None`` then that's a cue to
        # make a new connection to add to the pool.
        if connection is None:
            try:
                connection = self.make_connection()
            except BaseException:
                # Give the placeholder back. Otherwise every failed creation
                # would permanently shrink the pool by one slot.
                try:
                    self.pool.put_nowait(None)
                except Full:
                    # queue already full (e.g. after a reset); nothing to restore
                    pass
                raise
    finally:
        if lock is not None:
            lock.release()

    try:
        # ensure this connection is connected to Redis
        connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if connection.can_read():
                raise ConnectionError("Connection has data")
        except (ConnectionError, TimeoutError, OSError):
            connection.disconnect()
            connection.connect()
            if connection.can_read():
                raise ConnectionError("Connection not ready")
    except BaseException:
        # release the connection back to the pool so that we don't leak it
        self.release(connection)
        raise

    return connection

2924 

def release(self, connection):
    """
    Release ``connection`` back to the pool.

    A connection this pool does not own (``owns_connection`` is false) is
    disconnected and replaced by a ``None`` placeholder slot. An owned
    connection that reports ``should_reconnect()`` is disconnected before
    being put back. If the queue is already full -- perhaps because the
    pool was ``reset()`` after a fork -- the connection or placeholder is
    simply dropped.
    """
    # Make sure we haven't changed process.
    self._checkpid()

    # In maintenance mode, serialize with other pool operations through
    # ``self._lock``. Whether *this call* holds the lock is tracked in a
    # local variable. A flag stored on ``self`` is shared by every thread,
    # so another caller could clear it (leaking our acquisition) or act on
    # it (releasing a lock it never acquired).
    lock = self._lock if self._in_maintenance else None
    if lock is not None:
        lock.acquire()
    try:
        if not self.owns_connection(connection):
            # pool doesn't own this connection. do not add it back
            # to the pool. instead add a None value which is a placeholder
            # that will cause the pool to recreate the connection if
            # its needed.
            connection.disconnect()
            try:
                self.pool.put_nowait(None)
            except Full:
                # after a reset() the queue is already full of placeholders
                pass
            return
        if connection.should_reconnect():
            connection.disconnect()
        # Put the connection back into the pool.
        try:
            self.pool.put_nowait(connection)
        except Full:
            # perhaps the pool has been reset() after a fork? regardless,
            # we don't want this connection
            pass
    finally:
        if lock is not None:
            lock.release()

2958 

def disconnect(self, inuse_connections: bool = True):
    """
    Disconnects either all connections in the pool or just the free connections.

    :param inuse_connections: when true (the default), every connection
        this pool has created (``self._connections``) is disconnected.
        Otherwise only the idle ones currently parked in the queue are
        disconnected.
    """
    self._checkpid()
    # In maintenance mode, serialize with other pool operations through
    # ``self._lock``. Whether *this call* holds the lock is tracked in a
    # local variable. A flag stored on ``self`` is shared by every thread,
    # so another caller could clear it (leaking our acquisition) or act on
    # it (releasing a lock it never acquired).
    lock = self._lock if self._in_maintenance else None
    if lock is not None:
        lock.acquire()
    try:
        if inuse_connections:
            connections = self._connections
        else:
            connections = self._get_free_connections()
        for connection in connections:
            connection.disconnect()
    finally:
        if lock is not None:
            lock.release()

2979 

2980 def _get_free_connections(self): 

2981 with self._lock: 

2982 return {conn for conn in self.pool.queue if conn} 

2983 

2984 def _get_in_use_connections(self): 

2985 with self._lock: 

2986 # free connections 

2987 connections_in_queue = {conn for conn in self.pool.queue if conn} 

2988 # in self._connections we keep all created connections 

2989 # so the ones that are not in the queue are the in use ones 

2990 return { 

2991 conn for conn in self._connections if conn not in connections_in_queue 

2992 } 

2993 

def set_in_maintenance(self, in_maintenance: bool):
    """
    Turn this blocking pool's maintenance mode on or off.

    While the flag is set, the pool's connection operations serialize on
    the pool lock. This is how new connections are kept from being created
    while a MOVING notification is being processed; the pool is in
    maintenance mode only during that processing.

    :param in_maintenance: ``True`` to enter maintenance mode, ``False``
        to leave it.
    """
    self._in_maintenance = in_maintenance