Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/redis/connection.py: 20%

858 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-26 07:09 +0000

1import copy 

2import errno 

3import io 

4import os 

5import socket 

6import sys 

7import threading 

8import weakref 

9from abc import abstractmethod 

10from io import SEEK_END 

11from itertools import chain 

12from queue import Empty, Full, LifoQueue 

13from time import time 

14from typing import Optional, Union 

15from urllib.parse import parse_qs, unquote, urlparse 

16 

17from redis.backoff import NoBackoff 

18from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider 

19from redis.exceptions import ( 

20 AuthenticationError, 

21 AuthenticationWrongNumberOfArgsError, 

22 BusyLoadingError, 

23 ChildDeadlockedError, 

24 ConnectionError, 

25 DataError, 

26 ExecAbortError, 

27 InvalidResponse, 

28 ModuleError, 

29 NoPermissionError, 

30 NoScriptError, 

31 ReadOnlyError, 

32 RedisError, 

33 ResponseError, 

34 TimeoutError, 

35) 

36from redis.retry import Retry 

37from redis.utils import ( 

38 CRYPTOGRAPHY_AVAILABLE, 

39 HIREDIS_AVAILABLE, 

40 HIREDIS_PACK_AVAILABLE, 

41 str_if_bytes, 

42) 

43 

# Probe for the optional ssl module once, and remember the outcome so the
# SSL-specific exception types are only referenced when they exist.
try:
    import ssl
except ImportError:
    ssl_available = False
else:
    ssl_available = True

# For every exception type a read may raise when no data is ready, the errno
# value that identifies the "nothing to read right now" condition.
NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK}

if ssl_available:
    if hasattr(ssl, "SSLWantReadError"):
        # NOTE(review): both SSL exception types are registered with the
        # value 2 -- confirm this is also the errno SSLWantWriteError carries.
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS.update(
            {ssl.SSLWantReadError: 2, ssl.SSLWantWriteError: 2}
        )
    else:
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2

# Tuple form of the keys above, usable directly in an ``except`` clause.
NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS)

61 

62if HIREDIS_AVAILABLE: 

63 import hiredis 

64 

# Framing symbols used when packing commands into the Redis protocol.
SYM_STAR, SYM_DOLLAR = b"*", b"$"
SYM_CRLF = b"\r\n"
SYM_EMPTY = b""

SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."

# Unique marker that lets a default argument be told apart from an explicit
# value such as None.
SENTINEL = object()

# Messages used as lookup keys when classifying "ERR" replies.
MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
MODULE_EXPORTS_DATA_TYPES_ERROR = (
    "Error unloading module: the module exports one or more "
    "module-side data types, can't unload"
)

# The user sent an AUTH command to a server without authorization configured.
NO_AUTH_SET_ERROR = {
    message: AuthenticationError
    for message in (
        # Redis >= 6.0
        "AUTH <password> called without any password configured for the "
        "default user. Are you sure your configuration is correct?",
        # Redis < 6.0
        "Client sent AUTH, but no password is set",
    )
}

90 

91 

class Encoder:
    """Encode outgoing values to bytes-like objects and decode replies to str."""

    def __init__(self, encoding, encoding_errors, decode_responses):
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses

    def encode(self, value):
        """Return a bytestring or bytes-like representation of ``value``.

        ``bytes`` and ``memoryview`` objects are returned untouched, ``int``
        and ``float`` are rendered through ``repr()``, and ``str`` is encoded
        with the configured encoding and error policy.

        Raises:
            DataError: for ``bool`` and for any other unsupported type.
        """
        if isinstance(value, (bytes, memoryview)):
            return value
        # bool has to be rejected before the int check: it is an int subclass
        if isinstance(value, bool):
            raise DataError(
                "Invalid input of type: 'bool'. Convert to a "
                "bytes, string, int or float first."
            )
        if isinstance(value, (int, float)):
            return repr(value).encode()
        if isinstance(value, str):
            return value.encode(self.encoding, self.encoding_errors)
        # anything else is a type we do not know how to serialise
        typename = type(value).__name__
        raise DataError(
            f"Invalid input of type: '{typename}'. "
            f"Convert to a bytes, string, int or float first."
        )

    def decode(self, value, force=False):
        """Return ``value`` decoded to ``str`` when decoding is enabled.

        Decoding happens only if ``decode_responses`` is set or ``force`` is
        true; values that are neither ``bytes`` nor ``memoryview`` are
        returned as they are.
        """
        if not (self.decode_responses or force):
            return value
        raw = value.tobytes() if isinstance(value, memoryview) else value
        if isinstance(raw, bytes):
            return raw.decode(self.encoding, self.encoding_errors)
        return raw

131 

132 

class BaseParser:
    """Common base for reply parsers: maps error replies to exception objects."""

    # Keyed by the first word of an error reply. A value is either the
    # exception class to use, or (for "ERR") a nested mapping from the rest
    # of the message to the exception class.
    EXCEPTION_CLASSES = {
        "ERR": {
            "max number of clients reached": ConnectionError,
            "invalid password": AuthenticationError,
            # some Redis server versions report invalid command syntax
            # in lowercase
            "wrong number of arguments "
            "for 'auth' command": AuthenticationWrongNumberOfArgsError,
            # some Redis server versions report invalid command syntax
            # in uppercase
            "wrong number of arguments "
            "for 'AUTH' command": AuthenticationWrongNumberOfArgsError,
            MODULE_LOAD_ERROR: ModuleError,
            MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
            NO_SUCH_MODULE_ERROR: ModuleError,
            MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
            **NO_AUTH_SET_ERROR,
        },
        "WRONGPASS": AuthenticationError,
        "EXECABORT": ExecAbortError,
        "LOADING": BusyLoadingError,
        "NOSCRIPT": NoScriptError,
        "READONLY": ReadOnlyError,
        "NOAUTH": AuthenticationError,
        "NOPERM": NoPermissionError,
    }

    def parse_error(self, response):
        """Build (without raising) the exception matching an error reply.

        The first space-separated word selects the entry in
        ``EXCEPTION_CLASSES``. Unknown prefixes produce a ``ResponseError``
        carrying the whole reply; known prefixes produce the mapped exception
        carrying the reply with the prefix stripped.
        """
        error_code, _, message = response.partition(" ")
        if error_code not in self.EXCEPTION_CLASSES:
            return ResponseError(response)

        mapped = self.EXCEPTION_CLASSES[error_code]
        if isinstance(mapped, dict):
            # nested table: choose by exact message, else a generic error
            mapped = mapped.get(message, ResponseError)
        return mapped(message)

