Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/redis/connection.py: 16%

716 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-23 06:16 +0000

1import copy 

2import os 

3import socket 

4import ssl 

5import sys 

6import threading 

7import weakref 

8from abc import abstractmethod 

9from itertools import chain 

10from queue import Empty, Full, LifoQueue 

11from time import time 

12from typing import Any, Callable, List, Optional, Tuple, Type, Union 

13from urllib.parse import parse_qs, unquote, urlparse 

14 

15from ._cache import ( 

16 DEFAULT_BLACKLIST, 

17 DEFAULT_EVICTION_POLICY, 

18 DEFAULT_WHITELIST, 

19 AbstractCache, 

20 _LocalCache, 

21) 

22from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser 

23from .backoff import NoBackoff 

24from .credentials import CredentialProvider, UsernamePasswordCredentialProvider 

25from .exceptions import ( 

26 AuthenticationError, 

27 AuthenticationWrongNumberOfArgsError, 

28 ChildDeadlockedError, 

29 ConnectionError, 

30 DataError, 

31 RedisError, 

32 ResponseError, 

33 TimeoutError, 

34) 

35from .retry import Retry 

36from .typing import KeysT, ResponseT 

37from .utils import ( 

38 CRYPTOGRAPHY_AVAILABLE, 

39 HIREDIS_AVAILABLE, 

40 HIREDIS_PACK_AVAILABLE, 

41 SSL_AVAILABLE, 

42 get_lib_version, 

43 str_if_bytes, 

44) 

45 

46if HIREDIS_AVAILABLE: 

47 import hiredis 

48 

# RESP wire-protocol framing bytes.
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

# RESP protocol version assumed when `protocol` is omitted (or not an int).
DEFAULT_RESP_VERSION = 2

# Sentinel distinguishing "argument not supplied" from an explicit None/[].
SENTINEL = object()

# Prefer the C-accelerated hiredis parser when the package is installed;
# otherwise fall back to the pure-Python RESP2 parser.
DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]]
if HIREDIS_AVAILABLE:
    DefaultParser = _HiredisParser
else:
    DefaultParser = _RESP2Parser

63 

64 

class HiredisRespSerializer:
    """Serialize commands into RESP wire format via hiredis' C packer."""

    def pack(self, *args: List):
        """Pack a series of arguments into the Redis protocol"""
        first = args[0]
        # A multi-word command name (e.g. 'CONFIG GET') must be split into
        # separate arguments before packing.
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        try:
            packed = hiredis.pack_command(args)
        except TypeError:
            # Re-raise as the library's DataError, preserving the traceback.
            _, value, traceback = sys.exc_info()
            raise DataError(value).with_traceback(traceback)

        return [packed]

81 

82 

