Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/connection.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1395 statements  

1import copy 

2import os 

3import socket 

4import sys 

5import threading 

6import time 

7import weakref 

8from abc import ABC, abstractmethod 

9from itertools import chain 

10from queue import Empty, Full, LifoQueue 

11from typing import ( 

12 Any, 

13 Callable, 

14 Dict, 

15 Iterable, 

16 List, 

17 Literal, 

18 Optional, 

19 Type, 

20 TypeVar, 

21 Union, 

22) 

23from urllib.parse import parse_qs, unquote, urlparse 

24 

25from redis.cache import ( 

26 CacheEntry, 

27 CacheEntryStatus, 

28 CacheFactory, 

29 CacheFactoryInterface, 

30 CacheInterface, 

31 CacheKey, 

32) 

33 

34from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser 

35from .auth.token import TokenInterface 

36from .backoff import NoBackoff 

37from .credentials import CredentialProvider, UsernamePasswordCredentialProvider 

38from .event import AfterConnectionReleasedEvent, EventDispatcher 

39from .exceptions import ( 

40 AuthenticationError, 

41 AuthenticationWrongNumberOfArgsError, 

42 ChildDeadlockedError, 

43 ConnectionError, 

44 DataError, 

45 MaxConnectionsError, 

46 RedisError, 

47 ResponseError, 

48 TimeoutError, 

49) 

50from .maint_notifications import ( 

51 MaintenanceState, 

52 MaintNotificationsConfig, 

53 MaintNotificationsConnectionHandler, 

54 MaintNotificationsPoolHandler, 

55) 

56from .retry import Retry 

57from .utils import ( 

58 CRYPTOGRAPHY_AVAILABLE, 

59 HIREDIS_AVAILABLE, 

60 SSL_AVAILABLE, 

61 compare_versions, 

62 deprecated_args, 

63 ensure_string, 

64 format_error_message, 

65 get_lib_version, 

66 str_if_bytes, 

67) 

68 

69if SSL_AVAILABLE: 

70 import ssl 

71 from ssl import VerifyFlags 

72else: 

73 ssl = None 

74 VerifyFlags = None 

75 

76if HIREDIS_AVAILABLE: 

77 import hiredis 

78 

# RESP wire-protocol framing bytes used by the pure-Python serializer.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

# RESP version assumed when the caller passes protocol=None.
DEFAULT_RESP_VERSION = 2

# Sentinel object used to distinguish "argument not supplied" from an
# explicit None (None is a meaningful value for several parameters).
SENTINEL = object()

# Prefer the C-accelerated hiredis parser when the extension is installed;
# otherwise fall back to the pure-Python RESP2 parser.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _HiredisParser
else:
    DefaultParser = _RESP2Parser

93 

94 

class HiredisRespSerializer:
    """Serialize commands into the RESP wire format via the hiredis C extension."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol"""
        # A multi-word command name (e.g. "CONFIG GET") arrives as a single
        # first argument; the server expects each word sent separately, so
        # split it before handing everything to hiredis.
        first = args[0]
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        packed = []
        try:
            packed.append(hiredis.pack_command(args))
        except TypeError:
            # Surface unencodable values as a DataError while preserving the
            # original traceback for debugging.
            _, value, traceback = sys.exc_info()
            raise DataError(value).with_traceback(traceback)

        return packed

111 

112 

class PythonRespSerializer:
    """Pure-Python RESP serializer used when hiredis is unavailable."""

    def __init__(self, buffer_cutoff, encode) -> None:
        # Above this size, values are emitted as separate output chunks.
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        # A command name containing literal spaces (e.g. 'CONFIG GET') must
        # be split into separate arguments; these pieces are bytestrings so
        # they are not re-encoded below.
        first = args[0]
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        chunks = []
        cutoff = self._buffer_cutoff
        # RESP array header: *<count>\r\n
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        for item in map(self.encode, args):
            item_len = len(item)
            # To avoid large string mallocs, large values and memoryviews are
            # appended to the output list as standalone chunks instead of
            # being folded into the accumulating buffer.
            if (
                len(pending) > cutoff
                or item_len > cutoff
                or isinstance(item, memoryview)
            ):
                chunks.append(
                    SYM_EMPTY.join(
                        (pending, SYM_DOLLAR, str(item_len).encode(), SYM_CRLF)
                    )
                )
                chunks.append(item)
                pending = SYM_CRLF
            else:
                pending = SYM_EMPTY.join(
                    (
                        pending,
                        SYM_DOLLAR,
                        str(item_len).encode(),
                        SYM_CRLF,
                        item,
                        SYM_CRLF,
                    )
                )

        chunks.append(pending)
        return chunks

162 

163 

class ConnectionInterface:
    """Abstract interface implemented by all Redis connection types."""

    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs describing this connection for repr()."""
        pass

    @abstractmethod
    def register_connect_callback(self, callback):
        """Register a method to be invoked whenever the connection is (re)established."""
        pass

    @abstractmethod
    def deregister_connect_callback(self, callback):
        """Remove a previously registered connect callback."""
        pass

    @abstractmethod
    def set_parser(self, parser_class):
        """Replace the response parser with a new instance of parser_class."""
        pass

    @abstractmethod
    def get_protocol(self):
        """Return the configured RESP protocol version."""
        pass

    @abstractmethod
    def connect(self):
        """Establish the connection to the server."""
        pass

    @abstractmethod
    def on_connect(self):
        """Initialize a freshly established connection (auth, handshake, etc.)."""
        pass

    @abstractmethod
    def disconnect(self, *args):
        """Close the connection."""
        pass

    @abstractmethod
    def check_health(self):
        """Verify the connection is still usable."""
        pass

    @abstractmethod
    def send_packed_command(self, command, check_health=True):
        """Send an already RESP-encoded command to the server."""
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Encode and send a command to the server."""
        pass

    @abstractmethod
    def can_read(self, timeout=0):
        """Return whether a response is ready to be read within `timeout`."""
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read and return the server's response."""
        pass

    @abstractmethod
    def pack_command(self, *args):
        """Return the RESP encoding of a single command."""
        pass

    @abstractmethod
    def pack_commands(self, commands):
        """Return the RESP encoding of multiple commands (pipeline)."""
        pass

    @property
    @abstractmethod
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        """Metadata returned by the server during the connection handshake."""
        pass

    @abstractmethod
    def set_re_auth_token(self, token: TokenInterface):
        """Store a token to be used for re-authentication."""
        pass

    @abstractmethod
    def re_auth(self):
        """Re-authenticate the connection using the stored token."""
        pass

    @abstractmethod
    def mark_for_reconnect(self):
        """
        Mark the connection to be reconnected on the next command.
        This is useful when a connection is moved to a different node.
        """
        pass

    @abstractmethod
    def should_reconnect(self):
        """
        Returns True if the connection should be reconnected.
        """
        pass

    @abstractmethod
    def reset_should_reconnect(self):
        """
        Reset the internal flag to False.
        """
        pass

265 

266 

