1from array import array
2from bisect import bisect_right
3import io
4from os import PathLike
5from os.path import isfile
6from struct import Struct
7import sys
8from typing import BinaryIO, ClassVar, Literal, cast
9import warnings
10
11from pyzstd import (
12 _DEPRECATED_PLACEHOLDER,
13 ZstdCompressor,
14 ZstdDecompressor,
15 _DeprecatedPlaceholder,
16 _LevelOrOption,
17 _Option,
18 _StrOrBytesPath,
19 _ZstdDict,
20)
21
22if sys.version_info < (3, 12):
23 from typing_extensions import Buffer
24else:
25 from collections.abc import Buffer
26
27if sys.version_info < (3, 11):
28 from typing_extensions import Self
29else:
30 from typing import Self
31
# Names exported by a star-import of this module.
__all__ = ("SeekableFormatError", "SeekableZstdFile")

# Values stored in SeekableZstdFile._mode to track the state of the object.
_MODE_CLOSED = 0
_MODE_READ = 1
_MODE_WRITE = 2
37
38
class SeekableFormatError(Exception):
    """An error related to Zstandard Seekable Format."""

    def __init__(self, msg: str) -> None:
        # Every message gets the same prefix, so the origin of the error is
        # visible from the text alone.
        prefix = "Zstandard Seekable Format error: "
        super().__init__(prefix + msg)
