Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1396 statements  

1import copy 

2import os 

3import socket 

4import sys 

5import threading 

6import time 

7import weakref 

8from abc import ABC, abstractmethod 

9from itertools import chain 

10from queue import Empty, Full, LifoQueue 

11from typing import ( 

12 Any, 

13 Callable, 

14 Dict, 

15 Iterable, 

16 List, 

17 Literal, 

18 Optional, 

19 Type, 

20 TypeVar, 

21 Union, 

22) 

23from urllib.parse import parse_qs, unquote, urlparse 

24 

25from redis.cache import ( 

26 CacheEntry, 

27 CacheEntryStatus, 

28 CacheFactory, 

29 CacheFactoryInterface, 

30 CacheInterface, 

31 CacheKey, 

32) 

33 

34from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser 

35from .auth.token import TokenInterface 

36from .backoff import NoBackoff 

37from .credentials import CredentialProvider, UsernamePasswordCredentialProvider 

38from .driver_info import DriverInfo, resolve_driver_info 

39from .event import AfterConnectionReleasedEvent, EventDispatcher 

40from .exceptions import ( 

41 AuthenticationError, 

42 AuthenticationWrongNumberOfArgsError, 

43 ChildDeadlockedError, 

44 ConnectionError, 

45 DataError, 

46 MaxConnectionsError, 

47 RedisError, 

48 ResponseError, 

49 TimeoutError, 

50) 

51from .maint_notifications import ( 

52 MaintenanceState, 

53 MaintNotificationsConfig, 

54 MaintNotificationsConnectionHandler, 

55 MaintNotificationsPoolHandler, 

56) 

57from .retry import Retry 

58from .utils import ( 

59 CRYPTOGRAPHY_AVAILABLE, 

60 HIREDIS_AVAILABLE, 

61 SSL_AVAILABLE, 

62 compare_versions, 

63 deprecated_args, 

64 ensure_string, 

65 format_error_message, 

66 str_if_bytes, 

67) 

68 

69if SSL_AVAILABLE: 

70 import ssl 

71 from ssl import VerifyFlags 

72else: 

73 ssl = None 

74 VerifyFlags = None 

75 

76if HIREDIS_AVAILABLE: 

77 import hiredis 

78 

# RESP wire-protocol framing symbols.
SYM_STAR = b"*"  # array header prefix (argument count)
SYM_DOLLAR = b"$"  # bulk-string header prefix (byte length)
SYM_CRLF = b"\r\n"  # protocol line terminator
SYM_EMPTY = b""  # joiner used when concatenating byte fragments

# RESP version used when the caller supplies protocol=None.
DEFAULT_RESP_VERSION = 2

# Unique placeholder that distinguishes "argument not supplied" from an
# explicit None (see set_tmp_settings and retry_on_error handling).
SENTINEL = object()

# Prefer the C-accelerated hiredis parser when the extension is installed,
# falling back to the pure-Python RESP2 parser otherwise.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _HiredisParser
else:
    DefaultParser = _RESP2Parser

93 

94 

class HiredisRespSerializer:
    """Command serializer backed by the C-accelerated hiredis extension."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol.

        Returns a single-element list containing the packed byte string.
        Raises DataError when hiredis rejects an argument type.
        """
        first = args[0]
        # A command name given as one string (e.g. "CONFIG GET") must be
        # split into separate arguments before packing.
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        try:
            packed = hiredis.pack_command(args)
        except TypeError:
            # Re-raise as the library's DataError, preserving the traceback.
            _, value, traceback = sys.exc_info()
            raise DataError(value).with_traceback(traceback)

        return [packed]

111 

112 

class PythonRespSerializer:
    """Pure-Python RESP command serializer."""

    def __init__(self, buffer_cutoff, encode) -> None:
        # buffer_cutoff: byte threshold above which values are emitted as
        # separate output chunks instead of being folded into one buffer.
        self._buffer_cutoff = buffer_cutoff
        # encode: callable converting each argument to bytes.
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol.

        Returns a list of byte chunks ready to be written to the socket.
        """
        # The client might have included one or more literal arguments in
        # the command name, e.g. 'CONFIG GET'. The server expects these as
        # separate arguments, so split the first one manually. The split
        # pieces are bytestrings so they are not re-encoded below.
        first = args[0]
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        chunks = []
        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))
        cutoff = self._buffer_cutoff

        for arg in map(self.encode, args):
            arg_length = len(arg)
            header = SYM_EMPTY.join(
                (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
            )
            # To avoid large string mallocs, large values and memoryviews
            # are flushed as separate chunks in the output list.
            if (
                len(buff) > cutoff
                or arg_length > cutoff
                or isinstance(arg, memoryview)
            ):
                chunks.append(header)
                chunks.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join((header, arg, SYM_CRLF))

        chunks.append(buff)
        return chunks

162 

163 

class ConnectionInterface:
    """Abstract contract that every connection implementation must fulfil."""

    @abstractmethod
    def repr_pieces(self):
        """Return an iterable of (name, value) pairs used to build __repr__."""
        pass

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register a callback invoked when the connection is (re-)established."""
        pass

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a previously registered connect callback."""
        pass

    @abstractmethod
    def set_parser(self, parser_class):
        """Attach a fresh instance of ``parser_class`` to this connection."""
        pass

    @abstractmethod
    def get_protocol(self):
        """Return the RESP protocol version configured for this connection."""
        pass

    @abstractmethod
    def connect(self):
        """Establish the connection if not already connected."""
        pass

    @abstractmethod
    def on_connect(self):
        """Run post-connect initialization (authentication, handshake, ...)."""
        pass

    @abstractmethod
    def disconnect(self, *args):
        """Close the connection."""
        pass

    @abstractmethod
    def check_health(self):
        """Check that the connection is still usable."""
        pass

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send an already-packed command (bytes or list of byte chunks)."""
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Pack and send a command to the server."""
        pass

    @abstractmethod
    def can_read(self, timeout=0):
        """Return whether data is available to read within ``timeout``."""
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read and return the server's response."""
        pass

    @abstractmethod
    def pack_command(self, *args):
        """Return the wire-format representation of a single command."""
        pass

    @abstractmethod
    def pack_commands(self, commands):
        """Return the wire-format representation of multiple commands."""
        pass

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Metadata returned by the server during the connection handshake."""
        pass

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be used for re-authentication."""
        pass

    @abstractmethod
    def re_auth(self):
        """Re-authenticate the connection using the stored token."""
        pass

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """
        pass

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """
        pass

    @abstractmethod
    def reset_should_reconnect(self):
        """
        Reset the internal flag to False.
        """
        pass

265 

266 

