Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1395 statements  

1import copy 

2import os 

3import socket 

4import sys 

5import threading 

6import time 

7import weakref 

8from abc import ABC, abstractmethod 

9from itertools import chain 

10from queue import Empty, Full, LifoQueue 

11from typing import ( 

12 Any, 

13 Callable, 

14 Dict, 

15 Iterable, 

16 List, 

17 Literal, 

18 Optional, 

19 Type, 

20 TypeVar, 

21 Union, 

22) 

23from urllib.parse import parse_qs, unquote, urlparse 

24 

25from redis.cache import ( 

26 CacheEntry, 

27 CacheEntryStatus, 

28 CacheFactory, 

29 CacheFactoryInterface, 

30 CacheInterface, 

31 CacheKey, 

32) 

33 

34from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser 

35from .auth.token import TokenInterface 

36from .backoff import NoBackoff 

37from .credentials import CredentialProvider, UsernamePasswordCredentialProvider 

38from .event import AfterConnectionReleasedEvent, EventDispatcher 

39from .exceptions import ( 

40 AuthenticationError, 

41 AuthenticationWrongNumberOfArgsError, 

42 ChildDeadlockedError, 

43 ConnectionError, 

44 DataError, 

45 MaxConnectionsError, 

46 RedisError, 

47 ResponseError, 

48 TimeoutError, 

49) 

50from .maint_notifications import ( 

51 MaintenanceState, 

52 MaintNotificationsConfig, 

53 MaintNotificationsConnectionHandler, 

54 MaintNotificationsPoolHandler, 

55) 

56from .retry import Retry 

57from .utils import ( 

58 CRYPTOGRAPHY_AVAILABLE, 

59 HIREDIS_AVAILABLE, 

60 SSL_AVAILABLE, 

61 compare_versions, 

62 deprecated_args, 

63 ensure_string, 

64 format_error_message, 

65 get_lib_version, 

66 str_if_bytes, 

67) 

68 

69if SSL_AVAILABLE: 

70 import ssl 

71 from ssl import VerifyFlags 

72else: 

73 ssl = None 

74 VerifyFlags = None 

75 

76if HIREDIS_AVAILABLE: 

77 import hiredis 

78 

# RESP wire-protocol byte fragments used when hand-packing commands.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

# RESP protocol version assumed when the caller supplies none.
DEFAULT_RESP_VERSION = 2

# Unique sentinel to distinguish "argument not supplied" from an explicit None.
SENTINEL = object()

# Prefer the C-accelerated hiredis parser when the package is installed;
# otherwise fall back to the pure-Python RESP2 parser.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _HiredisParser
else:
    DefaultParser = _RESP2Parser

93 

94 

class HiredisRespSerializer:
    """Command packer backed by the C-accelerated ``hiredis.pack_command``."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol"""
        packed = []

        # A command name may contain literal sub-arguments (e.g. 'CONFIG GET');
        # the server expects them as separate arguments, so split the first one.
        command = args[0]
        if isinstance(command, str):
            args = tuple(command.encode().split()) + args[1:]
        elif b" " in command:
            args = tuple(command.split()) + args[1:]

        try:
            packed.append(hiredis.pack_command(args))
        except TypeError:
            # Re-raise as the library's DataError, preserving the traceback.
            _, err, tb = sys.exc_info()
            raise DataError(err).with_traceback(tb)

        return packed

111 

112 

class PythonRespSerializer:
    """Pure-Python command packer producing RESP-encoded byte chunks."""

    def __init__(self, buffer_cutoff, encode) -> None:
        # Chunks larger than this cutoff are emitted separately instead of
        # being appended to the running buffer.
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        # The client might have included 1 or more literal arguments in the
        # command name, e.g. 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        first = args[0]
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        chunks = []
        cutoff = self._buffer_cutoff
        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        for item in map(self.encode, args):
            item_len = len(item)
            # To avoid large string mallocs, flush the running buffer and
            # emit the value as its own chunk when sending large values or
            # memoryviews.
            if (
                len(buff) > cutoff
                or item_len > cutoff
                or isinstance(item, memoryview)
            ):
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, str(item_len).encode(), SYM_CRLF)
                )
                chunks.append(buff)
                chunks.append(item)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(item_len).encode(),
                        SYM_CRLF,
                        item,
                        SYM_CRLF,
                    )
                )
        chunks.append(buff)
        return chunks

162 

163 

class ConnectionInterface:
    """Abstract interface describing the operations a Redis connection supports."""

    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs describing this connection for repr()."""
        pass

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register a callback invoked when the connection is (re)established."""
        pass

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a previously registered connect callback."""
        pass

    @abstractmethod
    def set_parser(self, parser_class):
        """Replace the connection's response parser with a parser_class instance."""
        pass

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version in use."""
        pass

    @abstractmethod
    def connect(self):
        """Establish the connection to the Redis server if not already connected."""
        pass

    @abstractmethod
    def on_connect(self):
        """Initialize the connection (authenticate, select database, handshake)."""
        pass

    @abstractmethod
    def disconnect(self, *args):
        """Close the connection to the Redis server."""
        pass

    @abstractmethod
    def check_health(self):
        """Verify the connection is alive, issuing a health check if due."""
        pass

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send an already RESP-packed command to the server."""
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Pack and send a command to the server."""
        pass

    @abstractmethod
    def can_read(self, timeout=0):
        """Return True if data is available to read within ``timeout`` seconds."""
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read and return the server's response to the last command."""
        pass

    @abstractmethod
    def pack_command(self, *args):
        """Return the RESP encoding of a single command."""
        pass

    @abstractmethod
    def pack_commands(self, commands):
        """Return the RESP encoding of multiple commands (e.g. a pipeline)."""
        pass

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Metadata returned by the server during the HELLO handshake."""
        pass

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to re-authenticate the connection with later."""
        pass

    @abstractmethod
    def re_auth(self):
        """Re-authenticate the connection using the stored token."""
        pass

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """
        pass

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """
        pass

    @abstractmethod
    def reset_should_reconnect(self):
        """
        Reset the internal flag to False.
        """
        pass

265 

266 