171 

172 

class SocketBuffer:
    """Read buffer layered over a socket.

    Data received from the socket is appended to an in-memory ``io.BytesIO``;
    the stream position of that object marks how far the data has been
    consumed.
    """

    def __init__(
        self, socket: socket.socket, socket_read_size: int, socket_timeout: float
    ):
        self._sock = socket
        self.socket_read_size = socket_read_size
        self.socket_timeout = socket_timeout
        self._buffer = io.BytesIO()

    def unread_bytes(self) -> int:
        """
        Remaining unread length of buffer
        """
        pos = self._buffer.tell()
        end = self._buffer.seek(0, SEEK_END)
        # put the read position back where it was
        self._buffer.seek(pos)
        return end - pos

    def _read_from_socket(
        self,
        length: Optional[int] = None,
        timeout: Union[float, object] = SENTINEL,
        raise_on_timeout: Optional[bool] = True,
    ) -> bool:
        """Receive from the socket and append to the end of the buffer.

        Args:
            length: when given, keep receiving until at least this many
                bytes were added; otherwise return after the first chunk.
            timeout: when given, applied to the socket for this call only;
                the configured ``socket_timeout`` is restored afterwards.
            raise_on_timeout: when false, a timeout or a matching
                non-blocking error returns ``False`` instead of raising.

        Returns:
            ``True`` if data was added, ``False`` if nothing was available
            and ``raise_on_timeout`` is false.

        Raises:
            ConnectionError: the peer closed the connection, or a
                non-blocking error occurred that is not tolerated.
            TimeoutError: the read timed out and ``raise_on_timeout`` is set.
        """
        sock = self._sock
        socket_read_size = self.socket_read_size
        marker = 0
        custom_timeout = timeout is not SENTINEL

        buf = self._buffer
        # remember the read position; writes must go to the end of the buffer
        current_pos = buf.tell()
        buf.seek(0, SEEK_END)
        if custom_timeout:
            sock.settimeout(timeout)
        try:
            while True:
                data = sock.recv(socket_read_size)
                # an empty string indicates the server shutdown the socket
                if isinstance(data, bytes) and len(data) == 0:
                    raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
                buf.write(data)
                data_length = len(data)
                marker += data_length

                if length is not None and length > marker:
                    continue
                return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # if we're in nonblocking mode and the recv raises a
            # blocking error, simply return False indicating that
            # there's no data to be read. otherwise raise the
            # original exception.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == allowed:
                return False
            raise ConnectionError(f"Error while reading from socket: {ex.args}")
        finally:
            # always restore the read position and the socket timeout
            buf.seek(current_pos)
            if custom_timeout:
                sock.settimeout(self.socket_timeout)

    def can_read(self, timeout: float) -> bool:
        """Return whether data is buffered or arrives within ``timeout``."""
        return bool(self.unread_bytes()) or self._read_from_socket(
            timeout=timeout, raise_on_timeout=False
        )

    def read(self, length: int) -> bytes:
        """Return ``length`` bytes of payload, consuming the trailing CRLF."""
        length = length + 2  # make sure to read the \r\n terminator
        # BufferIO will return less than requested if buffer is short
        data = self._buffer.read(length)
        missing = length - len(data)
        if missing:
            # fill up the buffer and read the remainder
            self._read_from_socket(missing)
            data += self._buffer.read(missing)
        return data[:-2]

    def readline(self) -> bytes:
        """Return the next CRLF-terminated line without its terminator."""
        buf = self._buffer
        data = buf.readline()
        while not data.endswith(SYM_CRLF):
            # there's more data in the socket that we need
            self._read_from_socket()
            data += buf.readline()

        return data[:-2]

    def get_pos(self) -> int:
        """
        Get current read position
        """
        return self._buffer.tell()

    def rewind(self, pos: int) -> None:
        """
        Rewind the buffer to a specific position, to re-start reading
        """
        self._buffer.seek(pos)

    def purge(self) -> None:
        """
        After a successful read, purge the read part of buffer
        """
        unread = self.unread_bytes()

        # Only if we have read all of the buffer do we truncate, to
        # reduce the amount of memory thrashing. This heuristic
        # can be changed or removed later.
        if unread > 0:
            return

        # Nothing unread remains past this point, so there is no data to
        # move to the front: shrink the buffer and reset the position.
        self._buffer.truncate(unread)
        self._buffer.seek(0)

    def close(self) -> None:
        """Close the buffer and drop the references to buffer and socket."""
        try:
            self._buffer.close()
        except Exception:
            # issue #633 suggests the purge/close somehow raised a
            # BadFileDescriptor error. Perhaps the client ran out of
            # memory or something else? It's probably OK to ignore
            # any error being raised from purge/close since we're
            # removing the reference to the instance below.
            pass
        self._buffer = None
        self._sock = None

307 

308 

class PythonParser(BaseParser):
    """Reply parser written in plain Python, reading through a SocketBuffer."""

    def __init__(self, socket_read_size):
        self.socket_read_size = socket_read_size
        self.encoder = None
        self._sock = None
        self._buffer = None

    def __del__(self):
        # best-effort cleanup; errors during finalisation are ignored
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        """Bind the parser to the connection's socket and encoder."""
        sock = connection._sock
        self._sock = sock
        self._buffer = SocketBuffer(
            sock, self.socket_read_size, connection.socket_timeout
        )
        self.encoder = connection.encoder

    def on_disconnect(self):
        """Release the socket, the buffer and the encoder."""
        self._sock = None
        if self._buffer is not None:
            self._buffer.close()
            self._buffer = None
        self.encoder = None

    def can_read(self, timeout):
        """Return the buffer's ``can_read`` result, or the falsy buffer itself."""
        buffer = self._buffer
        if not buffer:
            return buffer
        return buffer.can_read(timeout)

    def read_response(self, disable_decoding=False):
        """Read one complete reply.

        If reading fails for any reason the buffer is rewound to where the
        reply started before the exception propagates; on success the
        consumed part of the buffer is purged.
        """
        start = self._buffer.get_pos() if self._buffer else None
        try:
            result = self._read_response(disable_decoding=disable_decoding)
        except BaseException:
            if self._buffer:
                self._buffer.rewind(start)
            raise
        self._buffer.purge()
        return result

    def _read_response(self, disable_decoding=False):
        """Parse a single reply, recursing for multi-bulk replies."""
        raw = self._buffer.readline()
        if not raw:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        marker, payload = raw[:1], raw[1:]

        # server returned an error
        if marker == b"-":
            error = self.parse_error(payload.decode("utf-8", errors="replace"))
            # a ConnectionError is raised right away so the user is notified
            if isinstance(error, ConnectionError):
                raise error
            # any other error might belong inside a pipeline response; the
            # connection's read_response() and/or the pipeline's execute()
            # will raise it if necessary, so hand back the instance.
            return error

        # int value
        if marker == b":":
            return int(payload)

        # null bulk / null multi-bulk response
        if marker in (b"$", b"*") and payload == b"-1":
            return None

        if marker == b"+":
            # single value
            response = payload
        elif marker == b"$":
            # bulk response
            response = self._buffer.read(int(payload))
        elif marker == b"*":
            # multi-bulk response
            response = [
                self._read_response(disable_decoding=disable_decoding)
                for _ in range(int(payload))
            ]
        else:
            raise InvalidResponse(f"Protocol Error: {raw!r}")

        if disable_decoding is False:
            response = self.encoder.decode(response)
        return response