class MaintNotificationsAbstractConnection:
    """
    Abstract class for handling maintenance notifications logic.
    This class is expected to be used as base class together with ConnectionInterface.

    This class is intended to be used with multiple inheritance!

    All logic related to maintenance notifications is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig],
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Initialize the maintenance notifications for the connection.

        Args:
            maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
            maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
            maintenance_state (MaintenanceState): The current maintenance state of the connection.
            maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
            orig_host_address (Optional[str]): The original host address of the connection.
            orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
            orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
            parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser to use for maintenance notifications.
                If not provided, the parser from the connection is used.
                This is useful when the parser is created after this object.
        """
        self.maint_notifications_config = maint_notifications_config
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash
        # Wire up handlers and snapshot original host/timeout settings so
        # temporary maintenance overrides can later be reverted.
        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            parser,
        )

    @abstractmethod
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
        """Return the parser instance attached to the connection."""
        pass

    @abstractmethod
    def _get_socket(self) -> Optional[socket.socket]:
        """Return the underlying socket, or None if not connected."""
        pass

    @abstractmethod
    def get_protocol(self) -> Union[int, str]:
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        pass

    @property
    @abstractmethod
    def host(self) -> str:
        pass

    @host.setter
    @abstractmethod
    def host(self, value: str):
        pass

    @property
    @abstractmethod
    def socket_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_timeout.setter
    @abstractmethod
    def socket_timeout(self, value: Optional[Union[float, int]]):
        pass

    @property
    @abstractmethod
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        pass

    @socket_connect_timeout.setter
    @abstractmethod
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        pass

    @abstractmethod
    def disconnect(self, *args):
        pass

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Enable maintenance notifications by setting up
        handlers and storing original connection parameters.

        Should be used ONLY with parsers that support push notifications.

        Raises:
            RedisError: if notifications are enabled but no parser was
                provided, or the parser is not hiredis/RESP3.
        """
        # Nothing to do when the feature is disabled; clear both handler
        # slots so other methods can safely test them.
        if (
            not self.maint_notifications_config
            or not self.maint_notifications_config.enabled
        ):
            self._maint_notifications_pool_handler = None
            self._maint_notifications_connection_handler = None
            return

        if not parser:
            raise RedisError(
                "To configure maintenance notifications, a parser must be provided!"
            )

        if not isinstance(parser, _HiredisParser) and not isinstance(
            parser, _RESP3Parser
        ):
            raise RedisError(
                "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
            )

        if maint_notifications_pool_handler:
            # Extract a reference to a new pool handler that copies all properties
            # of the original one and has a different connection reference
            # This is needed because when we attach the handler to the parser
            # we need to make sure that the handler has a reference to the
            # connection that the parser is attached to.
            self._maint_notifications_pool_handler = (
                maint_notifications_pool_handler.get_handler_for_connection()
            )
            self._maint_notifications_pool_handler.set_connection(self)
        else:
            self._maint_notifications_pool_handler = None

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )

        # Set up pool handler if available
        if self._maint_notifications_pool_handler:
            parser.set_node_moving_push_handler(
                self._maint_notifications_pool_handler.handle_notification
            )

        # Set up connection handler
        parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

        # Store original connection parameters so temporary maintenance
        # overrides (set_tmp_settings) can be reverted later.
        self.orig_host_address = orig_host_address if orig_host_address else self.host
        self.orig_socket_timeout = (
            orig_socket_timeout if orig_socket_timeout else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout
            else self.socket_connect_timeout
        )

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Attach a per-connection copy of the given pool handler and (re)wire
        the parser's push handlers accordingly."""
        # Deep copy the pool handler to avoid sharing the same pool handler
        # between multiple connections, because otherwise each connection will override
        # the connection reference and the pool handler will only hold a reference
        # to the last connection that was set.
        maint_notifications_pool_handler_copy = (
            maint_notifications_pool_handler.get_handler_for_connection()
        )

        maint_notifications_pool_handler_copy.set_connection(self)
        self._get_parser().set_node_moving_push_handler(
            maint_notifications_pool_handler_copy.handle_notification
        )

        self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, maint_notifications_pool_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            # Handler already wired to the parser; only refresh its config.
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )

    def activate_maint_notifications_handling_if_enabled(self, check_health=True):
        """Perform the server-side maintenance-notifications handshake when
        all preconditions are met (RESP3, feature enabled, handler present,
        host available)."""
        # Send maintenance notifications handshake if RESP3 is active
        # and maintenance notifications are enabled
        # and we have a host to determine the endpoint type from
        # When the maint_notifications_config enabled mode is "auto",
        # we just log a warning if the handshake fails
        # When the mode is enabled=True, we raise an exception in case of failure
        if (
            self.get_protocol() not in [2, "2"]
            and self.maint_notifications_config
            and self.maint_notifications_config.enabled
            and self._maint_notifications_connection_handler
            and hasattr(self, "host")
        ):
            self._enable_maintenance_notifications(
                maint_notifications_config=self.maint_notifications_config,
                check_health=check_health,
            )

    def _enable_maintenance_notifications(
        self, maint_notifications_config: MaintNotificationsConfig, check_health=True
    ):
        """Send CLIENT MAINT_NOTIFICATIONS ON to the server.

        In "auto" mode a ResponseError is logged and swallowed; any other
        failure (or strict enabled=True mode) propagates.
        """
        try:
            host = getattr(self, "host", None)
            if host is None:
                raise ValueError(
                    "Cannot enable maintenance notifications for connection"
                    " object that doesn't have a host attribute."
                )
            else:
                endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
                self.send_command(
                    "CLIENT",
                    "MAINT_NOTIFICATIONS",
                    "ON",
                    "moving-endpoint-type",
                    endpoint_type.value,
                    check_health=check_health,
                )
                response = self.read_response()
                if not response or str_if_bytes(response) != "OK":
                    raise ResponseError(
                        "The server doesn't support maintenance notifications"
                    )
        except Exception as e:
            if (
                isinstance(e, ResponseError)
                and maint_notifications_config.enabled == "auto"
            ):
                # Log warning but don't fail the connection
                import logging

                logger = logging.getLogger(__name__)
                logger.debug(f"Failed to enable maintenance notifications: {e}")
            else:
                raise

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution if needed.

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """

        # Method 1: Try to get the actual IP from the established socket connection
        # This is most accurate as it shows the exact IP being used
        try:
            conn_socket = self._get_socket()
            if conn_socket is not None:
                peer_addr = conn_socket.getpeername()
                if peer_addr and len(peer_addr) >= 1:
                    # For TCP sockets, peer_addr is typically (host, port) tuple
                    # Return just the host part
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host
        # This is less accurate but works when socket is not available
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP
                # This mimics what the connection would do during _connect()
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # Return the IP from the first result
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # sockaddr[0] is the IP address
                    return str(addr_info[0][4][0])
        except (AttributeError, OSError, socket.gaierror):
            # DNS resolution might fail
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        """Current maintenance state of this connection."""
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def getpeername(self):
        """
        Returns the peer name of the connection.
        """
        # Returns only the address part of the (host, port) tuple, or None
        # when there is no connected socket.
        conn_socket = self._get_socket()
        if conn_socket:
            return conn_socket.getpeername()[0]
        return None

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """Apply a relaxed timeout to the live socket and parser.

        A value of -1 means "restore the configured socket_timeout".
        """
        conn_socket = self._get_socket()
        if conn_socket:
            timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
            conn_socket.settimeout(timeout)
            self.update_parser_timeout(timeout)

    def update_parser_timeout(self, timeout: Optional[float] = None):
        """Propagate a timeout to the parser's internal buffer, matching the
        per-parser attribute each implementation reads."""
        parser = self._get_parser()
        if parser and parser._buffer:
            if isinstance(parser, _RESP3Parser) and timeout:
                parser._buffer.socket_timeout = timeout
            elif isinstance(parser, _HiredisParser):
                parser._socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        The value of SENTINEL is used to indicate that the property should not be updated.
        """
        # For the timeout, -1 similarly means "do not update".
        if tmp_host_address and tmp_host_address != SENTINEL:
            self.host = str(tmp_host_address)
        if tmp_relaxed_timeout != -1:
            self.socket_timeout = tmp_relaxed_timeout
            self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """Restore the original host and/or timeouts captured at configure
        time (see _configure_maintenance_notifications)."""
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relaxed_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout

653 

654 

655class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface): 

656 "Manages communication to and from a Redis server" 

657 

658 @deprecated_args( 

659 args_to_warn=["lib_name", "lib_version"], 

660 reason="Use 'driver_info' parameter instead. " 

661 "lib_name and lib_version will be removed in a future version.", 

662 ) 

663 def __init__( 

664 self, 

665 db: int = 0, 

666 password: Optional[str] = None, 

667 socket_timeout: Optional[float] = None, 

668 socket_connect_timeout: Optional[float] = None, 

669 retry_on_timeout: bool = False, 

670 retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL, 

671 encoding: str = "utf-8", 

672 encoding_errors: str = "strict", 

673 decode_responses: bool = False, 

674 parser_class=DefaultParser, 

675 socket_read_size: int = 65536, 

676 health_check_interval: int = 0, 

677 client_name: Optional[str] = None, 

678 lib_name: Optional[str] = None, 

679 lib_version: Optional[str] = None, 

680 driver_info: Optional[DriverInfo] = None, 

681 username: Optional[str] = None, 

682 retry: Union[Any, None] = None, 

683 redis_connect_func: Optional[Callable[[], None]] = None, 

684 credential_provider: Optional[CredentialProvider] = None, 

685 protocol: Optional[int] = 2, 

686 command_packer: Optional[Callable[[], None]] = None, 

687 event_dispatcher: Optional[EventDispatcher] = None, 

688 maint_notifications_config: Optional[MaintNotificationsConfig] = None, 

689 maint_notifications_pool_handler: Optional[ 

690 MaintNotificationsPoolHandler 

691 ] = None, 

692 maintenance_state: "MaintenanceState" = MaintenanceState.NONE, 

693 maintenance_notification_hash: Optional[int] = None, 

694 orig_host_address: Optional[str] = None, 

695 orig_socket_timeout: Optional[float] = None, 

696 orig_socket_connect_timeout: Optional[float] = None, 

697 ): 

698 """ 

699 Initialize a new Connection. 

700 

701 To specify a retry policy for specific errors, first set 

702 `retry_on_error` to a list of the error/s to retry on, then set 

703 `retry` to a valid `Retry` object. 

704 To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. 

705 

706 Parameters 

707 ---------- 

708 driver_info : DriverInfo, optional 

709 Driver metadata for CLIENT SETINFO. If provided, lib_name and lib_version 

710 are ignored. If not provided, a DriverInfo will be created from lib_name 

711 and lib_version (or defaults if those are also None). 

712 lib_name : str, optional 

713 **Deprecated.** Use driver_info instead. Library name for CLIENT SETINFO. 

714 lib_version : str, optional 

715 **Deprecated.** Use driver_info instead. Library version for CLIENT SETINFO. 

716 """ 

717 if (username or password) and credential_provider is not None: 

718 raise DataError( 

719 "'username' and 'password' cannot be passed along with 'credential_" 

720 "provider'. Please provide only one of the following arguments: \n" 

721 "1. 'password' and (optional) 'username'\n" 

722 "2. 'credential_provider'" 

723 ) 

724 if event_dispatcher is None: 

725 self._event_dispatcher = EventDispatcher() 

726 else: 

727 self._event_dispatcher = event_dispatcher 

728 self.pid = os.getpid() 

729 self.db = db 

730 self.client_name = client_name 

731 

732 # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version 

733 self.driver_info = resolve_driver_info(driver_info, lib_name, lib_version) 

734 

735 self.credential_provider = credential_provider 

736 self.password = password 

737 self.username = username 

738 self._socket_timeout = socket_timeout 

739 if socket_connect_timeout is None: 

740 socket_connect_timeout = socket_timeout 

741 self._socket_connect_timeout = socket_connect_timeout 

742 self.retry_on_timeout = retry_on_timeout 

743 if retry_on_error is SENTINEL: 

744 retry_on_errors_list = [] 

745 else: 

746 retry_on_errors_list = list(retry_on_error) 

747 if retry_on_timeout: 

748 # Add TimeoutError to the errors list to retry on 

749 retry_on_errors_list.append(TimeoutError) 

750 self.retry_on_error = retry_on_errors_list 

751 if retry or self.retry_on_error: 

752 if retry is None: 

753 self.retry = Retry(NoBackoff(), 1) 

754 else: 

755 # deep-copy the Retry object as it is mutable 

756 self.retry = copy.deepcopy(retry) 

757 if self.retry_on_error: 

758 # Update the retry's supported errors with the specified errors 

759 self.retry.update_supported_errors(self.retry_on_error) 

760 else: 

761 self.retry = Retry(NoBackoff(), 0) 

762 self.health_check_interval = health_check_interval 

763 self.next_health_check = 0 

764 self.redis_connect_func = redis_connect_func 

765 self.encoder = Encoder(encoding, encoding_errors, decode_responses) 