class MaintNotificationsAbstractConnection:
    """
    Abstract class for handling maintenance notifications logic.
    This class is expected to be used as base class together with ConnectionInterface.

    This class is intended to be used with multiple inheritance!

    All logic related to maintenance notifications is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig],
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
        maintenance_notification_hash: Optional[int] = None,
        orig_host_address: Optional[str] = None,
        orig_socket_timeout: Optional[float] = None,
        orig_socket_connect_timeout: Optional[float] = None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Initialize the maintenance notifications for the connection.

        Args:
            maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications.
            maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications.
            maintenance_state (MaintenanceState): The current maintenance state of the connection.
            maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection.
            orig_host_address (Optional[str]): The original host address of the connection.
            orig_socket_timeout (Optional[float]): The original socket timeout of the connection.
            orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection.
            parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser to use for maintenance notifications.
                If not provided, the parser from the connection is used.
                This is useful when the parser is created after this object.
        """
        self.maint_notifications_config = maint_notifications_config
        # Assignment goes through the maintenance_state property setter below.
        self.maintenance_state = maintenance_state
        self.maintenance_notification_hash = maintenance_notification_hash
        self._configure_maintenance_notifications(
            maint_notifications_pool_handler,
            orig_host_address,
            orig_socket_timeout,
            orig_socket_connect_timeout,
            parser,
        )

    @abstractmethod
    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
        """Return the parser attached to the concrete connection."""
        pass

    @abstractmethod
    def _get_socket(self) -> Optional[socket.socket]:
        """Return the underlying socket, or None when not connected."""
        pass

    @abstractmethod
    def get_protocol(self) -> Union[int, str]:
        """
        Returns:
            The RESP protocol version, or ``None`` if the protocol is not specified,
            in which case the server default will be used.
        """
        pass

    @property
    @abstractmethod
    def host(self) -> str:
        """The host address the connection targets."""
        pass

    @host.setter
    @abstractmethod
    def host(self, value: str):
        pass

    @property
    @abstractmethod
    def socket_timeout(self) -> Optional[Union[float, int]]:
        """Read timeout (seconds) for the connection's socket."""
        pass

    @socket_timeout.setter
    @abstractmethod
    def socket_timeout(self, value: Optional[Union[float, int]]):
        pass

    @property
    @abstractmethod
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        """Connect timeout (seconds) for the connection's socket."""
        pass

    @socket_connect_timeout.setter
    @abstractmethod
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        pass

    @abstractmethod
    def send_command(self, *args, **kwargs):
        """Encode and send a command to the server."""
        pass

    @abstractmethod
    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read and return the server's response."""
        pass

    @abstractmethod
    def disconnect(self, *args):
        """Close the connection."""
        pass

    def _configure_maintenance_notifications(
        self,
        maint_notifications_pool_handler: Optional[
            MaintNotificationsPoolHandler
        ] = None,
        orig_host_address=None,
        orig_socket_timeout=None,
        orig_socket_connect_timeout=None,
        parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None,
    ):
        """
        Enable maintenance notifications by setting up
        handlers and storing original connection parameters.

        Should be used ONLY with parsers that support push notifications.

        Raises:
            RedisError: if notifications are enabled but no parser is given,
                or the parser does not support push notifications.
        """
        # Disabled (or no config at all): clear both handlers and bail out.
        if (
            not self.maint_notifications_config
            or not self.maint_notifications_config.enabled
        ):
            self._maint_notifications_pool_handler = None
            self._maint_notifications_connection_handler = None
            return

        if not parser:
            raise RedisError(
                "To configure maintenance notifications, a parser must be provided!"
            )

        # Only parsers that deliver RESP3 push messages can carry notifications.
        if not isinstance(parser, _HiredisParser) and not isinstance(
            parser, _RESP3Parser
        ):
            raise RedisError(
                "Maintenance notifications are only supported with hiredis and RESP3 parsers!"
            )

        if maint_notifications_pool_handler:
            # Extract a reference to a new pool handler that copies all properties
            # of the original one and has a different connection reference
            # This is needed because when we attach the handler to the parser
            # we need to make sure that the handler has a reference to the
            # connection that the parser is attached to.
            self._maint_notifications_pool_handler = (
                maint_notifications_pool_handler.get_handler_for_connection()
            )
            self._maint_notifications_pool_handler.set_connection(self)
        else:
            self._maint_notifications_pool_handler = None

        self._maint_notifications_connection_handler = (
            MaintNotificationsConnectionHandler(self, self.maint_notifications_config)
        )

        # Set up pool handler if available
        if self._maint_notifications_pool_handler:
            parser.set_node_moving_push_handler(
                self._maint_notifications_pool_handler.handle_notification
            )

        # Set up connection handler
        parser.set_maintenance_push_handler(
            self._maint_notifications_connection_handler.handle_notification
        )

        # Store original connection parameters so temporary maintenance
        # overrides (see set_tmp_settings/reset_tmp_settings) can be undone.
        # NOTE(review): truthiness checks mean an explicit 0/0.0 timeout falls
        # back to the current connection value -- confirm this is intended.
        self.orig_host_address = orig_host_address if orig_host_address else self.host
        self.orig_socket_timeout = (
            orig_socket_timeout if orig_socket_timeout else self.socket_timeout
        )
        self.orig_socket_connect_timeout = (
            orig_socket_connect_timeout
            if orig_socket_connect_timeout
            else self.socket_connect_timeout
        )

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Attach a per-connection copy of the given pool handler to this connection."""
        # Deep copy the pool handler to avoid sharing the same pool handler
        # between multiple connections, because otherwise each connection will override
        # the connection reference and the pool handler will only hold a reference
        # to the last connection that was set.
        maint_notifications_pool_handler_copy = (
            maint_notifications_pool_handler.get_handler_for_connection()
        )

        maint_notifications_pool_handler_copy.set_connection(self)
        self._get_parser().set_node_moving_push_handler(
            maint_notifications_pool_handler_copy.handle_notification
        )

        self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy

        # Update maintenance notification connection handler if it doesn't exist
        if not self._maint_notifications_connection_handler:
            self._maint_notifications_connection_handler = (
                MaintNotificationsConnectionHandler(
                    self, maint_notifications_pool_handler.config
                )
            )
            self._get_parser().set_maintenance_push_handler(
                self._maint_notifications_connection_handler.handle_notification
            )
        else:
            # Handler exists already -- only refresh its config.
            self._maint_notifications_connection_handler.config = (
                maint_notifications_pool_handler.config
            )

    def activate_maint_notifications_handling_if_enabled(self, check_health=True):
        """Send the maintenance-notifications handshake when applicable (no-op otherwise)."""
        # Send maintenance notifications handshake if RESP3 is active
        # and maintenance notifications are enabled
        # and we have a host to determine the endpoint type from
        # When the maint_notifications_config enabled mode is "auto",
        # we just log a warning if the handshake fails
        # When the mode is enabled=True, we raise an exception in case of failure
        if (
            self.get_protocol() not in [2, "2"]
            and self.maint_notifications_config
            and self.maint_notifications_config.enabled
            and self._maint_notifications_connection_handler
            and hasattr(self, "host")
        ):
            self._enable_maintenance_notifications(
                maint_notifications_config=self.maint_notifications_config,
                check_health=check_health,
            )

    def _enable_maintenance_notifications(
        self, maint_notifications_config: MaintNotificationsConfig, check_health=True
    ):
        """
        Issue CLIENT MAINT_NOTIFICATIONS ON to the server.

        In "auto" mode a ResponseError is only logged; any other failure
        (or a failure in strict mode) propagates to the caller.
        """
        try:
            host = getattr(self, "host", None)
            if host is None:
                raise ValueError(
                    "Cannot enable maintenance notifications for connection"
                    " object that doesn't have a host attribute."
                )
            else:
                # The endpoint type tells the server which address format to
                # send in MOVING notifications.
                endpoint_type = maint_notifications_config.get_endpoint_type(host, self)
                self.send_command(
                    "CLIENT",
                    "MAINT_NOTIFICATIONS",
                    "ON",
                    "moving-endpoint-type",
                    endpoint_type.value,
                    check_health=check_health,
                )
                response = self.read_response()
                if not response or str_if_bytes(response) != "OK":
                    raise ResponseError(
                        "The server doesn't support maintenance notifications"
                    )
        except Exception as e:
            if (
                isinstance(e, ResponseError)
                and maint_notifications_config.enabled == "auto"
            ):
                # Log warning but don't fail the connection
                import logging

                logger = logging.getLogger(__name__)
                logger.debug(f"Failed to enable maintenance notifications: {e}")
            else:
                raise

    def get_resolved_ip(self) -> Optional[str]:
        """
        Extract the resolved IP address from an
        established connection or resolve it from the host.

        First tries to get the actual IP from the socket (most accurate),
        then falls back to DNS resolution if needed.

        Args:
            connection: The connection object to extract the IP from

        Returns:
            str: The resolved IP address, or None if it cannot be determined
        """

        # Method 1: Try to get the actual IP from the established socket connection
        # This is most accurate as it shows the exact IP being used
        try:
            conn_socket = self._get_socket()
            if conn_socket is not None:
                peer_addr = conn_socket.getpeername()
                if peer_addr and len(peer_addr) >= 1:
                    # For TCP sockets, peer_addr is typically (host, port) tuple
                    # Return just the host part
                    return peer_addr[0]
        except (AttributeError, OSError):
            # Socket might not be connected or getpeername() might fail
            pass

        # Method 2: Fallback to DNS resolution of the host
        # This is less accurate but works when socket is not available
        try:
            host = getattr(self, "host", "localhost")
            port = getattr(self, "port", 6379)
            if host:
                # Use getaddrinfo to resolve the hostname to IP
                # This mimics what the connection would do during _connect()
                addr_info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
                )
                if addr_info:
                    # Return the IP from the first result
                    # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
                    # sockaddr[0] is the IP address
                    return str(addr_info[0][4][0])
        except (AttributeError, OSError, socket.gaierror):
            # DNS resolution might fail
            pass

        return None

    @property
    def maintenance_state(self) -> MaintenanceState:
        """Current maintenance state of the connection."""
        return self._maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: "MaintenanceState"):
        self._maintenance_state = state

    def getpeername(self):
        """
        Returns the peer name of the connection.
        """
        # Only the host part of the (host, port) address tuple is returned.
        conn_socket = self._get_socket()
        if conn_socket:
            return conn_socket.getpeername()[0]
        return None

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        """Apply a relaxed timeout to the live socket; -1 restores the configured timeout."""
        conn_socket = self._get_socket()
        if conn_socket:
            timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout
            conn_socket.settimeout(timeout)
            # Keep the parser's read timeout in sync with the socket.
            self.update_parser_timeout(timeout)

    def update_parser_timeout(self, timeout: Optional[float] = None):
        """Propagate a timeout to the parser's read buffer, if one is active."""
        parser = self._get_parser()
        if parser and parser._buffer:
            # NOTE(review): the RESP3 branch skips falsy timeouts (None/0)
            # while the hiredis branch assigns unconditionally -- confirm the
            # asymmetry is intended.
            if isinstance(parser, _RESP3Parser) and timeout:
                parser._buffer.socket_timeout = timeout
            elif isinstance(parser, _HiredisParser):
                parser._socket_timeout = timeout

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[Union[str, object]] = SENTINEL,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        """
        The value of SENTINEL is used to indicate that the property should not be updated.
        """
        if tmp_host_address and tmp_host_address != SENTINEL:
            self.host = str(tmp_host_address)
        # -1 means "leave the timeouts unchanged".
        if tmp_relaxed_timeout != -1:
            self.socket_timeout = tmp_relaxed_timeout
            self.socket_connect_timeout = tmp_relaxed_timeout

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """Restore the original host and/or timeouts saved at configuration time."""
        if reset_host_address:
            self.host = self.orig_host_address
        if reset_relaxed_timeout:
            self.socket_timeout = self.orig_socket_timeout
            self.socket_connect_timeout = self.orig_socket_connect_timeout

653 

654 

655class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface): 

656 "Manages communication to and from a Redis server" 

657 

658 def __init__( 

659 self, 

660 db: int = 0, 

661 password: Optional[str] = None, 

662 socket_timeout: Optional[float] = None, 

663 socket_connect_timeout: Optional[float] = None, 

664 retry_on_timeout: bool = False, 

665 retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL, 

666 encoding: str = "utf-8", 

667 encoding_errors: str = "strict", 

668 decode_responses: bool = False, 

669 parser_class=DefaultParser, 

670 socket_read_size: int = 65536, 

671 health_check_interval: int = 0, 

672 client_name: Optional[str] = None, 

673 lib_name: Optional[str] = "redis-py", 

674 lib_version: Optional[str] = get_lib_version(), 

675 username: Optional[str] = None, 

676 retry: Union[Any, None] = None, 

677 redis_connect_func: Optional[Callable[[], None]] = None, 

678 credential_provider: Optional[CredentialProvider] = None, 

679 protocol: Optional[int] = 2, 

680 command_packer: Optional[Callable[[], None]] = None, 

681 event_dispatcher: Optional[EventDispatcher] = None, 

682 maint_notifications_config: Optional[MaintNotificationsConfig] = None, 

683 maint_notifications_pool_handler: Optional[ 

684 MaintNotificationsPoolHandler 

685 ] = None, 

686 maintenance_state: "MaintenanceState" = MaintenanceState.NONE, 

687 maintenance_notification_hash: Optional[int] = None, 

688 orig_host_address: Optional[str] = None, 

689 orig_socket_timeout: Optional[float] = None, 

690 orig_socket_connect_timeout: Optional[float] = None, 

691 ): 

692 """ 

693 Initialize a new Connection. 

694 To specify a retry policy for specific errors, first set 

695 `retry_on_error` to a list of the error/s to retry on, then set 

696 `retry` to a valid `Retry` object. 

697 To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. 

698 """ 

699 if (username or password) and credential_provider is not None: 

700 raise DataError( 

701 "'username' and 'password' cannot be passed along with 'credential_" 

702 "provider'. Please provide only one of the following arguments: \n" 

703 "1. 'password' and (optional) 'username'\n" 

704 "2. 'credential_provider'" 

705 ) 

706 if event_dispatcher is None: 

707 self._event_dispatcher = EventDispatcher() 

708 else: 

709 self._event_dispatcher = event_dispatcher 

710 self.pid = os.getpid() 

711 self.db = db 

712 self.client_name = client_name 

713 self.lib_name = lib_name 

714 self.lib_version = lib_version 

715 self.credential_provider = credential_provider 

716 self.password = password 

717 self.username = username 

718 self._socket_timeout = socket_timeout 

719 if socket_connect_timeout is None: 

720 socket_connect_timeout = socket_timeout 

721 self._socket_connect_timeout = socket_connect_timeout 

722 self.retry_on_timeout = retry_on_timeout 

723 if retry_on_error is SENTINEL: 

724 retry_on_errors_list = [] 

725 else: 

726 retry_on_errors_list = list(retry_on_error) 

727 if retry_on_timeout: 

728 # Add TimeoutError to the errors list to retry on 

729 retry_on_errors_list.append(TimeoutError) 

730 self.retry_on_error = retry_on_errors_list 

731 if retry or self.retry_on_error: 

732 if retry is None: 

733 self.retry = Retry(NoBackoff(), 1) 

734 else: 

735 # deep-copy the Retry object as it is mutable 

736 self.retry = copy.deepcopy(retry) 

737 if self.retry_on_error: 

738 # Update the retry's supported errors with the specified errors 

739 self.retry.update_supported_errors(self.retry_on_error) 

740 else: 

741 self.retry = Retry(NoBackoff(), 0) 

742 self.health_check_interval = health_check_interval 

743 self.next_health_check = 0 

744 self.redis_connect_func = redis_connect_func 

745 self.encoder = Encoder(encoding, encoding_errors, decode_responses) 