class MaintNotificationsAbstractConnection:
    """
    Abstract class for handling maintenance notifications logic.
    This class is expected to be used as base class together with ConnectionInterface.

    This class is intended to be used with multiple inheritance!

    All logic related to maintenance notifications is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig],
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Initialize the maintenance notifications for the connection.

        Args:
            maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
            maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
            maintenance_state (MaintenanceState): The current maintenance state of the connection.
            maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
            orig_host_address (Optional[str]): The original host address of the connection.
            orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
            orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
            parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser to use for maintenance notifications.
                If not provided, the parser from the connection is used.
                This is useful when the parser is created after this object.
        """
        self.maint_notifications_config = maint_notifications_config
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash
        # Wire up handlers and capture the original connection parameters so
        # temporary maintenance overrides can later be reverted.
        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            parser,
        )

    @abstractmethod
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
        """Return the parser instance attached to this connection."""
        pass

    @abstractmethod
    def _get_socket(self) -> Optional[socket.socket]:
        """Return the underlying socket, or None if not connected."""
        pass

    @abstractmethod
    def get_protocol(self) -> Union[int, str]:
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        pass

    @property
    @abstractmethod
    def host(self) -> str:
        """The host address this connection targets."""
        pass

    @host.setter
    @abstractmethod
    def host(self, value: str):
        pass

    @property
    @abstractmethod
    def socket_timeout(self) -> Optional[Union[float, int]]:
        """The socket read/write timeout currently in effect."""
        pass

    @socket_timeout.setter
    @abstractmethod
    def socket_timeout(self, value: Optional[Union[float, int]]):
        pass

    @property
    @abstractmethod
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        """The timeout used when establishing the socket connection."""
        pass

    @socket_connect_timeout.setter
    @abstractmethod
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Pack and send a command to the server."""
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read and return the server's response to the last command."""
        pass

    @abstractmethod
    def disconnect(self, *args):
        """Close the connection."""
        pass

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Enable maintenance notifications by setting up
        handlers and storing original connection parameters.

        Should be used ONLY with parsers that support push notifications.
        """
        # When notifications are disabled (or no config given), clear both
        # handler slots and skip all further setup.
        if (
            not self.maint_notifications_config
            or not self.maint_notifications_config.enabled
        ):
            self._maint_notifications_pool_handler = None
            self._maint_notifications_connection_handler = None
            return

        if not parser:
            raise RedisError(
                "To configure maintenance notifications, a parser must be provided!"
            )

        # Only the hiredis and RESP3 parsers support push notifications.
        if not isinstance(parser, _HiredisParser) and not isinstance(
            parser, _RESP3Parser
        ):
            raise RedisError(
                "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
            )

        if maint_notifications_pool_handler:
            # Extract a reference to a new pool handler that copies all properties
            # of the original one and has a different connection reference
            # This is needed because when we attach the handler to the parser
            # we need to make sure that the handler has a reference to the
            # connection that the parser is attached to.
            self._maint_notifications_pool_handler = (
                maint_notifications_pool_handler.get_handler_for_connection()
            )
            self._maint_notifications_pool_handler.set_connection(self)
        else:
            self._maint_notifications_pool_handler = None

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )

        # Set up pool handler if available
        if self._maint_notifications_pool_handler:
            parser.set_node_moving_push_handler(
                self._maint_notifications_pool_handler.handle_notification
            )

        # Set up connection handler
        parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

        # Store original connection parameters so temporary maintenance
        # overrides (host/timeouts) can be reverted in reset_tmp_settings().
        self.orig_host_address = orig_host_address if orig_host_address else self.host
        self.orig_socket_timeout = (
            orig_socket_timeout if orig_socket_timeout else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout
            else self.socket_connect_timeout
        )

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Attach a per-connection copy of the given pool handler to this connection."""
        # Deep copy the pool handler to avoid sharing the same pool handler
        # between multiple connections, because otherwise each connection will override
        # the connection reference and the pool handler will only hold a reference
        # to the last connection that was set.
        maint_notifications_pool_handler_copy = (
            maint_notifications_pool_handler.get_handler_for_connection()
        )

        maint_notifications_pool_handler_copy.set_connection(self)
        self._get_parser().set_node_moving_push_handler(
            maint_notifications_pool_handler_copy.handle_notification
        )

        self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, maint_notifications_pool_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            # Handler already exists: just refresh its config from the pool handler.
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )

    def activate_maint_notifications_handling_if_enabled(self, check_health=True):
        """Perform the maintenance notifications handshake when applicable (no-op otherwise)."""
        # Send maintenance notifications handshake if RESP3 is active
        # and maintenance notifications are enabled
        # and we have a host to determine the endpoint type from
        # When the maint_notifications_config enabled mode is "auto",
        # we just log a warning if the handshake fails
        # When the mode is enabled=True, we raise an exception in case of failure
        if (
            self.get_protocol() not in [2, "2"]
            and self.maint_notifications_config
            and self.maint_notifications_config.enabled
            and self._maint_notifications_connection_handler
            and hasattr(self, "host")
        ):
            self._enable_maintenance_notifications(
                maint_notifications_config=self.maint_notifications_config,
                check_health=check_health,
            )

    def _enable_maintenance_notifications(
        self, maint_notifications_config: MaintNotificationsConfig, check_health=True
    ):
        """Send CLIENT MAINT_NOTIFICATIONS ON and validate the server's reply."""
        try:
            host = getattr(self, "host", None)
            if host is None:
                raise ValueError(
                    "Cannot enable maintenance notifications for connection"
                    " object that doesn't have a host attribute."
                )
            else:
                endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
                self.send_command(
                    "CLIENT",
                    "MAINT_NOTIFICATIONS",
                    "ON",
                    "moving-endpoint-type",
                    endpoint_type.value,
                    check_health=check_health,
                )
                response = self.read_response()
                if not response or str_if_bytes(response) != "OK":
                    raise ResponseError(
                        "The server doesn't support maintenance notifications"
                    )
        except Exception as e:
            # In "auto" mode an unsupported server is tolerated; any other
            # failure (or explicit enabled=True) propagates to the caller.
            if (
                isinstance(e, ResponseError)
                and maint_notifications_config.enabled == "auto"
            ):
                # Log warning but don't fail the connection
                import logging

                logger = logging.getLogger(__name__)
                logger.warning(f"Failed to enable maintenance notifications: {e}")
            else:
                raise

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution if needed.

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """

        # Method 1: Try to get the actual IP from the established socket connection
        # This is most accurate as it shows the exact IP being used
        try:
            conn_socket = self._get_socket()
            if conn_socket is not None:
                peer_addr = conn_socket.getpeername()
                if peer_addr and len(peer_addr) >= 1:
                    # For TCP sockets, peer_addr is typically (host, port) tuple
                    # Return just the host part
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host
        # This is less accurate but works when socket is not available
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP
                # This mimics what the connection would do during _connect()
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # Return the IP from the first result
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # sockaddr[0] is the IP address
                    return str(addr_info[0][4][0])
        except (AttributeError, OSError, socket.gaierror):
            # DNS resolution might fail
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        """The connection's current maintenance state."""
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def getpeername(self):
        """
        Returns the peer name of the connection.
        """
        conn_socket = self._get_socket()
        if conn_socket:
            return conn_socket.getpeername()[0]
        return None

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """Apply a relaxed timeout to the live socket (and parser); -1 restores the configured timeout."""
        conn_socket = self._get_socket()
        if conn_socket:
            # NOTE(review): -1 appears to be the "no relaxed timeout" sentinel here,
            # falling back to the connection's configured socket_timeout — confirm.
            timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
            conn_socket.settimeout(timeout)
            self.update_parser_timeout(timeout)

    def update_parser_timeout(self, timeout: Optional[float] = None):
        """Propagate a timeout to the parser's internal buffer, if one exists."""
        parser = self._get_parser()
        if parser and parser._buffer:
            # Each parser type stores its socket timeout differently.
            if isinstance(parser, _RESP3Parser) and timeout:
                parser._buffer.socket_timeout = timeout
            elif isinstance(parser, _HiredisParser):
                parser._socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        Temporarily override the host and/or timeouts during maintenance.

        The value of SENTINEL is used to indicate that the property should not be updated.
        """
        if tmp_host_address and tmp_host_address != SENTINEL:
            self.host = str(tmp_host_address)
        # NOTE(review): -1 seems to mean "leave timeouts unchanged" — confirm.
        if tmp_relaxed_timeout != -1:
            self.socket_timeout = tmp_relaxed_timeout
            self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """Restore the original host and/or timeouts captured at configuration time."""
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relaxed_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout

653 

654 

655class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface): 

656 "Manages communication to and from a Redis server" 

657 

658 def __init__( 

659 self, 

660 db: int = 0, 

661 password: Optional[str] = None, 

662 socket_timeout: Optional[float] = None, 

663 socket_connect_timeout: Optional[float] = None, 

664 retry_on_timeout: bool = False, 

665 retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL, 

666 encoding: str = "utf-8", 

667 encoding_errors: str = "strict", 

668 decode_responses: bool = False, 

669 parser_class=DefaultParser, 

670 socket_read_size: int = 65536, 

671 health_check_interval: int = 0, 

672 client_name: Optional[str] = None, 

673 lib_name: Optional[str] = "redis-py", 

674 lib_version: Optional[str] = get_lib_version(), 

675 username: Optional[str] = None, 

676 retry: Union[Any, None] = None, 

677 redis_connect_func: Optional[Callable[[], None]] = None, 

678 credential_provider: Optional[CredentialProvider] = None, 

679 protocol: Optional[int] = 2, 

680 command_packer: Optional[Callable[[], None]] = None, 

681 event_dispatcher: Optional[EventDispatcher] = None, 

682 maint_notifications_config: Optional[MaintNotificationsConfig] = None, 

683 maint_notifications_pool_handler: Optional[ 

684 MaintNotificationsPoolHandler 

685 ] = None, 

686 maintenance_state: "MaintenanceState" = MaintenanceState.NONE, 

687 maintenance_notification_hash: Optional[int] = None, 

688 orig_host_address: Optional[str] = None, 

689 orig_socket_timeout: Optional[float] = None, 

690 orig_socket_connect_timeout: Optional[float] = None, 

691 ): 

692 """ 

693 Initialize a new Connection. 