766 self.handshake_metadata = None 

767 self._sock = None 

768 self._socket_read_size = socket_read_size 

769 self._connect_callbacks = [] 

770 self._buffer_cutoff = 6000 

771 self._re_auth_token: Optional[TokenInterface] = None 

772 try: 

773 p = int(protocol) 

774 except TypeError: 

775 p = DEFAULT_RESP_VERSION 

776 except ValueError: 

777 raise ConnectionError("protocol must be an integer") 

778 finally: 

779 if p < 2 or p > 3: 

780 raise ConnectionError("protocol must be either 2 or 3") 

781 # p = DEFAULT_RESP_VERSION 

782 self.protocol = p 

783 if self.protocol == 3 and parser_class == _RESP2Parser: 

784 # If the protocol is 3 but the parser is RESP2, change it to RESP3 

785 # This is needed because the parser might be set before the protocol 

786 # or might be provided as a kwarg to the constructor 

787 # We need to react on discrepancy only for RESP2 and RESP3 

788 # as hiredis supports both 

789 parser_class = _RESP3Parser 

790 self.set_parser(parser_class) 

791 

792 self._command_packer = self._construct_command_packer(command_packer) 

793 self._should_reconnect = False 

794 

795 # Set up maintenance notifications 

796 MaintNotificationsAbstractConnection.__init__( 

797 self, 

798 maint_notifications_config, 

799 maint_notifications_pool_handler, 

800 maintenance_state, 

801 maintenance_notification_hash, 

802 orig_host_address, 

803 orig_socket_timeout, 

804 orig_socket_connect_timeout, 

805 self._parser, 

806 ) 

807 

808 def __repr__(self): 

809 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) 

810 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" 

811 

    @abstractmethod
    def repr_pieces(self):
        """Return an iterable of (name, value) pairs consumed by __repr__."""
        pass

815 

    def __del__(self):
        # Best-effort cleanup when the connection object is garbage
        # collected. Exceptions are deliberately swallowed: raising from
        # __del__ (e.g. during interpreter shutdown) only produces noise.
        try:
            self.disconnect()
        except Exception:
            pass

821 

822 def _construct_command_packer(self, packer): 

823 if packer is not None: 

824 return packer 

825 elif HIREDIS_AVAILABLE: 

826 return HiredisRespSerializer() 

827 else: 

828 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode) 

829 

830 def register_connect_callback(self, callback): 

831 """ 

832 Register a callback to be called when the connection is established either 

833 initially or reconnected. This allows listeners to issue commands that 

834 are ephemeral to the connection, for example pub/sub subscription or 

835 key tracking. The callback must be a _method_ and will be kept as 

836 a weak reference. 

837 """ 

838 wm = weakref.WeakMethod(callback) 

839 if wm not in self._connect_callbacks: 

840 self._connect_callbacks.append(wm) 

841 

842 def deregister_connect_callback(self, callback): 

843 """ 

844 De-register a previously registered callback. It will no-longer receive 

845 notifications on connection events. Calling this is not required when the 

846 listener goes away, since the callbacks are kept as weak methods. 

847 """ 

848 try: 

849 self._connect_callbacks.remove(weakref.WeakMethod(callback)) 

850 except ValueError: 

851 pass 

852 

    def set_parser(self, parser_class):
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        # Note: any previously attached parser instance is discarded.
        self._parser = parser_class(socket_read_size=self._socket_read_size)

860 

    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
        """Return the parser instance currently attached to this connection."""
        return self._parser

863 

864 def connect(self): 

865 "Connects to the Redis server if not already connected" 

866 # try once the socket connect with the handshake, retry the whole 

867 # connect/handshake flow based on retry policy 

868 self.retry.call_with_retry( 

869 lambda: self.connect_check_health( 

870 check_health=True, retry_socket_connect=False 

871 ), 

872 lambda error: self.disconnect(error), 

873 ) 

874 

    def connect_check_health(
        self, check_health: bool = True, retry_socket_connect: bool = True
    ):
        """Establish the socket and run post-connect initialization.

        No-op when already connected. Socket-level failures are translated
        to TimeoutError/ConnectionError; failures during initialization
        disconnect before re-raising.
        """
        if self._sock:
            # Already connected; nothing to do.
            return
        try:
            if retry_socket_connect:
                # Retry only the raw socket connect according to policy.
                sock = self.retry.call_with_retry(
                    lambda: self._connect(), lambda error: self.disconnect(error)
                )
            else:
                sock = self._connect()
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect_check_health(check_health=check_health)
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback:
                callback(self)

913 

    @abstractmethod
    def _connect(self):
        """Create and return the underlying connected socket (subclass hook)."""
        pass

917 

    @abstractmethod
    def _host_error(self):
        """Return the host description used in error messages (subclass hook)."""
        pass

921 

    def _error_message(self, exception):
        # Build a human-readable error string including the host details
        # supplied by the subclass via _host_error().
        return format_error_message(self._host_error(), exception)

924 

    def on_connect(self):
        """Initialize the connection with the health check enabled
        (delegates to on_connect_check_health)."""
        self.on_connect_check_health(check_health=True)

927 

    def on_connect_check_health(self, check_health: bool = True):
        """Initialize the connection: authenticate, negotiate the RESP
        protocol version (via HELLO when protocol 3 is requested), enable
        maintenance notifications, set the client name / library info and
        select the configured database.

        ``check_health`` is forwarded to commands sent after authentication;
        commands sent *before* AUTH always skip the health check because a
        PING would fail on an unauthenticated connection.
        """
        self._parser.on_connect(self)
        # keep a handle on the original parser; if we swap to RESP3 below we
        # copy its (possibly cluster-specific) exception classes across
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # password-only credentials authenticate as the default user
                auth_args = ["default", auth_args[0]]
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command(
                "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
            )
            self.handshake_metadata = self.read_response()
            # if response.get(b"proto") != self.protocol and response.get(
            #     "proto"
            # ) != self.protocol:
            #     raise ConnectionError("Invalid RESP version")
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            self.send_command("HELLO", self.protocol, check_health=check_health)
            self.handshake_metadata = self.read_response()
            # the server echoes the protocol it agreed to; key may be bytes
            # or str depending on decoding settings
            if (
                self.handshake_metadata.get(b"proto") != self.protocol
                and self.handshake_metadata.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # Activate maintenance notifications for this connection
        # if enabled in the configuration
        # This is a no-op if maintenance notifications are not enabled
        self.activate_maint_notifications_handling_if_enabled(check_health=check_health)

        # if a client_name is given, set it
        if self.client_name:
            self.send_command(
                "CLIENT",
                "SETNAME",
                self.client_name,
                check_health=check_health,
            )
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # Set the library name and version from driver_info.
        # CLIENT SETINFO is best-effort: older servers reject it, so a
        # ResponseError is deliberately swallowed.
        try:
            if self.driver_info and self.driver_info.formatted_name:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-NAME",
                    self.driver_info.formatted_name,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        try:
            if self.driver_info and self.driver_info.lib_version:
                self.send_command(
                    "CLIENT",
                    "SETINFO",
                    "LIB-VER",
                    self.driver_info.lib_version,
                    check_health=check_health,
                )
                self.read_response()
        except ResponseError:
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db, check_health=check_health)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

1043 

1044 def disconnect(self, *args): 

1045 "Disconnects from the Redis server" 

1046 self._parser.on_disconnect() 

1047 

1048 conn_sock = self._sock 

1049 self._sock = None 

1050 # reset the reconnect flag 

1051 self.reset_should_reconnect() 

1052 if conn_sock is None: 

1053 return 

1054 

1055 if os.getpid() == self.pid: 

1056 try: 

1057 conn_sock.shutdown(socket.SHUT_RDWR) 

1058 except (OSError, TypeError): 

1059 pass 

1060 

1061 try: 

1062 conn_sock.close() 

1063 except OSError: 

1064 pass 

1065 

    def mark_for_reconnect(self):
        """Flag this connection to be re-established before its next use."""
        self._should_reconnect = True

    def should_reconnect(self):
        """Return True if the connection has been flagged for reconnect."""
        return self._should_reconnect

    def reset_should_reconnect(self):
        """Clear the reconnect flag (called after a (re)connect/disconnect)."""
        self._should_reconnect = False

1074 

1075 def _send_ping(self): 

1076 """Send PING, expect PONG in return""" 

1077 self.send_command("PING", check_health=False) 

1078 if str_if_bytes(self.read_response()) != "PONG": 

1079 raise ConnectionError("Bad response from PING health check") 

1080 

1081 def _ping_failed(self, error): 

1082 """Function to call when PING fails""" 

1083 self.disconnect() 

1084 

1085 def check_health(self): 

1086 """Check the health of the connection with a PING/PONG""" 

1087 if self.health_check_interval and time.monotonic() > self.next_health_check: 

1088 self.retry.call_with_retry(self._send_ping, self._ping_failed) 

1089 

1090 def send_packed_command(self, command, check_health=True): 

1091 """Send an already packed command to the Redis server""" 

1092 if not self._sock: 

1093 self.connect_check_health(check_health=False) 

1094 # guard against health check recursion 

1095 if check_health: 

1096 self.check_health() 

1097 try: 

1098 if isinstance(command, str): 

1099 command = [command] 

1100 for item in command: 

1101 self._sock.sendall(item) 

1102 except socket.timeout: 

1103 self.disconnect() 

1104 raise TimeoutError("Timeout writing to socket") 

1105 except OSError as e: 

1106 self.disconnect() 

1107 if len(e.args) == 1: 

1108 errno, errmsg = "UNKNOWN", e.args[0] 

1109 else: 

1110 errno = e.args[0] 

1111 errmsg = e.args[1] 

1112 raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.") 

1113 except BaseException: 

1114 # BaseExceptions can be raised when a socket send operation is not 

1115 # finished, e.g. due to a timeout. Ideally, a caller could then re-try 