746 self.handshake_metadata = None 

747 self._sock = None 

748 self._socket_read_size = socket_read_size 

749 self._connect_callbacks = [] 

750 self._buffer_cutoff = 6000 

751 self._re_auth_token: Optional[TokenInterface] = None 

752 try: 

753 p = int(protocol) 

754 except TypeError: 

755 p = DEFAULT_RESP_VERSION 

756 except ValueError: 

757 raise ConnectionError("protocol must be an integer") 

758 finally: 

759 if p < 2 or p > 3: 

760 raise ConnectionError("protocol must be either 2 or 3") 

761 # p = DEFAULT_RESP_VERSION 

762 self.protocol = p 

763 if self.protocol == 3 and parser_class == _RESP2Parser: 

764 # If the protocol is 3 but the parser is RESP2, change it to RESP3 

765 # This is needed because the parser might be set before the protocol 

766 # or might be provided as a kwarg to the constructor 

767 # We need to react on discrepancy only for RESP2 and RESP3 

768 # as hiredis supports both 

769 parser_class = _RESP3Parser 

770 self.set_parser(parser_class) 

771 

772 self._command_packer = self._construct_command_packer(command_packer) 

773 self._should_reconnect = False 

774 

775 # Set up maintenance notifications 

776 MaintNotificationsAbstractConnection.__init__( 

777 self, 

778 maint_notifications_config, 

779 maint_notifications_pool_handler, 

780 maintenance_state, 

781 maintenance_notification_hash, 

782 orig_host_address, 

783 orig_socket_timeout, 

784 orig_socket_connect_timeout, 

785 self._parser, 

786 ) 

787 

788 def __repr__(self): 

789 repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) 

790 return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" 

791 

    @abstractmethod
    def repr_pieces(self):
        """Return (name, value) pairs describing this connection for __repr__."""
        pass

795 

    def __del__(self):
        # Best-effort cleanup: never let errors escape the destructor
        # (e.g. during interpreter shutdown when globals may be gone).
        try:
            self.disconnect()
        except Exception:
            pass

801 

802 def _construct_command_packer(self, packer): 

803 if packer is not None: 

804 return packer 

805 elif HIREDIS_AVAILABLE: 

806 return HiredisRespSerializer() 

807 else: 

808 return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode) 

809 

810 def register_connect_callback(self, callback): 

811 """ 

812 Register a callback to be called when the connection is established either 

813 initially or reconnected. This allows listeners to issue commands that 

814 are ephemeral to the connection, for example pub/sub subscription or 

815 key tracking. The callback must be a _method_ and will be kept as 

816 a weak reference. 

817 """ 

818 wm = weakref.WeakMethod(callback) 

819 if wm not in self._connect_callbacks: 

820 self._connect_callbacks.append(wm) 

821 

822 def deregister_connect_callback(self, callback): 

823 """ 

824 De-register a previously registered callback. It will no-longer receive 

825 notifications on connection events. Calling this is not required when the 

826 listener goes away, since the callbacks are kept as weak methods. 

827 """ 

828 try: 

829 self._connect_callbacks.remove(weakref.WeakMethod(callback)) 

830 except ValueError: 

831 pass 

832 

833 def set_parser(self, parser_class): 

834 """ 

835 Creates a new instance of parser_class with socket size: 

836 _socket_read_size and assigns it to the parser for the connection 

837 :param parser_class: The required parser class 

838 """ 

839 self._parser = parser_class(socket_read_size=self._socket_read_size) 

840 

    def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]:
        """Return the parser currently attached to this connection."""
        return self._parser

843 

844 def connect(self): 

845 "Connects to the Redis server if not already connected" 

846 # try once the socket connect with the handshake, retry the whole 

847 # connect/handshake flow based on retry policy 

848 self.retry.call_with_retry( 

849 lambda: self.connect_check_health( 

850 check_health=True, retry_socket_connect=False 

851 ), 

852 lambda error: self.disconnect(error), 

853 ) 

854 

855 def connect_check_health( 

856 self, check_health: bool = True, retry_socket_connect: bool = True 

857 ): 

858 if self._sock: 

859 return 

860 try: 

861 if retry_socket_connect: 

862 sock = self.retry.call_with_retry( 

863 lambda: self._connect(), lambda error: self.disconnect(error) 

864 ) 

865 else: 

866 sock = self._connect() 

867 except socket.timeout: 

868 raise TimeoutError("Timeout connecting to server") 

869 except OSError as e: 

870 raise ConnectionError(self._error_message(e)) 

871 

872 self._sock = sock 

873 try: 

874 if self.redis_connect_func is None: 

875 # Use the default on_connect function 

876 self.on_connect_check_health(check_health=check_health) 

877 else: 

878 # Use the passed function redis_connect_func 

879 self.redis_connect_func(self) 

880 except RedisError: 

881 # clean up after any error in on_connect 

882 self.disconnect() 

883 raise 

884 

885 # run any user callbacks. right now the only internal callback 

886 # is for pubsub channel/pattern resubscription 

887 # first, remove any dead weakrefs 

888 self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()] 

889 for ref in self._connect_callbacks: 

890 callback = ref() 

891 if callback: 

892 callback(self) 

893 

    @abstractmethod
    def _connect(self):
        """Create and return the underlying socket (transport-specific)."""
        pass

897 

    @abstractmethod
    def _host_error(self):
        """Return the host description used when formatting error messages."""
        pass

901 

902 def _error_message(self, exception): 

903 return format_error_message(self._host_error(), exception) 

904 

    def on_connect(self):
        # Delegate to the health-checking variant with checks enabled.
        self.on_connect_check_health(check_health=True)

907 

def on_connect_check_health(self, check_health: bool = True):
    """Initialize the connection, authenticate and select a database.

    Performs, in order: parser setup, AUTH/HELLO handshake (switching the
    parser to RESP3 when a protocol > 2 is requested), maintenance
    notification activation, CLIENT SETNAME / SETINFO, and SELECT.

    :param check_health: forwarded to the later setup commands; the
        AUTH/HELLO commands always disable it because a PING issued before
        authentication would fail.
    :raises AuthenticationError: if the server rejects the credentials.
    :raises ConnectionError: if the server does not honor the requested
        RESP version, rejects the client name, or the database is invalid.
    """
    self._parser.on_connect(self)
    # Keep a handle on the current parser: if we swap to a RESP3 parser
    # below we must carry over its (possibly cluster-specific) exception map.
    parser = self._parser

    auth_args = None
    # if credential provider or username and/or password are set, authenticate
    if self.credential_provider or (self.username or self.password):
        cred_provider = (
            self.credential_provider
            or UsernamePasswordCredentialProvider(self.username, self.password)
        )
        auth_args = cred_provider.get_credentials()

    # if resp version is specified and we have auth args,
    # we need to send them via HELLO
    if auth_args and self.protocol not in [2, "2"]:
        if isinstance(self._parser, _RESP2Parser):
            self.set_parser(_RESP3Parser)
            # update cluster exception classes
            self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
            self._parser.on_connect(self)
        if len(auth_args) == 1:
            # a bare password authenticates the "default" user
            auth_args = ["default", auth_args[0]]
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        self.send_command(
            "HELLO", self.protocol, "AUTH", *auth_args, check_health=False
        )
        self.handshake_metadata = self.read_response()
    elif auth_args:
        # avoid checking health here -- PING will fail if we try
        # to check the health prior to the AUTH
        self.send_command("AUTH", *auth_args, check_health=False)

        try:
            auth_response = self.read_response()
        except AuthenticationWrongNumberOfArgsError:
            # a username and password were specified but the Redis
            # server seems to be < 6.0.0 which expects a single password
            # arg. retry auth with just the password.
            # https://github.com/andymccurdy/redis-py/issues/1274
            self.send_command("AUTH", auth_args[-1], check_health=False)
            auth_response = self.read_response()

        if str_if_bytes(auth_response) != "OK":
            raise AuthenticationError("Invalid Username or Password")

    # if resp version is specified, switch to it
    elif self.protocol not in [2, "2"]:
        if isinstance(self._parser, _RESP2Parser):
            self.set_parser(_RESP3Parser)
            # update cluster exception classes
            self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
            self._parser.on_connect(self)
        self.send_command("HELLO", self.protocol, check_health=check_health)
        self.handshake_metadata = self.read_response()
        # the "proto" key may come back as bytes or str depending on decoding
        if (
            self.handshake_metadata.get(b"proto") != self.protocol
            and self.handshake_metadata.get("proto") != self.protocol
        ):
            raise ConnectionError("Invalid RESP version")

    # Activate maintenance notifications for this connection
    # if enabled in the configuration
    # This is a no-op if maintenance notifications are not enabled
    self.activate_maint_notifications_handling_if_enabled(check_health=check_health)

    # if a client_name is given, set it
    if self.client_name:
        self.send_command(
            "CLIENT",
            "SETNAME",
            self.client_name,
            check_health=check_health,
        )
        if str_if_bytes(self.read_response()) != "OK":
            raise ConnectionError("Error setting client name")

    # CLIENT SETINFO is best-effort: servers older than 7.2 reply with an
    # error, which we deliberately ignore.
    try:
        # set the library name and version
        if self.lib_name:
            self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-NAME",
                self.lib_name,
                check_health=check_health,
            )
            self.read_response()
    except ResponseError:
        pass

    try:
        if self.lib_version:
            self.send_command(
                "CLIENT",
                "SETINFO",
                "LIB-VER",
                self.lib_version,
                check_health=check_health,
            )
            self.read_response()
    except ResponseError:
        pass

    # if a database is specified, switch to it
    if self.db:
        self.send_command("SELECT", self.db, check_health=check_health)
        if str_if_bytes(self.read_response()) != "OK":
            raise ConnectionError("Invalid Database")

1023 

def disconnect(self, *args):
    """Tear down the transport socket and reset connection state."""
    self._parser.on_disconnect()

    sock, self._sock = self._sock, None
    # reset the reconnect flag
    self.reset_should_reconnect()
    if sock is None:
        return

    # Only the process that created the socket performs a clean shutdown;
    # a forked child merely closes its copy of the descriptor.
    if os.getpid() == self.pid:
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except (OSError, TypeError):
            pass

    try:
        sock.close()
    except OSError:
        pass

1045 

def mark_for_reconnect(self):
    """Flag this connection to be re-established before its next use."""
    self._should_reconnect = True

1048 

def should_reconnect(self):
    """Return True when the connection has been flagged for reconnect."""
    return self._should_reconnect

1051 

def reset_should_reconnect(self):
    """Clear the pending-reconnect flag."""
    self._should_reconnect = False

1054 

def _send_ping(self):
    """Issue a PING and verify that the server answers PONG."""
    # check_health=False: this *is* the health check, avoid recursion
    self.send_command("PING", check_health=False)
    reply = self.read_response()
    if str_if_bytes(reply) != "PONG":
        raise ConnectionError("Bad response from PING health check")

1060 

1061 def _ping_failed(self, error): 

1062 """Function to call when PING fails""" 

1063 self.disconnect() 

1064 

def check_health(self):
    """Run a PING/PONG round-trip once the health-check deadline passes."""
    if not self.health_check_interval:
        return
    if time.monotonic() <= self.next_health_check:
        return
    self.retry.call_with_retry(self._send_ping, self._ping_failed)

1069 