400 

401 

class HiredisParser(BaseParser):
    """Reply parser that hands the received bytes to a ``hiredis.Reader``."""

    def __init__(self, socket_read_size):
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not installed")
        self.socket_read_size = socket_read_size
        # reusable receive buffer that recv_into() writes into
        self._buffer = bytearray(socket_read_size)

    def __del__(self):
        # best-effort cleanup; errors during finalisation are ignored
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection, **kwargs):
        """Bind to the connection's socket and create a fresh reader.

        Keyword arguments passed by the caller are not used; the reader
        options are derived from the connection's encoder.
        """
        self._sock = connection._sock
        self._socket_timeout = connection.socket_timeout
        encoder = connection.encoder
        reader_options = {
            "protocolError": InvalidResponse,
            "replyError": self.parse_error,
            "errors": encoder.encoding_errors,
        }
        if encoder.decode_responses:
            reader_options["encoding"] = encoder.encoding
        self._reader = hiredis.Reader(**reader_options)
        self._next_response = False

    def on_disconnect(self):
        """Drop the socket, the reader and any cached reply."""
        self._sock = None
        self._reader = None
        self._next_response = False

    def can_read(self, timeout):
        """Return whether a reply is available or data arrives in ``timeout``.

        A reply obtained from the reader here is cached in ``_next_response``.
        """
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        if self._next_response is not False:
            return True
        self._next_response = self._reader.gets()
        if self._next_response is not False:
            return True
        return self.read_from_socket(timeout=timeout, raise_on_timeout=False)

    def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True):
        """Receive one chunk from the socket and feed it to the reader.

        Returns ``True`` when data was fed, ``False`` when nothing was
        available and ``raise_on_timeout`` is false. A ``timeout`` given
        here is applied for this call only.
        """
        sock = self._sock
        use_custom_timeout = timeout is not SENTINEL
        try:
            if use_custom_timeout:
                sock.settimeout(timeout)
            received = self._sock.recv_into(self._buffer)
            if received == 0:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            self._reader.feed(self._buffer, 0, received)
            # data was read from the socket and added to the buffer.
            # return True to indicate that data was read.
            return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # if we're in nonblocking mode and the recv raises a
            # blocking error, simply return False indicating that
            # there's no data to be read. otherwise raise the
            # original exception.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == allowed:
                return False
            raise ConnectionError(f"Error while reading from socket: {ex.args}")
        finally:
            if use_custom_timeout:
                sock.settimeout(self._socket_timeout)

    def read_response(self, disable_decoding=False):
        """Return the next reply, reading from the socket until one is complete."""
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        # _next_response might be cached from a can_read() call
        if self._next_response is not False:
            cached, self._next_response = self._next_response, False
            return cached

        # gets(False) is used when decoding is disabled, plain gets() otherwise
        gets_args = (False,) if disable_decoding else ()
        response = self._reader.gets(*gets_args)
        while response is False:
            self.read_from_socket()
            response = self._reader.gets(*gets_args)

        # if the response is a ConnectionError or the response is a list and
        # the first item is a ConnectionError, raise it as something bad
        # happened
        if isinstance(response, ConnectionError):
            raise response
        if (
            isinstance(response, list)
            and response
            and isinstance(response[0], ConnectionError)
        ):
            raise response[0]
        return response

509 

510 

# Parser class used when a connection is not given one explicitly: the
# hiredis-backed parser when hiredis is available, the pure-Python one if not.
DefaultParser: BaseParser = HiredisParser if HIREDIS_AVAILABLE else PythonParser

516 

517 

class HiredisRespSerializer:
    """Command packer that delegates to ``hiredis.pack_command``."""

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        command = args[0]
        # a command name such as 'CONFIG GET' is split into separate arguments
        if isinstance(command, str):
            args = tuple(command.encode().split()) + args[1:]
        elif b" " in command:
            args = tuple(command.split()) + args[1:]

        try:
            packed = hiredis.pack_command(args)
        except TypeError as exc:
            # surface as DataError, keeping the traceback of the TypeError
            raise DataError(exc).with_traceback(exc.__traceback__)

        return [packed]

534 

535 

class PythonRespSerializer:
    """Command packer implemented in plain Python."""

    def __init__(self, buffer_cutoff, encode) -> None:
        self._buffer_cutoff = buffer_cutoff
        self.encode = encode

    def pack(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. These arguments should be bytestrings so that they are
        # not encoded.
        command = args[0]
        if isinstance(command, str):
            args = tuple(command.encode().split()) + args[1:]
        elif b" " in command:
            args = tuple(command.split()) + args[1:]

        cutoff = self._buffer_cutoff
        chunks = []
        pending = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))

        for encoded in map(self.encode, args):
            size = len(encoded)
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values or memoryviews
            emit_separately = (
                len(pending) > cutoff
                or size > cutoff
                or isinstance(encoded, memoryview)
            )
            header = SYM_EMPTY.join(
                (pending, SYM_DOLLAR, str(size).encode(), SYM_CRLF)
            )
            if emit_separately:
                chunks.append(header)
                chunks.append(encoded)
                pending = SYM_CRLF
            else:
                pending = SYM_EMPTY.join((header, encoded, SYM_CRLF))

        chunks.append(pending)
        return chunks

585 

586 