1116 # to send un-sent data. However, the send_packed_command() API 

1117 # does not support it so there is no point in keeping the connection open. 

1118 self.disconnect() 

1119 raise 

1120 

1121 def send_command(self, *args, **kwargs): 

1122 """Pack and send a command to the Redis server""" 

1123 self.send_packed_command( 

1124 self._command_packer.pack(*args), 

1125 check_health=kwargs.get("check_health", True), 

1126 ) 

1127 

1128 def can_read(self, timeout=0): 

1129 """Poll the socket to see if there's data that can be read.""" 

1130 sock = self._sock 

1131 if not sock: 

1132 self.connect() 

1133 

1134 host_error = self._host_error() 

1135 

1136 try: 

1137 return self._parser.can_read(timeout) 

1138 

1139 except OSError as e: 

1140 self.disconnect() 

1141 raise ConnectionError(f"Error while reading from {host_error}: {e.args}") 

1142 

    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the response from a previously sent command.

        ``disable_decoding`` skips response decoding; ``push_request`` is
        honored only on RESP3 connections. Unless ``disconnect_on_error``
        is False, any read failure tears the connection down before the
        exception propagates.
        """

        host_error = self._host_error()

        try:
            if self.protocol in ["3", 3]:
                response = self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                # RESP2 parsers do not accept push_request
                response = self._parser.read_response(disable_decoding=disable_decoding)
        except socket.timeout:
            if disconnect_on_error:
                self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                self.disconnect()
            raise

        # a successful read proves the connection is alive; push the next
        # scheduled health check forward
        if self.health_check_interval:
            self.next_health_check = time.monotonic() + self.health_check_interval

        if isinstance(response, ResponseError):
            try:
                raise response
            finally:
                del response  # avoid creating ref cycles
        return response

1186 

    def pack_command(self, *args):
        """Pack a series of arguments into the Redis protocol wire format."""
        return self._command_packer.pack(*args)

1190 

1191 def pack_commands(self, commands): 

1192 """Pack multiple commands into the Redis protocol""" 

1193 output = [] 

1194 pieces = [] 

1195 buffer_length = 0 

1196 buffer_cutoff = self._buffer_cutoff 

1197 

1198 for cmd in commands: 

1199 for chunk in self._command_packer.pack(*cmd): 

1200 chunklen = len(chunk) 

1201 if ( 

1202 buffer_length > buffer_cutoff 

1203 or chunklen > buffer_cutoff 

1204 or isinstance(chunk, memoryview) 

1205 ): 

1206 if pieces: 

1207 output.append(SYM_EMPTY.join(pieces)) 

1208 buffer_length = 0 

1209 pieces = [] 

1210 

1211 if chunklen > buffer_cutoff or isinstance(chunk, memoryview): 

1212 output.append(chunk) 

1213 else: 

1214 pieces.append(chunk) 

1215 buffer_length += chunklen 

1216 

1217 if pieces: 

1218 output.append(SYM_EMPTY.join(pieces)) 

1219 return output 

1220 

    def get_protocol(self) -> Union[int, str]:
        """Return the configured RESP protocol version (may be int or str)."""
        return self.protocol

1223 

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Server metadata returned by the HELLO handshake (keys may be
        bytes or str depending on decoding settings)."""
        return self._handshake_metadata

    @handshake_metadata.setter
    def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
        self._handshake_metadata = value

1231 

    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to re-authenticate with on the next re_auth() call."""
        self._re_auth_token = token

1234 

1235 def re_auth(self): 

1236 if self._re_auth_token is not None: 

1237 self.send_command( 

1238 "AUTH", 

1239 self._re_auth_token.try_get("oid"), 

1240 self._re_auth_token.get_value(), 

1241 ) 

1242 self.read_response() 

1243 self._re_auth_token = None 

1244 

    def _get_socket(self) -> Optional[socket.socket]:
        """Return the raw socket, or None when disconnected."""
        return self._sock

1247 

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        """Timeout (seconds) for socket reads/writes; None blocks forever."""
        return self._socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        """Timeout (seconds) applied only while establishing the connection."""
        return self._socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._socket_connect_timeout = value

1263 

1264 

class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        self._host = host
        self.port = int(port)
        self.socket_type = socket_type
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Key/value pairs describing this connection for repr()."""
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        """Create and return a connected TCP socket.

        Mimics socket.create_connection() to support both IPv4 and IPv6,
        but sets socket options before calling connect().
        """
        last_error = None

        # NOTE(review): socket_type is passed in getaddrinfo's *family*
        # slot — matches historical behavior; confirm intent.
        for family, socktype, proto, _canonname, address in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for opt, val in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, opt, val)

                # connect with the connect-specific timeout ...
                sock.settimeout(self.socket_connect_timeout)
                sock.connect(address)

                # ... then switch to the steady-state timeout
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as exc:
                last_error = exc
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()

        if last_error is not None:
            raise last_error
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        """Endpoint string used in error messages."""
        return f"{self.host}:{self.port}"

    @property
    def host(self) -> str:
        return self._host

    @host.setter
    def host(self, value: str):
        self._host = value

1346 

1347 

class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
    """Connection decorator that adds client-side caching on top of a real
    connection.

    Cacheable command responses are stored in the shared ``cache`` and
    invalidated through RESP3 push messages (CLIENT TRACKING). All other
    behavior delegates to the wrapped connection.
    """

    # Placeholder stored while a response is still in flight, so another
    # thread does not cache a racing result for the same key.
    DUMMY_CACHE_VALUE = b"foo"
    # Client-side caching requires Redis (the "redis" server) 7.4+.
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            MaintNotificationsAbstractConnection.__init__(
                self,
                self._conn.maint_notifications_config,
                self._conn._maint_notifications_pool_handler,
                self._conn.maintenance_state,
                self._conn.maintenance_notification_hash,
                self._conn.host,
                self._conn.socket_timeout,
                self._conn.socket_connect_timeout,
                self._conn._get_parser(),
            )

    def repr_pieces(self):
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_pool_handler_for_connection(
                maint_notifications_pool_handler
            )

    def get_protocol(self):
        return self._conn.get_protocol()

    def connect(self):
        """Connect and verify the server supports client-side caching."""
        self._conn.connect()

        server_name = self._conn.handshake_metadata.get(b"server", None)
        if server_name is None:
            server_name = self._conn.handshake_metadata.get("server", None)
        server_ver = self._conn.handshake_metadata.get(b"version", None)
        if server_ver is None:
            server_ver = self._conn.handshake_metadata.get("version", None)
        # BUGFIX: the original tested ``server_ver is None`` twice, so a
        # missing "server" field was never detected here.
        if server_name is None or server_ver is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args):
        # a disconnected connection can no longer receive invalidations,
        # so its cached entries must be dropped
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        self._conn.send_packed_command(command)

    def send_command(self, *args, **kwargs):
        """Send a command, short-circuiting when a cached response exists."""
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(
                CacheKey(command=args[0], redis_keys=(), redis_args=())
            ):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

        if kwargs.get("keys") is None:
            raise ValueError("Cannot create cache key.")

        # Creates cache key.
        self._current_command_cache_key = CacheKey(
            command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args
        )

        with self._cache_lock:
            # We have to trigger invalidation processing in case if
            # it was cached by another connection to avoid
            # queueing invalidations in stale connections.
            entry = self._cache.get(self._current_command_cache_key)
            if entry:
                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)

                return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        """Return a cached response when available, otherwise read from the
        wrapped connection and cache the result if allowed."""
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            if (
                self._current_command_cache_key is not None
                and self._cache.get(self._current_command_cache_key) is not None
                and self._cache.get(self._current_command_cache_key).status
                != CacheEntryStatus.IN_PROGRESS
            ):
                # deep-copy so callers cannot mutate the cached value
                res = copy.deepcopy(
                    self._cache.get(self._current_command_cache_key).cache_value
                )
                self._current_command_cache_key = None
                return res

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that are still valid
            # and weren't invalidated by another connection in the meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()

    def mark_for_reconnect(self):
        self._conn.mark_for_reconnect()

    def should_reconnect(self):
        return self._conn.should_reconnect()

    def reset_should_reconnect(self):
        self._conn.reset_should_reconnect()

    @property
    def host(self) -> str:
        return self._conn.host

    @host.setter
    def host(self, value: str):
        self._conn.host = value

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_connect_timeout = value

    def _get_socket(self) -> Optional[socket.socket]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._get_socket()
        else:
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )

    def _get_maint_notifications_connection_instance(
        self, connection
    ) -> MaintNotificationsAbstractConnection:
        """
        Validate that connection instance supports maintenance notifications.
        With this helper method we ensure that we are working
        with the correct connection type.
        After we validate that connection instance supports maintenance
        notifications we can safely return the connection instance
        as MaintNotificationsAbstractConnection.
        """
        if not isinstance(connection, MaintNotificationsAbstractConnection):
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )
        else:
            return connection

    @property
    def maintenance_state(self) -> MaintenanceState:
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: MaintenanceState):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.maintenance_state = state

    def getpeername(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.getpeername()

    def get_resolved_ip(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.get_resolved_ip()

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.update_current_socket_timeout(relaxed_timeout)

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout)

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout)

    def _connect(self):
        self._conn._connect()

    def _host_error(self):
        # BUGFIX: the original dropped the ``return`` and always yielded None,
        # producing useless "reading from None" error messages.
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        """Connect callback: turn on CLIENT TRACKING and route invalidation
        push messages into the cache."""
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        """Drain any queued invalidation push messages before sending."""
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                self._cache.delete_by_redis_keys(data[1])

1680 