class PythonRespSerializer:
    """Pure-Python fallback packer that renders commands into RESP format."""

    def __init__(self, buffer_cutoff, encode) -> None:
        # buffer_cutoff: max size of an in-progress buffer before it is
        # flushed into the output list; encode: callable turning an
        # argument into bytes.
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        # The client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects
        # these arguments to be sent separately, so split the first
        # argument manually. These arguments should be bytestrings so
        # that they are not encoded.
        first = args[0]
        if isinstance(first, str):
            args = tuple(first.encode().split()) + args[1:]
        elif b" " in first:
            args = tuple(first.split()) + args[1:]

        cutoff = self._buffer_cutoff
        output = []
        buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        for arg in map(self.encode, args):
            # To avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews.
            arg_length = len(arg)
            too_big = len(buff) > cutoff or arg_length > cutoff
            if too_big or isinstance(arg, memoryview):
                header = (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
                output.append(SYM_EMPTY.join(header))
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join(
                    (
                        buff,
                        SYM_DOLLAR,
                        str(arg_length).encode(),
                        SYM_CRLF,
                        arg,
                        SYM_CRLF,
                    )
                )
        output.append(buff)
        return output

132 

133 

class AbstractConnection:
    "Manages communication to and from a Redis server"

    def __init__(
        self,
        db: int = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        retry_on_timeout: bool = False,
        retry_on_error=SENTINEL,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        parser_class=DefaultParser,
        socket_read_size: int = 65536,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        retry: Union[Any, None] = None,
        redis_connect_func: Optional[Callable[[], None]] = None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        command_packer: Optional[Callable[[], None]] = None,
        cache_enabled: bool = False,
        client_cache: Optional[AbstractCache] = None,
        cache_max_size: int = 10000,
        cache_ttl: int = 0,
        cache_policy: str = DEFAULT_EVICTION_POLICY,
        cache_blacklist: List[str] = DEFAULT_BLACKLIST,
        cache_whitelist: List[str] = DEFAULT_WHITELIST,
    ):
        """
        Initialize a new Connection.
        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        self.pid = os.getpid()
        self.db = db
        self.client_name = client_name
        self.lib_name = lib_name
        self.lib_version = lib_version
        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.socket_timeout = socket_timeout
        if socket_connect_timeout is None:
            socket_connect_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_error = []
        if retry_on_timeout:
            # Add TimeoutError to the errors list to retry on.  Work on a
            # copy so a caller-supplied list is never mutated in place.
            retry_on_error = list(retry_on_error) + [TimeoutError]
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if retry is None:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.redis_connect_func = redis_connect_func
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self._sock = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks = []
        self._buffer_cutoff = 6000
        try:
            p = int(protocol)
        except TypeError:
            p = DEFAULT_RESP_VERSION
        except ValueError:
            raise ConnectionError("protocol must be an integer")
        # Validate outside the try/except: the previous code performed this
        # check in a `finally` clause, which raised UnboundLocalError when
        # int() raised ValueError (p never bound), masking the intended
        # ConnectionError above.
        if p < 2 or p > 3:
            raise ConnectionError("protocol must be either 2 or 3")
        self.protocol = p
        self._command_packer = self._construct_command_packer(command_packer)
        if cache_enabled:
            _cache = _LocalCache(cache_max_size, cache_ttl, cache_policy)
        else:
            _cache = None
        # An explicitly supplied client_cache wins over cache_enabled.
        self.client_cache = client_cache if client_cache is not None else _cache
        if self.client_cache is not None:
            if self.protocol not in [3, "3"]:
                raise RedisError(
                    "client caching is only supported with protocol version 3 or higher"
                )
            self.cache_blacklist = cache_blacklist
            self.cache_whitelist = cache_whitelist

    def __repr__(self):
        repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
        return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"

    @abstractmethod
    def repr_pieces(self):
        """Return a list of (name, value) pairs used to build __repr__."""
        pass

    def __del__(self):
        # Best-effort socket cleanup; never raise from a destructor.
        try:
            self.disconnect()
        except Exception:
            pass

    def _construct_command_packer(self, packer):
        """Pick the command packer: user-supplied, hiredis, or pure Python."""
        if packer is not None:
            return packer
        elif HIREDIS_PACK_AVAILABLE:
            return HiredisRespSerializer()
        else:
            return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)

    def register_connect_callback(self, callback):
        """
        Register a callback to be called when the connection is established either
        initially or reconnected. This allows listeners to issue commands that
        are ephemeral to the connection, for example pub/sub subscription or
        key tracking. The callback must be a _method_ and will be kept as
        a weak reference.
        """
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister_connect_callback(self, callback):
        """
        De-register a previously registered callback. It will no-longer receive
        notifications on connection events. Calling this is not required when the
        listener goes away, since the callbacks are kept as weak methods.
        """
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def set_parser(self, parser_class):
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    def connect(self):
        "Connects to the Redis server if not already connected"
        if self._sock:
            return
        try:
            sock = self.retry.call_with_retry(
                lambda: self._connect(), lambda error: self.disconnect(error)
            )
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect()
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        # first, remove any dead weakrefs
        self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
        for ref in self._connect_callbacks:
            callback = ref()
            if callback:
                callback(self)

    @abstractmethod
    def _connect(self):
        """Create and return a connected socket (transport-specific)."""
        pass

    @abstractmethod
    def _host_error(self):
        """Return a human-readable endpoint string for error messages."""
        pass

    @abstractmethod
    def _error_message(self, exception):
        """Format a connection error message from an OSError."""
        pass

    def on_connect(self):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)
        parser = self._parser

        auth_args = None
        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()

        # if resp version is specified and we have auth args,
        # we need to send them via HELLO
        if auth_args and self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            if len(auth_args) == 1:
                # HELLO AUTH requires a username; "default" is the implicit one
                auth_args = ["default", auth_args[0]]
            self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
            response = self.read_response()
        elif auth_args:
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if resp version is specified, switch to it
        elif self.protocol not in [2, "2"]:
            if isinstance(self._parser, _RESP2Parser):
                self.set_parser(_RESP3Parser)
                # update cluster exception classes
                self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
                self._parser.on_connect(self)
            self.send_command("HELLO", self.protocol)
            response = self.read_response()
            if (
                response.get(b"proto") != self.protocol
                and response.get("proto") != self.protocol
            ):
                raise ConnectionError("Invalid RESP version")

        # if a client_name is given, set it
        if self.client_name:
            self.send_command("CLIENT", "SETNAME", self.client_name)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        try:
            # set the library name and version
            if self.lib_name:
                self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
                self.read_response()
        except ResponseError:
            # older servers don't support CLIENT SETINFO; ignore
            pass

        try:
            if self.lib_version:
                self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
                self.read_response()
        except ResponseError:
            pass

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

        # if client caching is enabled, start tracking
        if self.client_cache:
            self.send_command("CLIENT", "TRACKING", "ON")
            self.read_response()
            self._parser.set_invalidation_push_handler(self._cache_invalidation_process)

    def disconnect(self, *args):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()

        conn_sock = self._sock
        self._sock = None
        if conn_sock is None:
            return

        # only shut the socket down if we're in the process that created it;
        # a forked child must not tear down the parent's connection
        if os.getpid() == self.pid:
            try:
                conn_sock.shutdown(socket.SHUT_RDWR)
            except (OSError, TypeError):
                pass

        try:
            conn_sock.close()
        except OSError:
            pass

        if self.client_cache:
            self.client_cache.flush()

    def _send_ping(self):
        """Send PING, expect PONG in return"""
        self.send_command("PING", check_health=False)
        if str_if_bytes(self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    def _ping_failed(self, error):
        """Function to call when PING fails"""
        self.disconnect()

    def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if self.health_check_interval and time() > self.next_health_check:
            self.retry.call_with_retry(self._send_ping, self._ping_failed)

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the Redis server"""
        if not self._sock:
            self.connect()
        # guard against health check recursion
        if check_health:
            self.check_health()
        try:
            if isinstance(command, str):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except OSError as e:
            self.disconnect()
            if len(e.args) == 1:
                errno, errmsg = "UNKNOWN", e.args[0]
            else:
                errno = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
        except BaseException:
            # BaseExceptions can be raised when a socket send operation is not
            # finished, e.g. due to a timeout. Ideally, a caller could then re-try
            # to send un-sent data. However, the send_packed_command() API
            # does not support it so there is no point in keeping the connection open.
            self.disconnect()
            raise

    def send_command(self, *args, **kwargs):
        """Pack and send a command to the Redis server"""
        self.send_packed_command(
            self._command_packer.pack(*args),
            check_health=kwargs.get("check_health", True),
        )

    def can_read(self, timeout=0):
        """Poll the socket to see if there's data that can be read."""
        sock = self._sock
        if not sock:
            self.connect()

        host_error = self._host_error()

        try:
            return self._parser.can_read(timeout)
        except OSError as e:
            self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    def read_response(
        self,
        disable_decoding=False,
        *,
        disconnect_on_error=True,
        push_request=False,
    ):
        """Read the response from a previously sent command"""

        host_error = self._host_error()

        try:
            # only the pure-Python RESP3 parser understands push_request
            if self.protocol in ["3", 3] and not HIREDIS_AVAILABLE:
                response = self._parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            else:
                response = self._parser.read_response(disable_decoding=disable_decoding)
        except socket.timeout:
            if disconnect_on_error:
                self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            if disconnect_on_error:
                self.disconnect()
            raise ConnectionError(
                f"Error while reading from {host_error}" f" : {e.args}"
            )
        except BaseException:
            # Also by default close in case of BaseException. A lot of code
            # relies on this behaviour when doing Command/Response pairs.
            # See #1128.
            if disconnect_on_error:
                self.disconnect()
            raise

        if self.health_check_interval:
            self.next_health_check = time() + self.health_check_interval

        if isinstance(response, ResponseError):
            try:
                raise response
            finally:
                del response  # avoid creating ref cycles
        return response

    def pack_command(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        return self._command_packer.pack(*args)

    def pack_commands(self, commands):
        """Pack multiple commands into the Redis protocol"""
        output = []
        pieces = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self._command_packer.pack(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                        buffer_length = 0
                        pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

    def _cache_invalidation_process(
        self, data: List[Union[str, Optional[List[str]]]]
    ) -> None:
        """
        Invalidate (delete) all redis commands associated with a specific key.
        `data` is a list of strings, where the first string is the invalidation message
        and the second string is the list of keys to invalidate.
        (if the list of keys is None, then all keys are invalidated)
        """
        if data[1] is None:
            self.client_cache.flush()
        else:
            for key in data[1]:
                self.client_cache.invalidate_key(str_if_bytes(key))

    def _get_from_local_cache(self, command: str):
        """
        If the command is in the local cache, return the response
        """
        if (
            self.client_cache is None
            or command[0] in self.cache_blacklist
            or command[0] not in self.cache_whitelist
        ):
            return None
        # drain any pending invalidation push messages before serving a hit
        while self.can_read():
            self.read_response(push_request=True)
        return self.client_cache.get(command)

    def _add_to_local_cache(
        self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
    ):
        """
        Add the command and response to the local cache if the command
        is allowed to be cached
        """
        if (
            self.client_cache is not None
            and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
            and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
        ):
            self.client_cache.set(command, response, keys)

653 

654 

class Connection(AbstractConnection):
    "Manages TCP communication to and from a Redis server"

    def __init__(
        self,
        host="localhost",
        port=6379,
        socket_keepalive=False,
        socket_keepalive_options=None,
        socket_type=0,
        **kwargs,
    ):
        # host/port: TCP endpoint of the Redis server.
        # socket_keepalive(_options): enable and tune SO_KEEPALIVE.
        # socket_type: address-family hint forwarded to getaddrinfo
        # (0 lets the OS choose between IPv4/IPv6).
        self.host = host
        self.port = int(port)
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.socket_type = socket_type
        super().__init__(**kwargs)

    def repr_pieces(self):
        # (name, value) pairs rendered by AbstractConnection.__repr__
        pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        "Create a TCP socket connection"
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        err = None
        for res in socket.getaddrinfo(
            self.host, self.port, self.socket_type, socket.SOCK_STREAM
        ):
            family, socktype, proto, canonname, socket_address = res
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in self.socket_keepalive_options.items():
                        sock.setsockopt(socket.IPPROTO_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except OSError as _:
                # remember the last error and try the next resolved address
                err = _
                if sock is not None:
                    sock.close()

        if err is not None:
            raise err
        raise OSError("socket.getaddrinfo returned an empty list")

    def _host_error(self):
        # endpoint string used in error messages
        return f"{self.host}:{self.port}"

    def _error_message(self, exception):
        # args for socket.error can either be (errno, "message")
        # or just "message"

        host_error = self._host_error()

        if len(exception.args) == 1:
            try:
                return f"Error connecting to {host_error}. \
                {exception.args[0]}."
            except AttributeError:
                return f"Connection Error: {exception.args[0]}"
        else:
            try:
                return (
                    f"Error {exception.args[0]} connecting to "
                    f"{host_error}. {exception.args[1]}."
                )
            except AttributeError:
                return f"Connection Error: {exception.args[0]}"

744 

745 

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=False,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        ssl_min_version=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required). Defaults to "required".
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to False.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
            ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.

        Raises:
            RedisError
        """  # noqa
        if not SSL_AVAILABLE:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            # translate the user-facing string names into ssl module flags
            CERT_REQS = {
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        self.check_hostname = ssl_check_hostname
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        self.ssl_min_version = ssl_min_version
        super().__init__(**kwargs)

    def _connect(self):
        "Wrap the socket with SSL support"
        # establish the plain TCP connection first, then wrap it
        sock = super()._connect()
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        if self.ssl_min_version is not None:
            context.minimum_version = self.ssl_min_version
        sslsock = context.wrap_socket(sock, server_hostname=self.host)
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        # stapled and full OCSP validation are mutually exclusive
        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        # validation for the stapled case
        if self.ssl_validate_ocsp_stapled:
            import OpenSSL

            from .ocsp import ocsp_staple_verifier

            # if a context is provided use it - otherwise, a basic context
            if self.ssl_ocsp_context is None:
                staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
                staple_ctx.use_certificate_file(self.certfile)
                staple_ctx.use_privatekey_file(self.keyfile)
            else:
                staple_ctx = self.ssl_ocsp_context

            staple_ctx.set_ocsp_client_callback(
                ocsp_staple_verifier, self.ssl_ocsp_expected_cert
            )

            # need another socket
            con = OpenSSL.SSL.Connection(staple_ctx, socket.socket())
            con.request_ocsp()
            con.connect((self.host, self.port))
            con.do_handshake()
            con.shutdown()
            return sslsock

        # pure ocsp validation
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
            from .ocsp import OCSPVerifier

            o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
            if o.is_valid():
                return sslsock
            else:
                raise ConnectionError("ocsp validation error")
        return sslsock

889 

890 

class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, path="", socket_timeout=None, **kwargs):
        """
        :param path: filesystem path of the Redis server's unix socket
        :param socket_timeout: timeout (seconds) for socket reads/writes
        """
        self.path = path
        # Forward socket_timeout to the base initializer.  Previously it was
        # only assigned to self here, and AbstractConnection.__init__ then
        # overwrote self.socket_timeout with its own default (None), so the
        # argument was silently ignored.
        super().__init__(socket_timeout=socket_timeout, **kwargs)

    def repr_pieces(self):
        # (name, value) pairs rendered by AbstractConnection.__repr__
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        "Create a Unix domain socket connection"
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        # connect with the (shorter) connect timeout, then switch to the
        # steady-state read/write timeout once connected
        sock.settimeout(self.socket_connect_timeout)
        sock.connect(self.path)
        sock.settimeout(self.socket_timeout)
        return sock

    def _host_error(self):
        # endpoint string used in error messages
        return self.path

    def _error_message(self, exception):
        # args for socket.error can either be (errno, "message")
        # or just "message"
        host_error = self._host_error()
        if len(exception.args) == 1:
            return (
                f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
            )
        else:
            return (
                f"Error {exception.args[0]} connecting to unix socket: "
                f"{host_error}. {exception.args[1]}."
            )

929 

930 

# Strings (compared case-insensitively) that parse as boolean False when
# they appear as querystring values in a Redis URL.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Coerce a URL querystring value to a bool (None for missing/empty)."""
    if value is None or value == "":
        return None
    is_falsy_string = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if is_falsy_string else bool(value)

940 

941 

942URL_QUERY_ARGUMENT_PARSERS = { 

943 "db": int, 

944 "socket_timeout": float, 

945 "socket_connect_timeout": float, 

946 "socket_keepalive": to_bool, 

947 "retry_on_timeout": to_bool, 

948 "retry_on_error": list, 

949 "max_connections": int, 

950 "health_check_interval": int, 

951 "ssl_check_hostname": to_bool, 

952 "timeout": float, 

953} 

954 

955 

956def parse_url(url): 

957 if not ( 

958 url.startswith("redis://") 

959 or url.startswith("rediss://") 

960 or url.startswith("unix://") 

961 ): 

962 raise ValueError( 

963 "Redis URL must specify one of the following " 

964 "schemes (redis://, rediss://, unix://)" 

965 ) 

966 

967 url = urlparse(url) 

968 kwargs = {} 

969 

970 for name, value in parse_qs(url.query).items(): 

971 if value and len(value) > 0: 

972 value = unquote(value[0]) 

973 parser = URL_QUERY_ARGUMENT_PARSERS.get(name) 

974 if parser: 

975 try: 

976 kwargs[name] = parser(value) 

977 except (TypeError, ValueError): 

978 raise ValueError(f"Invalid value for `{name}` in connection URL.") 

979 else: 

980 kwargs[name] = value 

981 

982 if url.username: 

983 kwargs["username"] = unquote(url.username) 

984 if url.password: 

985 kwargs["password"] = unquote(url.password) 

986 

987 # We only support redis://, rediss:// and unix:// schemes. 

988 if url.scheme == "unix": 

989 if url.path: 

990 kwargs["path"] = unquote(url.path) 

991 kwargs["connection_class"] = UnixDomainSocketConnection 

992 

993 else: # implied: url.scheme in ("redis", "rediss"): 

994 if url.hostname: 

995 kwargs["host"] = unquote(url.hostname) 

996 if url.port: 

997 kwargs["port"] = int(url.port) 

998 

999 # If there's a path argument, use it as the db argument if a 

1000 # querystring value wasn't specified 

1001 if url.path and "db" not in kwargs: 

1002 try: 

1003 kwargs["db"] = int(unquote(url.path).replace("/", "")) 

1004 except (AttributeError, ValueError): 

1005 pass 

1006 

1007 if url.scheme == "rediss": 

1008 kwargs["connection_class"] = SSLConnection 

1009 

1010 return kwargs 

1011 

1012 

class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
    unix sockets.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls, url, **kwargs):
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0
        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0
        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)

        # An explicit connection_class keyword overrides whatever the URL
        # scheme implies (e.g. SSLConnection for rediss://).
        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self,
        connection_class=Connection,
        max_connections: Optional[int] = None,
        **connection_kwargs,
    ):
        # Treat "no limit" as a limit large enough never to be reached.
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.
        self._fork_lock = threading.Lock()
        self.reset()

    def __repr__(self) -> str:
        # NOTE: this instantiates (but does not connect) a throwaway
        # connection object purely to render the configured kwargs.
        return (
            f"<{type(self).__module__}.{type(self).__name__}"
            f"({repr(self.connection_class(**self.connection_kwargs))})>"
        )

    def reset(self) -> None:
        """(Re)initialize the pool's state for the current process."""
        self._lock = threading.Lock()
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self) -> None:
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    def get_connection(self, command_name: str, *keys, **options) -> "Connection":
        "Get a connection from the pool"
        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # if client caching is not enabled connections that the pool
            # provides should be ready to send a command.
            # if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            # (if caching enabled the connection will not always be ready
            # to send a command because it may contain invalidation messages)
            try:
                if connection.can_read() and connection.client_cache is None:
                    raise ConnectionError("Connection has data")
            except (ConnectionError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise

        return connection

    def get_encoder(self) -> Encoder:
        "Return an encoder based on encoding settings"
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self) -> "Connection":
        "Create a new connection"
        if self._created_connections >= self.max_connections:
            raise ConnectionError("Too many connections")
        self._created_connections += 1
        return self.connection_class(**self.connection_kwargs)

    def release(self, connection: "Connection") -> None:
        "Releases the connection back to the pool"
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own
                pass

            if self.owns_connection(connection):
                self._available_connections.append(connection)
            else:
                # pool doesn't own this connection. do not add it back
                # to the pool and decrement the count so that another
                # connection can take its place if needed
                self._created_connections -= 1
                connection.disconnect()
                return

    def owns_connection(self, connection: "Connection") -> bool:
        # A connection belongs to this pool only if it was created in the
        # same process (see _checkpid() for the fork-safety rationale).
        return connection.pid == self.pid

    def disconnect(self, inuse_connections: bool = True) -> None:
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other threads. Otherwise only disconnect
        connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(
                    self._available_connections, self._in_use_connections
                )
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()

    def close(self) -> None:
        """Close the pool, disconnecting all connections"""
        self.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        """Apply a new retry policy to future and existing connections."""
        self.connection_kwargs.update({"retry": retry})
        for conn in self._available_connections:
            conn.retry = retry
        for conn in self._in_use_connections:
            conn.retry = retry

    def flush_cache(self):
        """Flush the client-side cache of every connection in the pool."""
        self._checkpid()
        with self._lock:
            connections = chain(self._available_connections, self._in_use_connections)

            for connection in connections:
                try:
                    connection.client_cache.flush()
                except AttributeError:
                    # cache is not enabled
                    pass

    def delete_command_from_cache(self, command: str):
        """Evict a command's cached entries from every connection's cache."""
        self._checkpid()
        with self._lock:
            connections = chain(self._available_connections, self._in_use_connections)

            for connection in connections:
                try:
                    connection.client_cache.delete_command(command)
                except AttributeError:
                    # cache is not enabled
                    pass

    def invalidate_key_from_cache(self, key: str):
        """Invalidate a key in every connection's client-side cache."""
        self._checkpid()
        with self._lock:
            connections = chain(self._available_connections, self._in_use_connections)

            for connection in connections:
                try:
                    connection.client_cache.invalidate_key(key)
                except AttributeError:
                    # cache is not enabled
                    pass