694 To specify a retry policy for specific errors, first set 

695 `retry_on_error` to a list of the error/s to retry on, then set 

696 `retry` to a valid `Retry` object. 

697 To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. 

698 """ 

699 if (username or password) and credential_provider is not None: 

700 raise DataError( 

701 "'username' and 'password' cannot be passed along with 'credential_" 

702 "provider'. Please provide only one of the following arguments: \n" 

703 "1. 'password' and (optional) 'username'\n" 

704 "2. 'credential_provider'" 

705 ) 

706 if event_dispatcher is None: 

707 self._event_dispatcher = EventDispatcher() 

708 else: 

709 self._event_dispatcher = event_dispatcher 

710 self.pid = os.getpid() 

711 self.db = db 

712 self.client_name = client_name 

713 self.lib_name = lib_name 

714 self.lib_version = lib_version 

715 self.credential_provider = credential_provider 

716 self.password = password 

717 self.username = username 

718 self._socket_timeout = socket_timeout 

719 if socket_connect_timeout is None: 

720 socket_connect_timeout = socket_timeout 

721 self._socket_connect_timeout = socket_connect_timeout 

722 self.retry_on_timeout = retry_on_timeout 

723 if retry_on_error is SENTINEL: 

724 retry_on_errors_list = [] 

725 else: 

726 retry_on_errors_list = list(retry_on_error) 

727 if retry_on_timeout: 

728 # Add TimeoutError to the errors list to retry on 

729 retry_on_errors_list.append(TimeoutError) 

730 self.retry_on_error = retry_on_errors_list 

731 if retry or self.retry_on_error: 

732 if retry is None: 

733 self.retry = Retry(NoBackoff(), 1) 

734 else: 

735 # deep-copy the Retry object as it is mutable 

736 self.retry = copy.deepcopy(retry) 

737 if self.retry_on_error: 

738 # Update the retry's supported errors with the specified errors 

739 self.retry.update_supported_errors(self.retry_on_error) 

740 else: 

741 self.retry = Retry(NoBackoff(), 0) 

742 self.health_check_interval = health_check_interval 

743 self.next_health_check = 0 

744 self.redis_connect_func = redis_connect_func 

745 self.encoder = Encoder(encoding, encoding_errors, decode_responses) 

746 self.handshake_metadata = None 

747 self._sock = None 

748 self._socket_read_size = socket_read_size 

749 self._connect_callbacks = [] 

750 self._buffer_cutoff = 6000 

751 self._re_auth_token: Optional[TokenInterface] = None 

752 try: 

753 p = int(protocol) 

754 except TypeError: 

755 p = DEFAULT_RESP_VERSION 

756 except ValueError: 

757 raise ConnectionError("protocol must be an integer") 

758 finally: 

759 if p < 2 or p > 3: 

760 raise ConnectionError("protocol must be either 2 or 3") 

761 # p = DEFAULT_RESP_VERSION 

762 self.protocol = p 

763 if self.protocol == 3 and parser_class == _RESP2Parser: 

764 # If the protocol is 3 but the parser is RESP2, change it to RESP3 

765 # This is needed because the parser might be set before the protocol 

766 # or might be provided as a kwarg to the constructor 

767 # We need to react on discrepancy only for RESP2 and RESP3 

768 # as hiredis supports both 

769 parser_class = _RESP3Parser 

770 self.set_parser(parser_class) 

771 

772 self._command_packer = self._construct_command_packer(command_packer) 

773 self._should_reconnect = False 

774 

775 # Set up maintenance notifications 

776 MaintNotificationsAbstractConnection.__init__( 

777 self, 

778 maint_notifications_config, 

779 maint_notifications_pool_handler, 

780 maintenance_state, 

781 maintenance_notification_hash, 

782 orig_host_address, 

783 orig_socket_timeout, 

784 orig_socket_connect_timeout, 

785 self._parser, 

786 ) 

787 

788 def __repr__(self): 

789 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) 

790 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" 

791 

    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs describing this connection for __repr__."""
        pass

795 

    def __del__(self):
        # Best-effort cleanup: a destructor must never raise, so swallow any
        # error coming out of disconnect().
        try:
            self.disconnect()
        except Exception:
            pass

801 

802 def _construct_command_packer(self, packer): 

803 if packer is not None: 

804 return packer 

805 elif HIREDIS_AVAILABLE: 

806 return HiredisRespSerializer() 

807 else: 

808 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode) 

809 

810 def register_connect_callback(self, callback): 

811 """ 

812 Register a callback to be called when the connection is established either 

813 initially or reconnected. This allows listeners to issue commands that 

814 are ephemeral to the connection, for example pub/sub subscription or 

815 key tracking. The callback must be a _method_ and will be kept as 

816 a weak reference. 

817 """ 

818 wm = weakref.WeakMethod(callback) 

819 if wm not in self._connect_callbacks: 

820 self._connect_callbacks.append(wm) 

821 

822 def deregister_connect_callback(self, callback): 

823 """ 

824 De-register a previously registered callback. It will no-longer receive 

825 notifications on connection events. Calling this is not required when the 

826 listener goes away, since the callbacks are kept as weak methods. 

827 """ 

828 try: 

829 self._connect_callbacks.remove(weakref.WeakMethod(callback)) 

830 except ValueError: 

831 pass 

832 

    def set_parser(self, parser_class):
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

840 

    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
        # Accessor used by the maintenance-notifications base class.
        return self._parser

843 

    def connect(self):
        "Connects to the Redis server if not already connected"
        # Delegates to the variant that also runs the post-connect health check.
        self.connect_check_health(check_health=True)

847 

    def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """
        Connect to the Redis server if not already connected.

        :param check_health: forwarded to on_connect_check_health() when the
            default initialization path is used.
        :param retry_socket_connect: when True, the socket connect is wrapped
            in this connection's Retry policy.
        :raises TimeoutError: if the socket connect times out.
        :raises ConnectionError: for any other OS-level connect failure.
        """
        # Already connected — nothing to do.
        if self._sock:
            return
        try:
            if retry_socket_connect:
                sock = self.retry.call_with_retry(
                    lambda: self._connect(), lambda error: self.disconnect(error)
                )
            else:
                sock = self._connect()
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback:
                callback(self)

886 

    @abstractmethod
    def _connect(self):
        """Create and return the underlying transport socket (subclass-specific)."""
        pass

890 

    @abstractmethod
    def _host_error(self):
        """Return a string identifying this connection's endpoint for error messages."""
        pass

894 

    def _error_message(self, exception):
        # Format the OS-level exception together with the endpoint description.
        return format_error_message(self._host_error(), exception)

897 

    def on_connect(self):
        # Default post-connect initialization, with health checking enabled.
        self.on_connect_check_health(check_health=True)

900 

901 def on_connect_check_health(self, check_health: bool = True): 

902 "Initialize the connection, authenticate and select a database" 

903 self._parser.on_connect(self) 

904 parser = self._parser 

905 

906 auth_args = None 

907 # if credential provider or username and/or password are set, authenticate 

908 if self.credential_provider or (self.username or self.password): 

909 cred_provider = ( 

910 self.credential_provider 

911 or UsernamePasswordCredentialProvider(self.username, self.password) 

912 ) 

913 auth_args = cred_provider.get_credentials() 

914 

915 # if resp version is specified and we have auth args, 

916 # we need to send them via HELLO 

917 if auth_args and self.protocol not in [2, "2"]: 

918 if isinstance(self._parser, _RESP2Parser): 

919 self.set_parser(_RESP3Parser) 

920 # update cluster exception classes 

921 self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES 

922 self._parser.on_connect(self) 

923 if len(auth_args) == 1: 

924 auth_args = ["default", auth_args[0]] 

925 # avoid checking health here -- PING will fail if we try 

926 # to check the health prior to the AUTH 

927 self.send_command( 

928 "HELLO", self.protocol, "AUTH", *auth_args, check_health=False 

929 ) 

930 self.handshake_metadata = self.read_response() 

931 # if response.get(b"proto") != self.protocol and response.get( 

932 # "proto" 

933 # ) != self.protocol: 

934 # raise ConnectionError("Invalid RESP version") 

935 elif auth_args: 

936 # avoid checking health here -- PING will fail if we try 