1681 

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
                or an ssl.VerifyMode. Defaults to "required".
            ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
            ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        # map the string spelling of verify_mode onto ssl constants;
        # ssl.VerifyMode values pass through untouched
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ssl_include_verify_flags = ssl_include_verify_flags
        self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # hostname checking is meaningless (and rejected by ssl) when
        # certificate verification is disabled, so force it off then
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Wrap the socket with SSL support, handling potential errors.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except (OSError, RedisError):
            # close the plain socket if the TLS handshake/setup fails
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.ssl_include_verify_flags:
            for flag in self.ssl_include_verify_flags:
                context.verify_flags |= flag
        if self.ssl_exclude_verify_flags:
            for flag in self.ssl_exclude_verify_flags:
                context.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        # stapled and full OCSP validation are mutually exclusive
        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        # validation for the stapled case
        if self.ssl_validate_ocsp_stapled:
            import OpenSSL

            from .ocsp import ocsp_staple_verifier

            # if a context is provided use it - otherwise, a basic context
            if self.ssl_ocsp_context is None:
                staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
                staple_ctx.use_certificate_file(self.certfile)
                staple_ctx.use_privatekey_file(self.keyfile)
            else:
                staple_ctx = self.ssl_ocsp_context

            staple_ctx.set_ocsp_client_callback(
                ocsp_staple_verifier, self.ssl_ocsp_expected_cert
            )

            # need another socket
            con = OpenSSL.SSL.Connection(staple_ctx, socket.socket())
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
            return sslsock

        # pure ocsp validation
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
            from .ocsp import OCSPVerifier

            o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
            if o.is_valid():
                return sslsock
            else:
                raise ConnectionError("ocsp validation error")
        return sslsock

1864 

1865 

class UnixDomainSocketConnection(AbstractConnection):
    """Manages UDS communication to and from a Redis server."""

    def __init__(self, path="", socket_timeout=None, **kwargs):
        """
        Initialize a Unix-domain-socket connection.

        :param path: Filesystem path of the server's Unix domain socket.
        :param socket_timeout: Timeout applied to the socket after the
            connection is established (``None`` means blocking mode).

        All remaining keyword arguments are forwarded to
        ``AbstractConnection.__init__``.
        """
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return the (name, value) pairs used to build this connection's repr."""
        if self.client_name:
            return [
                ("path", self.path),
                ("db", self.db),
                ("client_name", self.client_name),
            ]
        return [("path", self.path), ("db", self.db)]

    def _connect(self):
        """Create, connect and return the underlying Unix domain socket."""
        uds_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        # Use the (usually shorter) connect timeout while establishing the
        # connection; switch to the regular socket timeout once connected.
        uds_sock.settimeout(self.socket_connect_timeout)
        try:
            uds_sock.connect(self.path)
        except OSError:
            # Close the half-open socket so it does not leak and raise
            # ResourceWarnings; then let the original error propagate.
            try:
                uds_sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                pass
            uds_sock.close()
            raise
        uds_sock.settimeout(self.socket_timeout)
        return uds_sock

    def _host_error(self):
        """Return the identifier used in error messages (the socket path)."""
        return self.path

1899 

1900 