class AbstractConnection:
    "Manages communication to and from a Redis server"

    def __init__(
        self,
        db=0,
        password=None,
        retry_on_timeout=False,
        retry_on_error=SENTINEL,
        encoding="utf-8",
        encoding_errors="strict",
        decode_responses=False,
        parser_class=DefaultParser,
        socket_read_size=65536,
        health_check_interval=0,
        client_name=None,
        username=None,
        retry=None,
        redis_connect_func=None,
        credential_provider: Optional[CredentialProvider] = None,
        command_packer=None,
    ):
        """
        Initialize a new Connection.
        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.

        Raises DataError when `username`/`password` are combined with
        `credential_provider`.
        """
        if (username or password) and credential_provider is not None:
            raise DataError(
                "'username' and 'password' cannot be passed along with 'credential_"
                "provider'. Please provide only one of the following arguments: \n"
                "1. 'password' and (optional) 'username'\n"
                "2. 'credential_provider'"
            )
        self.pid = os.getpid()
        self.db = db
        self.client_name = client_name
        self.credential_provider = credential_provider
        self.password = password
        self.username = username
        self.retry_on_timeout = retry_on_timeout
        if retry_on_error is SENTINEL:
            retry_on_error = []
        if retry_on_timeout:
            # Add TimeoutError to the errors to retry on. A new list is built
            # so that a list supplied by the caller is never modified.
            retry_on_error = [*retry_on_error, TimeoutError]
        self.retry_on_error = retry_on_error
        if retry or retry_on_error:
            if retry is None:
                self.retry = Retry(NoBackoff(), 1)
            else:
                # deep-copy the Retry object as it is mutable
                self.retry = copy.deepcopy(retry)
            # Update the retry's supported errors with the specified errors
            self.retry.update_supported_errors(retry_on_error)
        else:
            self.retry = Retry(NoBackoff(), 0)
        self.health_check_interval = health_check_interval
        self.next_health_check = 0
        self.redis_connect_func = redis_connect_func
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self._sock = None
        self._socket_read_size = socket_read_size
        self.set_parser(parser_class)
        self._connect_callbacks = []
        # must be set before the command packer is constructed, which reads it
        self._buffer_cutoff = 6000
        self._command_packer = self._construct_command_packer(command_packer)

    def __repr__(self):
        repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
        return f"{self.__class__.__name__}<{repr_args}>"

    @abstractmethod
    def repr_pieces(self):
        """Return the (name, value) pairs shown by ``__repr__``."""
        pass

    def __del__(self):
        # best-effort cleanup; errors during finalisation are ignored
        try:
            self.disconnect()
        except Exception:
            pass

    def _construct_command_packer(self, packer):
        """Return ``packer`` if given, else the hiredis or Python serializer."""
        if packer is not None:
            return packer
        elif HIREDIS_PACK_AVAILABLE:
            return HiredisRespSerializer()
        else:
            return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode)

    def register_connect_callback(self, callback):
        """Register a bound method to call after each connect (held weakly)."""
        self._connect_callbacks.append(weakref.WeakMethod(callback))

    def clear_connect_callbacks(self):
        """Forget all registered connect callbacks."""
        self._connect_callbacks = []

    def set_parser(self, parser_class):
        """
        Creates a new instance of parser_class with socket size:
        _socket_read_size and assigns it to the parser for the connection
        :param parser_class: The required parser class
        """
        self._parser = parser_class(socket_read_size=self._socket_read_size)

    def connect(self):
        "Connects to the Redis server if not already connected"
        if self._sock:
            return
        try:
            sock = self.retry.call_with_retry(
                lambda: self._connect(), lambda error: self.disconnect(error)
            )
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except OSError as e:
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            if self.redis_connect_func is None:
                # Use the default on_connect function
                self.on_connect()
            else:
                # Use the passed function redis_connect_func
                self.redis_connect_func(self)
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        for ref in self._connect_callbacks:
            callback = ref()
            # a dead weak reference yields None and is skipped
            if callback:
                callback(self)

    @abstractmethod
    def _connect(self):
        """Create and return the connected socket."""
        pass

    @abstractmethod
    def _host_error(self):
        """Return a description of the peer used in error messages."""
        pass

    @abstractmethod
    def _error_message(self, exception):
        """Return the message for a failed connection attempt."""
        pass

    def on_connect(self):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)

        # if credential provider or username and/or password are set, authenticate
        if self.credential_provider or (self.username or self.password):
            cred_provider = (
                self.credential_provider
                or UsernamePasswordCredentialProvider(self.username, self.password)
            )
            auth_args = cred_provider.get_credentials()
            # avoid checking health here -- PING will fail if we try
            # to check the health prior to the AUTH
            self.send_command("AUTH", *auth_args, check_health=False)

            try:
                auth_response = self.read_response()
            except AuthenticationWrongNumberOfArgsError:
                # a username and password were specified but the Redis
                # server seems to be < 6.0.0 which expects a single password
                # arg. retry auth with just the password.
                # https://github.com/andymccurdy/redis-py/issues/1274
                self.send_command("AUTH", auth_args[-1], check_health=False)
                auth_response = self.read_response()

            if str_if_bytes(auth_response) != "OK":
                raise AuthenticationError("Invalid Username or Password")

        # if a client_name is given, set it
        if self.client_name:
            self.send_command("CLIENT", "SETNAME", self.client_name)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Error setting client name")

        # if a database is specified, switch to it
        if self.db:
            self.send_command("SELECT", self.db)
            if str_if_bytes(self.read_response()) != "OK":
                raise ConnectionError("Invalid Database")

    def disconnect(self, *args):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()
        if self._sock is None:
            return

        # shutdown is only attempted in the process that created the connection
        if os.getpid() == self.pid:
            try:
                self._sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass

        try:
            self._sock.close()
        except OSError:
            pass
        self._sock = None

    def _send_ping(self):
        """Send PING, expect PONG in return"""
        self.send_command("PING", check_health=False)
        if str_if_bytes(self.read_response()) != "PONG":
            raise ConnectionError("Bad response from PING health check")

    def _ping_failed(self, error):
        """Function to call when PING fails"""
        self.disconnect()

    def check_health(self):
        """Check the health of the connection with a PING/PONG"""
        if self.health_check_interval and time() > self.next_health_check:
            self.retry.call_with_retry(self._send_ping, self._ping_failed)

    def send_packed_command(self, command, check_health=True):
        """Send an already packed command to the Redis server

        Connects first if there is no socket. On any failure the connection
        is disconnected before the error propagates.

        Raises:
            TimeoutError: the write timed out.
            ConnectionError: the write failed with another OSError.
        """
        if not self._sock:
            self.connect()
        # guard against health check recursion
        if check_health:
            self.check_health()
        try:
            if isinstance(command, str):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except OSError as e:
            self.disconnect()
            # args for OSError can be (errno, "message"), just ("message",),
            # or empty; never index past what is actually there.
            if len(e.args) >= 2:
                err_no, errmsg = e.args[0], e.args[1]
            elif e.args:
                err_no, errmsg = "UNKNOWN", e.args[0]
            else:
                err_no, errmsg = "UNKNOWN", e.__class__.__name__
            raise ConnectionError(f"Error {err_no} while writing to socket. {errmsg}.")
        except Exception:
            self.disconnect()
            raise

    def send_command(self, *args, **kwargs):
        """Pack and send a command to the Redis server"""
        self.send_packed_command(
            self._command_packer.pack(*args),
            check_health=kwargs.get("check_health", True),
        )

    def can_read(self, timeout=0):
        """Poll the socket to see if there's data that can be read."""
        sock = self._sock
        if not sock:
            self.connect()

        host_error = self._host_error()

        try:
            return self._parser.can_read(timeout)
        except OSError as e:
            self.disconnect()
            raise ConnectionError(f"Error while reading from {host_error}: {e.args}")

    def read_response(self, disable_decoding=False):
        """Read the response from a previously sent command

        The connection is disconnected on any error raised while parsing.
        A reply that is a ResponseError instance is raised, not returned.
        """

        host_error = self._host_error()

        try:
            response = self._parser.read_response(disable_decoding=disable_decoding)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError(f"Timeout reading from {host_error}")
        except OSError as e:
            self.disconnect()
            raise ConnectionError(
                f"Error while reading from {host_error}" f" : {e.args}"
            )
        except Exception:
            self.disconnect()
            raise

        # a successful read postpones the next health check
        if self.health_check_interval:
            self.next_health_check = time() + self.health_check_interval

        if isinstance(response, ResponseError):
            raise response
        return response

    def pack_command(self, *args):
        """Pack a series of arguments into the Redis protocol"""
        return self._command_packer.pack(*args)

    def pack_commands(self, commands):
        """Pack multiple commands into the Redis protocol"""
        output = []
        pieces = []
        buffer_length = 0
        buffer_cutoff = self._buffer_cutoff

        for cmd in commands:
            for chunk in self._command_packer.pack(*cmd):
                chunklen = len(chunk)
                if (
                    buffer_length > buffer_cutoff
                    or chunklen > buffer_cutoff
                    or isinstance(chunk, memoryview)
                ):
                    # flush the small pieces collected so far
                    if pieces:
                        output.append(SYM_EMPTY.join(pieces))
                    buffer_length = 0
                    pieces = []

                if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
                    # large chunks and memoryviews are emitted on their own
                    output.append(chunk)
                else:
                    pieces.append(chunk)
                    buffer_length += chunklen

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output