44
45
46__doc__ = """\
47Zstandard Seekable Format (Ver 0.1.0, Apr 2017)
48Square brackets are used to indicate optional fields.
49All numeric fields are little-endian unless specified otherwise.
50A. Seek table is a skippable frame at the end of file:
51 Magic_Number Frame_Size [Seek_Table_Entries] Seek_Table_Footer
52 4 bytes 4 bytes 8-12 bytes each 9 bytes
53 Magic_Number must be 0x184D2A5E.
54B. Seek_Table_Entries:
55 Compressed_Size Decompressed_Size [Checksum]
56 4 bytes 4 bytes 4 bytes
57 Checksum is optional.
58C. Seek_Table_Footer:
59 Number_Of_Frames Seek_Table_Descriptor Seekable_Magic_Number
60 4 bytes 1 byte 4 bytes
61 Seekable_Magic_Number must be 0x8F92EAB1.
62D. Seek_Table_Descriptor:
63 Bit_number Field_name
64 7 Checksum_Flag
65 6-2 Reserved_Bits (should ensure they are set to 0)
66 1-0 Unused_Bits (should not interpret these bits)"""
# Version of the Zstandard Seekable Format described above. It is interpolated
# into the Reserved_Bits error message raised by _SeekTable.load_seek_table().
__format_version__ = "0.1.0"
68
69
70class _SeekTable:
71 _s_2uint32 = Struct("<II")
72 _s_3uint32 = Struct("<III")
73 _s_footer = Struct("<IBI")
74
75 # read_mode is True for read mode, False for write/append modes.
76 def __init__(self, *, read_mode: bool) -> None:
77 self._read_mode = read_mode
78 self._clear_seek_table()
79
80 def _clear_seek_table(self) -> None:
81 self._has_checksum = False
82 # The seek table frame size, used for append mode.
83 self._seek_frame_size = 0
84 # The file size, used for seeking to EOF.
85 self._file_size = 0
86
87 self._frames_count = 0
88 self._full_c_size = 0
89 self._full_d_size = 0
90
91 if self._read_mode:
92 # Item: cumulated_size
93 # Length: frames_count + 1
94 # q is int64_t. On Linux/macOS/Windows, Py_off_t is signed, so
95 # ZstdFile/SeekableZstdFile use int64_t as file position/size.
96 self._cumulated_c_size = array("q", [0])
97 self._cumulated_d_size = array("q", [0])
98 else:
99 # Item: (c_size1, d_size1,
100 # c_size2, d_size2,
101 # c_size3, d_size3,
102 # ...)
103 # Length: frames_count * 2
104 # I is uint32_t.
105 self._frames = array("I")
106
107 def append_entry(self, compressed_size: int, decompressed_size: int) -> None:
108 if compressed_size == 0:
109 if decompressed_size == 0:
110 # (0, 0) frame is no sense
111 return
112 # Impossible frame
113 raise ValueError
114
115 self._frames_count += 1
116 self._full_c_size += compressed_size
117 self._full_d_size += decompressed_size
118
119 if self._read_mode:
120 self._cumulated_c_size.append(self._full_c_size)
121 self._cumulated_d_size.append(self._full_d_size)
122 else:
123 self._frames.append(compressed_size)
124 self._frames.append(decompressed_size)
125
126 # seek_to_0 is True or False.
127 # In read mode, seeking to 0 is necessary.
128 def load_seek_table(self, fp: BinaryIO, seek_to_0: bool) -> None: # noqa: FBT001
129 # Get file size
130 fsize = fp.seek(0, 2) # 2 is SEEK_END
131 if fsize == 0:
132 return
133 if fsize < 17: # 17=4+4+9
134 msg = (
135 "File size is less than the minimal size "
136 "(17 bytes) of Zstandard Seekable Format."
137 )
138 raise SeekableFormatError(msg)
139
140 # Read footer
141 fp.seek(-9, 2) # 2 is SEEK_END
142 footer = fp.read(9)
143 frames_number, descriptor, magic_number = self._s_footer.unpack(footer)
144 # Check format
145 if magic_number != 0x8F92EAB1:
146 msg = (
147 "The last 4 bytes of the file is not Zstandard Seekable "
148 'Format Magic Number (b"\\xb1\\xea\\x92\\x8f)". '
149 "SeekableZstdFile class only supports Zstandard Seekable "
150 "Format file or 0-size file. To read a zstd file that is "
151 "not in Zstandard Seekable Format, use ZstdFile class."
152 )
153 raise SeekableFormatError(msg)
154
155 # Seek_Table_Descriptor
156 self._has_checksum = descriptor & 0b10000000
157 if descriptor & 0b01111100:
158 msg = (
159 f"In Zstandard Seekable Format version {__format_version__}, the "
160 "Reserved_Bits in Seek_Table_Descriptor must be 0."
161 )
162 raise SeekableFormatError(msg)
163
164 # Frame size
165 entry_size = 12 if self._has_checksum else 8
166 skippable_frame_size = 17 + frames_number * entry_size
167 if fsize < skippable_frame_size:
168 raise SeekableFormatError(
169 "File size is less than expected size of the seek table frame."
170 )
171
172 # Read seek table
173 fp.seek(-skippable_frame_size, 2) # 2 is SEEK_END
174 skippable_frame = fp.read(skippable_frame_size)
175 skippable_magic_number, content_size = self._s_2uint32.unpack_from(
176 skippable_frame, 0
177 )
178
179 # Check format
180 if skippable_magic_number != 0x184D2A5E:
181 msg = "Seek table frame's Magic_Number is wrong."
182 raise SeekableFormatError(msg)
183 if content_size != skippable_frame_size - 8:
184 msg = "Seek table frame's Frame_Size is wrong."
185 raise SeekableFormatError(msg)
186
187 # No more fp operations
188 if seek_to_0:
189 fp.seek(0)
190
191 # Parse seek table
192 offset = 8
193 for idx in range(frames_number):
194 if self._has_checksum:
195 compressed_size, decompressed_size, _ = self._s_3uint32.unpack_from(
196 skippable_frame, offset
197 )
198 offset += 12
199 else:
200 compressed_size, decompressed_size = self._s_2uint32.unpack_from(
201 skippable_frame, offset
202 )
203 offset += 8
204
205 # Check format
206 if compressed_size == 0 and decompressed_size != 0:
207 msg = (
208 f"Wrong seek table. The index {idx} frame (0-based) "
209 "is 0 size, but decompressed size is non-zero, "
210 "this is impossible."
211 )
212 raise SeekableFormatError(msg)
213
214 # Append to seek table
215 self.append_entry(compressed_size, decompressed_size)
216
217 # Check format
218 if self._full_c_size > fsize - skippable_frame_size:
219 msg = (
220 f"Wrong seek table. Since index {idx} frame (0-based), "
221 "the cumulated compressed size is greater than "
222 "file size."
223 )
224 raise SeekableFormatError(msg)
225
226 # Check format
227 if self._full_c_size != fsize - skippable_frame_size:
228 raise SeekableFormatError("The cumulated compressed size is wrong")
229
230 # Parsed successfully, save for future use.
231 self._seek_frame_size = skippable_frame_size
232 self._file_size = fsize
233
234 # Find frame index by decompressed position
235 def index_by_dpos(self, pos: int) -> int | None:
236 # Array's first item is 0, so need this.
237 pos = max(pos, 0)
238
239 i = bisect_right(self._cumulated_d_size, pos)
240 if i != self._frames_count + 1:
241 return i
242 # None means >= EOF
243 return None
244
245 def get_frame_sizes(self, i: int) -> tuple[int, int]:
246 return (self._cumulated_c_size[i - 1], self._cumulated_d_size[i - 1])
247
248 def get_full_c_size(self) -> int:
249 return self._full_c_size
250
251 def get_full_d_size(self) -> int:
252 return self._full_d_size
253
254 # Merge the seek table to max_frames frames.
255 # The format allows up to 0xFFFF_FFFF frames. When frames
256 # number exceeds, use this method to merge.
257 def _merge_frames(self, max_frames: int) -> None:
258 if self._frames_count <= max_frames:
259 return
260
261 # Clear the table
262 arr = self._frames
263 a, b = divmod(self._frames_count, max_frames)
264 self._clear_seek_table()
265
266 # Merge frames
267 pos = 0
268 for i in range(max_frames):
269 # Slice length
270 length = (a + (1 if i < b else 0)) * 2
271
272 # Merge
273 c_size = 0
274 d_size = 0
275 for j in range(pos, pos + length, 2):
276 c_size += arr[j]
277 d_size += arr[j + 1]
278 self.append_entry(c_size, d_size)
279
280 pos += length
281
282 def write_seek_table(self, fp: BinaryIO) -> None:
283 # Exceeded format limit
284 if self._frames_count > 0xFFFFFFFF:
285 # Emit a warning
286 warnings.warn(
287 f"SeekableZstdFile's seek table has {self._frames_count} entries, "
288 "which exceeds the maximal value allowed by "
289 "Zstandard Seekable Format (0xFFFFFFFF). The "
290 "entries will be merged into 0xFFFFFFFF entries, "
291 "this may reduce seeking performance.",
292 RuntimeWarning,
293 3,
294 )
295
296 # Merge frames
297 self._merge_frames(0xFFFFFFFF)
298
299 # The skippable frame
300 offset = 0
301 size = 17 + 8 * self._frames_count
302 ba = bytearray(size)
303
304 # Header
305 self._s_2uint32.pack_into(ba, offset, 0x184D2A5E, size - 8)
306 offset += 8
307 # Entries
308 iter_frames = iter(self._frames)
309 for frame_c, frame_d in zip(iter_frames, iter_frames, strict=True):
310 self._s_2uint32.pack_into(ba, offset, frame_c, frame_d)
311 offset += 8
312 # Footer
313 self._s_footer.pack_into(ba, offset, self._frames_count, 0, 0x8F92EAB1)
314
315 # Write
316 fp.write(ba)
317
318 @property
319 def seek_frame_size(self) -> int:
320 return self._seek_frame_size
321
322 @property
323 def file_size(self) -> int:
324 return self._file_size
325
326 def __len__(self) -> int:
327 return self._frames_count
328
329 def get_info(self) -> tuple[int, int, int]:
330 return (self._frames_count, self._full_c_size, self._full_d_size)
331
332
class _EOFSuccess(EOFError):  # noqa: N818
    """Raised by _SeekableDecompressReader._decompress() when EOF is reached
    at a frame edge, as opposed to a plain EOFError for EOF elsewhere."""