# Case-insensitive string values that parse as boolean False in URL options.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Coerce a URL querystring value to a boolean.

    Returns ``None`` for missing/empty values, ``False`` for the
    recognised falsy strings in ``FALSE_STRINGS`` (case-insensitive),
    and the ordinary truthiness of the value otherwise.
    """
    if value is None or value == "":
        return None
    is_false_string = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_false_string else bool(value)

1910 

1911 

1912def parse_ssl_verify_flags(value): 

1913 # flags are passed in as a string representation of a list, 

1914 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN 

1915 verify_flags_str = value.replace("[", "").replace("]", "") 

1916 

1917 verify_flags = [] 

1918 for flag in verify_flags_str.split(","): 

1919 flag = flag.strip() 

1920 if not hasattr(VerifyFlags, flag): 

1921 raise ValueError(f"Invalid ssl verify flag: {flag}") 

1922 verify_flags.append(getattr(VerifyFlags, flag)) 

1923 return verify_flags 

1924 

1925 

# Maps recognised URL querystring option names to callables that convert the
# raw string value to its proper Python type; consulted by parse_url().
# Options not listed here are passed through as plain strings.
URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
    "ssl_include_verify_flags": parse_ssl_verify_flags,
    "ssl_exclude_verify_flags": parse_ssl_verify_flags,
    "timeout": float,
}

1940 

1941 

def parse_url(url):
    """Parse a ``redis://``, ``rediss://`` or ``unix://`` URL into a dict of
    connection keyword arguments.

    Querystring options are decoded with ``urllib.parse.unquote`` and cast
    via ``URL_QUERY_ARGUMENT_PARSERS`` where a parser is registered;
    unrecognised options are kept as strings. Raises ``ValueError`` for an
    unsupported scheme or an option value that fails to parse.
    """
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    parsed = urlparse(url)
    kwargs = {}

    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        raw = unquote(values[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if parser is None:
            kwargs[name] = raw
        else:
            try:
                kwargs[name] = parser(raw)
            except (TypeError, ValueError):
                raise ValueError(f"Invalid value for '{name}' in connection URL.")

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
    else:  # implied: parsed.scheme in ("redis", "rediss")
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # The URL path doubles as the db number, unless a "db" querystring
        # option was already provided.
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    return kwargs

1997 

1998 

# Type variable bound to ConnectionPool so classmethods like from_url()
# return the exact subclass they are invoked on.
_CP = TypeVar("_CP", bound="ConnectionPool")

2000 

2001 

class ConnectionPoolInterface(ABC):
    """Abstract interface implemented by all connection pool classes."""

    @abstractmethod
    def get_protocol(self):
        """Return the configured RESP protocol version, or None if unset."""
        pass

    @abstractmethod
    def reset(self):
        """Reset the pool's internal state, forgetting tracked connections."""
        pass

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(
        self, command_name: Optional[str], *keys, **options
    ) -> ConnectionInterface:
        """Return a connection from the pool, ready to send a command.

        The positional/keyword arguments are deprecated and ignored.
        """
        pass

    @abstractmethod
    def get_encoder(self):
        """Return an Encoder configured from the pool's connection kwargs."""
        pass

    @abstractmethod
    def release(self, connection: ConnectionInterface):
        """Return a previously checked-out connection to the pool."""
        pass

    @abstractmethod
    def disconnect(self, inuse_connections: bool = True):
        """Disconnect pooled connections (optionally including in-use ones)."""
        pass

    @abstractmethod
    def close(self):
        """Close the pool, disconnecting all connections."""
        pass

    @abstractmethod
    def set_retry(self, retry: Retry):
        """Set the retry policy for the pool and its connections."""
        pass

    @abstractmethod
    def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections with a freshly acquired token."""
        pass

2045 

2046 

class MaintNotificationsAbstractConnectionPool:
    """
    Abstract class for handling maintenance notifications logic.
    This class is mixed into the ConnectionPool classes.

    This class is not intended to be used directly!

    All logic related to maintenance notifications and
    connection pool handling is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **kwargs,
    ):
        """
        :param maint_notifications_config: Optional maintenance notifications
            configuration. If omitted and the RESP protocol is 3, a default
            (enabled) config is created. Explicitly enabling notifications
            with a protocol other than 3 raises ``RedisError``.
        :param kwargs: Connection kwargs; only ``protocol`` is inspected here.
        """
        # Initialize maintenance notifications
        is_protocol_supported = kwargs.get("protocol") in [3, "3"]
        if maint_notifications_config is None and is_protocol_supported:
            maint_notifications_config = MaintNotificationsConfig()

        if maint_notifications_config and maint_notifications_config.enabled:
            if not is_protocol_supported:
                raise RedisError(
                    "Maintenance notifications handlers on connection are only supported with RESP version 3"
                )

            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )

            # Future connections must carry the handler/config in their kwargs.
            self._update_connection_kwargs_for_maint_notifications(
                self._maint_notifications_pool_handler
            )
        else:
            self._maint_notifications_pool_handler = None

    @property
    @abstractmethod
    def connection_kwargs(self) -> Dict[str, Any]:
        # Kwargs applied to every future connection created by the pool.
        pass

    @connection_kwargs.setter
    @abstractmethod
    def connection_kwargs(self, value: Dict[str, Any]):
        pass

    @abstractmethod
    def _get_pool_lock(self) -> threading.RLock:
        # Lock guarding the pool's connection collections.
        pass

    @abstractmethod
    def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
        # Connections currently idle in the pool.
        pass

    @abstractmethod
    def _get_in_use_connections(
        self,
    ) -> Iterable["MaintNotificationsAbstractConnection"]:
        # Connections currently checked out by clients.
        pass

    def maint_notifications_enabled(self):
        """
        Returns:
            True if the maintenance notifications are enabled, False otherwise.
            The maintenance notifications config is stored in the pool handler.
            If the pool handler is not set, the maintenance notifications are not enabled.
        """
        maint_notifications_config = (
            self._maint_notifications_pool_handler.config
            if self._maint_notifications_pool_handler
            else None
        )

        return maint_notifications_config and maint_notifications_config.enabled

    def update_maint_notifications_config(
        self, maint_notifications_config: MaintNotificationsConfig
    ):
        """
        Updates the maintenance notifications configuration.
        This method should be called only if the pool was created
        without enabling the maintenance notifications and
        in a later point in time maintenance notifications
        are requested to be enabled.

        Raises ValueError when asked to disable notifications that are
        already enabled (one-way switch).
        """
        if (
            self.maint_notifications_enabled()
            and not maint_notifications_config.enabled
        ):
            raise ValueError(
                "Cannot disable maintenance notifications after enabling them"
            )
        # first update pool settings
        if not self._maint_notifications_pool_handler:
            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )
        else:
            self._maint_notifications_pool_handler.config = maint_notifications_config

        # then update connection kwargs and existing connections
        self._update_connection_kwargs_for_maint_notifications(
            self._maint_notifications_pool_handler
        )
        self._update_maint_notifications_configs_for_connections(
            self._maint_notifications_pool_handler
        )

    def _update_connection_kwargs_for_maint_notifications(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """
        Update the connection kwargs for all future connections.
        """
        if not self.maint_notifications_enabled():
            return

        self.connection_kwargs.update(
            {
                "maint_notifications_pool_handler": maint_notifications_pool_handler,
                "maint_notifications_config": maint_notifications_pool_handler.config,
            }
        )

        # Store original connection parameters for maintenance notifications.
        if self.connection_kwargs.get("orig_host_address", None) is None:
            # If orig_host_address is None it means we haven't
            # configured the original values yet
            self.connection_kwargs.update(
                {
                    "orig_host_address": self.connection_kwargs.get("host"),
                    "orig_socket_timeout": self.connection_kwargs.get(
                        "socket_timeout", None
                    ),
                    "orig_socket_connect_timeout": self.connection_kwargs.get(
                        "socket_connect_timeout", None
                    ),
                }
            )

    def _update_maint_notifications_configs_for_connections(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Update the maintenance notifications config for all connections in the pool."""
        with self._get_pool_lock():
            # Idle connections can be disconnected immediately; they will
            # reconnect with the new config on next use.
            for conn in self._get_free_connections():
                conn.set_maint_notifications_pool_handler_for_connection(
                    maint_notifications_pool_handler
                )
                conn.maint_notifications_config = (
                    maint_notifications_pool_handler.config
                )
                conn.disconnect()
            # In-use connections are only marked, so the client currently
            # holding them is not interrupted mid-command.
            for conn in self._get_in_use_connections():
                conn.set_maint_notifications_pool_handler_for_connection(
                    maint_notifications_pool_handler
                )
                conn.maint_notifications_config = (
                    maint_notifications_pool_handler.config
                )
                conn.mark_for_reconnect()

    def _should_update_connection(
        self,
        conn: "MaintNotificationsAbstractConnection",
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
    ) -> bool:
        """
        Check if the connection should be updated based on the matching criteria.

        A connection matches when the relevant attribute equals the provided
        matching value; if the matching value is None/falsy, every connection
        matches for that pattern.
        """
        if matching_pattern == "connected_address":
            if matching_address and conn.getpeername() != matching_address:
                return False
        elif matching_pattern == "configured_address":
            if matching_address and conn.host != matching_address:
                return False
        elif matching_pattern == "notification_hash":
            if (
                matching_notification_hash
                and conn.maintenance_notification_hash != matching_notification_hash
            ):
                return False
        return True

    def update_connection_settings(
        self,
        conn: "MaintNotificationsAbstractConnection",
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Update the settings for a single connection.

        Only the parameters explicitly provided (or the reset flags) are
        applied; everything else is left untouched.
        """
        if state:
            conn.maintenance_state = state

        if update_notification_hash:
            # update the notification hash only if requested
            conn.maintenance_notification_hash = maintenance_notification_hash

        if host_address is not None:
            conn.set_tmp_settings(tmp_host_address=host_address)

        if relaxed_timeout is not None:
            conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)

        if reset_relaxed_timeout or reset_host_address:
            conn.reset_tmp_settings(
                reset_host_address=reset_host_address,
                reset_relaxed_timeout=reset_relaxed_timeout,
            )

        conn.update_current_socket_timeout(relaxed_timeout)

    def update_connections_settings(
        self,
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
        include_free_connections: bool = True,
    ):
        """
        Update the settings for all matching connections in the pool.

        This method does not create new connections.
        This method does not affect the connection kwargs.

        :param state: The maintenance state to set for the connection.
        :param maintenance_notification_hash: The hash of the maintenance notification
            to set for the connection.
        :param host_address: The host address to set for the connection.
        :param relaxed_timeout: The relaxed timeout to set for the connection.
        :param matching_address: The address to match for the connection.
        :param matching_notification_hash: The notification hash to match for the connection.
        :param matching_pattern: The pattern to match for the connection.
        :param update_notification_hash: Whether to update the notification hash for the connection.
        :param reset_host_address: Whether to reset the host address to the original address.
        :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
        :param include_free_connections: Whether to include free/available connections.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn,
                    matching_pattern,
                    matching_address,
                    matching_notification_hash,
                ):
                    self.update_connection_settings(
                        conn,
                        state=state,
                        maintenance_notification_hash=maintenance_notification_hash,
                        host_address=host_address,
                        relaxed_timeout=relaxed_timeout,
                        update_notification_hash=update_notification_hash,
                        reset_host_address=reset_host_address,
                        reset_relaxed_timeout=reset_relaxed_timeout,
                    )

            if include_free_connections:
                for conn in self._get_free_connections():
                    if self._should_update_connection(
                        conn,
                        matching_pattern,
                        matching_address,
                        matching_notification_hash,
                    ):
                        self.update_connection_settings(
                            conn,
                            state=state,
                            maintenance_notification_hash=maintenance_notification_hash,
                            host_address=host_address,
                            relaxed_timeout=relaxed_timeout,
                            update_notification_hash=update_notification_hash,
                            reset_host_address=reset_host_address,
                            reset_relaxed_timeout=reset_relaxed_timeout,
                        )

    def update_connection_kwargs(
        self,
        **kwargs,
    ):
        """
        Update the connection kwargs for all future connections.

        This method updates the connection kwargs for all future connections created by the pool.
        Existing connections are not affected.
        """
        self.connection_kwargs.update(kwargs)

    def update_active_connections_for_reconnect(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Mark all active connections for reconnect.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.mark_for_reconnect()

    def disconnect_free_connections(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Disconnect all free/available connections.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.disconnect()

2389 

2390 

2391class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface): 

2392 """ 

2393 Create a connection pool. ``If max_connections`` is set, then this 

2394 object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's 

2395 limit is reached. 

2396 

2397 By default, TCP connections are created unless ``connection_class`` 

2398 is specified. Use class:`.UnixDomainSocketConnection` for 

2399 unix sockets. 

2400 :py:class:`~redis.SSLConnection` can be used for SSL enabled connections. 

2401 

2402 If ``maint_notifications_config`` is provided, the connection pool will support 

2403 maintenance notifications. 

2404 Maintenance notifications are supported only with RESP3. 

2405 If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3, 

2406 the maintenance notifications will be enabled by default. 

2407 

2408 Any additional keyword arguments are passed to the constructor of 

2409 ``connection_class``. 

2410 """ 

2411 

2412 @classmethod 

2413 def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: 

2414 """ 

2415 Return a connection pool configured from the given URL. 

2416 

2417 For example:: 

2418 

2419 redis://[[username]:[password]]@localhost:6379/0 

2420 rediss://[[username]:[password]]@localhost:6379/0 

2421 unix://[username@]/path/to/socket.sock?db=0[&password=password] 

2422 

2423 Three URL schemes are supported: 

2424 

2425 - `redis://` creates a TCP socket connection. See more at: 

2426 <https://www.iana.org/assignments/uri-schemes/prov/redis> 

2427 - `rediss://` creates a SSL wrapped TCP socket connection. See more at: 

2428 <https://www.iana.org/assignments/uri-schemes/prov/rediss> 

2429 - ``unix://``: creates a Unix Domain Socket connection. 

2430 

2431 The username, password, hostname, path and all querystring values 

2432 are passed through urllib.parse.unquote in order to replace any 

2433 percent-encoded values with their corresponding characters. 

2434 

2435 There are several ways to specify a database number. The first value 

2436 found will be used: 

2437 

2438 1. A ``db`` querystring option, e.g. redis://localhost?db=0 

2439 2. If using the redis:// or rediss:// schemes, the path argument 

2440 of the url, e.g. redis://localhost/0 

2441 3. A ``db`` keyword argument to this function. 

2442 

2443 If none of these options are specified, the default db=0 is used. 

2444 

2445 All querystring options are cast to their appropriate Python types. 

2446 Boolean arguments can be specified with string values "True"/"False" 

2447 or "Yes"/"No". Values that cannot be properly cast cause a 

2448 ``ValueError`` to be raised. Once parsed, the querystring arguments 

2449 and keyword arguments are passed to the ``ConnectionPool``'s 

2450 class initializer. In the case of conflicting arguments, querystring 

2451 arguments always win. 

2452 """ 

2453 url_options = parse_url(url) 

2454 

2455 if "connection_class" in kwargs: 

2456 url_options["connection_class"] = kwargs["connection_class"] 

2457 

2458 kwargs.update(url_options) 

2459 return cls(**kwargs) 

2460 

2461 def __init__( 

2462 self, 

2463 connection_class=Connection, 

2464 max_connections: Optional[int] = None, 

2465 cache_factory: Optional[CacheFactoryInterface] = None, 

2466 maint_notifications_config: Optional[MaintNotificationsConfig] = None, 

2467 **connection_kwargs, 

2468 ): 

2469 max_connections = max_connections or 2**31 

2470 if not isinstance(max_connections, int) or max_connections < 0: 

2471 raise ValueError('"max_connections" must be a positive integer') 

2472 

2473 self.connection_class = connection_class 

2474 self._connection_kwargs = connection_kwargs 

2475 self.max_connections = max_connections 

2476 self.cache = None 

2477 self._cache_factory = cache_factory 

2478 

2479 if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"): 

2480 if self._connection_kwargs.get("protocol") not in [3, "3"]: 

2481 raise RedisError("Client caching is only supported with RESP version 3") 

2482 

2483 cache = self._connection_kwargs.get("cache") 

2484 

2485 if cache is not None: 

2486 if not isinstance(cache, CacheInterface): 

2487 raise ValueError("Cache must implement CacheInterface") 

2488 

2489 self.cache = cache 

2490 else: 

2491 if self._cache_factory is not None: 

2492 self.cache = self._cache_factory.get_cache() 

2493 else: 

2494 self.cache = CacheFactory( 

2495 self._connection_kwargs.get("cache_config") 

2496 ).get_cache() 

2497 

2498 connection_kwargs.pop("cache", None) 

2499 connection_kwargs.pop("cache_config", None) 

2500 

2501 self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None) 

2502 if self._event_dispatcher is None: 

2503 self._event_dispatcher = EventDispatcher() 

2504 

2505 # a lock to protect the critical section in _checkpid(). 

2506 # this lock is acquired when the process id changes, such as 

2507 # after a fork. during this time, multiple threads in the child 

2508 # process could attempt to acquire this lock. the first thread 

2509 # to acquire the lock will reset the data structures and lock 

2510 # object of this pool. subsequent threads acquiring this lock 

2511 # will notice the first thread already did the work and simply 

2512 # release the lock. 

2513 

2514 self._fork_lock = threading.RLock() 

2515 self._lock = threading.RLock() 

2516 

2517 MaintNotificationsAbstractConnectionPool.__init__( 

2518 self, 

2519 maint_notifications_config=maint_notifications_config, 

2520 **connection_kwargs, 

2521 ) 

2522 

2523 self.reset() 

2524 

2525 def __repr__(self) -> str: 

2526 conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()]) 

2527 return ( 

2528 f"<{self.__class__.__module__}.{self.__class__.__name__}" 

2529 f"(<{self.connection_class.__module__}.{self.connection_class.__name__}" 

2530 f"({conn_kwargs})>)>" 

2531 ) 

2532 

2533 @property 

2534 def connection_kwargs(self) -> Dict[str, Any]: 

2535 return self._connection_kwargs 

2536 

2537 @connection_kwargs.setter 

2538 def connection_kwargs(self, value: Dict[str, Any]): 

2539 self._connection_kwargs = value 

2540 

2541 def get_protocol(self): 

2542 """ 

2543 Returns: 

2544 The RESP protocol version, or ``None`` if the protocol is not specified, 

2545 in which case the server default will be used. 

2546 """ 

2547 return self.connection_kwargs.get("protocol", None) 

2548 

2549 def reset(self) -> None: 

2550 self._created_connections = 0 

2551 self._available_connections = [] 

2552 self._in_use_connections = set() 

2553 

2554 # this must be the last operation in this method. while reset() is 

2555 # called when holding _fork_lock, other threads in this process 

2556 # can call _checkpid() which compares self.pid and os.getpid() without 

2557 # holding any lock (for performance reasons). keeping this assignment 

2558 # as the last operation ensures that those other threads will also 

2559 # notice a pid difference and block waiting for the first thread to 

2560 # release _fork_lock. when each of these threads eventually acquire 

2561 # _fork_lock, they will notice that another thread already called 

2562 # reset() and they will immediately release _fork_lock and continue on. 

2563 self.pid = os.getpid() 

2564 

2565 def _checkpid(self) -> None: 

2566 # _checkpid() attempts to keep ConnectionPool fork-safe on modern 

2567 # systems. this is called by all ConnectionPool methods that 

2568 # manipulate the pool's state such as get_connection() and release(). 

2569 # 

2570 # _checkpid() determines whether the process has forked by comparing 

2571 # the current process id to the process id saved on the ConnectionPool 

2572 # instance. if these values are the same, _checkpid() simply returns. 

2573 # 

2574 # when the process ids differ, _checkpid() assumes that the process 

2575 # has forked and that we're now running in the child process. the child 

2576 # process cannot use the parent's file descriptors (e.g., sockets). 

2577 # therefore, when _checkpid() sees the process id change, it calls 

2578 # reset() in order to reinitialize the child's ConnectionPool. this 

2579 # will cause the child to make all new connection objects. 

2580 # 

2581 # _checkpid() is protected by self._fork_lock to ensure that multiple 

2582 # threads in the child process do not call reset() multiple times. 

2583 # 

2584 # there is an extremely small chance this could fail in the following 

2585 # scenario: 

2586 # 1. process A calls _checkpid() for the first time and acquires 

2587 # self._fork_lock. 

2588 # 2. while holding self._fork_lock, process A forks (the fork() 

2589 # could happen in a different thread owned by process A) 

2590 # 3. process B (the forked child process) inherits the 

2591 # ConnectionPool's state from the parent. that state includes 

2592 # a locked _fork_lock. process B will not be notified when 

2593 # process A releases the _fork_lock and will thus never be 

2594 # able to acquire the _fork_lock. 

2595 # 

2596 # to mitigate this possible deadlock, _checkpid() will only wait 5 

2597 # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in 

2598 # that time it is assumed that the child is deadlocked and a 

2599 # redis.ChildDeadlockedError error is raised. 

2600 if self.pid != os.getpid(): 

2601 acquired = self._fork_lock.acquire(timeout=5) 

2602 if not acquired: 

2603 raise ChildDeadlockedError 

2604 # reset() the instance for the new process if another thread 

2605 # hasn't already done so 

2606 try: 

2607 if self.pid != os.getpid(): 

2608 self.reset() 

2609 finally: 

2610 self._fork_lock.release() 

2611 

2612 @deprecated_args( 

2613 args_to_warn=["*"], 

2614 reason="Use get_connection() without args instead", 

2615 version="5.3.0", 

2616 ) 

2617 def get_connection(self, command_name=None, *keys, **options) -> "Connection": 

2618 "Get a connection from the pool" 

2619 

2620 self._checkpid() 

2621 with self._lock: 

2622 try: 

2623 connection = self._available_connections.pop() 

2624 except IndexError: 

2625 connection = self.make_connection() 

2626 self._in_use_connections.add(connection) 

2627 

2628 try: 

2629 # ensure this connection is connected to Redis 

2630 connection.connect() 

2631 # connections that the pool provides should be ready to send 

2632 # a command. if not, the connection was either returned to the 

2633 # pool before all data has been read or the socket has been 

2634 # closed. either way, reconnect and verify everything is good. 

2635 try: 

2636 if ( 

2637 connection.can_read() 

2638 and self.cache is None 

2639 and not self.maint_notifications_enabled() 

2640 ): 

2641 raise ConnectionError("Connection has data") 

2642 except (ConnectionError, TimeoutError, OSError): 

2643 connection.disconnect() 

2644 connection.connect() 

2645 if connection.can_read(): 

2646 raise ConnectionError("Connection not ready") 

2647 except BaseException: 

2648 # release the connection back to the pool so that we don't 

2649 # leak it 

2650 self.release(connection) 

2651 raise 

2652 return connection 

2653 

2654 def get_encoder(self) -> Encoder: 

2655 "Return an encoder based on encoding settings" 

2656 kwargs = self.connection_kwargs 

2657 return Encoder( 

2658 encoding=kwargs.get("encoding", "utf-8"), 

2659 encoding_errors=kwargs.get("encoding_errors", "strict"), 

2660 decode_responses=kwargs.get("decode_responses", False), 

2661 ) 

2662 

2663 def make_connection(self) -> "ConnectionInterface": 

2664 "Create a new connection" 

2665 if self._created_connections >= self.max_connections: 

2666 raise MaxConnectionsError("Too many connections") 

2667 self._created_connections += 1 

2668 

2669 kwargs = dict(self.connection_kwargs) 

2670 

2671 if self.cache is not None: 

2672 return CacheProxyConnection( 

2673 self.connection_class(**kwargs), self.cache, self._lock 

2674 ) 

2675 return self.connection_class(**kwargs) 

2676 

2677 def release(self, connection: "Connection") -> None: 

2678 "Releases the connection back to the pool" 

2679 self._checkpid() 

2680 with self._lock: 

2681 try: 

2682 self._in_use_connections.remove(connection) 

2683 except KeyError: 

2684 # Gracefully fail when a connection is returned to this pool 

2685 # that the pool doesn't actually own 

2686 return 

2687 

2688 if self.owns_connection(connection): 

2689 if connection.should_reconnect(): 

2690 connection.disconnect() 

2691 self._available_connections.append(connection) 

2692 self._event_dispatcher.dispatch( 

2693 AfterConnectionReleasedEvent(connection) 

2694 ) 

2695 else: 

2696 # Pool doesn't own this connection, do not add it back 

2697 # to the pool. 

2698 # The created connections count should not be changed, 

2699 # because the connection was not created by the pool. 

2700 connection.disconnect() 

2701 return 

2702 

2703 def owns_connection(self, connection: "Connection") -> int: 

2704 return connection.pid == self.pid 

2705 

def disconnect(self, inuse_connections: bool = True) -> None:
    """
    Close the sockets of pooled connections.

    When ``inuse_connections`` is true, connections currently checked
    out (potentially by other threads) are disconnected as well;
    otherwise only the idle connections sitting in the pool are closed.
    """
    self._checkpid()
    with self._lock:
        targets = list(self._available_connections)
        if inuse_connections:
            targets.extend(self._in_use_connections)
        for conn in targets:
            conn.disconnect()

2725 

def close(self) -> None:
    """Shut the pool down, disconnecting every connection it holds."""
    self.disconnect(inuse_connections=True)

2729 

def set_retry(self, retry: "Retry") -> None:
    """Install ``retry`` as the retry policy for future connections and
    for every connection the pool currently holds (idle and in-use)."""
    self.connection_kwargs["retry"] = retry
    for conn in (*self._available_connections, *self._in_use_connections):
        conn.retry = retry

2736 

def re_auth_callback(self, token: "TokenInterface"):
    """Re-authenticate pooled connections after a credential refresh.

    Idle connections are re-authenticated right away with an ``AUTH``
    command (and the reply is drained); connections currently in use are
    only tagged with the new token so they can re-authenticate later.
    """
    with self._lock:
        for conn in self._available_connections:
            # Token values are read inside the callable so a retry picks
            # up the current values, exactly as the original lambdas did.
            def _send_auth(conn=conn):
                return conn.send_command(
                    "AUTH", token.try_get("oid"), token.get_value()
                )

            def _read_reply(conn=conn):
                return conn.read_response()

            conn.retry.call_with_retry(
                _send_auth, lambda error: self._mock(error)
            )
            conn.retry.call_with_retry(
                _read_reply, lambda error: self._mock(error)
            )
        for conn in self._in_use_connections:
            conn.set_re_auth_token(token)

2751 

2752 def _get_pool_lock(self): 

2753 return self._lock 

2754 

2755 def _get_free_connections(self): 

2756 with self._lock: 

2757 return self._available_connections 

2758 

2759 def _get_in_use_connections(self): 

2760 with self._lock: 

2761 return self._in_use_connections 

2762 

2763 async def _mock(self, error: RedisError): 

2764 """ 

2765 Dummy functions, needs to be passed as error callback to retry object. 

2766 :param error: 

2767 :return: 

2768 """ 

2769 pass 

2770 

2771 

2772class BlockingConnectionPool(ConnectionPool): 

2773 """ 

2774 Thread-safe blocking connection pool:: 

2775 

2776 >>> from redis.client import Redis 

2777 >>> client = Redis(connection_pool=BlockingConnectionPool()) 

2778 

2779 It performs the same function as the default 

2780 :py:class:`~redis.ConnectionPool` implementation, in that, 

2781 it maintains a pool of reusable connections that can be shared by 

2782 multiple redis clients (safely across threads if required). 

2783 

2784 The difference is that, in the event that a client tries to get a 

2785 connection from the pool when all of connections are in use, rather than 

2786 raising a :py:class:`~redis.ConnectionError` (as the default 

2787 :py:class:`~redis.ConnectionPool` implementation does), it 

2788 makes the client wait ("blocks") for a specified number of seconds until 

2789 a connection becomes available. 

2790 

2791 Use ``max_connections`` to increase / decrease the pool size:: 

2792 

2793 >>> pool = BlockingConnectionPool(max_connections=10) 

2794 

2795 Use ``timeout`` to tell it either how many seconds to wait for a connection 

2796 to become available, or to block forever: 

2797 

2798 >>> # Block forever. 

2799 >>> pool = BlockingConnectionPool(timeout=None) 

2800 

2801 >>> # Raise a ``ConnectionError`` after five seconds if a connection is 

2802 >>> # not available. 

2803 >>> pool = BlockingConnectionPool(timeout=5) 

2804 """ 

2805 

def __init__(
    self,
    max_connections=50,
    timeout=20,
    connection_class=Connection,
    queue_class=LifoQueue,
    **connection_kwargs,
):
    """Initialize the blocking pool.

    ``timeout`` is how many seconds ``get_connection()`` waits for a free
    connection before giving up; ``queue_class`` is the thread-safe queue
    type used to hold pooled connections.
    """
    self.queue_class = queue_class
    self.timeout = timeout
    # The maintenance-mode flags are set before the base initializer runs
    # because reset() reads them (and the base initializer may trigger a
    # reset).
    self._locked = False
    self._in_maintenance = False
    super().__init__(
        connection_class=connection_class,
        max_connections=max_connections,
        **connection_kwargs,
    )

2823 

def reset(self):
    """Re-create the connection queue (e.g. after a fork).

    The queue is pre-filled with ``None`` placeholders: a ``None`` popped
    by ``get_connection()`` is the cue to create a fresh connection on
    demand.
    """
    # Create and fill up a thread safe queue with ``None`` values.
    try:
        # While a MOVING maintenance notification is being processed, the
        # pool lock must be held; _locked records that *we* took it so the
        # ``finally`` below releases it exactly once.
        # NOTE(review): _locked is a plain instance attribute shared by all
        # threads, so concurrent callers could clobber each other's flag —
        # presumably these methods are not called concurrently while in
        # maintenance mode; verify.
        if self._in_maintenance:
            self._lock.acquire()
            self._locked = True
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break

        # Keep a list of actual connection instances so that we can
        # disconnect them later.
        self._connections = []
    finally:
        if self._locked:
            try:
                self._lock.release()
            except Exception:
                pass
            self._locked = False

    # this must be the last operation in this method. while reset() is
    # called when holding _fork_lock, other threads in this process
    # can call _checkpid() which compares self.pid and os.getpid() without
    # holding any lock (for performance reasons). keeping this assignment
    # as the last operation ensures that those other threads will also
    # notice a pid difference and block waiting for the first thread to
    # release _fork_lock. when each of these threads eventually acquire
    # _fork_lock, they will notice that another thread already called
    # reset() and they will immediately release _fork_lock and continue on.
    self.pid = os.getpid()

2858 

def make_connection(self):
    "Make a fresh connection."
    try:
        # Hold the pool lock while a maintenance (MOVING) notification is
        # being processed; _locked records that this call took the lock so
        # the ``finally`` below releases it exactly once.
        if self._in_maintenance:
            self._lock.acquire()
            self._locked = True

        if self.cache is not None:
            # Client-side caching is enabled: wrap the raw connection in a
            # cache-aware proxy that shares the pool's cache and lock.
            connection = CacheProxyConnection(
                self.connection_class(**self.connection_kwargs),
                self.cache,
                self._lock,
            )
        else:
            connection = self.connection_class(**self.connection_kwargs)
        # Track every connection ever created so disconnect() can reach
        # checked-out connections too.
        self._connections.append(connection)
        return connection
    finally:
        if self._locked:
            try:
                self._lock.release()
            except Exception:
                pass
            self._locked = False

2883 

@deprecated_args(
    args_to_warn=["*"],
    reason="Use get_connection() without args instead",
    version="5.3.0",
)
def get_connection(self, command_name=None, *keys, **options):
    """
    Get a connection, blocking for ``self.timeout`` until a connection
    is available from the pool.

    If the connection returned is ``None`` then creates a new connection.
    Because we use a last-in first-out queue, the existing connections
    (having been returned to the pool after the initial ``None`` values
    were added) will be returned before ``None`` values. This means we only
    create new connections when we need to, i.e.: the actual number of
    connections will only increase in response to demand.
    """
    # Make sure we haven't changed process.
    self._checkpid()

    # Try and get a connection from the pool. If one isn't available within
    # self.timeout then raise a ``ConnectionError``.
    connection = None
    try:
        # While a MOVING maintenance notification is being processed,
        # serialize access with the pool lock; _locked remembers that this
        # call took it so the ``finally`` below releases it exactly once.
        if self._in_maintenance:
            self._lock.acquire()
            self._locked = True
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            # Note that this is not caught by the redis client and will be
            # raised unless handled by application code. If you want to
            # block indefinitely instead of raising, construct the pool
            # with ``timeout=None``.
            raise ConnectionError("No connection available.")

        # If the ``connection`` is actually ``None`` then that's a cue to make
        # a new connection to add to the pool.
        if connection is None:
            connection = self.make_connection()
    finally:
        if self._locked:
            try:
                self._lock.release()
            except Exception:
                pass
            self._locked = False

    try:
        # ensure this connection is connected to Redis
        connection.connect()
        # connections that the pool provides should be ready to send
        # a command. if not, the connection was either returned to the
        # pool before all data has been read or the socket has been
        # closed. either way, reconnect and verify everything is good.
        try:
            if connection.can_read():
                raise ConnectionError("Connection has data")
        except (ConnectionError, TimeoutError, OSError):
            connection.disconnect()
            connection.connect()
            if connection.can_read():
                raise ConnectionError("Connection not ready")
    except BaseException:
        # release the connection back to the pool so that we don't leak it
        self.release(connection)
        raise

    return connection

2951 

def release(self, connection):
    "Releases the connection back to the pool."
    # Make sure we haven't changed process.
    self._checkpid()

    try:
        # Serialize with maintenance (MOVING) notification handling when
        # active; _locked records that this call took the lock.
        if self._in_maintenance:
            self._lock.acquire()
            self._locked = True
        if not self.owns_connection(connection):
            # pool doesn't own this connection. do not add it back
            # to the pool. instead add a None value which is a placeholder
            # that will cause the pool to recreate the connection if
            # its needed.
            connection.disconnect()
            self.pool.put_nowait(None)
            return
        if connection.should_reconnect():
            connection.disconnect()
        # Put the connection back into the pool.
        try:
            self.pool.put_nowait(connection)
        except Full:
            # perhaps the pool has been reset() after a fork? regardless,
            # we don't want this connection
            pass
    finally:
        if self._locked:
            try:
                self._lock.release()
            except Exception:
                pass
            self._locked = False

2985 

def disconnect(self, inuse_connections: bool = True):
    "Disconnects either all connections in the pool or just the free connections."
    self._checkpid()
    try:
        # Serialize with maintenance (MOVING) notification handling when
        # active; _locked records that this call took the lock.
        if self._in_maintenance:
            self._lock.acquire()
            self._locked = True
        if inuse_connections:
            # _connections tracks every connection this pool ever created,
            # so this covers both idle and checked-out connections.
            connections = self._connections
        else:
            connections = self._get_free_connections()
        for connection in connections:
            connection.disconnect()
    finally:
        if self._locked:
            try:
                self._lock.release()
            except Exception:
                pass
            self._locked = False

3006 

3007 def _get_free_connections(self): 

3008 with self._lock: 

3009 return {conn for conn in self.pool.queue if conn} 

3010 

3011 def _get_in_use_connections(self): 

3012 with self._lock: 

3013 # free connections 

3014 connections_in_queue = {conn for conn in self.pool.queue if conn} 

3015 # in self._connections we keep all created connections 

3016 # so the ones that are not in the queue are the in use ones 

3017 return { 

3018 conn for conn in self._connections if conn not in connections_in_queue 

3019 } 

3020 

def set_in_maintenance(self, in_maintenance: bool):
    """
    Toggle maintenance mode for this blocking pool.

    While the flag is set — which happens only while a MOVING
    notification is being processed — the pool's operations serialize on
    its lock and no new connections are created.
    """
    self._in_maintenance = in_maintenance