918 

919 

920class Connection(AbstractConnection): 

921 "Manages TCP communication to and from a Redis server" 

922 

923 def __init__( 

924 self, 

925 host="localhost", 

926 port=6379, 

927 socket_timeout=None, 

928 socket_connect_timeout=None, 

929 socket_keepalive=False, 

930 socket_keepalive_options=None, 

931 socket_type=0, 

932 **kwargs, 

933 ): 

934 self.host = host 

935 self.port = int(port) 

936 self.socket_timeout = socket_timeout 

937 self.socket_connect_timeout = socket_connect_timeout or socket_timeout 

938 self.socket_keepalive = socket_keepalive 

939 self.socket_keepalive_options = socket_keepalive_options or {} 

940 self.socket_type = socket_type 

941 super().__init__(**kwargs) 

942 

943 def repr_pieces(self): 

944 pieces = [("host", self.host), ("port", self.port), ("db", self.db)] 

945 if self.client_name: 

946 pieces.append(("client_name", self.client_name)) 

947 return pieces 

948 

def _connect(self):
    "Open a TCP socket, trying every address the host resolves to"
    # Same idea as socket.create_connection() (handles ipv4 and ipv6), but
    # the socket is built by hand so options are applied before connect().
    last_error = None
    candidates = socket.getaddrinfo(
        self.host, self.port, self.socket_type, socket.SOCK_STREAM
    )
    for family, socktype, proto, _canonname, address in candidates:
        sock = None
        try:
            sock = socket.socket(family, socktype, proto)

            # Disable Nagle's algorithm (TCP_NODELAY).
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

            # Optional keepalive plus any caller supplied TCP-level options.
            if self.socket_keepalive:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                for option, setting in self.socket_keepalive_options.items():
                    sock.setsockopt(socket.IPPROTO_TCP, option, setting)

            # The connect timeout applies only while connecting; afterwards
            # the regular socket timeout takes over.
            sock.settimeout(self.socket_connect_timeout)
            sock.connect(address)
            sock.settimeout(self.socket_timeout)
            return sock

        except OSError as exc:
            # Remember the failure and move on to the next address.
            last_error = exc
            if sock is not None:
                sock.close()

    if last_error is None:
        raise OSError("socket.getaddrinfo returned an empty list")
    raise last_error

989 

def _host_error(self):
    """Return the ``host:port`` label used in connection error messages."""
    return "{}:{}".format(self.host, self.port)

992 

def _error_message(self, exception):
    """Build the human-readable message for a failed connection attempt.

    Args:
        exception: The exception raised while connecting. Its ``args`` are
            either empty, ``("message",)`` or ``(errno, "message")``.

    Returns:
        A string naming the target from ``self._host_error()`` together
        with whatever detail the exception carries.
    """
    host_error = self._host_error()
    details = exception.args

    # Some socket errors (e.g. a bare timeout) carry no arguments at all;
    # indexing into ``args`` would raise IndexError and hide the real error.
    if not details:
        return f"Error connecting to {host_error}."

    # Single line on purpose: a backslash continuation inside the f-string
    # would embed the source indentation in the message.
    if len(details) == 1:
        return f"Error connecting to {host_error}. {details[0]}."

    return f"Error {details[0]} connecting to {host_error}. {details[1]}."

1013 

1014 