def send_packed_command(self, command, check_health=True):
    """Write one or more pre-serialized command buffers to the socket.

    Any write failure closes the connection: partially written commands
    cannot be resumed, so keeping the socket open would desynchronize
    the protocol stream.
    """
    if not self._sock:
        self.connect_check_health(check_health=False)
    # guard against health check recursion
    if check_health:
        self.check_health()
    try:
        chunks = [command] if isinstance(command, str) else command
        for chunk in chunks:
            self._sock.sendall(chunk)
    except socket.timeout:
        self.disconnect()
        raise TimeoutError("Timeout writing to socket")
    except OSError as e:
        self.disconnect()
        if len(e.args) == 1:
            errno, errmsg = "UNKNOWN", e.args[0]
        else:
            errno, errmsg = e.args[0], e.args[1]
        raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
    except BaseException:
        # BaseExceptions can be raised when a socket send operation is not
        # finished, e.g. due to a timeout. Ideally, a caller could then re-try
        # to send un-sent data. However, the send_packed_command() API
        # does not support it so there is no point in keeping the connection open.
        self.disconnect()
        raise

1100 

def send_command(self, *args, **kwargs):
    """Serialize *args* into the Redis wire protocol and transmit them."""
    packed = self._command_packer.pack(*args)
    self.send_packed_command(
        packed, check_health=kwargs.get("check_health", True)
    )

1107 

def can_read(self, timeout=0):
    """Poll the socket to see if there's data that can be read."""
    if not self._sock:
        self.connect()

    host_error = self._host_error()

    try:
        return self._parser.can_read(timeout)
    except OSError as e:
        # a failed poll means the socket is unusable; drop it
        self.disconnect()
        raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

1122 

def read_response(
    self,
    disable_decoding=False,
    *,
    disconnect_on_error=True,
    push_request=False,
):
    """Read the response from a previously sent command.

    :param disable_decoding: skip byte-to-str decoding in the parser.
    :param disconnect_on_error: close the socket when reading fails; the
        default keeps command/response streams from desynchronizing.
    :param push_request: only meaningful for RESP3 — let the parser hand
        back out-of-band push messages (e.g. invalidations).
    :raises TimeoutError: the read timed out.
    :raises ConnectionError: a socket-level error occurred.
    :raises ResponseError: the server replied with an error.
    """

    host_error = self._host_error()

    try:
        if self.protocol in ["3", 3]:
            # only the RESP3 parser accepts the push_request keyword
            response = self._parser.read_response(
                disable_decoding=disable_decoding, push_request=push_request
            )
        else:
            response = self._parser.read_response(disable_decoding=disable_decoding)
    except socket.timeout:
        if disconnect_on_error:
            self.disconnect()
        raise TimeoutError(f"Timeout reading from {host_error}")
    except OSError as e:
        if disconnect_on_error:
            self.disconnect()
        raise ConnectionError(f"Error while reading from {host_error} : {e.args}")
    except BaseException:
        # Also by default close in case of BaseException. A lot of code
        # relies on this behaviour when doing Command/Response pairs.
        # See #1128.
        if disconnect_on_error:
            self.disconnect()
        raise

    if self.health_check_interval:
        # a successful read proves liveness; push the next check out
        self.next_health_check = time.monotonic() + self.health_check_interval

    if isinstance(response, ResponseError):
        try:
            raise response
        finally:
            del response  # avoid creating ref cycles
    return response

1166 

def pack_command(self, *args):
    """Serialize a single command into the Redis wire protocol."""
    return self._command_packer.pack(*args)

1170 

def pack_commands(self, commands):
    """Serialize many commands, coalescing small chunks into joined buffers.

    Chunks larger than the buffer cutoff (or memoryviews, which must not
    be copied) are emitted as-is; smaller chunks are batched together.
    """
    cutoff = self._buffer_cutoff
    batches = []
    pending = []
    pending_len = 0

    for command in commands:
        for chunk in self._command_packer.pack(*command):
            size = len(chunk)
            oversized = size > cutoff or isinstance(chunk, memoryview)
            if pending_len > cutoff or oversized:
                # flush the accumulated small chunks first
                if pending:
                    batches.append(SYM_EMPTY.join(pending))
                    pending = []
                    pending_len = 0
            if oversized:
                batches.append(chunk)
            else:
                pending.append(chunk)
                pending_len += size

    if pending:
        batches.append(SYM_EMPTY.join(pending))
    return batches

1200 

def get_protocol(self) -> Union[int, str]:
    """Return the RESP protocol version configured for this connection."""
    return self.protocol

1203 

@property
def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
    """Server metadata captured from the HELLO handshake response."""
    return self._handshake_metadata

@handshake_metadata.setter
def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
    self._handshake_metadata = value

1211 

def set_re_auth_token(self, token: "TokenInterface"):
    """Store a credential token to be applied on the next ``re_auth()``."""
    self._re_auth_token = token

1214 

def re_auth(self):
    """Re-authenticate with a previously stored token, then clear it."""
    token = self._re_auth_token
    if token is None:
        return
    self.send_command("AUTH", token.try_get("oid"), token.get_value())
    self.read_response()
    self._re_auth_token = None

1224 

1225 def _get_socket(self) -> Optional[socket.socket]: 

1226 return self._sock 

1227 

@property
def socket_timeout(self) -> Optional[Union[float, int]]:
    """Read/write timeout (seconds) applied after the socket connects."""
    return self._socket_timeout

@socket_timeout.setter
def socket_timeout(self, value: Optional[Union[float, int]]):
    self._socket_timeout = value

1235 

@property
def socket_connect_timeout(self) -> Optional[Union[float, int]]:
    """Timeout (seconds) used while establishing the connection."""
    return self._socket_connect_timeout

@socket_connect_timeout.setter
def socket_connect_timeout(self, value: Optional[Union[float, int]]):
    self._socket_connect_timeout = value

1243 

1244 

class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        # Capture the TCP-specific settings before the shared
        # AbstractConnection initialization runs (it may use them).
        self._host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return (label, value) pairs used to build this object's repr."""
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        "Create a TCP socket connection"
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        err = None

        # getaddrinfo may yield several candidate addresses (e.g. IPv4 and
        # IPv6); try each until one connects, remembering the last failure.
        for res in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            family, socktype, proto, canonname, socket_address = res
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as _:
                err = _
                if sock is not None:
                    try:
                        sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
                    except OSError:
                        pass
                    sock.close()

        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        # human-readable endpoint used in error messages
        return f"{self.host}:{self.port}"

    @property
    def host(self) -> str:
        """Hostname or IP address this connection targets."""
        return self._host

    @host.setter
    def host(self, value: str):
        self._host = value

1326 

1327 