1313 

1314 

1315class BlockingConnectionPool(ConnectionPool): 

1316 """ 

1317 Thread-safe blocking connection pool:: 

1318 

1319 >>> from redis.client import Redis 

1320 >>> client = Redis(connection_pool=BlockingConnectionPool()) 

1321 

1322 It performs the same function as the default 

1323 :py:class:`~redis.ConnectionPool` implementation, in that, 

1324 it maintains a pool of reusable connections that can be shared by 

1325 multiple redis clients (safely across threads if required). 

1326 

1327 The difference is that, in the event that a client tries to get a 

connection from the pool when all of the connections are in use, rather than

1329 raising a :py:class:`~redis.ConnectionError` (as the default 

1330 :py:class:`~redis.ConnectionPool` implementation does), it 

1331 makes the client wait ("blocks") for a specified number of seconds until 

1332 a connection becomes available. 

1333 

1334 Use ``max_connections`` to increase / decrease the pool size:: 

1335 

1336 >>> pool = BlockingConnectionPool(max_connections=10) 

1337 

1338 Use ``timeout`` to tell it either how many seconds to wait for a connection 

1339 to become available, or to block forever: 

1340 

1341 >>> # Block forever. 

1342 >>> pool = BlockingConnectionPool(timeout=None) 

1343 

1344 >>> # Raise a ``ConnectionError`` after five seconds if a connection is 

1345 >>> # not available. 

1346 >>> pool = BlockingConnectionPool(timeout=5) 

1347 """ 