335
336
class _SeekableDecompressReader(io.RawIOBase):
    """Raw stream that decompresses a Zstandard Seekable Format file.

    The seek table is loaded when the reader is created; seek() uses it to
    jump to the start of the frame holding the target position and then
    decompresses forward to the exact offset.
    """

    def __init__(
        self, fp: BinaryIO, zstd_dict: _ZstdDict, option: _Option, read_size: int
    ) -> None:
        # The file object must offer, and pass, the readable/seekable checks.
        has_query_methods = hasattr(fp, "readable") and hasattr(fp, "seekable")
        if not has_query_methods:
            raise TypeError(
                "In SeekableZstdFile's reading mode, the file object should "
                "have .readable()/.seekable() methods."
            )
        if not fp.readable():
            raise TypeError(
                "In SeekableZstdFile's reading mode, the file object should "
                "be readable."
            )
        if not fp.seekable():
            raise TypeError(
                "In SeekableZstdFile's reading mode, the file object should "
                "be seekable. If the file object is not seekable, it can be "
                "read sequentially using ZstdFile class."
            )

        self._fp = fp
        self._zstd_dict = zstd_dict
        self._option = option
        self._read_size = read_size

        # Parse the seek table; this leaves fp positioned at offset 0.
        table = _SeekTable(read_mode=True)
        table.load_seek_table(fp, seek_to_0=True)
        self._seek_table = table
        self._size = table.get_full_d_size()

        # Current decompressed position.
        self._pos = 0
        # None means "at a frame edge": a new decompressor is created
        # lazily by _decompress().
        self._decompressor: ZstdDecompressor | None = ZstdDecompressor(
            self._zstd_dict, self._option
        )

    def close(self) -> None:
        # Drop the decompressor before delegating to RawIOBase.close().
        self._decompressor = None
        return super().close()

    def readable(self) -> bool:
        return True

    def seekable(self) -> bool:
        return True

    def tell(self) -> int:
        return self._pos

    def _decompress(self, size: int) -> bytes:
        """
        Decompress up to size bytes.
        May return b"", in which case try again.
        Raises _EOFSuccess if EOF is reached at frame edge.
        Raises EOFError if EOF is reached elsewhere.
        """
        decomp = self._decompressor
        if decomp is None:  # frame edge
            chunk = self._fp.read(self._read_size)
            if not chunk:  # EOF
                raise _EOFSuccess
        elif decomp.needs_input:
            chunk = self._fp.read(self._read_size)
            if not chunk:  # EOF
                raise EOFError(
                    "Compressed file ended before the end-of-stream marker was reached"
                )
        else:
            # Feed back whatever the decompressor left unconsumed.
            chunk = decomp.unused_data
            if decomp.eof:  # frame edge
                decomp = None
                self._decompressor = None
                if not chunk:  # may not be at EOF
                    return b""

        if decomp is None:
            decomp = ZstdDecompressor(self._zstd_dict, self._option)
            self._decompressor = decomp

        produced = decomp.decompress(chunk, size)
        self._pos += len(produced)
        return produced

    def readinto(self, b: Buffer) -> int:
        with memoryview(b) as view, view.cast("B") as byte_view:
            capacity = byte_view.nbytes
            try:
                # _decompress() may legitimately yield b""; retry until it
                # produces data or signals a clean EOF.
                produced = self._decompress(capacity)
                while not produced:
                    produced = self._decompress(capacity)
            except _EOFSuccess:
                return 0
            count = len(produced)
            byte_view[:count] = produced
            return count

    # If the new position is within BufferedReader's buffer,
    # this method may not be called.
    def seek(self, offset: int, whence: int = 0) -> int:
        # Turn offset into an absolute decompressed position.
        if whence == 1:  # SEEK_CUR
            offset = self._pos + offset
        elif whence == 2:  # SEEK_END
            offset = self._size + offset
        elif whence != 0:  # anything but SEEK_SET
            raise ValueError(f"Invalid value for whence: {whence}")

        table = self._seek_table

        # Frame that holds the target position; None means offset >= EOF.
        target_frame = table.index_by_dpos(offset)
        if target_frame is None:
            self._pos = self._size
            self._decompressor = None
            self._fp.seek(table.file_size)
            return self._pos

        # Prepare to jump
        current_frame = table.index_by_dpos(self._pos)
        frame_c_start, frame_d_start = table.get_frame_sizes(target_frame)

        # If at P1, seeking to P2 will unnecessarily read the skippable
        # frame. So check self._fp position to skip the skippable frame.
        #       |--data1--|--skippable--|--data2--|
        # cpos:            ^P1
        # dpos:            ^P1           ^P2
        keep_going_forward = (
            target_frame == current_frame
            and offset >= self._pos
            and self._fp.tell() >= frame_c_start
        )
        if not keep_going_forward:
            # Jump to the start of the target frame.
            self._pos = frame_d_start
            self._decompressor = None
            self._fp.seek(frame_c_start)

        # Decompress and discard until the target position is reached.
        remaining = offset - self._pos
        while remaining > 0:
            remaining -= len(self._decompress(remaining))

        return self._pos

    def get_seek_table_info(self) -> tuple[int, int, int]:
        return self._seek_table.get_info()