class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface):
    """Decorator around a real connection that adds client-side caching.

    Responses to cacheable read-only commands are stored in ``cache`` and
    served locally until the server sends an invalidation push message
    (RESP3 ``CLIENT TRACKING``). Everything else is delegated to the
    wrapped connection.
    """

    # Placeholder stored while a command is in flight, so concurrent users
    # can tell the entry is not yet usable.
    DUMMY_CACHE_VALUE = b"foo"
    # Client-side caching requires Redis 7.4 or later.
    MIN_ALLOWED_VERSION = "7.4.0"
    DEFAULT_SERVER_NAME = "redis"

    def __init__(
        self,
        conn: ConnectionInterface,
        cache: CacheInterface,
        pool_lock: threading.RLock,
    ):
        """
        :param conn: the real connection being wrapped.
        :param cache: shared cache implementation.
        :param pool_lock: pool-level lock guarding cross-connection reads.
        """
        self.pid = os.getpid()
        self._conn = conn
        self.retry = self._conn.retry
        self.host = self._conn.host
        self.port = self._conn.port
        self.credential_provider = conn.credential_provider
        self._pool_lock = pool_lock
        self._cache = cache
        self._cache_lock = threading.RLock()
        # cache key of the command currently awaiting its response, if any
        self._current_command_cache_key = None
        self._current_options = None
        self.register_connect_callback(self._enable_tracking_callback)

        # Mirror the wrapped connection's maintenance-notification state
        # when it supports that feature.
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            MaintNotificationsAbstractConnection.__init__(
                self,
                self._conn.maint_notifications_config,
                self._conn._maint_notifications_pool_handler,
                self._conn.maintenance_state,
                self._conn.maintenance_notification_hash,
                self._conn.host,
                self._conn.socket_timeout,
                self._conn.socket_connect_timeout,
                self._conn._get_parser(),
            )

    def repr_pieces(self):
        return self._conn.repr_pieces()

    def register_connect_callback(self, callback):
        self._conn.register_connect_callback(callback)

    def deregister_connect_callback(self, callback):
        self._conn.deregister_connect_callback(callback)

    def set_parser(self, parser_class):
        self._conn.set_parser(parser_class)

    def set_maint_notifications_pool_handler_for_connection(
        self, maint_notifications_pool_handler
    ):
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            self._conn.set_maint_notifications_pool_handler_for_connection(
                maint_notifications_pool_handler
            )

    def get_protocol(self):
        return self._conn.get_protocol()

    def connect(self):
        """Connect the wrapped connection, then verify the server supports
        client-side caching (Redis server, version >= 7.4).

        :raises ConnectionError: if server identity/version cannot be read
            or the server does not support client-side caching.
        """
        self._conn.connect()

        # handshake metadata keys may be bytes or str depending on decoding
        server_name = self._conn.handshake_metadata.get(b"server", None)
        if server_name is None:
            server_name = self._conn.handshake_metadata.get("server", None)
        server_ver = self._conn.handshake_metadata.get(b"version", None)
        if server_ver is None:
            server_ver = self._conn.handshake_metadata.get("version", None)
        # BUGFIX: the original condition tested ``server_ver is None`` twice,
        # so a missing server name was never detected here.
        if server_name is None or server_ver is None:
            raise ConnectionError("Cannot retrieve information about server version")

        server_ver = ensure_string(server_ver)
        server_name = ensure_string(server_name)

        if (
            server_name != self.DEFAULT_SERVER_NAME
            or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
        ):
            raise ConnectionError(
                "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later"  # noqa: E501
            )

    def on_connect(self):
        self._conn.on_connect()

    def disconnect(self, *args):
        # a disconnect invalidates everything we may have cached
        with self._cache_lock:
            self._cache.flush()
        self._conn.disconnect(*args)

    def check_health(self):
        self._conn.check_health()

    def send_packed_command(self, command, check_health=True):
        # TODO: Investigate if it's possible to unpack command
        # or extract keys from packed command
        # BUGFIX: propagate check_health instead of silently dropping it.
        self._conn.send_packed_command(command, check_health=check_health)

    def send_command(self, *args, **kwargs):
        """Send a command, consulting the cache for cacheable reads.

        For cacheable commands a placeholder entry is installed before the
        command is sent so concurrent connections see it as in-progress;
        if a valid entry already exists, nothing is sent at all.
        """
        self._process_pending_invalidations()

        with self._cache_lock:
            # Command is write command or not allowed
            # to be cached.
            if not self._cache.is_cachable(
                CacheKey(command=args[0], redis_keys=(), redis_args=())
            ):
                self._current_command_cache_key = None
                self._conn.send_command(*args, **kwargs)
                return

            if kwargs.get("keys") is None:
                raise ValueError("Cannot create cache key.")

            # Creates cache key.
            self._current_command_cache_key = CacheKey(
                command=args[0], redis_keys=tuple(kwargs.get("keys")), redis_args=args
            )

        with self._cache_lock:
            # We have to trigger invalidation processing in case if
            # it was cached by another connection to avoid
            # queueing invalidations in stale connections.
            if self._cache.get(self._current_command_cache_key):
                entry = self._cache.get(self._current_command_cache_key)

                if entry.connection_ref != self._conn:
                    with self._pool_lock:
                        while entry.connection_ref.can_read():
                            entry.connection_ref.read_response(push_request=True)

                return

            # Set temporary entry value to prevent
            # race condition from another connection.
            self._cache.set(
                CacheEntry(
                    cache_key=self._current_command_cache_key,
                    cache_value=self.DUMMY_CACHE_VALUE,
                    status=CacheEntryStatus.IN_PROGRESS,
                    connection_ref=self._conn,
                )
            )

        # Send command over socket only if it's allowed
        # read-only command that not yet cached.
        self._conn.send_command(*args, **kwargs)

    def can_read(self, timeout=0):
        return self._conn.can_read(timeout)

    def read_response(
        self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
    ):
        """Return the cached value when available, otherwise read from the
        socket and cache the result (if the command was cacheable)."""
        with self._cache_lock:
            # Check if command response exists in a cache and it's not in progress.
            if (
                self._current_command_cache_key is not None
                and self._cache.get(self._current_command_cache_key) is not None
                and self._cache.get(self._current_command_cache_key).status
                != CacheEntryStatus.IN_PROGRESS
            ):
                # deepcopy so callers cannot mutate the cached value
                res = copy.deepcopy(
                    self._cache.get(self._current_command_cache_key).cache_value
                )
                self._current_command_cache_key = None
                return res

        response = self._conn.read_response(
            disable_decoding=disable_decoding,
            disconnect_on_error=disconnect_on_error,
            push_request=push_request,
        )

        with self._cache_lock:
            # Prevent not-allowed command from caching.
            if self._current_command_cache_key is None:
                return response
            # If response is None prevent from caching.
            if response is None:
                self._cache.delete_by_cache_keys([self._current_command_cache_key])
                return response

            cache_entry = self._cache.get(self._current_command_cache_key)

            # Cache only responses that still valid
            # and wasn't invalidated by another connection in meantime.
            if cache_entry is not None:
                cache_entry.status = CacheEntryStatus.VALID
                cache_entry.cache_value = response
                self._cache.set(cache_entry)

            self._current_command_cache_key = None

        return response

    def pack_command(self, *args):
        return self._conn.pack_command(*args)

    def pack_commands(self, commands):
        return self._conn.pack_commands(commands)

    @property
    def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
        return self._conn.handshake_metadata

    def set_re_auth_token(self, token: TokenInterface):
        self._conn.set_re_auth_token(token)

    def re_auth(self):
        self._conn.re_auth()

    def mark_for_reconnect(self):
        self._conn.mark_for_reconnect()

    def should_reconnect(self):
        return self._conn.should_reconnect()

    def reset_should_reconnect(self):
        self._conn.reset_should_reconnect()

    @property
    def host(self) -> str:
        return self._conn.host

    @host.setter
    def host(self, value: str):
        self._conn.host = value

    @property
    def socket_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_timeout

    @socket_timeout.setter
    def socket_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_timeout = value

    @property
    def socket_connect_timeout(self) -> Optional[Union[float, int]]:
        return self._conn.socket_connect_timeout

    @socket_connect_timeout.setter
    def socket_connect_timeout(self, value: Optional[Union[float, int]]):
        self._conn.socket_connect_timeout = value

    def _get_socket(self) -> Optional[socket.socket]:
        if isinstance(self._conn, MaintNotificationsAbstractConnection):
            return self._conn._get_socket()
        else:
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )

    def _get_maint_notifications_connection_instance(
        self, connection
    ) -> MaintNotificationsAbstractConnection:
        """
        Validate that connection instance supports maintenance notifications.
        With this helper method we ensure that we are working
        with the correct connection type.
        After we validate that connection instance supports maintenance
        notifications we can safely return the connection instance
        as MaintNotificationsAbstractConnection.
        """
        if not isinstance(connection, MaintNotificationsAbstractConnection):
            raise NotImplementedError(
                "Maintenance notifications are not supported by this connection type"
            )
        else:
            return connection

    @property
    def maintenance_state(self) -> MaintenanceState:
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.maintenance_state

    @maintenance_state.setter
    def maintenance_state(self, state: MaintenanceState):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.maintenance_state = state

    def getpeername(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.getpeername()

    def get_resolved_ip(self):
        con = self._get_maint_notifications_connection_instance(self._conn)
        return con.get_resolved_ip()

    def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.update_current_socket_timeout(relaxed_timeout)

    def set_tmp_settings(
        self,
        tmp_host_address: Optional[str] = None,
        tmp_relaxed_timeout: Optional[float] = None,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout)

    def reset_tmp_settings(
        self,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        con = self._get_maint_notifications_connection_instance(self._conn)
        con.reset_tmp_settings(reset_host_address, reset_relaxed_timeout)

    def _connect(self):
        self._conn._connect()

    def _host_error(self):
        # BUGFIX: the delegated value was previously computed but not
        # returned, so every caller saw None.
        return self._conn._host_error()

    def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
        # Turn on server-assisted invalidation push messages for this
        # connection and route them to our invalidation handler.
        conn.send_command("CLIENT", "TRACKING", "ON")
        conn.read_response()
        conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)

    def _process_pending_invalidations(self):
        # Drain any queued push messages so stale entries are evicted
        # before we consult the cache.
        while self.can_read():
            self._conn.read_response(push_request=True)

    def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
        with self._cache_lock:
            # Flush cache when DB flushed on server-side
            if data[1] is None:
                self._cache.flush()
            else:
                self._cache.delete_by_redis_keys(data[1])

1660 

1661 

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=True,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        ssl_ciphers=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
                or an ssl.VerifyMode. Defaults to "required".
            ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
            ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
            ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        # normalize the cert-requirement setting to an ssl.VerifyMode value
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {  # noqa: N806
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ssl_include_verify_flags = ssl_include_verify_flags
        self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        # hostname checking is meaningless (and rejected by ssl) when
        # certificate verification is disabled
        self.check_hostname = (
            ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False
        )
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        self.ssl_ciphers = ssl_ciphers
        super().__init__(**kwargs)

    def _connect(self):
        """
        Wrap the socket with SSL support, handling potential errors.
        """
        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except (OSError, RedisError):
            # don't leak the plain socket if the TLS handshake fails
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        """
        Wraps the socket with SSL support.

        Args:
            sock: The plain socket to wrap with SSL.

        Returns:
            An SSL wrapped socket.
        """
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        # verify_flags is a bitmask: include-flags are OR'ed in,
        # exclude-flags are masked out
        if self.ssl_include_verify_flags:
            for flag in self.ssl_include_verify_flags:
                context.verify_flags |= flag
        if self.ssl_exclude_verify_flags:
            for flag in self.ssl_exclude_verify_flags:
                context.verify_flags &= ~flag
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        if self.ssl_ciphers:
            context.set_ciphers(self.ssl_ciphers)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        # stapled and full OCSP validation are mutually exclusive
        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sslsock = context.wrap_socket(sock, server_hostname=self.host)

        # validation for the stapled case
        if self.ssl_validate_ocsp_stapled:
            import OpenSSL

            from .ocsp import ocsp_staple_verifier

            # if a context is provided use it - otherwise, a basic context
            if self.ssl_ocsp_context is None:
                staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
                staple_ctx.use_certificate_file(self.certfile)
                staple_ctx.use_privatekey_file(self.keyfile)
            else:
                staple_ctx = self.ssl_ocsp_context

            staple_ctx.set_ocsp_client_callback(
                ocsp_staple_verifier, self.ssl_ocsp_expected_cert
            )

            # need another socket
            con = OpenSSL.SSL.Connection(staple_ctx, socket.socket())
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
            return sslsock

        # pure ocsp validation
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
            from .ocsp import OCSPVerifier

            o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
            if o.is_valid():
                return sslsock
            else:
                raise ConnectionError("ocsp validation error")
        return sslsock

1844 

1845 

class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, path="", socket_timeout=None, **kwargs):
        super().__init__(**kwargs)
        self.path = path
        self.socket_timeout = socket_timeout

    def repr_pieces(self):
        """Return (label, value) pairs used to build this object's repr."""
        details = [("path", self.path), ("db", self.db)]
        if self.client_name:
            details.append(("client_name", self.client_name))
        return details

    def _connect(self):
        """Open an AF_UNIX stream socket and connect it to ``self.path``."""
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.settimeout(self.socket_connect_timeout)
        try:
            sock.connect(self.path)
        except OSError:
            # Close eagerly on failure to prevent ResourceWarnings
            # for unclosed sockets.
            try:
                sock.shutdown(socket.SHUT_RDWR)  # ensure a clean close
            except OSError:
                pass
            sock.close()
            raise
        # connected: switch from the connect timeout to the I/O timeout
        sock.settimeout(self.socket_timeout)
        return sock

    def _host_error(self):
        # the socket path identifies the endpoint in error messages
        return self.path

1879 

1880 

FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Interpret *value* as an optional boolean.

    Returns ``None`` for ``None`` or the empty string (meaning "unset"),
    ``False`` for strings whose upper-cased form appears in
    ``FALSE_STRINGS``, and plain Python truthiness for everything else.
    """
    if value is None or value == "":
        return None
    looks_false = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if looks_false else bool(value)

1890 

1891 

1892def parse_ssl_verify_flags(value): 

1893 # flags are passed in as a string representation of a list, 

1894 # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN 

1895 verify_flags_str = value.replace("[", "").replace("]", "") 

1896 

1897 verify_flags = [] 

1898 for flag in verify_flags_str.split(","): 

1899 flag = flag.strip() 

1900 if not hasattr(VerifyFlags, flag): 

1901 raise ValueError(f"Invalid ssl verify flag: {flag}") 

1902 verify_flags.append(getattr(VerifyFlags, flag)) 

1903 return verify_flags 

1904 

1905 

# Maps URL querystring argument names to the callables that coerce the raw
# (already-unquoted) string value to its proper Python type. Consulted by
# parse_url(); a parser raising TypeError/ValueError is surfaced there as a
# ValueError. Names not listed here are passed through as plain strings.
URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
    "ssl_include_verify_flags": parse_ssl_verify_flags,
    "ssl_exclude_verify_flags": parse_ssl_verify_flags,
    "timeout": float,
}

1920 

1921 

def parse_url(url):
    """Parse a redis://, rediss:// or unix:// URL into connection kwargs.

    Querystring values are unquoted and coerced via
    URL_QUERY_ARGUMENT_PARSERS; unknown options pass through as strings.
    Raises ValueError on an unsupported scheme or an uncoercible value.
    """
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    parsed = urlparse(url)
    kwargs = {}

    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        raw = unquote(values[0])
        parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if parser is None:
            kwargs[name] = raw
        else:
            try:
                kwargs[name] = parser(raw)
            except (TypeError, ValueError):
                raise ValueError(f"Invalid value for '{name}' in connection URL.")

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # We only support redis://, rediss:// and unix:// schemes.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection

    else:  # implied: parsed.scheme in ("redis", "rediss")
        if parsed.hostname:
            kwargs["host"] = unquote(parsed.hostname)
        if parsed.port:
            kwargs["port"] = int(parsed.port)

        # The URL path doubles as the db number unless a ?db= querystring
        # value already claimed it.
        if parsed.path and "db" not in kwargs:
            try:
                kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
            except (AttributeError, ValueError):
                pass

        if parsed.scheme == "rediss":
            kwargs["connection_class"] = SSLConnection

    return kwargs

1977 

1978 

# Type variable bound to ConnectionPool so that classmethods such as
# from_url() on a subclass are typed as returning that subclass.
_CP = TypeVar("_CP", bound="ConnectionPool")

1980 

1981 

class ConnectionPoolInterface(ABC):
    """Abstract interface that every connection pool implementation must satisfy."""

    @abstractmethod
    def get_protocol(self):
        """Return the configured RESP protocol version (or None if unset)."""
        pass

    @abstractmethod
    def reset(self):
        """Reinitialize the pool's internal state (e.g. after a fork)."""
        pass

    @abstractmethod
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(
        self, command_name: Optional[str], *keys, **options
    ) -> ConnectionInterface:
        """Return a ready-to-use connection from the pool.

        The positional/keyword arguments are deprecated and ignored.
        """
        pass

    @abstractmethod
    def get_encoder(self):
        """Return an Encoder configured from the pool's connection settings."""
        pass

    @abstractmethod
    def release(self, connection: ConnectionInterface):
        """Return a previously checked-out connection to the pool."""
        pass

    @abstractmethod
    def disconnect(self, inuse_connections: bool = True):
        """Disconnect pooled connections (optionally also in-use ones)."""
        pass

    @abstractmethod
    def close(self):
        """Close the pool, disconnecting all connections."""
        pass

    @abstractmethod
    def set_retry(self, retry: Retry):
        """Install a Retry policy on the pool and its connections."""
        pass

    @abstractmethod
    def re_auth_callback(self, token: TokenInterface):
        """Re-authenticate pooled connections with a refreshed auth token."""
        pass

2025 

2026 

class MaintNotificationsAbstractConnectionPool:
    """
    Abstract class for handling maintenance notifications logic.
    This class is mixed into the ConnectionPool classes.

    This class is not intended to be used directly!

    All logic related to maintenance notifications and
    connection pool handling is encapsulated in this class.
    """

    def __init__(
        self,
        maint_notifications_config: Optional[MaintNotificationsConfig] = None,
        **kwargs,
    ):
        # Initialize maintenance notifications
        # Maintenance notifications require RESP3 (push messages), so the
        # feature is auto-enabled only when protocol is 3.
        is_protocol_supported = kwargs.get("protocol") in [3, "3"]
        if maint_notifications_config is None and is_protocol_supported:
            maint_notifications_config = MaintNotificationsConfig()

        if maint_notifications_config and maint_notifications_config.enabled:
            if not is_protocol_supported:
                raise RedisError(
                    "Maintenance notifications handlers on connection are only supported with RESP version 3"
                )

            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )

            # Propagate the handler/config into the kwargs used for all
            # connections created from now on.
            self._update_connection_kwargs_for_maint_notifications(
                self._maint_notifications_pool_handler
            )
        else:
            # Handler is None <=> maintenance notifications are disabled.
            self._maint_notifications_pool_handler = None

    @property
    @abstractmethod
    def connection_kwargs(self) -> Dict[str, Any]:
        # Concrete pools expose the kwargs used to build new connections.
        pass

    @connection_kwargs.setter
    @abstractmethod
    def connection_kwargs(self, value: Dict[str, Any]):
        pass

    @abstractmethod
    def _get_pool_lock(self) -> threading.RLock:
        # Lock guarding the pool's connection collections.
        pass

    @abstractmethod
    def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]:
        # Connections currently idle/available in the pool.
        pass

    @abstractmethod
    def _get_in_use_connections(
        self,
    ) -> Iterable["MaintNotificationsAbstractConnection"]:
        # Connections currently checked out by clients.
        pass

    def maint_notifications_enabled(self):
        """
        Returns:
            True if the maintenance notifications are enabled, False otherwise.
            The maintenance notifications config is stored in the pool handler.
            If the pool handler is not set, the maintenance notifications are not enabled.
        """
        maint_notifications_config = (
            self._maint_notifications_pool_handler.config
            if self._maint_notifications_pool_handler
            else None
        )

        return maint_notifications_config and maint_notifications_config.enabled

    def update_maint_notifications_config(
        self, maint_notifications_config: MaintNotificationsConfig
    ):
        """
        Updates the maintenance notifications configuration.
        This method should be called only if the pool was created
        without enabling the maintenance notifications and
        in a later point in time maintenance notifications
        are requested to be enabled.

        Raises:
            ValueError: if notifications are already enabled and the new
                config would disable them (one-way switch).
        """
        if (
            self.maint_notifications_enabled()
            and not maint_notifications_config.enabled
        ):
            raise ValueError(
                "Cannot disable maintenance notifications after enabling them"
            )
        # first update pool settings
        if not self._maint_notifications_pool_handler:
            self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
                self, maint_notifications_config
            )
        else:
            self._maint_notifications_pool_handler.config = maint_notifications_config

        # then update connection kwargs and existing connections
        self._update_connection_kwargs_for_maint_notifications(
            self._maint_notifications_pool_handler
        )
        self._update_maint_notifications_configs_for_connections(
            self._maint_notifications_pool_handler
        )

    def _update_connection_kwargs_for_maint_notifications(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """
        Update the connection kwargs for all future connections.

        No-op unless maintenance notifications are enabled.
        """
        if not self.maint_notifications_enabled():
            return

        self.connection_kwargs.update(
            {
                "maint_notifications_pool_handler": maint_notifications_pool_handler,
                "maint_notifications_config": maint_notifications_pool_handler.config,
            }
        )

        # Store original connection parameters for maintenance notifications.
        if self.connection_kwargs.get("orig_host_address", None) is None:
            # If orig_host_address is None it means we haven't
            # configured the original values yet
            self.connection_kwargs.update(
                {
                    "orig_host_address": self.connection_kwargs.get("host"),
                    "orig_socket_timeout": self.connection_kwargs.get(
                        "socket_timeout", None
                    ),
                    "orig_socket_connect_timeout": self.connection_kwargs.get(
                        "socket_connect_timeout", None
                    ),
                }
            )

    def _update_maint_notifications_configs_for_connections(
        self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
    ):
        """Update the maintenance notifications config for all connections in the pool."""
        with self._get_pool_lock():
            # Idle connections can be dropped immediately; they will pick up
            # the new config when re-established.
            for conn in self._get_free_connections():
                conn.set_maint_notifications_pool_handler_for_connection(
                    maint_notifications_pool_handler
                )
                conn.maint_notifications_config = (
                    maint_notifications_pool_handler.config
                )
                conn.disconnect()
            # In-use connections must not be torn down mid-command; flag them
            # to reconnect once released instead.
            for conn in self._get_in_use_connections():
                conn.set_maint_notifications_pool_handler_for_connection(
                    maint_notifications_pool_handler
                )
                conn.maint_notifications_config = (
                    maint_notifications_pool_handler.config
                )
                conn.mark_for_reconnect()

    def _should_update_connection(
        self,
        conn: "MaintNotificationsAbstractConnection",
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
    ) -> bool:
        """
        Check if the connection should be updated based on the matching criteria.

        A falsy matching value (None/empty) makes the corresponding pattern
        match every connection.
        """
        if matching_pattern == "connected_address":
            if matching_address and conn.getpeername() != matching_address:
                return False
        elif matching_pattern == "configured_address":
            if matching_address and conn.host != matching_address:
                return False
        elif matching_pattern == "notification_hash":
            if (
                matching_notification_hash
                and conn.maintenance_notification_hash != matching_notification_hash
            ):
                return False
        return True

    def update_connection_settings(
        self,
        conn: "MaintNotificationsAbstractConnection",
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
    ):
        """
        Update the settings for a single connection.

        ``None`` arguments mean "leave unchanged"; the explicit boolean
        flags control overwriting/resetting because ``None`` is also a
        legitimate value for the notification hash.
        """
        if state:
            conn.maintenance_state = state

        if update_notification_hash:
            # update the notification hash only if requested
            conn.maintenance_notification_hash = maintenance_notification_hash

        if host_address is not None:
            conn.set_tmp_settings(tmp_host_address=host_address)

        if relaxed_timeout is not None:
            conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout)

        if reset_relaxed_timeout or reset_host_address:
            conn.reset_tmp_settings(
                reset_host_address=reset_host_address,
                reset_relaxed_timeout=reset_relaxed_timeout,
            )

        # Apply the (possibly relaxed) timeout to the live socket as well.
        conn.update_current_socket_timeout(relaxed_timeout)

    def update_connections_settings(
        self,
        state: Optional["MaintenanceState"] = None,
        maintenance_notification_hash: Optional[int] = None,
        host_address: Optional[str] = None,
        relaxed_timeout: Optional[float] = None,
        matching_address: Optional[str] = None,
        matching_notification_hash: Optional[int] = None,
        matching_pattern: Literal[
            "connected_address", "configured_address", "notification_hash"
        ] = "connected_address",
        update_notification_hash: bool = False,
        reset_host_address: bool = False,
        reset_relaxed_timeout: bool = False,
        include_free_connections: bool = True,
    ):
        """
        Update the settings for all matching connections in the pool.

        This method does not create new connections.
        This method does not affect the connection kwargs.

        :param state: The maintenance state to set for the connection.
        :param maintenance_notification_hash: The hash of the maintenance notification
            to set for the connection.
        :param host_address: The host address to set for the connection.
        :param relaxed_timeout: The relaxed timeout to set for the connection.
        :param matching_address: The address to match for the connection.
        :param matching_notification_hash: The notification hash to match for the connection.
        :param matching_pattern: The pattern to match for the connection.
        :param update_notification_hash: Whether to update the notification hash for the connection.
        :param reset_host_address: Whether to reset the host address to the original address.
        :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout.
        :param include_free_connections: Whether to include free/available connections.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn,
                    matching_pattern,
                    matching_address,
                    matching_notification_hash,
                ):
                    self.update_connection_settings(
                        conn,
                        state=state,
                        maintenance_notification_hash=maintenance_notification_hash,
                        host_address=host_address,
                        relaxed_timeout=relaxed_timeout,
                        update_notification_hash=update_notification_hash,
                        reset_host_address=reset_host_address,
                        reset_relaxed_timeout=reset_relaxed_timeout,
                    )

            if include_free_connections:
                for conn in self._get_free_connections():
                    if self._should_update_connection(
                        conn,
                        matching_pattern,
                        matching_address,
                        matching_notification_hash,
                    ):
                        self.update_connection_settings(
                            conn,
                            state=state,
                            maintenance_notification_hash=maintenance_notification_hash,
                            host_address=host_address,
                            relaxed_timeout=relaxed_timeout,
                            update_notification_hash=update_notification_hash,
                            reset_host_address=reset_host_address,
                            reset_relaxed_timeout=reset_relaxed_timeout,
                        )

    def update_connection_kwargs(
        self,
        **kwargs,
    ):
        """
        Update the connection kwargs for all future connections.

        This method updates the connection kwargs for all future connections created by the pool.
        Existing connections are not affected.
        """
        self.connection_kwargs.update(kwargs)

    def update_active_connections_for_reconnect(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Mark all active connections for reconnect.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_in_use_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.mark_for_reconnect()

    def disconnect_free_connections(
        self,
        moving_address_src: Optional[str] = None,
    ):
        """
        Disconnect all free/available connections.
        This is used when a cluster node is migrated to a different address.

        :param moving_address_src: The address of the node that is being moved.
        """
        with self._get_pool_lock():
            for conn in self._get_free_connections():
                if self._should_update_connection(
                    conn, "connected_address", moving_address_src
                ):
                    conn.disconnect()

2369 

2370 

2371class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface): 

2372 """ 