class SSLConnection(Connection):
    """Manages SSL connections to and from the Redis server(s).
    This class extends the Connection class, adding SSL functionality, and making
    use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext)
    """  # noqa

    def __init__(
        self,
        ssl_keyfile=None,
        ssl_certfile=None,
        ssl_cert_reqs="required",
        ssl_ca_certs=None,
        ssl_ca_data=None,
        ssl_check_hostname=False,
        ssl_ca_path=None,
        ssl_password=None,
        ssl_validate_ocsp=False,
        ssl_validate_ocsp_stapled=False,
        ssl_ocsp_context=None,
        ssl_ocsp_expected_cert=None,
        **kwargs,
    ):
        """Constructor

        Args:
            ssl_keyfile: Path to an ssl private key. Defaults to None.
            ssl_certfile: Path to an ssl certificate. Defaults to None.
            ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required). Defaults to "required".
            ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
            ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
            ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to False.
            ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None.
            ssl_password: Password for unlocking an encrypted private key. Defaults to None.

            ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification)
            ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
            ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
            ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.

        Raises:
            RedisError
        """  # noqa
        if not ssl_available:
            raise RedisError("Python wasn't built with SSL support")

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        # Accept None, one of the ssl.CERT_* constants, or its string name.
        if ssl_cert_reqs is None:
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, str):
            CERT_REQS = {
                "none": ssl.CERT_NONE,
                "optional": ssl.CERT_OPTIONAL,
                "required": ssl.CERT_REQUIRED,
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}"
                )
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ca_certs = ssl_ca_certs
        self.ca_data = ssl_ca_data
        self.ca_path = ssl_ca_path
        self.check_hostname = ssl_check_hostname
        self.certificate_password = ssl_password
        self.ssl_validate_ocsp = ssl_validate_ocsp
        self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
        self.ssl_ocsp_context = ssl_ocsp_context
        self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
        super().__init__(**kwargs)

    def _connect(self):
        "Wrap the socket with SSL support"
        # Reject unusable OCSP settings before opening any socket, so a
        # configuration error never leaves a connection behind.
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
            raise RedisError("cryptography is not installed.")

        if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp:
            raise RedisError(
                "Either an OCSP staple or pure OCSP connection must be validated "
                "- not both."
            )

        sock = super()._connect()
        try:
            return self._wrap_socket_with_ssl(sock)
        except BaseException:
            # Covers failures before the TLS wrapper exists (e.g. loading
            # certificates); closing the plain socket again later is harmless.
            sock.close()
            raise

    def _wrap_socket_with_ssl(self, sock):
        "Build the SSLContext, wrap ``sock`` and run the OCSP checks"
        context = ssl.create_default_context()
        context.check_hostname = self.check_hostname
        context.verify_mode = self.cert_reqs
        if self.certfile or self.keyfile:
            context.load_cert_chain(
                certfile=self.certfile,
                keyfile=self.keyfile,
                password=self.certificate_password,
            )
        if (
            self.ca_certs is not None
            or self.ca_path is not None
            or self.ca_data is not None
        ):
            context.load_verify_locations(
                cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
            )
        sslsock = context.wrap_socket(sock, server_hostname=self.host)
        try:
            self._verify_ocsp(sslsock)
        except BaseException:
            # The wrapped socket is the one to close from here on.
            sslsock.close()
            raise
        return sslsock

    def _verify_ocsp(self, sslsock):
        "Run the configured OCSP validation, raising when it fails"
        # validation for the stapled case
        if self.ssl_validate_ocsp_stapled:
            import OpenSSL

            from .ocsp import ocsp_staple_verifier

            # if a context is provided use it - otherwise, a basic context
            if self.ssl_ocsp_context is None:
                staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
                staple_ctx.use_certificate_file(self.certfile)
                staple_ctx.use_privatekey_file(self.keyfile)
            else:
                staple_ctx = self.ssl_ocsp_context

            staple_ctx.set_ocsp_client_callback(
                ocsp_staple_verifier, self.ssl_ocsp_expected_cert
            )

            # The stapled check needs its own, throwaway socket; close it
            # explicitly so it is not left for the garbage collector.
            probe_sock = socket.socket()
            try:
                con = OpenSSL.SSL.Connection(staple_ctx, probe_sock)
                con.request_ocsp()
                con.connect((self.host, self.port))
                con.do_handshake()
                con.shutdown()
            finally:
                probe_sock.close()
            return

        # pure ocsp validation
        if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE:
            from .ocsp import OCSPVerifier

            verifier = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs)
            if not verifier.is_valid():
                raise ConnectionError("ocsp validation error")

1153 

1154 