937 # to check the health prior to the AUTH 

938 self.send_command("AUTH", *auth_args, check_health=False) 

939 

940 try: 

941 auth_response = self.read_response() 

942 except AuthenticationWrongNumberOfArgsError: 

943 # a username and password were specified but the Redis 

944 # server seems to be < 6.0.0 which expects a single password 

945 # arg. retry auth with just the password. 

946 # https://github.com/andymccurdy/redis-py/issues/1274 

947 self.send_command("AUTH", auth_args[-1], check_health=False) 

948 auth_response = self.read_response() 

949 

950 if str_if_bytes(auth_response) != "OK": 

951 raise AuthenticationError("Invalid Username or Password") 

952 

953 # if resp version is specified, switch to it 

954 elif self.protocol not in [2, "2"]: 

955 if isinstance(self._parser, _RESP2Parser): 

956 self.set_parser(_RESP3Parser) 

957 # update cluster exception classes 

958 self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES 

959 self._parser.on_connect(self) 

960 self.send_command("HELLO", self.protocol, check_health=check_health) 

961 self.handshake_metadata = self.read_response() 

962 if ( 

963 self.handshake_metadata.get(b"proto") != self.protocol 

964 and self.handshake_metadata.get("proto") != self.protocol 

965 ): 

966 raise ConnectionError("Invalid RESP version") 

967 

968 # Activate maintenance notifications for this connection 

969 # if enabled in the configuration 

970 # This is a no-op if maintenance notifications are not enabled 

971 self.activate_maint_notifications_handling_if_enabled(check_health=check_health) 

972 

973 # if a client_name is given, set it 

974 if self.client_name: 

975 self.send_command( 

976 "CLIENT", 

977 "SETNAME", 

978 self.client_name, 

979 check_health=check_health, 

980 ) 

981 if str_if_bytes(self.read_response()) != "OK": 

982 raise ConnectionError("Error setting client name") 

983 

984 try: 

985 # set the library name and version 

986 if self.lib_name: 

987 self.send_command( 

988 "CLIENT", 

989 "SETINFO", 

990 "LIB-NAME", 

991 self.lib_name, 

992 check_health=check_health, 

993 ) 

994 self.read_response() 

995 except ResponseError: 

996 pass 

997 

998 try: 

999 if self.lib_version: 

1000 self.send_command( 

1001 "CLIENT", 

1002 "SETINFO", 

1003 "LIB-VER", 

1004 self.lib_version, 

1005 check_health=check_health, 

1006 ) 

1007 self.read_response() 

1008 except ResponseError: 

1009 pass 

1010 

1011 # if a database is specified, switch to it 

1012 if self.db: 

1013 self.send_command("SELECT", self.db, check_health=check_health) 

1014 if str_if_bytes(self.read_response()) != "OK": 

1015 raise ConnectionError("Invalid Database") 

1016 

1017 def disconnect(self, *args): 

1018 "Disconnects from the Redis server" 

1019 self._parser.on_disconnect() 

1020 

1021 conn_sock = self._sock 

1022 self._sock = None 

1023 # reset the reconnect flag 

1024 self.reset_should_reconnect() 

1025 if conn_sock is None: 

1026 return 

1027 

1028 if os.getpid() == self.pid: 

1029 try: 

1030 conn_sock.shutdown(socket.SHUT_RDWR) 

1031 except (OSError, TypeError): 

1032 pass 

1033 

1034 try: 

1035 conn_sock.close() 

1036 except OSError: 

1037 pass 

1038 

1039 def mark_for_reconnect(self): 

1040 self._should_reconnect = True 

1041 

1042 def should_reconnect(self): 

1043 return self._should_reconnect 

1044 

1045 def reset_should_reconnect(self): 

1046 self._should_reconnect = False 

1047 

1048 def _send_ping(self): 

1049 """Send PING, expect PONG in return""" 

1050 self.send_command("PING", check_health=False) 

1051 if str_if_bytes(self.read_response()) != "PONG": 

1052 raise ConnectionError("Bad response from PING health check") 

1053 

1054 def _ping_failed(self, error): 

1055 """Function to call when PING fails""" 

1056 self.disconnect() 

1057 

1058 def check_health(self): 

1059 """Check the health of the connection with a PING/PONG""" 

1060 if self.health_check_interval and time.monotonic() > self.next_health_check: 

1061 self.retry.call_with_retry(self._send_ping, self._ping_failed) 

1062 

1063 def send_packed_command(self, command, check_health=True): 

1064 """Send an already packed command to the Redis server""" 

1065 if not self._sock: 

1066 self.connect_check_health(check_health=False) 

1067 # guard against health check recursion 

1068 if check_health: 

1069 self.check_health() 

1070 try: 

1071 if isinstance(command, str): 

1072 command = [command] 

1073 for item in command: 

1074 self._sock.sendall(item) 

1075 except socket.timeout: 

1076 self.disconnect() 

1077 raise TimeoutError("Timeout writing to socket") 

1078 except OSError as e: 

1079 self.disconnect() 

1080 if len(e.args) == 1: 

1081 errno, errmsg = "UNKNOWN", e.args[0] 

1082 else: 

1083 errno = e.args[0] 

1084 errmsg = e.args[1] 

1085 raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.") 

1086 except BaseException: 

1087 # BaseExceptions can be raised when a socket send operation is not 

1088 # finished, e.g. due to a timeout. Ideally, a caller could then re-try 

1089 # to send un-sent data. However, the send_packed_command() API 

1090 # does not support it so there is no point in keeping the connection open. 

1091 self.disconnect() 

1092 raise 

1093 

1094 def send_command(self, *args, **kwargs): 

1095 """Pack and send a command to the Redis server""" 

1096 self.send_packed_command( 

1097 self._command_packer.pack(*args), 

1098 check_health=kwargs.get("check_health", True), 

1099 ) 

1100 

1101 def can_read(self, timeout=0): 

1102 """Poll the socket to see if there's data that can be read.""" 

1103 sock = self._sock 

1104 if not sock: 

1105 self.connect() 

1106 

1107 host_error = self._host_error() 

1108 

1109 try: 

1110 return self._parser.can_read(timeout) 

1111 

1112 except OSError as e: 

1113 self.disconnect() 

1114 raise ConnectionError(f"Error while reading from {host_error}: {e.args}") 

1115 

    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the response from a previously sent command.

        :param disable_decoding: forwarded to the parser; skip decoding bytes.
        :param disconnect_on_error: when True (default) the connection is
            closed on any read failure before the error is re-raised.
        :param push_request: RESP3 only; forwarded to the parser so push
            messages can be handled.
        """

        host_error = self._host_error()

        try:
            # Only the RESP3 parser accepts push_request; the RESP2 call
            # must not receive that keyword.
            if self.protocol in ["3", 3]:
                response = self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = self._parser.read_response(disable_decoding=disable_decoding)
        except socket.timeout:
            if disconnect_on_error:
                self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                self.disconnect()
            raise

        # A successful read proves liveness: push the next health check
        # one full interval into the future.
        if self.health_check_interval:
            self.next_health_check = time.monotonic() + self.health_check_interval

        if isinstance(response, ResponseError):
            try:
                raise response
            finally:
                del response  # avoid creating ref cycles
        return response

1159 

1160 def pack_command(self, *args): 

1161 """Pack a series of arguments into the Redis protocol""" 

1162 return self._command_packer.pack(*args) 

1163 

1164 def pack_commands(self, commands): 

1165 """Pack multiple commands into the Redis protocol""" 

1166 output = [] 

1167 pieces = [] 

1168 buffer_length = 0 

1169 buffer_cutoff = self._buffer_cutoff 

1170 

1171 for cmd in commands: 

1172 for chunk in self._command_packer.pack(*cmd): 

1173 chunklen = len(chunk) 

1174 if ( 

1175 buffer_length > buffer_cutoff 

1176 or chunklen > buffer_cutoff 

1177 or isinstance(chunk, memoryview) 

1178 ): 

1179 if pieces: 

1180 output.append(SYM_EMPTY.join(pieces)) 

1181 buffer_length = 0 

1182 pieces = [] 

1183 

1184 if chunklen > buffer_cutoff or isinstance(chunk, memoryview): 

1185 output.append(chunk) 

1186 else: 

1187 pieces.append(chunk) 

1188 buffer_length += chunklen 

1189 

1190 if pieces: 

1191 output.append(SYM_EMPTY.join(pieces)) 

1192 return output 

1193 

1194 def get_protocol(self) -> Union[int, str]: 

1195 return self.protocol 

1196 

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Server metadata captured during the HELLO/AUTH handshake.

        Keys and values may be ``bytes`` or ``str`` depending on whether
        response decoding is enabled.
        """
        return self._handshake_metadata

    @handshake_metadata.setter
    def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
        self._handshake_metadata = value