2373 Create a connection pool. ``If max_connections`` is set, then this 

2374 object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's 

2375 limit is reached. 

2376 

2377 By default, TCP connections are created unless ``connection_class`` 

2378 is specified. Use class:`.UnixDomainSocketConnection` for 

2379 unix sockets. 

2380 :py:class:`~redis.SSLConnection` can be used for SSL enabled connections. 

2381 

2382 If ``maint_notifications_config`` is provided, the connection pool will support 

2383 maintenance notifications. 

2384 Maintenance notifications are supported only with RESP3. 

2385 If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3, 

2386 the maintenance notifications will be enabled by default. 

2387 

2388 Any additional keyword arguments are passed to the constructor of 

2389 ``connection_class``. 

2390 """ 

2391 

2392 @classmethod 

2393 def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: 

2394 """ 

2395 Return a connection pool configured from the given URL. 

2396 

2397 For example:: 

2398 

2399 redis://[[username]:[password]]@localhost:6379/0 

2400 rediss://[[username]:[password]]@localhost:6379/0 

2401 unix://[username@]/path/to/socket.sock?db=0[&password=password] 

2402 

2403 Three URL schemes are supported: 

2404 

2405 - `redis://` creates a TCP socket connection. See more at: 

2406 <https://www.iana.org/assignments/uri-schemes/prov/redis> 