1348 

    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):
        # queue_class and timeout must be assigned before calling the parent
        # initializer: ConnectionPool.__init__ ends by calling reset(), which
        # reads self.queue_class to build the pool queue.
        self.queue_class = queue_class
        # seconds to block waiting for a free connection; None blocks forever
        self.timeout = timeout
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )

1364 

    def reset(self):
        # (Re)initialize the pool state for the current process.
        # Create and fill up a thread safe queue with ``None`` values.
        # Each ``None`` is a placeholder that tells get_connection() it may
        # create a brand new connection instead of reusing an existing one.
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break

        # Keep a list of actual connection instances so that we can
        # disconnect them later.
        self._connections = []

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

1388 

1389 def make_connection(self): 

1390 "Make a fresh connection." 

1391 connection = self.connection_class(**self.connection_kwargs) 

1392 self._connections.append(connection) 

1393 return connection 

1394 

1395 def get_connection(self, command_name, *keys, **options): 

1396 """ 

1397 Get a connection, blocking for ``self.timeout`` until a connection 

1398 is available from the pool. 

1399 

1400 If the connection returned is ``None`` then creates a new connection. 

1401 Because we use a last-in first-out queue, the existing connections 

1402 (having been returned to the pool after the initial ``None`` values 

1403 were added) will be returned before ``None`` values. This means we only 

1404 create new connections when we need to, i.e.: the actual number of 

1405 connections will only increase in response to demand. 

1406 """ 