1204 

1205 def set_re_auth_token(self, token: TokenInterface): 

1206 self._re_auth_token = token 

1207 

1208 def re_auth(self): 

1209 if self._re_auth_token is not None: 

1210 self.send_command( 

1211 "AUTH", 

1212 self._re_auth_token.try_get("oid"), 

1213 self._re_auth_token.get_value(), 

1214 ) 

1215 self.read_response() 

1216 self._re_auth_token = None 

1217 

1218 def _get_socket(self) -> Optional[socket.socket]: 

1219 return self._sock 

1220 

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        """Read/write timeout (seconds) applied once the socket is connected."""
        return self._socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._socket_timeout = value

1228 

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        """Timeout (seconds) used only while establishing the connection."""
        return self._socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._socket_connect_timeout = value

1236 

1237 

class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        # Remaining keyword arguments (db, timeouts, retry policy, ...)
        # are handled by AbstractConnection.__init__.
        self._host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        # Normalize None to {} so _connect can iterate unconditionally.
        self.socket_keepalive_options = socket_keepalive_options or {}
        # Address family forwarded to getaddrinfo (0 == AF_UNSPEC,
        # i.e. allow both IPv4 and IPv6 results).
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        # (name, value) pairs used to build this connection's repr().
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        "Create a TCP socket connection"
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        err = None

        for res in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            family, socktype, proto, canonname, socket_address = res
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as _:
                # Remember the failure and fall through to the next
                # resolved address after cleaning up this socket.
                err = _
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()

        # Every resolved address failed (or none were returned).
        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        # Endpoint label used in connection/timeout error messages.
        return f"{self.host}:{self.port}"

    @property
    def host(self) -> str:
        return self._host

    @host.setter
    def host(self, value: str):
        self._host = value

1319 

1320 

class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
    """Connection decorator adding client-side caching on top of a real connection.

    Cachable read-only command responses are served from a local cache;
    CLIENT TRACKING invalidation push messages evict stale entries.
    """

    # Placeholder value stored while the real response is still in flight.
    DUMMY_CACHE_VALUE = b"foo"
    # Client-side caching requires at least this Redis server version.
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        # Cache key of the command currently awaiting its response.
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            MaintNotificationsAbstractConnection.__init__(
                self,
                self._conn.maint_notifications_config,
                self._conn._maint_notifications_pool_handler,
                self._conn.maintenance_state,
                self._conn.maintenance_notification_hash,
                self._conn.host,
                self._conn.socket_timeout,
                self._conn.socket_connect_timeout,
                self._conn._get_parser(),
            )

    def repr_pieces(self):
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_pool_handler_for_connection(
                maint_notifications_pool_handler
            )

    def get_protocol(self):
        return self._conn.get_protocol()

    def connect(self):
        """Connect the wrapped connection and verify the server supports caching."""
        self._conn.connect()

        server_name = self._conn.handshake_metadata.get(b"server", None)
        if server_name is None:
            server_name = self._conn.handshake_metadata.get("server", None)
        server_ver = self._conn.handshake_metadata.get(b"version", None)
        if server_ver is None:
            server_ver = self._conn.handshake_metadata.get("version", None)
        # BUGFIX: previously this tested ``server_ver`` twice, so a
        # missing server *name* was never caught here.
        if server_name is None or server_ver is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args):
        # Local cache entries are only valid for a live tracking session.
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        # BUGFIX: forward check_health instead of silently dropping it.
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(CacheKey(command=args[0], redis_keys=())):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

            if kwargs.get("keys") is None:
                raise ValueError("Cannot create cache key.")

            # Creates cache key.
            self._current_command_cache_key = CacheKey(
                command=args[0], redis_keys=tuple(kwargs.get("keys"))
            )

        with self._cache_lock:
            # We have to trigger invalidation processing in case if
            # it was cached by another connection to avoid
            # queueing invalidations in stale connections.
            entry = self._cache.get(self._current_command_cache_key)
            if entry:
                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)

                return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            if (
                self._current_command_cache_key is not None
                and self._cache.get(self._current_command_cache_key) is not None
                and self._cache.get(self._current_command_cache_key).status
                != CacheEntryStatus.IN_PROGRESS
            ):
                # Deep-copy so callers cannot mutate the cached value.
                res = copy.deepcopy(
                    self._cache.get(self._current_command_cache_key).cache_value
                )
                self._current_command_cache_key = None
                return res

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()

    def mark_for_reconnect(self):
        self._conn.mark_for_reconnect()

    def should_reconnect(self):
        return self._conn.should_reconnect()

    def reset_should_reconnect(self):
        self._conn.reset_should_reconnect()

    @property
    def host(self) -> str:
        return self._conn.host

    @host.setter
    def host(self, value: str):
        self._conn.host = value

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_connect_timeout = value

    def _get_socket(self) -> Optional[socket.socket]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._get_socket()
        else:
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )

    def _get_maint_notifications_connection_instance(
        self, connection
    ) -> MaintNotificationsAbstractConnection:
        """
        Validate that connection instance supports maintenance notifications.
        With this helper method we ensure that we are working
        with the correct connection type.
        After we validate that connection instance supports maintenance notifications
        we can safely return the connection instance
        as MaintNotificationsAbstractConnection.
        """
        if not isinstance(connection, MaintNotificationsAbstractConnection):
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )
        else:
            return connection

    @property
    def maintenance_state(self) -> MaintenanceState:
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: MaintenanceState):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.maintenance_state = state

    def getpeername(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.getpeername()

    def get_resolved_ip(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.get_resolved_ip()

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.update_current_socket_timeout(relaxed_timeout)

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout)

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout)

    def _connect(self):
        self._conn._connect()

    def _host_error(self):
        # BUGFIX: the delegated value was computed but never returned.
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        # Enable server-assisted invalidation and route push messages
        # into the local cache eviction callback.
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        # Drain queued push messages so stale entries are evicted before
        # the next command consults the cache.
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                self._cache.delete_by_redis_keys(data[1])

1651 

1652 

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
                or an ssl.VerifyMode. Defaults to "required".
            ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
            ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        # Accept either an ssl.VerifyMode or one of the string aliases.
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ssl_include_verify_flags = ssl_include_verify_flags
        self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # Hostname checking is forced off when certificates are not
        # verified; ssl rejects check_hostname=True with CERT_NONE.
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Wrap the socket with SSL support, handling potential errors.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except (OSError, RedisError):
            # The plain socket would otherwise leak on handshake failure.
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        # Apply include flags first, then mask out excluded ones.
        if self.ssl_include_verify_flags:
            for flag in self.ssl_include_verify_flags:
                context.verify_flags |= flag
        if self.ssl_exclude_verify_flags:
            for flag in self.ssl_exclude_verify_flags:
                context.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        # validation for the stapled case
        if self.ssl_validate_ocsp_stapled:
            import OpenSSL

            from .ocsp import ocsp_staple_verifier

            # if a context is provided use it - otherwise, a basic context
            if self.ssl_ocsp_context is None:
                staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
                staple_ctx.use_certificate_file(self.certfile)
                staple_ctx.use_privatekey_file(self.keyfile)
            else:
                staple_ctx = self.ssl_ocsp_context

            staple_ctx.set_ocsp_client_callback(
                ocsp_staple_verifier, self.ssl_ocsp_expected_cert
            )

            # need another socket
            con = OpenSSL.SSL.Connection(staple_ctx, socket.socket())
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
            return sslsock

        # pure ocsp validation
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
            from .ocsp import OCSPVerifier

            o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
            if o.is_valid():
                return sslsock
            else:
                raise ConnectionError("ocsp validation error")
        return sslsock