2407 - `rediss://` creates a SSL wrapped TCP socket connection. See more at: 

2408 <https://www.iana.org/assignments/uri-schemes/prov/rediss> 

2409 - ``unix://``: creates a Unix Domain Socket connection. 

2410 

2411 The username, password, hostname, path and all querystring values 

2412 are passed through urllib.parse.unquote in order to replace any 

2413 percent-encoded values with their corresponding characters. 

2414 

2415 There are several ways to specify a database number. The first value 

2416 found will be used: 

2417 

2418 1. A ``db`` querystring option, e.g. redis://localhost?db=0 

2419 2. If using the redis:// or rediss:// schemes, the path argument 

2420 of the url, e.g. redis://localhost/0 

2421 3. A ``db`` keyword argument to this function. 

2422 

2423 If none of these options are specified, the default db=0 is used. 

2424 

2425 All querystring options are cast to their appropriate Python types. 

2426 Boolean arguments can be specified with string values "True"/"False" 

2427 or "Yes"/"No". Values that cannot be properly cast cause a 

2428 ``ValueError`` to be raised. Once parsed, the querystring arguments 

2429 and keyword arguments are passed to the ``ConnectionPool``'s 

2430 class initializer. In the case of conflicting arguments, querystring 

2431 arguments always win. 

2432 """ 

2433 url_options = parse_url(url) 

2434 

2435 if "connection_class" in kwargs: 

2436 url_options["connection_class"] = kwargs["connection_class"] 

2437 

2438 kwargs.update(url_options) 

2439 return cls(**kwargs) 

2440 

2441 def __init__( 

2442 self, 

2443 connection_class=Connection, 

2444 max_connections: Optional[int] = None, 

2445 cache_factory: Optional[CacheFactoryInterface] = None, 

2446 maint_notifications_config: Optional[MaintNotificationsConfig] = None, 

2447 **connection_kwargs, 

2448 ): 

2449 max_connections = max_connections or 2**31 

2450 if not isinstance(max_connections, int) or max_connections < 0: 

2451 raise ValueError('"max_connections" must be a positive integer') 

2452 

2453 self.connection_class = connection_class 

2454 self._connection_kwargs = connection_kwargs 

2455 self.max_connections = max_connections 

2456 self.cache = None 

2457 self._cache_factory = cache_factory 

2458 

2459 if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"): 

2460 if self._connection_kwargs.get("protocol") not in [3, "3"]: 

2461 raise RedisError("Client caching is only supported with RESP version 3") 

2462 

2463 cache = self._connection_kwargs.get("cache") 

2464 

2465 if cache is not None: 

2466 if not isinstance(cache, CacheInterface): 

2467 raise ValueError("Cache must implement CacheInterface") 

2468 

2469 self.cache = cache 

2470 else: 

2471 if self._cache_factory is not None: 

2472 self.cache = self._cache_factory.get_cache() 

2473 else: 

2474 self.cache = CacheFactory( 

2475 self._connection_kwargs.get("cache_config") 

2476 ).get_cache() 

2477 

2478 connection_kwargs.pop("cache", None) 

2479 connection_kwargs.pop("cache_config", None) 

2480 

2481 self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None) 

2482 if self._event_dispatcher is None: 

2483 self._event_dispatcher = EventDispatcher() 

2484 

2485 # a lock to protect the critical section in _checkpid(). 

2486 # this lock is acquired when the process id changes, such as 

2487 # after a fork. during this time, multiple threads in the child 

2488 # process could attempt to acquire this lock. the first thread 

2489 # to acquire the lock will reset the data structures and lock 

2490 # object of this pool. subsequent threads acquiring this lock 

2491 # will notice the first thread already did the work and simply 

2492 # release the lock. 

2493 

2494 self._fork_lock = threading.RLock() 

2495 self._lock = threading.RLock() 

2496 

2497 MaintNotificationsAbstractConnectionPool.__init__( 

2498 self, 

2499 maint_notifications_config=maint_notifications_config, 

2500 **connection_kwargs, 

2501 ) 

2502 

2503 self.reset() 

2504 

2505 def __repr__(self) -> str: 

2506 conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()]) 

2507 return ( 

2508 f"<{self.__class__.__module__}.{self.__class__.__name__}" 

2509 f"(<{self.connection_class.__module__}.{self.connection_class.__name__}" 

2510 f"({conn_kwargs})>)>" 

2511 ) 

2512 

2513 @property 

2514 def connection_kwargs(self) -> Dict[str, Any]: 

2515 return self._connection_kwargs 

2516 

2517 @connection_kwargs.setter 

2518 def connection_kwargs(self, value: Dict[str, Any]): 

2519 self._connection_kwargs = value 

2520 

2521 def get_protocol(self): 

2522 """ 

2523 Returns: 

2524 The RESP protocol version, or ``None`` if the protocol is not specified, 

2525 in which case the server default will be used. 

2526 """ 

2527 return self.connection_kwargs.get("protocol", None) 

2528 

2529 def reset(self) -> None: 

2530 self._created_connections = 0 

2531 self._available_connections = [] 

2532 self._in_use_connections = set() 

2533 

2534 # this must be the last operation in this method. while reset() is 

2535 # called when holding _fork_lock, other threads in this process 

2536 # can call _checkpid() which compares self.pid and os.getpid() without 

2537 # holding any lock (for performance reasons). keeping this assignment 

2538 # as the last operation ensures that those other threads will also 

2539 # notice a pid difference and block waiting for the first thread to 

2540 # release _fork_lock. when each of these threads eventually acquire 

2541 # _fork_lock, they will notice that another thread already called 

2542 # reset() and they will immediately release _fork_lock and continue on. 

2543 self.pid = os.getpid() 

2544 

2545 def _checkpid(self) -> None: 

2546 # _checkpid() attempts to keep ConnectionPool fork-safe on modern 

2547 # systems. this is called by all ConnectionPool methods that 

2548 # manipulate the pool's state such as get_connection() and release(). 

2549 # 

2550 # _checkpid() determines whether the process has forked by comparing 

2551 # the current process id to the process id saved on the ConnectionPool 

2552 # instance. if these values are the same, _checkpid() simply returns. 

2553 # 

2554 # when the process ids differ, _checkpid() assumes that the process 

2555 # has forked and that we're now running in the child process. the child 

2556 # process cannot use the parent's file descriptors (e.g., sockets). 

2557 # therefore, when _checkpid() sees the process id change, it calls 

2558 # reset() in order to reinitialize the child's ConnectionPool. this 

2559 # will cause the child to make all new connection objects. 

2560 # 

2561 # _checkpid() is protected by self._fork_lock to ensure that multiple 

2562 # threads in the child process do not call reset() multiple times. 

2563 # 

2564 # there is an extremely small chance this could fail in the following 

2565 # scenario: 

2566 # 1. process A calls _checkpid() for the first time and acquires 

2567 # self._fork_lock. 

2568 # 2. while holding self._fork_lock, process A forks (the fork() 

2569 # could happen in a different thread owned by process A) 

2570 # 3. process B (the forked child process) inherits the 

2571 # ConnectionPool's state from the parent. that state includes 

2572 # a locked _fork_lock. process B will not be notified when 

2573 # process A releases the _fork_lock and will thus never be 

2574 # able to acquire the _fork_lock. 

2575 # 

2576 # to mitigate this possible deadlock, _checkpid() will only wait 5 

2577 # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in 

2578 # that time it is assumed that the child is deadlocked and a 

2579 # redis.ChildDeadlockedError error is raised. 

2580 if self.pid != os.getpid(): 

2581 acquired = self._fork_lock.acquire(timeout=5) 

2582 if not acquired: 

2583 raise ChildDeadlockedError 

2584 # reset() the instance for the new process if another thread 

2585 # hasn't already done so 

2586 try: 

2587 if self.pid != os.getpid(): 

2588 self.reset() 

2589 finally: 

2590 self._fork_lock.release() 

2591 

2592 @deprecated_args( 

2593 args_to_warn=["*"], 

2594 reason="Use get_connection() without args instead", 

2595 version="5.3.0", 

2596 ) 

2597 def get_connection(self, command_name=None, *keys, **options) -> "Connection": 

2598 "Get a connection from the pool" 

2599 

2600 self._checkpid() 

2601 with self._lock: 

2602 try: 

2603 connection = self._available_connections.pop() 

2604 except IndexError: 

2605 connection = self.make_connection() 

2606 self._in_use_connections.add(connection) 

2607 

2608 try: 

2609 # ensure this connection is connected to Redis 

2610 connection.connect() 

2611 # connections that the pool provides should be ready to send 

2612 # a command. if not, the connection was either returned to the 

2613 # pool before all data has been read or the socket has been 

2614 # closed. either way, reconnect and verify everything is good. 

2615 try: 

2616 if ( 

2617 connection.can_read() 

2618 and self.cache is None 

2619 and not self.maint_notifications_enabled() 

2620 ): 

2621 raise ConnectionError("Connection has data") 

2622 except (ConnectionError, TimeoutError, OSError): 

2623 connection.disconnect() 

2624 connection.connect() 

2625 if connection.can_read(): 

2626 raise ConnectionError("Connection not ready") 

2627 except BaseException: 

2628 # release the connection back to the pool so that we don't 

2629 # leak it 

2630 self.release(connection) 

2631 raise 

2632 return connection 

2633 

2634 def get_encoder(self) -> Encoder: 

2635 "Return an encoder based on encoding settings" 

2636 kwargs = self.connection_kwargs 

2637 return Encoder( 

2638 encoding=kwargs.get("encoding", "utf-8"), 

2639 encoding_errors=kwargs.get("encoding_errors", "strict"), 

2640 decode_responses=kwargs.get("decode_responses", False), 

2641 ) 

2642 

2643 def make_connection(self) -> "ConnectionInterface": 

2644 "Create a new connection" 

2645 if self._created_connections >= self.max_connections: 

2646 raise MaxConnectionsError("Too many connections") 

2647 self._created_connections += 1 

2648 

2649 kwargs = dict(self.connection_kwargs) 

2650 

2651 if self.cache is not None: 

2652 return CacheProxyConnection( 

2653 self.connection_class(**kwargs), self.cache, self._lock 

2654 ) 

2655 return self.connection_class(**kwargs) 

2656 

2657 def release(self, connection: "Connection") -> None: 

2658 "Releases the connection back to the pool" 

2659 self._checkpid() 

2660 with self._lock: 

2661 try: 

2662 self._in_use_connections.remove(connection) 

2663 except KeyError: 

2664 # Gracefully fail when a connection is returned to this pool 

2665 # that the pool doesn't actually own 

2666 return 

2667 

2668 if self.owns_connection(connection): 

2669 if connection.should_reconnect(): 

2670 connection.disconnect() 

2671 self._available_connections.append(connection) 

2672 self._event_dispatcher.dispatch( 

2673 AfterConnectionReleasedEvent(connection) 

2674 ) 

2675 else: 

2676 # Pool doesn't own this connection, do not add it back 

2677 # to the pool. 

2678 # The created connections count should not be changed, 

2679 # because the connection was not created by the pool. 

2680 connection.disconnect() 

2681 return 

2682 

2683 def owns_connection(self, connection: "Connection") -> int: 

2684 return connection.pid == self.pid 

2685 

2686 def disconnect(self, inuse_connections: bool = True) -> None: 

2687 """ 