1407 # Make sure we haven't changed process. 

1408 self._checkpid() 

1409 

1410 # Try and get a connection from the pool. If one isn't available within 

1411 # self.timeout then raise a ``ConnectionError``. 

1412 connection = None 

1413 try: 

1414 connection = self.pool.get(block=True, timeout=self.timeout) 

1415 except Empty: 

1416 # Note that this is not caught by the redis client and will be 

1417 # raised unless handled by application code. If you want never to 

1418 raise ConnectionError("No connection available.") 

1419 

1420 # If the ``connection`` is actually ``None`` then that's a cue to make 

1421 # a new connection to add to the pool. 

1422 if connection is None: 

1423 connection = self.make_connection() 

1424 

1425 try: 

1426 # ensure this connection is connected to Redis 

1427 connection.connect() 

1428 # connections that the pool provides should be ready to send 

1429 # a command. if not, the connection was either returned to the 

1430 # pool before all data has been read or the socket has been 

1431 # closed. either way, reconnect and verify everything is good. 

1432 try: 

1433 if connection.can_read(): 

1434 raise ConnectionError("Connection has data") 

1435 except (ConnectionError, OSError): 

1436 connection.disconnect() 

1437 connection.connect() 

1438 if connection.can_read(): 

1439 raise ConnectionError("Connection not ready") 

1440 except BaseException: 

1441 # release the connection back to the pool so that we don't leak it 

1442 self.release(connection) 

1443 raise 

1444 

1445 return connection 

1446 

1447 def release(self, connection): 

1448 "Releases the connection back to the pool." 

1449 # Make sure we haven't changed process. 

1450 self._checkpid() 

1451 if not self.owns_connection(connection): 

1452 # pool doesn't own this connection. do not add it back 

1453 # to the pool. instead add a None value which is a placeholder 

1454 # that will cause the pool to recreate the connection if 

1455 # its needed. 

1456 connection.disconnect() 

1457 self.pool.put_nowait(None) 

1458 return 

1459 

1460 # Put the connection back into the pool. 

1461 try: 

1462 self.pool.put_nowait(connection) 

1463 except Full: 

1464 # perhaps the pool has been reset() after a fork? regardless, 

1465 # we don't want this connection 

1466 pass 

1467 

    def disconnect(self):
        "Disconnects all connections in the pool."
        # Make sure we haven't changed process (fork safety).
        self._checkpid()
        # _connections tracks every connection this pool ever created,
        # including ones currently checked out, so all of them are closed.
        for connection in self._connections:
            connection.disconnect()