from __future__ import annotations

import collections
import functools
import logging
import math
import os
import threading
import warnings
from concurrent.futures import Future, ThreadPoolExecutor
from itertools import groupby
from operator import itemgetter
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Generic,
    NamedTuple,
    Optional,
    OrderedDict,
    TypeVar,
)

if TYPE_CHECKING:
    import mmap

    from typing_extensions import ParamSpec

    P = ParamSpec("P")
else:
    P = TypeVar("P")

T = TypeVar("T")


logger = logging.getLogger("fsspec")

Fetcher = Callable[[int, int], bytes]  # Maps (start, end) to bytes
MultiFetcher = Callable[[list[tuple[int, int]]], bytes]  # Maps [(start, end)] to bytes


class BaseCache:
    """Pass-through cache: doesn't keep anything, calls every time

    Acts as base class for other cachers

    Parameters
    ----------
    blocksize: int
        How far to read ahead in numbers of bytes
    fetcher: func
        Function of the form f(start, end) which gets bytes from remote as
        specified
    size: int
        How big this file is
    """
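
    # Hedged sketch: any callable returning bytes for a (start, stop) range
    # can serve as the fetcher, e.g.:
    #
    #   data = b"0123456789"
    #   cache = BaseCache(blocksize=4, fetcher=lambda s, e: data[s:e], size=len(data))
    #   cache._fetch(2, 5)  # -> b"234"; BaseCache itself never stores anything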

    name: ClassVar[str] = "none"

    def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
        self.blocksize = blocksize
        self.nblocks = 0
        self.fetcher = fetcher
        self.size = size
        self.hit_count = 0
        self.miss_count = 0
        # the bytes that we actually requested
        self.total_requested_bytes = 0

    def _fetch(self, start: int | None, stop: int | None) -> bytes:
        if start is None:
            start = 0
        if stop is None:
            stop = self.size
        if start >= self.size or start >= stop:
            return b""
        return self.fetcher(start, stop)

    def _reset_stats(self) -> None:
        """Reset hit and miss counts for a more granular report, e.g. by file."""
        self.hit_count = 0
        self.miss_count = 0
        self.total_requested_bytes = 0

    def _log_stats(self) -> str:
        """Return a formatted string of the cache statistics."""
        if self.hit_count == 0 and self.miss_count == 0:
            # a cache that does nothing, this is for logs only
            return ""
        return f" , {self.name}: {self.hit_count} hits, {self.miss_count} misses, {self.total_requested_bytes} total requested bytes"

    def __repr__(self) -> str:
        # TODO: use rich for better formatting
        return f"""
        <{self.__class__.__name__}:
            block size  :   {self.blocksize}
            block count :   {self.nblocks}
            file size   :   {self.size}
            cache hits  :   {self.hit_count}
            cache misses:   {self.miss_count}
            total requested bytes: {self.total_requested_bytes}>
        """


class MMapCache(BaseCache):
    """memory-mapped sparse file cache

    Opens a temporary file, which is filled block-wise when data is requested.
    Ensure there is enough disk space in the temporary location.

    This cache method may only work on POSIX systems

    Parameters
    ----------
    blocksize: int
        How far to read ahead in numbers of bytes
    fetcher: Fetcher
        Function of the form f(start, end) which gets bytes from remote as
        specified
    size: int
        How big this file is
    location: str
        Where to create the temporary file. If None, a temporary file is
        created using tempfile.TemporaryFile().
    blocks: set[int]
        Set of block numbers that have already been fetched. If None, an empty
        set is created.
    multi_fetcher: MultiFetcher
        Function of the form f([(start, end)]) which gets bytes from remote
        as specified. This function is used to fetch multiple blocks at once.
        If not specified, the fetcher function is used instead.
    """
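
    # Hedged sketch: already-fetched block numbers are tracked in ``blocks``
    # and served straight from the memory map, e.g.:
    #
    #   cache = MMapCache(blocksize=4096, fetcher=fetch, size=file_size,
    #                     location="/tmp/example.cache")
    #
    # where ``fetch``, ``file_size`` and the path are stand-ins for a real
    # byte-range reader, its file length, and a scratch location; passing the
    # same ``location`` and ``blocks`` back in lets the on-disk cache be
    # reused across sessions.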

    name = "mmap"

    def __init__(
        self,
        blocksize: int,
        fetcher: Fetcher,
        size: int,
        location: str | None = None,
        blocks: set[int] | None = None,
        multi_fetcher: MultiFetcher | None = None,
    ) -> None:
        super().__init__(blocksize, fetcher, size)
        self.blocks = set() if blocks is None else blocks
        self.location = location
        self.multi_fetcher = multi_fetcher
        self.cache = self._makefile()

    def _makefile(self) -> mmap.mmap | bytearray:
        import mmap
        import tempfile

        if self.size == 0:
            return bytearray()

        # posix version
        if self.location is None or not os.path.exists(self.location):
            if self.location is None:
                fd = tempfile.TemporaryFile()
                self.blocks = set()
            else:
                fd = open(self.location, "wb+")
            fd.seek(self.size - 1)
            fd.write(b"1")
            fd.flush()
        else:
            fd = open(self.location, "r+b")

        return mmap.mmap(fd.fileno(), self.size)

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        logger.debug(f"MMap cache fetching {start}-{end}")
        if start is None:
            start = 0
        if end is None:
            end = self.size
        if start >= self.size or start >= end:
            return b""
        start_block = start // self.blocksize
        end_block = end // self.blocksize
        block_range = range(start_block, end_block + 1)
        # Determine which blocks need to be fetched. This sequence is sorted by construction.
        need = (i for i in block_range if i not in self.blocks)
        # Count the number of blocks already cached
        self.hit_count += sum(1 for i in block_range if i in self.blocks)

        ranges = []

        # Consolidate needed blocks.
        # Algorithm adapted from the Python 2.x itertools documentation.
        # We group an enumerated sequence of block numbers. The key computes
        # the difference between an ascending count (provided by enumerate)
        # and each needed block number; consecutive blocks share the same
        # difference, so whenever the difference changes, the sequence has
        # skipped over previously cached block(s) and a new group starts.
        # In other words, this neatly groups runs of consecutive block
        # numbers so they can be fetched together.
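        # Worked example (illustrative, not executed): if ``need`` yields
        # [1, 2, 3, 7, 8], enumerate pairs them as (0, 1), (1, 2), (2, 3),
        # (3, 7), (4, 8); the key i - block gives -1, -1, -1, -4, -4, so
        # groupby yields the runs (1, 2, 3) and (7, 8), which become the
        # byte ranges (1 * blocksize, 4 * blocksize) and
        # (7 * blocksize, 9 * blocksize), clipped to the file size.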

        for _, _blocks in groupby(enumerate(need), key=lambda x: x[0] - x[1]):
            # Extract the blocks from the enumerated sequence
            _blocks = tuple(map(itemgetter(1), _blocks))
            # Compute start of first block
            sstart = _blocks[0] * self.blocksize
            # Compute the end of the last block. Last block may not be full size.
            send = min(_blocks[-1] * self.blocksize + self.blocksize, self.size)

            # Fetch bytes (could be multiple consecutive blocks)
            self.total_requested_bytes += send - sstart
            logger.debug(
                f"MMap get blocks {_blocks[0]}-{_blocks[-1]} ({sstart}-{send})"
            )
            ranges.append((sstart, send))

            # Update set of cached blocks
            self.blocks.update(_blocks)
            # Update cache statistics with number of blocks we had to cache
            self.miss_count += len(_blocks)

        if not ranges:
            return self.cache[start:end]

        if self.multi_fetcher:
            logger.debug(f"MMap get blocks {ranges}")
            for idx, r in enumerate(self.multi_fetcher(ranges)):
                (sstart, send) = ranges[idx]
                logger.debug(f"MMap copy block ({sstart}-{send})")
                self.cache[sstart:send] = r
        else:
            for sstart, send in ranges:
                logger.debug(f"MMap get block ({sstart}-{send})")
                self.cache[sstart:send] = self.fetcher(sstart, send)

        return self.cache[start:end]

    def __getstate__(self) -> dict[str, Any]:
        state = self.__dict__.copy()
        # Remove the unpicklable entries.
        del state["cache"]
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        # Restore instance attributes
        self.__dict__.update(state)
        self.cache = self._makefile()


class ReadAheadCache(BaseCache):
    """Cache which reads only when we get beyond a block of data

    This is a much simpler version of BytesCache, and does not attempt to
    fill holes in the cache or keep fragments alive. It is best suited to
    many small reads in a sequential order (e.g., reading lines from a file).
    """
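
    # Hedged sketch: sequential reads are served from the read-ahead buffer.
    # With blocksize=4 over b"0123456789", the first _fetch(0, 2) requests
    # bytes 0-6 (2 requested + 4 read-ahead), so the following _fetch(2, 4)
    # and _fetch(4, 6) are cache hits; only a read reaching past byte 6
    # triggers a new request, which replaces the buffer entirely.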

    name = "readahead"

    def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
        super().__init__(blocksize, fetcher, size)
        self.cache = b""
        self.start = 0
        self.end = 0

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        if start is None:
            start = 0
        if end is None or end > self.size:
            end = self.size
        if start >= self.size or start >= end:
            return b""
        l = end - start
        if start >= self.start and end <= self.end:
            # cache hit
            self.hit_count += 1
            return self.cache[start - self.start : end - self.start]
        elif self.start <= start < self.end:
            # partial hit
            self.miss_count += 1
            part = self.cache[start - self.start :]
            l -= len(part)
            start = self.end
        else:
            # miss
            self.miss_count += 1
            part = b""
        end = min(self.size, end + self.blocksize)
        self.total_requested_bytes += end - start
        self.cache = self.fetcher(start, end)  # new block replaces old
        self.start = start
        self.end = self.start + len(self.cache)
        return part + self.cache[:l]


class FirstChunkCache(BaseCache):
    """Caches the first block of a file only

    This may be useful for file types where the metadata is stored in the header,
    but is randomly accessed.
    """
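
    # Hedged sketch: with blocksize=512, the first call touching bytes below
    # 512 caches bytes 0-512 once; later header reads like _fetch(0, 100) or
    # _fetch(256, 300) are then served from memory, while _fetch(2048, 4096)
    # always goes straight to the fetcher.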

    name = "first"

    def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
        if blocksize > size:
            # this will buffer the whole thing
            blocksize = size
        super().__init__(blocksize, fetcher, size)
        self.cache: bytes | None = None

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        start = start or 0
        if start > self.size:
            logger.debug("FirstChunkCache: requested start > file size")
            return b""

        # end may be None, meaning "read to the end of the file"
        end = self.size if end is None else min(end, self.size)

        if start < self.blocksize:
            if self.cache is None:
                self.miss_count += 1
                if end > self.blocksize:
                    self.total_requested_bytes += end
                    data = self.fetcher(0, end)
                    self.cache = data[: self.blocksize]
                    return data[start:]
                self.cache = self.fetcher(0, self.blocksize)
                self.total_requested_bytes += self.blocksize
            part = self.cache[start:end]
            if end > self.blocksize:
                self.total_requested_bytes += end - self.blocksize
                part += self.fetcher(self.blocksize, end)
            self.hit_count += 1
            return part
        else:
            self.miss_count += 1
            self.total_requested_bytes += end - start
            return self.fetcher(start, end)


class BlockCache(BaseCache):
    """
    Cache holding memory as a set of blocks.

    Requests are only ever made ``blocksize`` at a time, and are
    stored in an LRU cache. The least recently accessed block is
    discarded when more than ``maxblocks`` are stored.

    Parameters
    ----------
    blocksize : int
        The number of bytes to store in each block.
        Requests are only ever made for ``blocksize``, so this
        should balance the overhead of making a request against
        the granularity of the blocks.
    fetcher : Callable
    size : int
        The total size of the file being cached.
    maxblocks : int
        The maximum number of blocks to cache. The maximum memory
        use for this cache is then ``blocksize * maxblocks``.
    """
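
    # Hedged sketch: for a 100-byte file with blocksize=10, _fetch(25, 47)
    # touches blocks 2-4, so blocks 2, 3 and 4 are each fetched (or reused)
    # via the functools.lru_cache-wrapped _fetch_block, and the answer is
    # assembled from slices of those cached blocks in _read_cache.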

    name = "blockcache"

    def __init__(
        self, blocksize: int, fetcher: Fetcher, size: int, maxblocks: int = 32
    ) -> None:
        super().__init__(blocksize, fetcher, size)
        self.nblocks = math.ceil(size / blocksize)
        self.maxblocks = maxblocks
        self._fetch_block_cached = functools.lru_cache(maxblocks)(self._fetch_block)

    def cache_info(self):
        """
        The statistics on the block cache.

        Returns
        -------
        NamedTuple
            Returned directly from the LRU Cache used internally.
        """
        return self._fetch_block_cached.cache_info()

    def __getstate__(self) -> dict[str, Any]:
        # Copy so that pickling does not mutate the live instance.
        state = self.__dict__.copy()
        del state["_fetch_block_cached"]
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(state)
        self._fetch_block_cached = functools.lru_cache(state["maxblocks"])(
            self._fetch_block
        )

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        if start is None:
            start = 0
        if end is None:
            end = self.size
        if start >= self.size or start >= end:
            return b""

        # byte position -> block numbers
        start_block_number = start // self.blocksize
        end_block_number = end // self.blocksize

        # these are cached, so safe to do multiple calls for the same start and end.
        for block_number in range(start_block_number, end_block_number + 1):
            self._fetch_block_cached(block_number)

        return self._read_cache(
            start,
            end,
            start_block_number=start_block_number,
            end_block_number=end_block_number,
        )

    def _fetch_block(self, block_number: int) -> bytes:
        """
        Fetch the block of data for `block_number`.
        """
        if block_number > self.nblocks:
            raise ValueError(
                f"'block_number={block_number}' is greater than "
                f"the number of blocks ({self.nblocks})"
            )

        start = block_number * self.blocksize
        end = start + self.blocksize
        self.total_requested_bytes += end - start
        self.miss_count += 1
        logger.info("BlockCache fetching block %d", block_number)
        block_contents = super()._fetch(start, end)
        return block_contents

    def _read_cache(
        self, start: int, end: int, start_block_number: int, end_block_number: int
    ) -> bytes:
        """
        Read from our block cache.

        Parameters
        ----------
        start, end : int
            The start and end byte positions.
        start_block_number, end_block_number : int
            The start and end block numbers.
        """
        start_pos = start % self.blocksize
        end_pos = end % self.blocksize

        self.hit_count += 1
        if start_block_number == end_block_number:
            block: bytes = self._fetch_block_cached(start_block_number)
            return block[start_pos:end_pos]

        else:
            # read from the initial
            out = [self._fetch_block_cached(start_block_number)[start_pos:]]

            # intermediate blocks
            # Note: it'd be nice to combine these into one big request. However
            # that doesn't play nicely with our LRU cache.
            out.extend(
                map(
                    self._fetch_block_cached,
                    range(start_block_number + 1, end_block_number),
                )
            )

            # final block
            out.append(self._fetch_block_cached(end_block_number)[:end_pos])

            return b"".join(out)


class BytesCache(BaseCache):
    """Cache which holds data in an in-memory bytes object

    Implements read-ahead by the block size, for semi-random reads progressing
    through the file.

    Parameters
    ----------
    trim: bool
        As we read more data, whether to discard the start of the buffer when
        we are more than a blocksize ahead of it.
    """
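
    # Hedged sketch: with trim=True, sequential _fetch calls grow the buffer
    # by read-ahead, and once the buffer spans more than about two blocksizes
    # its oldest whole blocksize(s) are dropped, so memory use stays bounded
    # while semi-random forward reads remain cheap.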

    name: ClassVar[str] = "bytes"

    def __init__(
        self, blocksize: int, fetcher: Fetcher, size: int, trim: bool = True
    ) -> None:
        super().__init__(blocksize, fetcher, size)
        self.cache = b""
        self.start: int | None = None
        self.end: int | None = None
        self.trim = trim

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        # TODO: only set start/end after fetch, in case it fails?
        # is this where retry logic might go?
        if start is None:
            start = 0
        if end is None:
            end = self.size
        if start >= self.size or start >= end:
            return b""
        if (
            self.start is not None
            and start >= self.start
            and self.end is not None
            and end < self.end
        ):
            # cache hit: we have all the required data
            offset = start - self.start
            self.hit_count += 1
            return self.cache[offset : offset + end - start]

        if self.blocksize:
            bend = min(self.size, end + self.blocksize)
        else:
            bend = end

        if bend == start or start > self.size:
            return b""

        if (self.start is None or start < self.start) and (
            self.end is None or end > self.end
        ):
            # First read, or extending both before and after
            self.total_requested_bytes += bend - start
            self.miss_count += 1
            self.cache = self.fetcher(start, bend)
            self.start = start
        else:
            assert self.start is not None
            assert self.end is not None
            self.miss_count += 1

            if start < self.start:
                if self.end is None or self.end - end > self.blocksize:
                    self.total_requested_bytes += bend - start
                    self.cache = self.fetcher(start, bend)
                    self.start = start
                else:
                    self.total_requested_bytes += self.start - start
                    new = self.fetcher(start, self.start)
                    self.start = start
                    self.cache = new + self.cache
            elif self.end is not None and bend > self.end:
                if self.end > self.size:
                    pass
                elif end - self.end > self.blocksize:
                    self.total_requested_bytes += bend - start
                    self.cache = self.fetcher(start, bend)
                    self.start = start
                else:
                    self.total_requested_bytes += bend - self.end
                    new = self.fetcher(self.end, bend)
                    self.cache = self.cache + new

        self.end = self.start + len(self.cache)
        offset = start - self.start
        out = self.cache[offset : offset + end - start]
        if self.trim:
            num = (self.end - self.start) // (self.blocksize + 1)
            if num > 1:
                self.start += self.blocksize * num
                self.cache = self.cache[self.blocksize * num :]
        return out

    def __len__(self) -> int:
        return len(self.cache)


class AllBytes(BaseCache):
    """Cache entire contents of the file"""

    name: ClassVar[str] = "all"

    def __init__(
        self,
        blocksize: int | None = None,
        fetcher: Fetcher | None = None,
        size: int | None = None,
        data: bytes | None = None,
    ) -> None:
        super().__init__(blocksize, fetcher, size)  # type: ignore[arg-type]
        if data is None:
            self.miss_count += 1
            self.total_requested_bytes += self.size
            data = self.fetcher(0, self.size)
        self.data = data

    def _fetch(self, start: int | None, stop: int | None) -> bytes:
        self.hit_count += 1
        return self.data[start:stop]


class KnownPartsOfAFile(BaseCache):
    """
    Cache holding known file parts.

    Parameters
    ----------
    blocksize: int
        How far to read ahead in numbers of bytes
    fetcher: func
        Function of the form f(start, end) which gets bytes from remote as
        specified
    size: int
        How big this file is
    data: dict
        A dictionary mapping explicit `(start, stop)` file-offset tuples
        to known bytes.
    strict: bool, default True
        Whether to fetch reads that go beyond a known byte-range boundary.
        If `False`, any read that ends outside a known part will be zero
        padded. Note that zero padding will not be used for reads that
        begin outside a known byte-range.
    """
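
    # Hedged sketch: contiguous parts are merged on construction, so
    # data={(0, 5): b"hello", (5, 10): b"world"} is consolidated into
    # {(0, 10): b"helloworld"}; a later _fetch(3, 8) is then served from
    # the single merged block without calling the fetcher.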

    name: ClassVar[str] = "parts"

    def __init__(
        self,
        blocksize: int,
        fetcher: Fetcher,
        size: int,
        data: Optional[dict[tuple[int, int], bytes]] = None,
        strict: bool = True,
        **_: Any,
    ):
        super().__init__(blocksize, fetcher, size)
        self.strict = strict

        # simple consolidation of contiguous blocks
        if data:
            old_offsets = sorted(data.keys())
            offsets = [old_offsets[0]]
            blocks = [data.pop(old_offsets[0])]
            for start, stop in old_offsets[1:]:
                start0, stop0 = offsets[-1]
                if start == stop0:
                    offsets[-1] = (start0, stop)
                    blocks[-1] += data.pop((start, stop))
                else:
                    offsets.append((start, stop))
                    blocks.append(data.pop((start, stop)))

            self.data = dict(zip(offsets, blocks))
        else:
            self.data = {}

    def _fetch(self, start: int | None, stop: int | None) -> bytes:
        if start is None:
            start = 0
        if stop is None:
            stop = self.size

        out = b""
        for (loc0, loc1), data in self.data.items():
            # If self.strict=False, use zero-padded data
            # for reads beyond the end of a "known" buffer
            if loc0 <= start < loc1:
                off = start - loc0
                out = data[off : off + stop - start]
                if not self.strict or loc0 <= stop <= loc1:
                    # The request is within a known range, or
                    # it begins within a known range, and we
                    # are allowed to pad reads beyond the
                    # buffer with zero
                    out += b"\x00" * (stop - start - len(out))
                    self.hit_count += 1
                    return out
                else:
                    # The request ends outside a known range,
                    # and we are being "strict" about reads
                    # beyond the buffer
                    start = loc1
                    break

        # We only get here if there is a request outside the
        # known parts of the file. In an ideal world, this
        # should never happen
        if self.fetcher is None:
            # We cannot fetch the data, so raise an error
            raise ValueError(f"Read is outside the known file parts: {(start, stop)}. ")
        # We can fetch the data, but should warn the user
        # that this may be slow
        warnings.warn(
            f"Read is outside the known file parts: {(start, stop)}. "
            f"IO/caching performance may be poor!"
        )
        logger.debug(f"KnownPartsOfAFile cache fetching {start}-{stop}")
        self.total_requested_bytes += stop - start
        self.miss_count += 1
        return out + super()._fetch(start, stop)


class UpdatableLRU(Generic[P, T]):
    """
    Custom implementation of LRU cache that allows updating keys

    Used by BackgroundBlockCache
    """
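
    # Hedged sketch: unlike functools.lru_cache, entries can be injected
    # directly; e.g., with some hypothetical expensive_fn:
    #
    #   lru = UpdatableLRU(expensive_fn, max_size=2)
    #   lru(1)                  # computes and caches expensive_fn(1)
    #   lru.add_key(result, 2)  # stores a precomputed value for key (2,)
    #   lru.is_key_cached(2)    # -> True, without calling expensive_fn
    #
    # which is what lets BackgroundBlockCache publish blocks fetched in a
    # background thread.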

    class CacheInfo(NamedTuple):
        hits: int
        misses: int
        maxsize: int
        currsize: int

    def __init__(self, func: Callable[P, T], max_size: int = 128) -> None:
        self._cache: OrderedDict[Any, T] = collections.OrderedDict()
        self._func = func
        self._max_size = max_size
        self._hits = 0
        self._misses = 0
        self._lock = threading.Lock()

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        if kwargs:
            raise TypeError(f"Got unexpected keyword argument {kwargs.keys()}")
        with self._lock:
            if args in self._cache:
                self._cache.move_to_end(args)
                self._hits += 1
                return self._cache[args]

        result = self._func(*args, **kwargs)

        with self._lock:
            self._cache[args] = result
            self._misses += 1
            if len(self._cache) > self._max_size:
                self._cache.popitem(last=False)

        return result

    def is_key_cached(self, *args: Any) -> bool:
        with self._lock:
            return args in self._cache

    def add_key(self, result: T, *args: Any) -> None:
        with self._lock:
            self._cache[args] = result
            if len(self._cache) > self._max_size:
                self._cache.popitem(last=False)

    def cache_info(self) -> UpdatableLRU.CacheInfo:
        with self._lock:
            return self.CacheInfo(
                maxsize=self._max_size,
                currsize=len(self._cache),
                hits=self._hits,
                misses=self._misses,
            )


class BackgroundBlockCache(BaseCache):
    """
    Cache holding memory as a set of blocks with pre-loading of
    the next block in the background.

    Requests are only ever made ``blocksize`` at a time, and are
    stored in an LRU cache. The least recently accessed block is
    discarded when more than ``maxblocks`` are stored. If the
    next block is not in cache, it is loaded in a separate thread
    in a non-blocking way.

    Parameters
    ----------
    blocksize : int
        The number of bytes to store in each block.
        Requests are only ever made for ``blocksize``, so this
        should balance the overhead of making a request against
        the granularity of the blocks.
    fetcher : Callable
    size : int
        The total size of the file being cached.
    maxblocks : int
        The maximum number of blocks to cache. The maximum memory
        use for this cache is then ``blocksize * maxblocks``.
    """

    name: ClassVar[str] = "background"

    def __init__(
        self, blocksize: int, fetcher: Fetcher, size: int, maxblocks: int = 32
    ) -> None:
        super().__init__(blocksize, fetcher, size)
        self.nblocks = math.ceil(size / blocksize)
        self.maxblocks = maxblocks
        self._fetch_block_cached = UpdatableLRU(self._fetch_block, maxblocks)

        self._thread_executor = ThreadPoolExecutor(max_workers=1)
        self._fetch_future_block_number: int | None = None
        self._fetch_future: Future[bytes] | None = None
        self._fetch_future_lock = threading.Lock()

    def cache_info(self) -> UpdatableLRU.CacheInfo:
        """
        The statistics on the block cache.

        Returns
        -------
        NamedTuple
            Returned directly from the LRU Cache used internally.
        """
        return self._fetch_block_cached.cache_info()

    def __getstate__(self) -> dict[str, Any]:
        # Copy so that pickling does not mutate the live instance.
        state = self.__dict__.copy()
        del state["_fetch_block_cached"]
        del state["_thread_executor"]
        del state["_fetch_future_block_number"]
        del state["_fetch_future"]
        del state["_fetch_future_lock"]
        return state

    def __setstate__(self, state) -> None:
        self.__dict__.update(state)
        self._fetch_block_cached = UpdatableLRU(self._fetch_block, state["maxblocks"])
        self._thread_executor = ThreadPoolExecutor(max_workers=1)
        self._fetch_future_block_number = None
        self._fetch_future = None
        self._fetch_future_lock = threading.Lock()

    def _fetch(self, start: int | None, end: int | None) -> bytes:
        if start is None:
            start = 0
        if end is None:
            end = self.size
        if start >= self.size or start >= end:
            return b""

        # byte position -> block numbers
        start_block_number = start // self.blocksize
        end_block_number = end // self.blocksize

        fetch_future_block_number = None
        fetch_future = None
        with self._fetch_future_lock:
            # Background thread is running. Check whether we can or must join it.
            if self._fetch_future is not None:
                assert self._fetch_future_block_number is not None
                if self._fetch_future.done():
                    logger.info("BlockCache joined background fetch without waiting.")
                    self._fetch_block_cached.add_key(
                        self._fetch_future.result(), self._fetch_future_block_number
                    )
                    # Cleanup the fetch variables. Done with fetching the block.
                    self._fetch_future_block_number = None
                    self._fetch_future = None
                else:
                    # Must join if we need the block for the current fetch
                    must_join = bool(
                        start_block_number
                        <= self._fetch_future_block_number
                        <= end_block_number
                    )
                    if must_join:
                        # Copy to the local variables to release lock
                        # before waiting for result
                        fetch_future_block_number = self._fetch_future_block_number
                        fetch_future = self._fetch_future

                        # Cleanup the fetch variables. Have a local copy.
                        self._fetch_future_block_number = None
                        self._fetch_future = None

        # Need to wait for the future for the current read
        if fetch_future is not None:
            logger.info("BlockCache waiting for background fetch.")
            # Wait until result and put it in cache
            self._fetch_block_cached.add_key(
                fetch_future.result(), fetch_future_block_number
            )

        # these are cached, so safe to do multiple calls for the same start and end.
        for block_number in range(start_block_number, end_block_number + 1):
            self._fetch_block_cached(block_number)

        # fetch next block in the background if nothing is running in the background,
        # the block is within file and it is not already cached
        end_block_plus_1 = end_block_number + 1
        with self._fetch_future_lock:
            if (
                self._fetch_future is None
                and end_block_plus_1 <= self.nblocks
                and not self._fetch_block_cached.is_key_cached(end_block_plus_1)
            ):
                self._fetch_future_block_number = end_block_plus_1
                self._fetch_future = self._thread_executor.submit(
                    self._fetch_block, end_block_plus_1, "async"
                )

        return self._read_cache(
            start,
            end,
            start_block_number=start_block_number,
            end_block_number=end_block_number,
        )

    def _fetch_block(self, block_number: int, log_info: str = "sync") -> bytes:
        """
        Fetch the block of data for `block_number`.
        """
        if block_number > self.nblocks:
            raise ValueError(
                f"'block_number={block_number}' is greater than "
                f"the number of blocks ({self.nblocks})"
            )

        start = block_number * self.blocksize
        end = start + self.blocksize
        logger.info("BlockCache fetching block (%s) %d", log_info, block_number)
        self.total_requested_bytes += end - start
        self.miss_count += 1
        block_contents = super()._fetch(start, end)
        return block_contents

    def _read_cache(
        self, start: int, end: int, start_block_number: int, end_block_number: int
    ) -> bytes:
        """
        Read from our block cache.

        Parameters
        ----------
        start, end : int
            The start and end byte positions.
        start_block_number, end_block_number : int
            The start and end block numbers.
        """
        start_pos = start % self.blocksize
        end_pos = end % self.blocksize

        # kind of pointless to count this as a hit, but it is
        self.hit_count += 1

        if start_block_number == end_block_number:
            block = self._fetch_block_cached(start_block_number)
            return block[start_pos:end_pos]

        else:
            # read from the initial
            out = [self._fetch_block_cached(start_block_number)[start_pos:]]

            # intermediate blocks
            # Note: it'd be nice to combine these into one big request. However
            # that doesn't play nicely with our LRU cache.
            out.extend(
                map(
                    self._fetch_block_cached,
                    range(start_block_number + 1, end_block_number),
                )
            )

            # final block
            out.append(self._fetch_block_cached(end_block_number)[:end_pos])

            return b"".join(out)


caches: dict[str | None, type[BaseCache]] = {
    # one custom case
    None: BaseCache,
}


def register_cache(cls: type[BaseCache], clobber: bool = False) -> None:
    """'Register' cache implementation.

    Parameters
    ----------
    clobber: bool, optional
        If set to True (default is False), allow overwriting an existing
        entry.

    Raises
    ------
    ValueError
    """
    name = cls.name
    if not clobber and name in caches:
        raise ValueError(f"Cache with name {name!r} is already known: {caches[name]}")
    caches[name] = cls
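
# Hedged sketch: registering a custom cache makes it selectable by name
# (e.g. via the ``cache_type`` option of fsspec file objects); the class
# below is hypothetical:
#
#   class NoopCache(BaseCache):
#       name = "noop"
#
#   register_cache(NoopCache)
#   caches["noop"]  # -> NoopCache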


for c in (
    BaseCache,
    MMapCache,
    BytesCache,
    ReadAheadCache,
    BlockCache,
    FirstChunkCache,
    AllBytes,
    KnownPartsOfAFile,
    BackgroundBlockCache,
):
    register_cache(c)