class UnixDomainSocketConnection(AbstractConnection):
    "Manages UDS communication to and from a Redis server"

    def __init__(self, path="", **kwargs):
        """Store the socket ``path``; other keyword arguments go to the parent."""
        self.path = path
        super().__init__(**kwargs)

    def repr_pieces(self):
        """Return ``(name, value)`` pairs describing this connection."""
        pieces = [("path", self.path), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def _connect(self):
        "Create a Unix domain socket connection"
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.settimeout(self.socket_timeout)
            sock.connect(self.path)
        except BaseException:
            # Do not leak the file descriptor when the connect attempt fails.
            sock.close()
            raise
        return sock

    def _host_error(self):
        """Return the socket path, used to label connection errors."""
        return self.path

    def _error_message(self, exception):
        """Build the message for a failed connection attempt.

        ``exception.args`` may be empty, ``("message",)`` or
        ``(errno, "message")``.
        """
        host_error = self._host_error()
        details = exception.args
        # An exception without arguments must not trigger an IndexError here.
        if not details:
            return f"Error connecting to unix socket: {host_error}."
        if len(details) == 1:
            return f"Error connecting to unix socket: {host_error}. {details[0]}."
        return (
            f"Error {details[0]} connecting to unix socket: "
            f"{host_error}. {details[1]}."
        )

1191 

1192 

# Upper-cased spellings that to_bool() treats as False.
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    """Convert a query-string style value to ``True``, ``False`` or ``None``.

    ``None`` and the empty string mean "not specified" and yield ``None``.
    Strings matching ``FALSE_STRINGS`` (case-insensitively) yield ``False``.
    Anything else is passed through ``bool()``.
    """
    unspecified = value is None or value == ""
    if unspecified:
        return None
    spelled_false = isinstance(value, str) and value.upper() in FALSE_STRINGS
    return False if spelled_false else bool(value)

1202 

1203 

# Converters applied to recognised querystring options; options not listed
# here are kept as plain strings.
URL_QUERY_ARGUMENT_PARSERS = {
    "db": int,
    "socket_timeout": float,
    "socket_connect_timeout": float,
    "socket_keepalive": to_bool,
    "retry_on_timeout": to_bool,
    "retry_on_error": list,
    "max_connections": int,
    "health_check_interval": int,
    "ssl_check_hostname": to_bool,
}


def parse_url(url):
    """Translate a redis://, rediss:// or unix:// URL into keyword arguments.

    Querystring options are converted with ``URL_QUERY_ARGUMENT_PARSERS``;
    username, password, host, port, path and db are taken from the URL
    itself. ``connection_class`` is set for the unix and rediss schemes.

    Raises:
        ValueError: if the scheme is unsupported or a querystring value
            cannot be converted.
    """
    if not url.startswith(("redis://", "rediss://", "unix://")):
        raise ValueError(
            "Redis URL must specify one of the following "
            "schemes (redis://, rediss://, unix://)"
        )

    parsed = urlparse(url)
    kwargs = {}

    for name, values in parse_qs(parsed.query).items():
        if not values:
            continue
        # Only the first occurrence of a repeated option is used.
        text = unquote(values[0])
        convert = URL_QUERY_ARGUMENT_PARSERS.get(name)
        if not convert:
            kwargs[name] = text
            continue
        try:
            kwargs[name] = convert(text)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid value for `{name}` in connection URL.")

    if parsed.username:
        kwargs["username"] = unquote(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote(parsed.password)

    # The scheme check above guarantees one of redis, rediss or unix.
    if parsed.scheme == "unix":
        if parsed.path:
            kwargs["path"] = unquote(parsed.path)
        kwargs["connection_class"] = UnixDomainSocketConnection
        return kwargs

    # redis:// and rediss://
    if parsed.hostname:
        kwargs["host"] = unquote(parsed.hostname)
    if parsed.port:
        kwargs["port"] = int(parsed.port)

    # The path supplies the db number unless the querystring already did;
    # a path that is not a number is silently ignored.
    if parsed.path and "db" not in kwargs:
        try:
            kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
        except (AttributeError, ValueError):
            pass

    if parsed.scheme == "rediss":
        kwargs["connection_class"] = SSLConnection

    return kwargs

1272 

1273 

class ConnectionPool:
    """
    Create a connection pool. If ``max_connections`` is set, then this
    object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's
    limit is reached.

    By default, TCP connections are created unless ``connection_class``
    is specified. Use :py:class:`.UnixDomainSocketConnection` for
    unix sockets.

    Any additional keyword arguments are passed to the constructor of
    ``connection_class``.
    """

    @classmethod
    def from_url(cls, url, **kwargs):
        """
        Return a connection pool configured from the given URL.

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

            1. A ``db`` querystring option, e.g. redis://localhost?db=0
            2. If using the redis:// or rediss:// schemes, the path argument
               of the url, e.g. redis://localhost/0
            3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.
        """
        url_options = parse_url(url)

        # An explicit connection_class keyword is the one option that
        # overrides what the URL scheme implies.
        if "connection_class" in kwargs:
            url_options["connection_class"] = kwargs["connection_class"]

        kwargs.update(url_options)
        return cls(**kwargs)

    def __init__(
        self, connection_class=Connection, max_connections=None, **connection_kwargs
    ):
        # A falsy limit (None or 0) means "effectively unlimited".
        max_connections = max_connections or 2**31
        if not isinstance(max_connections, int) or max_connections < 0:
            raise ValueError('"max_connections" must be a positive integer')

        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections

        # a lock to protect the critical section in _checkpid().
        # this lock is acquired when the process id changes, such as
        # after a fork. during this time, multiple threads in the child
        # process could attempt to acquire this lock. the first thread
        # to acquire the lock will reset the data structures and lock
        # object of this pool. subsequent threads acquiring this lock
        # will notice the first thread already did the work and simply
        # release the lock.
        self._fork_lock = threading.Lock()
        self.reset()

    def __repr__(self):
        # Describes the pool through the repr of a connection built from
        # the pool's connection kwargs.
        return (
            f"{type(self).__name__}"
            f"<{repr(self.connection_class(**self.connection_kwargs))}>"
        )

    def reset(self):
        "Discard all bookkeeping and start over with an empty pool"
        self._lock = threading.Lock()
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def _checkpid(self):
        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
        # systems. this is called by all ConnectionPool methods that
        # manipulate the pool's state such as get_connection() and release().
        #
        # _checkpid() determines whether the process has forked by comparing
        # the current process id to the process id saved on the ConnectionPool
        # instance. if these values are the same, _checkpid() simply returns.
        #
        # when the process ids differ, _checkpid() assumes that the process
        # has forked and that we're now running in the child process. the child
        # process cannot use the parent's file descriptors (e.g., sockets).
        # therefore, when _checkpid() sees the process id change, it calls
        # reset() in order to reinitialize the child's ConnectionPool. this
        # will cause the child to make all new connection objects.
        #
        # _checkpid() is protected by self._fork_lock to ensure that multiple
        # threads in the child process do not call reset() multiple times.
        #
        # there is an extremely small chance this could fail in the following
        # scenario:
        #   1. process A calls _checkpid() for the first time and acquires
        #      self._fork_lock.
        #   2. while holding self._fork_lock, process A forks (the fork()
        #      could happen in a different thread owned by process A)
        #   3. process B (the forked child process) inherits the
        #      ConnectionPool's state from the parent. that state includes
        #      a locked _fork_lock. process B will not be notified when
        #      process A releases the _fork_lock and will thus never be
        #      able to acquire the _fork_lock.
        #
        # to mitigate this possible deadlock, _checkpid() will only wait 5
        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
        # that time it is assumed that the child is deadlocked and a
        # redis.ChildDeadlockedError error is raised.
        if self.pid != os.getpid():
            acquired = self._fork_lock.acquire(timeout=5)
            if not acquired:
                raise ChildDeadlockedError
            # reset() the instance for the new process if another thread
            # hasn't already done so
            try:
                if self.pid != os.getpid():
                    self.reset()
            finally:
                self._fork_lock.release()

    def get_connection(self, command_name, *keys, **options):
        "Get a connection from the pool"
        self._checkpid()
        with self._lock:
            try:
                connection = self._available_connections.pop()
            except IndexError:
                connection = self.make_connection()
            self._in_use_connections.add(connection)

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't
            # leak it
            self.release(connection)
            raise

        return connection

    def get_encoder(self):
        "Return an encoder based on encoding settings"
        kwargs = self.connection_kwargs
        return Encoder(
            encoding=kwargs.get("encoding", "utf-8"),
            encoding_errors=kwargs.get("encoding_errors", "strict"),
            decode_responses=kwargs.get("decode_responses", False),
        )

    def make_connection(self):
        "Create a new connection"
        if self._created_connections >= self.max_connections:
            raise ConnectionError("Too many connections")
        self._created_connections += 1
        return self.connection_class(**self.connection_kwargs)

    def release(self, connection):
        "Releases the connection back to the pool"
        self._checkpid()
        with self._lock:
            try:
                self._in_use_connections.remove(connection)
            except KeyError:
                # Gracefully fail when a connection is returned to this pool
                # that the pool doesn't actually own
                pass

            if self.owns_connection(connection):
                self._available_connections.append(connection)
            else:
                # pool doesn't own this connection. do not add it back
                # to the pool and decrement the count so that another
                # connection can take its place if needed
                self._created_connections -= 1
                connection.disconnect()

    def owns_connection(self, connection):
        "Return True when the connection's pid matches this pool's pid"
        return connection.pid == self.pid

    def disconnect(self, inuse_connections=True):
        """
        Disconnects connections in the pool

        If ``inuse_connections`` is True, disconnect connections that are
        current in use, potentially by other threads. Otherwise only disconnect
        connections that are idle in the pool.
        """
        self._checkpid()
        with self._lock:
            if inuse_connections:
                connections = chain(
                    self._available_connections, self._in_use_connections
                )
            else:
                connections = self._available_connections

            for connection in connections:
                connection.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        "Use ``retry`` for future connections and for all existing ones"
        self.connection_kwargs.update({"retry": retry})
        # Both collections are mutated under _lock by get_connection() and
        # release(); iterating them unlocked could fail with "Set changed
        # size during iteration" when another thread is active.
        with self._lock:
            for conn in chain(
                self._available_connections, self._in_use_connections
            ):
                conn.retry = retry

1528 

1529 

class BlockingConnectionPool(ConnectionPool):
    """
    Thread-safe blocking connection pool::

        >>> from redis.client import Redis
        >>> client = Redis(connection_pool=BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.ConnectionPool` implementation, in that,
    it maintains a pool of reusable connections that can be shared by
    multiple redis clients (safely across threads if required).

    The difference is that, in the event that a client tries to get a
    connection from the pool when all of connections are in use, rather than
    raising a :py:class:`~redis.ConnectionError` (as the default
    :py:class:`~redis.ConnectionPool` implementation does), it
    makes the client wait ("blocks") for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a connection
    to become available, or to block forever:

        >>> # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
        >>> # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """

    def __init__(
        self,
        max_connections=50,
        timeout=20,
        connection_class=Connection,
        queue_class=LifoQueue,
        **connection_kwargs,
    ):
        # These must exist before the parent initializer runs, because it
        # calls reset(), which reads queue_class.
        self.queue_class = queue_class
        self.timeout = timeout
        super().__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs,
        )

    def reset(self):
        "Rebuild the queue of placeholders and forget all connections"
        # Create and fill up a thread safe queue with ``None`` values.
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break

        # Keep a list of actual connection instances so that we can
        # disconnect them later.
        self._connections = []

        # this must be the last operation in this method. while reset() is
        # called when holding _fork_lock, other threads in this process
        # can call _checkpid() which compares self.pid and os.getpid() without
        # holding any lock (for performance reasons). keeping this assignment
        # as the last operation ensures that those other threads will also
        # notice a pid difference and block waiting for the first thread to
        # release _fork_lock. when each of these threads eventually acquire
        # _fork_lock, they will notice that another thread already called
        # reset() and they will immediately release _fork_lock and continue on.
        self.pid = os.getpid()

    def make_connection(self):
        "Make a fresh connection."
        connection = self.connection_class(**self.connection_kwargs)
        self._connections.append(connection)
        return connection

    def _put_placeholder(self):
        "Return an empty slot (``None``) to the queue, ignoring a full queue"
        try:
            self.pool.put_nowait(None)
        except Full:
            # The queue was rebuilt by reset() (e.g. after a fork) and is
            # already full of placeholders; nothing needs to be handed back.
            pass

    def get_connection(self, command_name, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we only
        create new connections when we need to, i.e.: the actual number of
        connections will only increase in response to demand.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        # Try and get a connection from the pool. If one isn't available within
        # self.timeout then raise a ``ConnectionError``.
        connection = None
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            # Note that this is not caught by the redis client and will be
            # raised unless handled by application code. If you never want
            # this to be raised, create the pool with ``timeout=None`` so
            # that callers block until a connection becomes available.
            raise ConnectionError("No connection available.")

        # If the ``connection`` is actually ``None`` then that's a cue to make
        # a new connection to add to the pool.
        if connection is None:
            try:
                connection = self.make_connection()
            except BaseException:
                # Hand the slot back, otherwise every failed construction
                # would permanently shrink the pool.
                self._put_placeholder()
                raise

        try:
            # ensure this connection is connected to Redis
            connection.connect()
            # connections that the pool provides should be ready to send
            # a command. if not, the connection was either returned to the
            # pool before all data has been read or the socket has been
            # closed. either way, reconnect and verify everything is good.
            try:
                if connection.can_read():
                    raise ConnectionError("Connection has data")
            except (ConnectionError, OSError):
                connection.disconnect()
                connection.connect()
                if connection.can_read():
                    raise ConnectionError("Connection not ready")
        except BaseException:
            # release the connection back to the pool so that we don't leak it
            self.release(connection)
            raise

        return connection

    def release(self, connection):
        "Releases the connection back to the pool."
        # Make sure we haven't changed process.
        self._checkpid()
        if not self.owns_connection(connection):
            # pool doesn't own this connection. do not add it back
            # to the pool. instead add a None value which is a placeholder
            # that will cause the pool to recreate the connection if
            # its needed. right after a fork reset the queue is already
            # full, so a full queue must not raise here.
            connection.disconnect()
            self._put_placeholder()
            return

        # Put the connection back into the pool.
        try:
            self.pool.put_nowait(connection)
        except Full:
            # perhaps the pool has been reset() after a fork? regardless,
            # we don't want this connection
            pass

    def disconnect(self):
        "Disconnects all connections in the pool."
        self._checkpid()
        for connection in self._connections:
            connection.disconnect()

    def set_retry(self, retry: "Retry") -> None:
        "Use ``retry`` for future connections and for all existing ones"
        # The inherited implementation walks attributes that this class's
        # reset() never creates; every connection made by this pool is
        # tracked in ``_connections`` instead.
        self.connection_kwargs.update({"retry": retry})
        for conn in self._connections:
            conn.retry = retry

1688 connection.disconnect()