2688 Disconnects connections in the pool 

2689 

2690 If ``inuse_connections`` is True, disconnect connections that are 

2691 currently in use, potentially by other threads. Otherwise only disconnect 

2692 connections that are idle in the pool. 

2693 """ 

2694 self._checkpid() 

2695 with self._lock: 

2696 if inuse_connections: 

2697 connections = chain( 

2698 self._available_connections, self._in_use_connections 

2699 ) 

2700 else: 

2701 connections = self._available_connections 

2702 

2703 for connection in connections: 

2704 connection.disconnect() 

2705 

2706 def close(self) -> None: 

2707 """Close the pool, disconnecting all connections""" 

2708 self.disconnect() 

2709 

2710 def set_retry(self, retry: Retry) -> None: 

2711 self.connection_kwargs.update({"retry": retry}) 

2712 for conn in self._available_connections: 

2713 conn.retry = retry 

2714 for conn in self._in_use_connections: 

2715 conn.retry = retry 

2716 

    def re_auth_callback(self, token: TokenInterface):
        """
        Re-authenticate the pool's connections with a freshly issued token.

        Idle connections are re-authenticated immediately by sending an
        ``AUTH`` command; connections currently checked out only have the
        token recorded on them so they re-authenticate themselves the next
        time they are used.

        :param token: the new credential token to authenticate with.
        """
        with self._lock:
            for conn in self._available_connections:
                # Send AUTH <oid> <token-value>, retried per the
                # connection's retry policy. The error callback (_mock)
                # deliberately swallows failures.
                conn.retry.call_with_retry(
                    lambda: conn.send_command(
                        "AUTH", token.try_get("oid"), token.get_value()
                    ),
                    lambda error: self._mock(error),
                )
                # Consume the AUTH reply so the connection is left with a
                # clean socket buffer before going back into rotation.
                conn.retry.call_with_retry(
                    lambda: conn.read_response(), lambda error: self._mock(error)
                )
            for conn in self._in_use_connections:
                # Busy connections can't be interrupted; defer their
                # re-auth until they are next used.
                conn.set_re_auth_token(token)

2731 

    def _get_pool_lock(self):
        # Expose the pool's internal lock, e.g. for handlers that must
        # coordinate with pool-level operations.
        return self._lock

2734 

    def _get_free_connections(self):
        # NOTE(review): this returns the live internal list, not a copy —
        # the lock only guards the read itself; callers must not mutate the
        # result without holding the pool lock.
        with self._lock:
            return self._available_connections

2738 

    def _get_in_use_connections(self):
        # NOTE(review): returns the live internal collection, not a copy —
        # callers must not mutate it without holding the pool lock.
        with self._lock:
            return self._in_use_connections

2742 

2743 async def _mock(self, error: RedisError): 

2744 """ 

2745 Dummy functions, needs to be passed as error callback to retry object. 

2746 :param error: 

2747 :return: 

2748 """ 

2749 pass 

2750 

2751 

2752class BlockingConnectionPool(ConnectionPool): 

2753 """ 

2754 Thread-safe blocking connection pool:: 

2755 

2756 >>> from redis.client import Redis 

2757 >>> client = Redis(connection_pool=BlockingConnectionPool()) 

2758 

2759 It performs the same function as the default 

2760 :py:class:`~redis.ConnectionPool` implementation, in that, 

2761 it maintains a pool of reusable connections that can be shared by 

2762 multiple redis clients (safely across threads if required). 

2763 

2764 The difference is that, in the event that a client tries to get a 

    connection from the pool when all connections are in use, rather than

2766 raising a :py:class:`~redis.ConnectionError` (as the default 

2767 :py:class:`~redis.ConnectionPool` implementation does), it 

2768 makes the client wait ("blocks") for a specified number of seconds until 

2769 a connection becomes available. 

2770 

2771 Use ``max_connections`` to increase / decrease the pool size:: 

2772 

2773 >>> pool = BlockingConnectionPool(max_connections=10) 

2774 

2775 Use ``timeout`` to tell it either how many seconds to wait for a connection 

2776 to become available, or to block forever: 

2777 

2778 >>> # Block forever. 

2779 >>> pool = BlockingConnectionPool(timeout=None) 

2780 

2781 >>> # Raise a ``ConnectionError`` after five seconds if a connection is 

2782 >>> # not available. 

2783 >>> pool = BlockingConnectionPool(timeout=5) 

2784 """ 

2785 

    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):
        """
        :param max_connections: maximum number of connections the pool holds.
        :param timeout: seconds to block waiting for a free connection before
            raising ``ConnectionError``; ``None`` blocks forever.
        :param connection_class: class used to create new connections.
        :param queue_class: queue type backing the pool (LIFO by default, so
            the most recently returned connection is handed out first).
        :param connection_kwargs: forwarded to ``connection_class``.
        """
        # These attributes must be assigned *before* calling the parent
        # constructor: ConnectionPool.__init__() invokes reset(), which
        # reads queue_class and the maintenance flags defined here.
        self.queue_class = queue_class
        self.timeout = timeout
        self._in_maintenance = False
        self._locked = False
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )

2803 

    def reset(self):
        """
        Re-initialize the pool: create and fill a thread-safe queue with
        ``None`` placeholders.

        Called from the parent constructor and again when a fork is detected
        (see ``_checkpid``). Each ``None`` handed out later is a cue for
        ``get_connection`` to create a brand-new connection on demand.
        """
        try:
            # While a maintenance (MOVING) notification is being processed,
            # serialize pool mutations behind the pool lock. _locked records
            # whether *we* acquired it so the finally block releases it
            # exactly once.
            if self._in_maintenance:
                self._lock.acquire()
                self._locked = True
            self.pool = self.queue_class(self.max_connections)
            while True:
                try:
                    self.pool.put_nowait(None)
                except Full:
                    break

            # Keep a list of actual connection instances so that we can
            # disconnect them later.
            self._connections = []
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

2838 

    def make_connection(self):
        """Make a fresh connection and register it with the pool.

        The connection is wrapped in a ``CacheProxyConnection`` when
        client-side caching is enabled for this pool.
        """
        try:
            # Serialize creation with maintenance processing (same
            # lock/_locked pattern as reset()).
            if self._in_maintenance:
                self._lock.acquire()
                self._locked = True

            if self.cache is not None:
                connection = CacheProxyConnection(
                    self.connection_class(**self.connection_kwargs),
                    self.cache,
                    self._lock,
                )
            else:
                connection = self.connection_class(**self.connection_kwargs)
            # Track every created instance so disconnect() can reach
            # connections even while they are checked out.
            self._connections.append(connection)
            return connection
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False

2863 

    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    def get_connection(self, command_name=None, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        try:
            # Serialize with maintenance processing (same lock/_locked
            # pattern as reset()/make_connection()).
            if self._in_maintenance:
                self._lock.acquire()
                self._locked = True
            try:
                connection = self.pool.get(block=True, timeout=self.timeout)
            except Empty:
                # Note that this is not caught by the redis client and will be
                # raised unless handled by application code. If you want to
                # avoid blocking entirely, use a short ``timeout`` and handle
                # this error, or size the pool for your peak concurrency.
                raise ConnectionError("No connection available.")

            # If the ``connection`` is actually ``None`` then that's a cue to
            # make a new connection to add to the pool.
            if connection is None:
                connection = self.make_connection()
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, TimeoutError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        return connection

2931 

    def release(self, connection):
        """Release the connection back to the pool.

        Connections created by another process (pre-fork) are disconnected
        and replaced by a ``None`` placeholder; connections flagged for
        reconnect are disconnected before being re-queued.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        try:
            # Serialize with maintenance processing (same lock/_locked
            # pattern as reset()).
            if self._in_maintenance:
                self._lock.acquire()
                self._locked = True
            if not self.owns_connection(connection):
                # pool doesn't own this connection. do not add it back
                # to the pool. instead add a None value which is a placeholder
                # that will cause the pool to recreate the connection if
                # its needed.
                connection.disconnect()
                self.pool.put_nowait(None)
                return
            if connection.should_reconnect():
                connection.disconnect()
            # Put the connection back into the pool.
            try:
                self.pool.put_nowait(connection)
            except Full:
                # perhaps the pool has been reset() after a fork? regardless,
                # we don't want this connection
                pass
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False

2965 

    def disconnect(self, inuse_connections: bool = True):
        """Disconnect either all connections in the pool or just the free ones.

        :param inuse_connections: when True, disconnect every connection the
            pool ever created (``self._connections`` tracks them all, in-use
            or idle); otherwise only the idle ones currently in the queue.
        """
        self._checkpid()
        try:
            # Serialize with maintenance processing (same lock/_locked
            # pattern as reset()).
            if self._in_maintenance:
                self._lock.acquire()
                self._locked = True
            if inuse_connections:
                connections = self._connections
            else:
                connections = self._get_free_connections()
            for connection in connections:
                connection.disconnect()
        finally:
            if self._locked:
                try:
                    self._lock.release()
                except Exception:
                    pass
                self._locked = False

2986 

2987 def _get_free_connections(self): 

2988 with self._lock: 

2989 return {conn for conn in self.pool.queue if conn} 

2990 

2991 def _get_in_use_connections(self): 

2992 with self._lock: 

2993 # free connections 

2994 connections_in_queue = {conn for conn in self.pool.queue if conn} 

2995 # in self._connections we keep all created connections 

2996 # so the ones that are not in the queue are the in use ones 

2997 return { 

2998 conn for conn in self._connections if conn not in connections_in_queue 

2999 } 

3000 

    def set_in_maintenance(self, in_maintenance: bool):
        """
        Set the flag marking this BlockingConnectionPool as in maintenance mode.

        While the flag is set, pool operations (reset/make_connection/
        get_connection/release/disconnect) additionally serialize on the pool
        lock, preventing new connections from being created mid-maintenance.
        The pool is put into maintenance mode only while a MOVING notification
        is being processed.
        """
        self._in_maintenance = in_maintenance