474
475
# Compared to the ZstdFile class, it's important to handle the seekability
# of the underlying file object carefully: it needs to be checked in each
# situation. For example, there may be a CD-R file system that is seekable
# when reading, but not seekable when appending.
class SeekableZstdFile(io.BufferedIOBase):
    """This class can only create/write/read Zstandard Seekable Format file,
    or read 0-size file.
    It provides relatively fast seeking ability in read mode.
    """

    # The format uses uint32_t for compressed/decompressed sizes. If flush
    # block a lot, compressed_size may exceed the limit, so set a max size.
    FRAME_MAX_C_SIZE: ClassVar[int] = 2 * 1024 * 1024 * 1024
    # Zstd seekable format's example code also use 1GiB as max content size.
    FRAME_MAX_D_SIZE: ClassVar[int] = 1 * 1024 * 1024 * 1024

    FLUSH_BLOCK: ClassVar[Literal[1]] = ZstdCompressor.FLUSH_BLOCK
    FLUSH_FRAME: ClassVar[Literal[2]] = ZstdCompressor.FLUSH_FRAME

    def __init__(
        self,
        filename: _StrOrBytesPath | BinaryIO,
        mode: Literal["r", "rb", "w", "wb", "a", "ab", "x", "xb"] = "r",
        *,
        level_or_option: _LevelOrOption | _Option = None,
        zstd_dict: _ZstdDict = None,
        read_size: int | _DeprecatedPlaceholder = _DEPRECATED_PLACEHOLDER,  # type: ignore[has-type]
        write_size: int | _DeprecatedPlaceholder = _DEPRECATED_PLACEHOLDER,  # type: ignore[has-type]
        max_frame_content_size: int = 1024 * 1024 * 1024,
    ) -> None:
        """Open a Zstandard Seekable Format file in binary mode. In read mode,
        the file can be 0-size file.

        filename can be either an actual file name (given as a str, bytes, or
        PathLike object), in which case the named file is opened, or it can be
        an existing file object to read from or write to.

        mode can be "r" for reading (default), "w" for (over)writing, "x" for
        creating exclusively, or "a" for appending. These can equivalently be
        given as "rb", "wb", "xb" and "ab" respectively.

        In append mode ("a" or "ab"), filename argument can't be a file object,
        please use file path.

        Parameters
        level_or_option: When it's an int object, it represents compression
            level. When it's a dict object, it contains advanced compression
            parameters. Note, in read mode (decompression), it can only be a
            dict object, that represents decompression option. It doesn't
            support int type compression level in this case.
        zstd_dict: A ZstdDict object, pre-trained dictionary for compression /
            decompression.
        read_size: Deprecated. Passing it emits a DeprecationWarning; it is
            only accepted in read mode.
        write_size: Deprecated. Passing it emits a DeprecationWarning; it is
            only accepted in write modes.
        max_frame_content_size: In write/append modes (compression), when
            the uncompressed data size reaches max_frame_content_size, a frame
            is generated automatically. If the size is small, it will increase
            seeking speed, but reduce compression ratio. If the size is large,
            it will reduce seeking speed, but increase compression ratio. You
            can also manually generate a frame using f.flush(f.FLUSH_FRAME).

        If setting up the reader (read mode) or positioning the file (append
        mode) fails, a file opened by this method is closed before the
        exception propagates.
        """
        if read_size == _DEPRECATED_PLACEHOLDER:
            read_size = 131075
        else:
            warnings.warn(
                "pyzstd.SeekableZstdFile()'s read_size parameter is deprecated",
                DeprecationWarning,
                stacklevel=2,
            )
        read_size = cast("int", read_size)
        if write_size == _DEPRECATED_PLACEHOLDER:
            write_size = 131591
        else:
            warnings.warn(
                "pyzstd.SeekableZstdFile()'s write_size parameter is deprecated",
                DeprecationWarning,
                stacklevel=2,
            )
        write_size = cast("int", write_size)

        # Set the "closed" state first, so close() is safe to call even if
        # an exception is raised below.
        self._fp: BinaryIO | None = None
        self._close_fp = False
        self._mode = _MODE_CLOSED
        self._buffer = None

        if not isinstance(mode, str):
            raise TypeError("mode must be a str")
        mode = mode.removesuffix("b")  # type: ignore[assignment]  # handle rb, wb, xb, ab

        # Read or write mode
        if mode == "r":
            if not isinstance(level_or_option, (type(None), dict)):
                raise TypeError(
                    "In read mode (decompression), level_or_option argument "
                    "should be a dict object, that represents decompression "
                    "option. It doesn't support int type compression level "
                    "in this case."
                )
            if read_size <= 0:
                raise ValueError("read_size argument should > 0")
            if write_size != 131591:
                raise ValueError("write_size argument is only valid in write modes.")
            # Specified max_frame_content_size argument
            if max_frame_content_size != 1024 * 1024 * 1024:
                raise ValueError(
                    "max_frame_content_size argument is only "
                    "valid in write modes (compression)."
                )
            mode_code = _MODE_READ

        elif mode in {"w", "a", "x"}:
            if not isinstance(level_or_option, (type(None), int, dict)):
                raise TypeError(
                    "level_or_option argument should be int or dict object."
                )
            if read_size != 131075:
                raise ValueError("read_size argument is only valid in read mode.")
            if write_size <= 0:
                raise ValueError("write_size argument should > 0")
            if not (0 < max_frame_content_size <= self.FRAME_MAX_D_SIZE):
                raise ValueError(
                    "max_frame_content_size argument should be "
                    f"0 < value <= {self.FRAME_MAX_D_SIZE}, "
                    f"provided value is {max_frame_content_size}."
                )

            # For seekable format
            self._max_frame_content_size = max_frame_content_size
            self._reset_frame_sizes()
            self._seek_table: _SeekTable | None = _SeekTable(read_mode=False)

            mode_code = _MODE_WRITE
            self._compressor: ZstdCompressor | None = ZstdCompressor(
                level_or_option=level_or_option, zstd_dict=zstd_dict
            )
            self._pos = 0

            # Load seek table in append mode
            if mode == "a":
                if not isinstance(filename, (str, bytes, PathLike)):
                    raise TypeError(
                        "In append mode ('a', 'ab'), "
                        "SeekableZstdFile.__init__() method can't "
                        "accept file object as filename argument. "
                        "Please use file path (str/bytes/PathLike)."
                    )

                # Load seek table if file exists
                if isfile(filename):
                    with open(filename, "rb") as f:
                        if not hasattr(f, "seekable") or not f.seekable():
                            raise TypeError(
                                "In SeekableZstdFile's append mode "
                                "('a', 'ab'), the opened 'rb' file "
                                "object should be seekable."
                            )
                        self._seek_table.load_seek_table(f, seek_to_0=False)

        else:
            raise ValueError(f"Invalid mode: {mode!r}")

        # File object
        if isinstance(filename, (str, bytes, PathLike)):
            self._fp = cast("BinaryIO", open(filename, mode + "b"))  # noqa: SIM115
            self._close_fp = True
        elif hasattr(filename, "read") or hasattr(filename, "write"):
            self._fp = filename
        else:
            raise TypeError("filename must be a str, bytes, file or PathLike object")

        self._mode = mode_code

        try:
            if self._mode == _MODE_READ:
                raw = _SeekableDecompressReader(
                    self._fp,
                    zstd_dict=zstd_dict,
                    option=cast("_Option", level_or_option),  # checked earlier on
                    read_size=read_size,
                )
                self._buffer = io.BufferedReader(raw)

            elif mode == "a":
                if self._fp.seekable():
                    # Position right after the last frame, so the old seek
                    # table gets overwritten.
                    self._fp.seek(self._seek_table.get_full_c_size())  # type: ignore[union-attr]
                    # Necessary if the current table has many (0, 0) entries
                    self._fp.truncate()
                else:
                    # Add the seek table frame
                    self._seek_table.append_entry(self._seek_table.seek_frame_size, 0)  # type: ignore[union-attr]
                    # Emit a warning
                    warnings.warn(
                        (
                            "SeekableZstdFile is opened in append mode "
                            "('a', 'ab'), but the underlying file object "
                            "is not seekable. Therefore the seek table (a "
                            "zstd skippable frame) at the end of the file "
                            "can't be overwritten. Each time open such file "
                            "in append mode, it will waste some storage "
                            f"space. {self._seek_table.seek_frame_size} bytes "  # type: ignore[union-attr]
                            "were wasted this time."
                        ),
                        RuntimeWarning,
                        2,
                    )
        except BaseException:
            # The object is never handed to the caller, so release the file
            # now instead of waiting for IOBase.__del__() to call close().
            self._abort_init()
            raise

    def _abort_init(self) -> None:
        """Release resources after a failure late in __init__().

        The object is marked as closed first, so a later close() is a no-op
        and never flushes or writes a seek table for a half-initialized
        object. The file object is closed only if __init__() opened it.
        """
        fp, owns_fp = self._fp, self._close_fp
        self._mode = _MODE_CLOSED
        self._buffer = None
        self._seek_table = None
        self._fp = None
        self._close_fp = False
        if owns_fp and fp is not None:
            fp.close()

    def _reset_frame_sizes(self) -> None:
        """Start accounting for a new frame."""
        self._current_c_size = 0
        self._current_d_size = 0
        # Uncompressed bytes that still fit in the current frame.
        self._left_d_size = self._max_frame_content_size

    def _check_not_closed(self) -> None:
        """Raise ValueError if the file is closed."""
        if self.closed:
            raise ValueError("I/O operation on closed file")

    def _check_can_read(self) -> None:
        """Raise io.UnsupportedOperation unless open for reading."""
        if not self.readable():
            raise io.UnsupportedOperation("File not open for reading")

    def _check_can_write(self) -> None:
        """Raise io.UnsupportedOperation unless open for writing."""
        if not self.writable():
            raise io.UnsupportedOperation("File not open for writing")

    def close(self) -> None:
        """Flush and close the file.

        In write modes the current frame is finished and the seek table is
        written. The underlying file object is closed only if it was opened
        by this object.

        May be called more than once without error. Once the file is
        closed, any other operation on it will raise a ValueError.
        """
        if self._mode == _MODE_CLOSED:
            return

        if self._fp is None:
            return
        try:
            if self._mode == _MODE_READ:
                # _buffer may be missing/None if __init__() didn't finish.
                if getattr(self, "_buffer", None):
                    self._buffer.close()  # type: ignore[union-attr]
                    self._buffer = None
            elif self._mode == _MODE_WRITE:
                self.flush(self.FLUSH_FRAME)
                self._seek_table.write_seek_table(self._fp)  # type: ignore[union-attr]
                self._compressor = None
        finally:
            self._mode = _MODE_CLOSED
            self._seek_table = None
            try:
                if self._close_fp:
                    self._fp.close()
            finally:
                self._fp = None
                self._close_fp = False

    def write(self, data: Buffer) -> int:
        """Write a bytes-like object to the file.

        Returns the number of uncompressed bytes written, which is
        always the length of data in bytes. Note that due to buffering,
        the file on disk may not reflect the data written until .flush()
        or .close() is called.
        """
        self._check_can_write()
        # Accept any data that supports the buffer protocol.
        # And memoryview's subview is faster than slice.
        with memoryview(data) as view, view.cast("B") as byte_view:
            nbytes = byte_view.nbytes
            pos = 0

            while nbytes > 0:
                # Write size: never cross the current frame's content limit.
                write_size = min(nbytes, self._left_d_size)

                # Compress & write
                compressed = self._compressor.compress(  # type: ignore[union-attr]
                    byte_view[pos : pos + write_size]
                )
                output_size = self._fp.write(compressed)  # type: ignore[union-attr]
                self._pos += write_size

                pos += write_size
                nbytes -= write_size

                # Cumulate
                self._current_c_size += output_size
                self._current_d_size += write_size
                self._left_d_size -= write_size

                # Should flush a frame
                if (
                    self._left_d_size == 0
                    or self._current_c_size >= self.FRAME_MAX_C_SIZE
                ):
                    self.flush(self.FLUSH_FRAME)

            return pos

    def flush(self, mode: Literal[1, 2] = ZstdCompressor.FLUSH_BLOCK) -> None:
        """Flush remaining data to the underlying stream.

        The mode argument can be FLUSH_BLOCK or FLUSH_FRAME (class
        attributes). FLUSH_FRAME also adds an entry to the seek table when
        the current frame is not empty.
        Abuse of this method will reduce compression ratio, use it only when
        necessary.

        If the program is interrupted afterwards, all data can be recovered.
        To ensure saving to disk, also need to use os.fsync(fd).

        This method does nothing in reading mode.
        """
        if self._mode == _MODE_READ:
            return

        self._check_not_closed()
        if mode not in {self.FLUSH_BLOCK, self.FLUSH_FRAME}:
            raise ValueError(
                "Invalid mode argument, expected either "
                "ZstdFile.FLUSH_FRAME or "
                "ZstdFile.FLUSH_BLOCK"
            )

        # Nothing to flush if the compressor was last used with this mode.
        if self._compressor.last_mode != mode:  # type: ignore[union-attr]
            # Flush zstd block/frame, and write.
            compressed = self._compressor.flush(mode)  # type: ignore[union-attr]
            output_size = self._fp.write(compressed)  # type: ignore[union-attr]
            if hasattr(self._fp, "flush"):
                self._fp.flush()  # type: ignore[union-attr]

            # Cumulate
            self._current_c_size += output_size
            # self._current_d_size += 0
            # self._left_d_size -= 0

        if mode == self.FLUSH_FRAME and self._current_c_size != 0:
            # Add an entry to seek table
            self._seek_table.append_entry(self._current_c_size, self._current_d_size)  # type: ignore[union-attr]
            self._reset_frame_sizes()

    def read(self, size: int | None = -1) -> bytes:
        """Read up to size uncompressed bytes from the file.

        If size is negative or omitted, read until EOF is reached.
        Returns b"" if the file is already at EOF.
        """
        if size is None:
            size = -1
        self._check_can_read()
        return self._buffer.read(size)  # type: ignore[union-attr]

    def read1(self, size: int = -1) -> bytes:
        """Read up to size uncompressed bytes, while trying to avoid
        making multiple reads from the underlying stream. Reads up to a
        buffer's worth of data if size is negative.

        Returns b"" if the file is at EOF.
        """
        self._check_can_read()
        if size < 0:
            size = io.DEFAULT_BUFFER_SIZE
        return self._buffer.read1(size)  # type: ignore[union-attr]

    def readinto(self, b: Buffer) -> int:
        """Read bytes into b.

        Returns the number of bytes read (0 for EOF).
        """
        self._check_can_read()
        return self._buffer.readinto(b)  # type: ignore[union-attr]

    def readinto1(self, b: Buffer) -> int:
        """Read bytes into b, while trying to avoid making multiple reads
        from the underlying stream.

        Returns the number of bytes read (0 for EOF).
        """
        self._check_can_read()
        return self._buffer.readinto1(b)  # type: ignore[union-attr]

    def readline(self, size: int | None = -1) -> bytes:
        """Read a line of uncompressed bytes from the file.

        The terminating newline (if present) is retained. If size is
        non-negative, no more than size bytes will be read (in which
        case the line may be incomplete). Returns b'' if already at EOF.
        """
        self._check_can_read()
        return self._buffer.readline(size)  # type: ignore[union-attr]

    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
        """Change the file position.

        The new position is specified by offset, relative to the
        position indicated by whence. Possible values for whence are:

            0: start of stream (default): offset must not be negative
            1: current stream position
            2: end of stream; offset must not be positive

        Returns the new file position.

        Note that seeking is emulated, so depending on the arguments,
        this operation may be extremely slow.
        """
        self._check_can_read()
        return self._buffer.seek(offset, whence)  # type: ignore[union-attr]

    def peek(self, size: int = -1) -> bytes:
        """Return buffered data without advancing the file position.

        Always returns at least one byte of data, unless at EOF.
        The exact number of bytes returned is unspecified.
        """
        self._check_can_read()
        return self._buffer.peek(size)  # type: ignore[union-attr]

    def __iter__(self) -> Self:
        self._check_can_read()
        return self

    def __next__(self) -> bytes:
        self._check_can_read()
        if ret := self._buffer.readline():  # type: ignore[union-attr]
            return ret
        raise StopIteration

    def tell(self) -> int:
        """Return the current file position."""
        self._check_not_closed()
        if self._mode == _MODE_READ:
            return self._buffer.tell()  # type: ignore[union-attr]
        if self._mode == _MODE_WRITE:
            return self._pos
        raise RuntimeError  # impossible code path

    def fileno(self) -> int:
        """Return the file descriptor for the underlying file."""
        self._check_not_closed()
        return self._fp.fileno()  # type: ignore[union-attr]

    @property
    def name(self) -> str:
        """Return the file name for the underlying file."""
        self._check_not_closed()
        return self._fp.name  # type: ignore[union-attr]

    @property
    def closed(self) -> bool:
        """True if this file is closed."""
        return self._mode == _MODE_CLOSED

    def writable(self) -> bool:
        """Return whether the file was opened for writing."""
        self._check_not_closed()
        return self._mode == _MODE_WRITE

    def readable(self) -> bool:
        """Return whether the file was opened for reading."""
        self._check_not_closed()
        return self._mode == _MODE_READ

    def seekable(self) -> bool:
        """Return whether the file supports seeking."""
        return self.readable() and self._buffer.seekable()  # type: ignore[union-attr]

    @property
    def seek_table_info(self) -> tuple[int, int, int] | None:
        """A tuple: (frames_number, compressed_size, decompressed_size)
        1, Frames_number and compressed_size don't count the seek table
           frame (a zstd skippable frame at the end of the file).
        2, In write modes, the part of data that has not been flushed to
           frames is not counted.
        3, If the SeekableZstdFile object is closed, it's None.
        """
        if self._mode == _MODE_WRITE:
            return self._seek_table.get_info()  # type: ignore[union-attr]
        if self._mode == _MODE_READ:
            return self._buffer.raw.get_seek_table_info()  # type: ignore[union-attr]
        return None

    @staticmethod
    def is_seekable_format_file(filename: _StrOrBytesPath | BinaryIO) -> bool:
        """Check if a file is Zstandard Seekable Format file or 0-size file.

        It parses the seek table at the end of the file, returns True if no
        format error.

        filename can be either a file path (str/bytes/PathLike), or can be an
        existing file object in reading mode. For a file object, the original
        position is restored afterwards.
        """
        # Check argument
        if isinstance(filename, (str, bytes, PathLike)):
            fp: BinaryIO = open(filename, "rb")  # noqa: SIM115
            is_file_path = True
        elif (
            hasattr(filename, "readable")
            and filename.readable()
            and hasattr(filename, "seekable")
            and filename.seekable()
        ):
            fp = filename
            is_file_path = False
            orig_pos = fp.tell()
        else:
            raise TypeError(
                "filename argument should be a str/bytes/PathLike object, "
                "or a file object that is readable and seekable."
            )

        # Write mode uses less RAM
        table = _SeekTable(read_mode=False)
        try:
            # Read/Parse the seek table
            table.load_seek_table(fp, seek_to_0=False)
        except SeekableFormatError:
            ret = False
        else:
            ret = True
        finally:
            if is_file_path:
                fp.close()
            else:
                fp.seek(orig_pos)

        return ret