1835 

1836 

class UnixDomainSocketConnection(AbstractConnection):
    """Manages UDS communication to and from a Redis server."""

    def __init__(self, path="", socket_timeout=None, **kwargs):
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """(name, value) pairs used to build this connection's repr()."""
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        """Open and return a connected AF_UNIX stream socket."""
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.settimeout(self.socket_connect_timeout)
        try:
            sock.connect(self.path)
        except OSError:
            # Close the half-open socket to prevent ResourceWarnings.
            try:
                sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                pass
            sock.close()
            raise
        # Connected: switch to the steady-state read/write timeout.
        sock.settimeout(self.socket_timeout)
        return sock

    def _host_error(self):
        """Endpoint label used in connection/timeout error messages."""
        return self.path

1870 

1871 

FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Coerce a querystring value to a boolean.

    ``None`` and the empty string map to ``None`` (i.e. "not specified").
    Strings matching FALSE_STRINGS (case-insensitively) map to ``False``;
    everything else is passed through ``bool()``.
    """
    if value is None or value == "":
        return None
    is_false_string = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_false_string else bool(value)

1881 

1882 

1883def parse_ssl_verify_flags(value): 

1884 # flags are passed in as a string representation of a list, 

1885 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN 

1886 verify_flags_str = value.replace("[", "").replace("]", "") 

1887 

1888 verify_flags = [] 

1889 for flag in verify_flags_str.split(","): 

1890 flag = flag.strip() 

1891 if not hasattr(VerifyFlags, flag): 

1892 raise ValueError(f"Invalid ssl verify flag: {flag}") 

1893 verify_flags.append(getattr(VerifyFlags, flag)) 

1894 return verify_flags 

1895 

1896 

# Maps querystring argument names to callables that coerce the raw
# (percent-decoded) string value to the proper Python type. Arguments not
# listed here are passed through to the pool/connection as plain strings.
URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
    "ssl_include_verify_flags": parse_ssl_verify_flags,
    "ssl_exclude_verify_flags": parse_ssl_verify_flags,
    "timeout": float,
}

1911 

1912 

def parse_url(url):
    """Parse a Redis connection URL into a dict of ConnectionPool kwargs.

    Supported schemes are ``redis://``, ``rediss://`` and ``unix://``.
    Username, password, hostname, path and querystring values are
    percent-decoded; querystring values are coerced via
    URL_QUERY_ARGUMENT_PARSERS, and unknown arguments pass through as
    strings. For ``redis``/``rediss`` URLs the path component supplies the
    ``db`` number unless a ``db`` query argument was given.

    :raises ValueError: if the scheme is unsupported or a querystring
        value cannot be coerced to its expected type.
    """
    # str.startswith accepts a tuple of prefixes - one call covers all schemes
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    url = urlparse(url)
    kwargs = {}

    for name, value in parse_qs(url.query).items():
        # parse_qs only yields non-empty lists, but guard defensively
        if value:
            value = unquote(value[0])
            parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
            if parser:
                try:
                    kwargs[name] = parser(value)
                except (TypeError, ValueError):
                    raise ValueError(f"Invalid value for '{name}' in connection URL.")
            else:
                kwargs[name] = value

    if url.username:
        kwargs["username"] = unquote(url.username)
    if url.password:
        kwargs["password"] = unquote(url.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if url.scheme == "unix":
        if url.path:
            kwargs["path"] = unquote(url.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    else:  # implied: url.scheme in ("redis", "rediss"):
        if url.hostname:
            kwargs["host"] = unquote(url.hostname)
        if url.port:
            kwargs["port"] = int(url.port)

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if url.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(url.path).replace("/", ""))
            except (AttributeError, ValueError):
                # non-numeric path: ignore and fall back to the default db
                pass

        if url.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    return kwargs

1968 

1969 

# Type variable bound to ConnectionPool so that from_url() called on a
# subclass is typed as returning an instance of that subclass.
_CP = TypeVar("_CP", bound="ConnectionPool")

1971 

1972 

class ConnectionPoolInterface(ABC):
    """Abstract interface implemented by connection pool variants.

    Defines the operations a pool must support: acquiring and releasing
    connections, encoder access, lifecycle management (reset/disconnect/
    close), retry configuration and credential re-authentication.
    """

    @abstractmethod
    def get_protocol(self):
        """Return the configured RESP protocol version (or None)."""
        pass

    @abstractmethod
    def reset(self):
        """Reset the pool's internal state."""
        pass

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(
        self, command_name: Optional[str], *keys, **options
    ) -> ConnectionInterface:
        """Acquire a connection from the pool.

        The positional/keyword arguments are deprecated; call with no
        arguments.
        """
        pass

    @abstractmethod
    def get_encoder(self):
        """Return an Encoder configured from the pool's settings."""
        pass

    @abstractmethod
    def release(self, connection: ConnectionInterface):
        """Return a previously acquired connection to the pool."""
        pass

    @abstractmethod
    def disconnect(self, inuse_connections: bool = True):
        """Disconnect pooled connections (optionally including in-use ones)."""
        pass

    @abstractmethod
    def close(self):
        """Close the pool, releasing all resources."""
        pass

    @abstractmethod
    def set_retry(self, retry: Retry):
        """Set the retry policy for the pool and its connections."""
        pass

    @abstractmethod
    def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections with a new credential token."""
        pass

2016 

2017 

class MaintNotificationsAbstractConnectionPool:
    """
    Abstract class for handling maintenance notifications logic.
    This class is mixed into the ConnectionPool classes.

    This class is not intended to be used directly!

    All logic related to maintenance notifications and
    connection pool handling is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **kwargs,
    ):
        """
        Set up the maintenance notifications pool handler.

        :param maint_notifications_config: Optional notifications config.
            When omitted and the pool uses RESP3, a default config is created.
        :param kwargs: The pool's connection kwargs; only ``protocol`` is
            inspected here.
        :raises RedisError: if notifications are enabled without RESP3.
        """
        # Initialize maintenance notifications
        is_protocol_supported = kwargs.get("protocol") in [3, "3"]
        if maint_notifications_config is None and is_protocol_supported:
            # RESP3 pools get a default (enabled-by-default) config.
            maint_notifications_config = MaintNotificationsConfig()

        if maint_notifications_config and maint_notifications_config.enabled:
            if not is_protocol_supported:
                raise RedisError(
                    "Maintenance notifications handlers on connection are only supported with RESP version 3"
                )

            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )

            self._update_connection_kwargs_for_maint_notifications(
                self._maint_notifications_pool_handler
            )
        else:
            self._maint_notifications_pool_handler = None

    @property
    @abstractmethod
    def connection_kwargs(self) -> Dict[str, Any]:
        """Kwargs used to construct new connections (mutable mapping)."""
        pass

    @connection_kwargs.setter
    @abstractmethod
    def connection_kwargs(self, value: Dict[str, Any]):
        pass

    @abstractmethod
    def _get_pool_lock(self) -> threading.RLock:
        """Return the lock guarding the pool's connection collections."""
        pass

    @abstractmethod
    def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
        """Return the currently idle/available connections."""
        pass

    @abstractmethod
    def _get_in_use_connections(
        self,
    ) -> Iterable["MaintNotificationsAbstractConnection"]:
        """Return the connections currently checked out of the pool."""
        pass

    def maint_notifications_enabled(self):
        """
        Returns:
            True if the maintenance notifications are enabled, False otherwise.
            The maintenance notifications config is stored in the pool handler.
            If the pool handler is not set, the maintenance notifications are not enabled.
        """
        # NOTE(review): the expression below evaluates to a truthy/falsy value
        # (None or the config's enabled flag), not necessarily a strict bool.
        maint_notifications_config = (
            self._maint_notifications_pool_handler.config
            if self._maint_notifications_pool_handler
            else None
        )

        return maint_notifications_config and maint_notifications_config.enabled

    def update_maint_notifications_config(
        self, maint_notifications_config: MaintNotificationsConfig
    ):
        """
        Updates the maintenance notifications configuration.
        This method should be called only if the pool was created
        without enabling the maintenance notifications and
        in a later point in time maintenance notifications
        are requested to be enabled.

        :raises ValueError: if notifications are currently enabled and the
            new config would disable them.
        """
        if (
            self.maint_notifications_enabled()
            and not maint_notifications_config.enabled
        ):
            raise ValueError(
                "Cannot disable maintenance notifications after enabling them"
            )
        # first update pool settings
        if not self._maint_notifications_pool_handler:
            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )
        else:
            self._maint_notifications_pool_handler.config = maint_notifications_config

        # then update connection kwargs and existing connections
        self._update_connection_kwargs_for_maint_notifications(
            self._maint_notifications_pool_handler
        )
        self._update_maint_notifications_configs_for_connections(
            self._maint_notifications_pool_handler
        )

    def _update_connection_kwargs_for_maint_notifications(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """
        Update the connection kwargs for all future connections.
        """
        # No-op unless notifications are enabled on this pool.
        if not self.maint_notifications_enabled():
            return

        self.connection_kwargs.update(
            {
                "maint_notifications_pool_handler": maint_notifications_pool_handler,
                "maint_notifications_config": maint_notifications_pool_handler.config,
            }
        )

        # Store original connection parameters for maintenance notifications.
        if self.connection_kwargs.get("orig_host_address", None) is None:
            # If orig_host_address is None it means we haven't
            # configured the original values yet
            self.connection_kwargs.update(
                {
                    "orig_host_address": self.connection_kwargs.get("host"),
                    "orig_socket_timeout": self.connection_kwargs.get(
                        "socket_timeout", None
                    ),
                    "orig_socket_connect_timeout": self.connection_kwargs.get(
                        "socket_connect_timeout", None
                    ),
                }
            )

    def _update_maint_notifications_configs_for_connections(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Update the maintenance notifications config for all connections in the pool."""
        with self._get_pool_lock():
            # Idle connections are disconnected outright; they will pick up
            # the new settings when they next connect.
            for conn in self._get_free_connections():
                conn.set_maint_notifications_pool_handler_for_connection(
                    maint_notifications_pool_handler
                )
                conn.maint_notifications_config = (
                    maint_notifications_pool_handler.config
                )
                conn.disconnect()
            # In-use connections cannot be dropped mid-command; mark them to
            # reconnect when they are released back to the pool.
            for conn in self._get_in_use_connections():
                conn.set_maint_notifications_pool_handler_for_connection(
                    maint_notifications_pool_handler
                )
                conn.maint_notifications_config = (
                    maint_notifications_pool_handler.config
                )
                conn.mark_for_reconnect()

    def _should_update_connection(
        self,
        conn: "MaintNotificationsAbstractConnection",
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
    ) -> bool:
        """
        Check if the connection should be updated based on the matching criteria.

        A missing match value (None) means "match everything" for that
        pattern, so the method defaults to True.
        """
        if matching_pattern == "connected_address":
            if matching_address and conn.getpeername() != matching_address:
                return False
        elif matching_pattern == "configured_address":
            if matching_address and conn.host != matching_address:
                return False
        elif matching_pattern == "notification_hash":
            if (
                matching_notification_hash
                and conn.maintenance_notification_hash != matching_notification_hash
            ):
                return False
        return True

    def update_connection_settings(
        self,
        conn: "MaintNotificationsAbstractConnection",
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Update the settings for a single connection.

        Only the settings whose corresponding argument/flag is provided are
        touched; everything else is left as-is.
        """
        if state:
            conn.maintenance_state = state

        if update_notification_hash:
            # update the notification hash only if requested
            conn.maintenance_notification_hash = maintenance_notification_hash

        if host_address is not None:
            conn.set_tmp_settings(tmp_host_address=host_address)

        if relaxed_timeout is not None:
            conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)

        if reset_relaxed_timeout or reset_host_address:
            conn.reset_tmp_settings(
                reset_host_address=reset_host_address,
                reset_relaxed_timeout=reset_relaxed_timeout,
            )

        # Apply the (possibly updated) timeout to the live socket as well.
        conn.update_current_socket_timeout(relaxed_timeout)

    def update_connections_settings(
        self,
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
        include_free_connections: bool = True,
    ):
        """
        Update the settings for all matching connections in the pool.

        This method does not create new connections.
        This method does not affect the connection kwargs.

        :param state: The maintenance state to set for the connection.
        :param maintenance_notification_hash: The hash of the maintenance notification
            to set for the connection.
        :param host_address: The host address to set for the connection.
        :param relaxed_timeout: The relaxed timeout to set for the connection.
        :param matching_address: The address to match for the connection.
        :param matching_notification_hash: The notification hash to match for the connection.
        :param matching_pattern: The pattern to match for the connection.
        :param update_notification_hash: Whether to update the notification hash for the connection.
        :param reset_host_address: Whether to reset the host address to the original address.
        :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
        :param include_free_connections: Whether to include free/available connections.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn,
                    matching_pattern,
                    matching_address,
                    matching_notification_hash,
                ):
                    self.update_connection_settings(
                        conn,
                        state=state,
                        maintenance_notification_hash=maintenance_notification_hash,
                        host_address=host_address,
                        relaxed_timeout=relaxed_timeout,
                        update_notification_hash=update_notification_hash,
                        reset_host_address=reset_host_address,
                        reset_relaxed_timeout=reset_relaxed_timeout,
                    )

            if include_free_connections:
                for conn in self._get_free_connections():
                    if self._should_update_connection(
                        conn,
                        matching_pattern,
                        matching_address,
                        matching_notification_hash,
                    ):
                        self.update_connection_settings(
                            conn,
                            state=state,
                            maintenance_notification_hash=maintenance_notification_hash,
                            host_address=host_address,
                            relaxed_timeout=relaxed_timeout,
                            update_notification_hash=update_notification_hash,
                            reset_host_address=reset_host_address,
                            reset_relaxed_timeout=reset_relaxed_timeout,
                        )

    def update_connection_kwargs(
        self,
        **kwargs,
    ):
        """
        Update the connection kwargs for all future connections.

        This method updates the connection kwargs for all future connections created by the pool.
        Existing connections are not affected.
        """
        self.connection_kwargs.update(kwargs)

    def update_active_connections_for_reconnect(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Mark all active connections for reconnect.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.mark_for_reconnect()

    def disconnect_free_connections(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Disconnect all free/available connections.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.disconnect()

2360 

2361 

class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface):
    """
    Create a connection pool. ``If max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use class:`.UnixDomainSocketConnection` for
    unix sockets.
    :py:class:`~redis.SSLConnection` can be used for SSL enabled connections.

    If ``maint_notifications_config`` is provided, the connection pool will support
    maintenance notifications.
    Maintenance notifications are supported only with RESP3.
    If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3,
    the maintenance notifications will be enabled by default.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0
        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0
        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)

        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        cache_factory: Optional[CacheFactoryInterface] = None,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **connection_kwargs,
    ):
        """
        :param connection_class: Class used to create new connections.
        :param max_connections: Pool size limit; defaults to a practically
            unbounded 2**31.
        :param cache_factory: Optional factory used to build the client-side
            cache when ``cache_config`` is supplied without a ``cache``.
        :param maint_notifications_config: Optional maintenance notifications
            config (RESP3 only; see class docstring).
        :raises ValueError: on an invalid ``max_connections`` or cache object.
        :raises RedisError: if caching is requested without RESP3.
        """
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self._connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        self.cache = None
        self._cache_factory = cache_factory

        if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
            if self._connection_kwargs.get("protocol") not in [3, "3"]:
                raise RedisError("Client caching is only supported with RESP version 3")

            cache = self._connection_kwargs.get("cache")

            if cache is not None:
                if not isinstance(cache, CacheInterface):
                    raise ValueError("Cache must implement CacheInterface")

                self.cache = cache
            else:
                if self._cache_factory is not None:
                    self.cache = self._cache_factory.get_cache()
                else:
                    self.cache = CacheFactory(
                        self._connection_kwargs.get("cache_config")
                    ).get_cache()

        # cache arguments must not reach the connection constructor
        connection_kwargs.pop("cache", None)
        connection_kwargs.pop("cache_config", None)

        self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None)
        if self._event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.

        self._fork_lock = threading.RLock()
        self._lock = threading.RLock()

        MaintNotificationsAbstractConnectionPool.__init__(
            self,
            maint_notifications_config=maint_notifications_config,
            **connection_kwargs,
        )

        self.reset()

    def __repr__(self) -> str:
        conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
            f"({conn_kwargs})>)>"
        )

    @property
    def connection_kwargs(self) -> Dict[str, Any]:
        """Kwargs passed to the connection class when creating connections."""
        return self._connection_kwargs

    @connection_kwargs.setter
    def connection_kwargs(self, value: Dict[str, Any]):
        self._connection_kwargs = value

    def get_protocol(self):
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        return self.connection_kwargs.get("protocol", None)

    def reset(self) -> None:
        """Reinitialize the pool's connection bookkeeping for this process."""
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self) -> None:
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options) -> "Connection":
        "Get a connection from the pool"

        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if (
                    connection.can_read()
                    and self.cache is None
                    and not self.maint_notifications_enabled()
                ):
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise
        return connection

    def get_encoder(self) -> Encoder:
        "Return an encoder based on encoding settings"
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self) -> "ConnectionInterface":
        "Create a new connection"
        if self._created_connections >= self.max_connections:
            raise MaxConnectionsError("Too many connections")
        self._created_connections += 1

        kwargs = dict(self.connection_kwargs)

        if self.cache is not None:
            # wrap the raw connection so responses flow through the cache
            return CacheProxyConnection(
                self.connection_class(**kwargs), self.cache, self._lock
            )
        return self.connection_class(**kwargs)

    def release(self, connection: "Connection") -> None:
        "Releases the connection back to the pool"
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own
                return

            if self.owns_connection(connection):
                if connection.should_reconnect():
                    connection.disconnect()
                self._available_connections.append(connection)
                self._event_dispatcher.dispatch(
                    AfterConnectionReleasedEvent(connection)
                )
            else:
                # Pool doesn't own this connection, do not add it back
                # to the pool.
                # The created connections count should not be changed,
                # because the connection was not created by the pool.
                connection.disconnect()
                return

    def owns_connection(self, connection: "Connection") -> bool:
        """Return True if the connection was created by this process's pool."""
        return connection.pid == self.pid

    def disconnect(self, inuse_connections: bool = True) -> None:
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        currently in use, potentially by other threads. Otherwise only disconnect
        connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(
                    self._available_connections, self._in_use_connections
                )
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()

    def close(self) -> None:
        """Close the pool, disconnecting all connections"""
        self.disconnect()

    def set_retry(self, retry: Retry) -> None:
        """Set the retry policy for future and all existing connections."""
        self.connection_kwargs.update({"retry": retry})
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate idle connections now; defer in-use ones via token."""
        with self._lock:
            for conn in self._available_connections:
                conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                conn.set_re_auth_token(token)

    def _get_pool_lock(self):
        return self._lock

    def _get_free_connections(self):
        with self._lock:
            return self._available_connections

    def _get_in_use_connections(self):
        with self._lock:
            return self._in_use_connections

    # FIX: this was declared ``async def`` although this pool is fully
    # synchronous; calling it via ``lambda error: self._mock(error)`` in
    # re_auth_callback() created a coroutine that was never awaited
    # (RuntimeWarning). A plain no-op function is the correct sync analog.
    def _mock(self, error: RedisError):
        """
        Dummy function, needs to be passed as error callback to retry object.
        :param error:
        :return:
        """
        pass

2741 

2742 

class BlockingConnectionPool(ConnectionPool):
    """
    Thread-safe blocking connection pool::

        >>> from redis.client import Redis
        >>> client = Redis(connection_pool=BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple redis clients (safely across threads if required).

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.ConnectionPool` implementation does), it
    makes the client wait ("blocks") for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):
        # Queue type used to hold the connection "slots"; a LIFO queue means
        # the most recently released connection is handed out first.
        self.queue_class = queue_class
        # Seconds get_connection() blocks waiting for a slot (None = forever).
        self.timeout = timeout
        # True while a MOVING maintenance notification is being processed
        # (see set_in_maintenance()); while set, pool operations additionally
        # serialize on self._lock.
        self._in_maintenance = False
        # Whether this pool currently holds self._lock because of the
        # maintenance mode above; checked by the ``finally`` blocks below so
        # the lock is released exactly when it was acquired here.
        # NOTE(review): this is a single shared attribute, not per-thread
        # state — presumably maintenance mode serializes callers; verify.
        self._locked = False
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )

    def reset(self):
        # Create and fill up a thread safe queue with ``None`` values.
        try:
            if self._in_maintenance:
                # Serialize with maintenance handling; record that *we* took
                # the lock so the finally block releases it exactly once.
                self._lock.acquire()
                self._locked = True
            self.pool = self.queue_class(self.max_connections)
            while True:
                try:
                    self.pool.put_nowait(None)
                except Full:
                    # Queue seeded to max_connections capacity.
                    break

            # Keep a list of actual connection instances so that we can
            # disconnect them later.
            self._connections = []
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    # Best effort: never let a release failure mask an
                    # exception (if any) raised in the try block.
                    pass
                self._locked = False

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def make_connection(self):
        "Make a fresh connection."
        try:
            if self._in_maintenance:
                self._lock.acquire()
                self._locked = True

            if self.cache is not None:
                # Client-side caching enabled: wrap the raw connection in a
                # proxy that consults/updates the cache around commands.
                connection = CacheProxyConnection(
                    self.connection_class(**self.connection_kwargs),
                    self.cache,
                    self._lock,
                )
            else:
                connection = self.connection_class(**self.connection_kwargs)
            # Track every created connection so disconnect() can also reach
            # connections that are currently checked out.
            self._connections.append(connection)
            return connection
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        try:
            if self._in_maintenance:
                self._lock.acquire()
                self._locked = True
            try:
                connection = self.pool.get(block=True, timeout=self.timeout)
            except Empty:
                # Note that this is not caught by the redis client and will be
                # raised unless handled by application code. If you want never
                # to time out, construct the pool with timeout=None.
                raise ConnectionError("No connection available.")

            # If the ``connection`` is actually ``None`` then that's a cue to make
            # a new connection to add to the pool.
            if connection is None:
                connection = self.make_connection()
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        return connection

    def release(self, connection):
        "Releases the connection back to the pool."
        # Make sure we haven't changed process.
        self._checkpid()

        try:
            if self._in_maintenance:
                self._lock.acquire()
                self._locked = True
            if not self.owns_connection(connection):
                # pool doesn't own this connection. do not add it back
                # to the pool. instead add a None value which is a placeholder
                # that will cause the pool to recreate the connection if
                # its needed.
                connection.disconnect()
                self.pool.put_nowait(None)
                return
            if connection.should_reconnect():
                # Connection asked to be re-established before reuse; drop
                # the socket now so the next checkout reconnects it.
                connection.disconnect()
            # Put the connection back into the pool.
            try:
                self.pool.put_nowait(connection)
            except Full:
                # perhaps the pool has been reset() after a fork? regardless,
                # we don't want this connection
                pass
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False

    def disconnect(self, inuse_connections: bool = True):
        "Disconnects either all connections in the pool or just the free connections."
        self._checkpid()
        try:
            if self._in_maintenance:
                self._lock.acquire()
                self._locked = True
            if inuse_connections:
                # self._connections tracks every connection created by this
                # pool, including ones currently checked out.
                connections = self._connections
            else:
                connections = self._get_free_connections()
            for connection in connections:
                connection.disconnect()
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False

    def _get_free_connections(self):
        # Snapshot of idle connections: real connection objects currently in
        # the queue (falsy entries are the ``None`` placeholders).
        with self._lock:
            return {conn for conn in self.pool.queue if conn}

    def _get_in_use_connections(self):
        with self._lock:
            # free connections
            connections_in_queue = {conn for conn in self.pool.queue if conn}
            # in self._connections we keep all created connections
            # so the ones that are not in the queue are the in use ones
            return {
                conn for conn in self._connections if conn not in connections_in_queue
            }

    def set_in_maintenance(self, in_maintenance: bool):
        """
        Sets a flag that this Blocking ConnectionPool is in maintenance mode.

        This is used to prevent new connections from being created while we are in maintenance mode.
        The pool will be in maintenance mode only when we are processing a MOVING notification.
        """
        self._in_maintenance = in_maintenance