1from array import array
2from bisect import bisect_right
3from os.path import isfile
4from struct import Struct
5from warnings import warn
6
7from pyzstd._zstdfile import _ZstdDecompressReader, ZstdFile, \
8 _MODE_CLOSED, _MODE_READ, _MODE_WRITE, \
9 PathLike, io, _DEPRECATED_PLACEHOLDER
10
# Public API of this module.
__all__ = (
    'SeekableFormatError',
    'SeekableZstdFile',
)
12
class SeekableFormatError(Exception):
    """Raised when data does not conform to Zstandard Seekable Format."""

    def __init__(self, msg):
        # Prefix every message so the origin of the error is obvious.
        full_msg = 'Zstandard Seekable Format error: ' + msg
        super().__init__(full_msg)
17
# Module documentation: a summary of the on-disk format handled here.
__doc__ = '''\
Zstandard Seekable Format (Ver 0.1.0, Apr 2017)
Square brackets are used to indicate optional fields.
All numeric fields are little-endian unless specified otherwise.
A. Seek table is a skippable frame at the end of file:
     Magic_Number  Frame_Size  [Seek_Table_Entries]  Seek_Table_Footer
     4 bytes       4 bytes     8 or 12 bytes each    9 bytes
   Magic_Number must be 0x184D2A5E.
   Frame_Size is the size of the rest of the frame, i.e. the whole
   frame size minus 8 bytes (Magic_Number and Frame_Size).
B. Seek_Table_Entries:
     Compressed_Size  Decompressed_Size  [Checksum]
     4 bytes          4 bytes            4 bytes
   Checksum is optional, it's present only if Checksum_Flag is set.
C. Seek_Table_Footer:
     Number_Of_Frames  Seek_Table_Descriptor  Seekable_Magic_Number
     4 bytes           1 byte                 4 bytes
   Seekable_Magic_Number must be 0x8F92EAB1.
D. Seek_Table_Descriptor:
     Bit_number  Field_name
     7           Checksum_Flag
     6-2         Reserved_Bits (must be 0)
     1-0         Unused_Bits (should not be interpreted)'''
# Format version implemented by this module; also quoted in error messages.
__format_version__ = '0.1.0'
40
41class _SeekTable:
42 _s_2uint32 = Struct('<II')
43 _s_3uint32 = Struct('<III')
44 _s_footer = Struct('<IBI')
45
46 # read_mode is True for read mode, False for write/append modes.
47 def __init__(self, read_mode):
48 self._read_mode = read_mode
49 self._clear_seek_table()
50
51 def _clear_seek_table(self):
52 self._has_checksum = False
53 # The seek table frame size, used for append mode.
54 self._seek_frame_size = 0
55 # The file size, used for seeking to EOF.
56 self._file_size = 0
57
58 self._frames_count = 0
59 self._full_c_size = 0
60 self._full_d_size = 0
61
62 if self._read_mode:
63 # Item: cumulated_size
64 # Length: frames_count + 1
65 # q is int64_t. On Linux/macOS/Windows, Py_off_t is signed, so
66 # ZstdFile/SeekableZstdFile use int64_t as file position/size.
67 self._cumulated_c_size = array('q', [0])
68 self._cumulated_d_size = array('q', [0])
69 else:
70 # Item: (c_size1, d_size1,
71 # c_size2, d_size2,
72 # c_size3, d_size3,
73 # ...)
74 # Length: frames_count * 2
75 # I is uint32_t.
76 self._frames = array('I')
77
78 def append_entry(self, compressed_size, decompressed_size):
79 if compressed_size == 0:
80 if decompressed_size == 0:
81 # (0, 0) frame is no sense
82 return
83 else:
84 # Impossible frame
85 raise ValueError
86
87 self._frames_count += 1
88 self._full_c_size += compressed_size
89 self._full_d_size += decompressed_size
90
91 if self._read_mode:
92 self._cumulated_c_size.append(self._full_c_size)
93 self._cumulated_d_size.append(self._full_d_size)
94 else:
95 self._frames.append(compressed_size)
96 self._frames.append(decompressed_size)
97
98 # seek_to_0 is True or False.
99 # In read mode, seeking to 0 is necessary.
100 def load_seek_table(self, fp, seek_to_0):
101 # Get file size
102 fsize = fp.seek(0, 2) # 2 is SEEK_END
103 if fsize == 0:
104 return
105 elif fsize < 17: # 17=4+4+9
106 msg = ('File size is less than the minimal size '
107 '(17 bytes) of Zstandard Seekable Format.')
108 raise SeekableFormatError(msg)
109
110 # Read footer
111 fp.seek(-9, 2) # 2 is SEEK_END
112 footer = fp.read(9)
113 frames_number, descriptor, magic_number = self._s_footer.unpack(footer)
114 # Check format
115 if magic_number != 0x8F92EAB1:
116 msg = ('The last 4 bytes of the file is not Zstandard Seekable '
117 'Format Magic Number (b"\\xb1\\xea\\x92\\x8f)". '
118 'SeekableZstdFile class only supports Zstandard Seekable '
119 'Format file or 0-size file. To read a zstd file that is '
120 'not in Zstandard Seekable Format, use ZstdFile class.')
121 raise SeekableFormatError(msg)
122
123 # Seek_Table_Descriptor
124 self._has_checksum = \
125 descriptor & 0b10000000
126 if descriptor & 0b01111100:
127 msg = ('In Zstandard Seekable Format version %s, the '
128 'Reserved_Bits in Seek_Table_Descriptor must be 0.') \
129 % __format_version__
130 raise SeekableFormatError(msg)
131
132 # Frame size
133 entry_size = 12 if self._has_checksum else 8
134 skippable_frame_size = 17 + frames_number * entry_size
135 if fsize < skippable_frame_size:
136 raise SeekableFormatError(('File size is less than expected '
137 'size of the seek table frame.'))
138
139 # Read seek table
140 fp.seek(-skippable_frame_size, 2) # 2 is SEEK_END
141 skippable_frame = fp.read(skippable_frame_size)
142 skippable_magic_number, content_size = \
143 self._s_2uint32.unpack_from(skippable_frame, 0)
144
145 # Check format
146 if skippable_magic_number != 0x184D2A5E:
147 msg = "Seek table frame's Magic_Number is wrong."
148 raise SeekableFormatError(msg)
149 if content_size != skippable_frame_size - 8:
150 msg = "Seek table frame's Frame_Size is wrong."
151 raise SeekableFormatError(msg)
152
153 # No more fp operations
154 if seek_to_0:
155 fp.seek(0)
156
157 # Parse seek table
158 offset = 8
159 for idx in range(frames_number):
160 if self._has_checksum:
161 compressed_size, decompressed_size, checksum = \
162 self._s_3uint32.unpack_from(skippable_frame, offset)
163 offset += 12
164 else:
165 compressed_size, decompressed_size = \
166 self._s_2uint32.unpack_from(skippable_frame, offset)
167 offset += 8
168
169 # Check format
170 if compressed_size == 0 and decompressed_size != 0:
171 msg = ('Wrong seek table. The index %d frame (0-based) '
172 'is 0 size, but decompressed size is non-zero, '
173 'this is impossible.') % idx
174 raise SeekableFormatError(msg)
175
176 # Append to seek table
177 self.append_entry(compressed_size, decompressed_size)
178
179 # Check format
180 if self._full_c_size > fsize - skippable_frame_size:
181 msg = ('Wrong seek table. Since index %d frame (0-based), '
182 'the cumulated compressed size is greater than '
183 'file size.') % idx
184 raise SeekableFormatError(msg)
185
186 # Check format
187 if self._full_c_size != fsize - skippable_frame_size:
188 raise SeekableFormatError('The cumulated compressed size is wrong')
189
190 # Parsed successfully, save for future use.
191 self._seek_frame_size = skippable_frame_size
192 self._file_size = fsize
193
194 # Find frame index by decompressed position
195 def index_by_dpos(self, pos):
196 # Array's first item is 0, so need this.
197 if pos < 0:
198 pos = 0
199
200 i = bisect_right(self._cumulated_d_size, pos)
201 if i != self._frames_count + 1:
202 return i
203 else:
204 # None means >= EOF
205 return None
206
207 def get_frame_sizes(self, i):
208 return (self._cumulated_c_size[i-1],
209 self._cumulated_d_size[i-1])
210
211 def get_full_c_size(self):
212 return self._full_c_size
213
214 def get_full_d_size(self):
215 return self._full_d_size
216
217 # Merge the seek table to max_frames frames.
218 # The format allows up to 0xFFFF_FFFF frames. When frames
219 # number exceeds, use this method to merge.
220 def _merge_frames(self, max_frames):
221 if self._frames_count <= max_frames:
222 return
223
224 # Clear the table
225 arr = self._frames
226 a, b = divmod(self._frames_count, max_frames)
227 self._clear_seek_table()
228
229 # Merge frames
230 pos = 0
231 for i in range(max_frames):
232 # Slice length
233 length = (a + (1 if i < b else 0)) * 2
234
235 # Merge
236 c_size = 0
237 d_size = 0
238 for j in range(pos, pos+length, 2):
239 c_size += arr[j]
240 d_size += arr[j+1]
241 self.append_entry(c_size, d_size)
242
243 pos += length
244
245 def write_seek_table(self, fp):
246 # Exceeded format limit
247 if self._frames_count > 0xFFFFFFFF:
248 # Emit a warning
249 warn(('SeekableZstdFile\'s seek table has %d entries, '
250 'which exceeds the maximal value allowed by '
251 'Zstandard Seekable Format (0xFFFFFFFF). The '
252 'entries will be merged into 0xFFFFFFFF entries, '
253 'this may reduce seeking performance.') % self._frames_count,
254 RuntimeWarning, 3)
255
256 # Merge frames
257 self._merge_frames(0xFFFFFFFF)
258
259 # The skippable frame
260 offset = 0
261 size = 17 + 8 * self._frames_count
262 ba = bytearray(size)
263
264 # Header
265 self._s_2uint32.pack_into(ba, offset, 0x184D2A5E, size-8)
266 offset += 8
267 # Entries
268 for i in range(0, len(self._frames), 2):
269 self._s_2uint32.pack_into(ba, offset,
270 self._frames[i],
271 self._frames[i+1])
272 offset += 8
273 # Footer
274 self._s_footer.pack_into(ba, offset,
275 self._frames_count, 0, 0x8F92EAB1)
276
277 # Write
278 fp.write(ba)
279
280 @property
281 def seek_frame_size(self):
282 return self._seek_frame_size
283
284 @property
285 def file_size(self):
286 return self._file_size
287
288 def __len__(self):
289 return self._frames_count
290
291 def get_info(self):
292 return (self._frames_count,
293 self._full_c_size,
294 self._full_d_size)
295
class _SeekableDecompressReader(_ZstdDecompressReader):
    """Decompress reader that uses the seek table for fast seeking.

    The seek table at the end of fp is parsed in __init__(). seek() then
    repositions fp at the start of the frame containing the target
    position, instead of decompressing everything before it.
    """
    def __init__(self, fp, zstd_dict, option, read_size):
        """fp must be readable and seekable, otherwise TypeError is raised.

        SeekableFormatError is raised (by the seek table loader) if fp is
        neither a Zstandard Seekable Format file nor a 0-size file.
        """
        # Check fp readable/seekable
        if not hasattr(fp, 'readable') or not hasattr(fp, "seekable"):
            raise TypeError(
                ("In SeekableZstdFile's reading mode, the file object should "
                 "have .readable()/.seekable() methods."))
        if not fp.readable():
            raise TypeError(
                ("In SeekableZstdFile's reading mode, the file object should "
                 "be readable."))
        if not fp.seekable():
            raise TypeError(
                ("In SeekableZstdFile's reading mode, the file object should "
                 "be seekable. If the file object is not seekable, it can be "
                 "read sequentially using ZstdFile class."))

        # Load seek table; fp is rewound to 0 afterwards so decompression
        # starts at the beginning of the file.
        self._seek_table = _SeekTable(read_mode=True)
        self._seek_table.load_seek_table(fp, seek_to_0=True)

        # Initialize super()
        super().__init__(fp, zstd_dict, option, read_size)
        # Total decompressed size from the seek table; seek() uses it for
        # SEEK_END and for positioning at EOF.
        self._decomp.size = self._seek_table.get_full_d_size()

    # super().seekable() returns self._fp.seekable().
    # Seekable has been checked in .__init__() method.
    # BufferedReader.seek() checks this in each invoke, if self._fp.seekable()
    # becomes False at runtime, .seek() method just raise OSError instead of
    # io.UnsupportedOperation.
    def seekable(self):
        return True

    # If the new position is within BufferedReader's buffer,
    # this method may not be called.
    def seek(self, offset, whence=0):
        """Move to a decompressed position and return the new position.

        A target at or beyond EOF is clamped to EOF; a negative absolute
        target ends up at position 0.
        """
        # Convert offset to an absolute decompressed position
        if whence == 0:  # SEEK_SET
            pass
        elif whence == 1:  # SEEK_CUR
            offset = self._decomp.pos + offset
        elif whence == 2:  # SEEK_END
            offset = self._decomp.size + offset
        else:
            raise ValueError("Invalid value for whence: {}".format(whence))

        # Get new frame index (1-based)
        new_frame = self._seek_table.index_by_dpos(offset)
        # offset >= EOF: position both the decompressor and fp at the end.
        if new_frame is None:
            self._decomp.eof = True
            self._decomp.pos = self._decomp.size
            self._fp.seek(self._seek_table.file_size)
            return self._decomp.pos

        # Prepare to jump. c_pos/d_pos are the compressed/decompressed start
        # offsets of new_frame.
        old_frame = self._seek_table.index_by_dpos(self._decomp.pos)
        c_pos, d_pos = self._seek_table.get_frame_sizes(new_frame)

        # Same frame and seeking forward: keep the current decompression
        # state and only skip forward. Except when fp is still before the
        # frame's start, e.g. at P1 below, where continuing would
        # unnecessarily read the skippable frame; then jump instead.
        # |--data1--|--skippable--|--data2--|
        # cpos:     ^P1
        # dpos:     ^P1
        #           ^P2
        if new_frame == old_frame and \
           offset >= self._decomp.pos and \
           self._fp.tell() >= c_pos:
            pass
        else:
            # Jump to the start of new_frame and restart decompression
            self._decomp.eof = False
            self._decomp.pos = d_pos
            self._decomp.reset_session()
            self._fp.seek(c_pos)

        # offset is bytes number to skip forward
        offset -= self._decomp.pos
        # If offset <= 0, .forward() method does nothing.
        self._decomp.forward(offset)

        return self._decomp.pos

    def get_seek_table_info(self):
        """Return (frames_number, compressed_size, decompressed_size)."""
        return self._seek_table.get_info()
380
# Compared to ZstdFile class, it's important to handle the seekability of
# the underlying file object carefully, and to check it in each situation.
# For example, there may be a CD-R file system that is seekable when
# reading, but not seekable when appending.
class SeekableZstdFile(ZstdFile):
    """A file object for Zstandard Seekable Format files.

    This class can only create/write/read Zstandard Seekable Format file,
    or read 0-size file. It provides relatively fast seeking ability in
    read mode.
    """
    # The format uses uint32_t for compressed/decompressed sizes. If blocks
    # are flushed a lot, a frame's compressed size may exceed the limit, so
    # a frame is also ended when its compressed size reaches this value.
    FRAME_MAX_C_SIZE = 2*1024*1024*1024
    # Zstd seekable format's example code also uses 1 GiB as max content size.
    FRAME_MAX_D_SIZE = 1*1024*1024*1024

    _READER_CLASS = _SeekableDecompressReader

    def __init__(self, filename, mode="r", *,
                 level_or_option=None, zstd_dict=None,
                 read_size=_DEPRECATED_PLACEHOLDER, write_size=_DEPRECATED_PLACEHOLDER,
                 max_frame_content_size=1024*1024*1024):
        """Open a Zstandard Seekable Format file in binary mode. In read mode,
        the file can be 0-size file.

        filename can be either an actual file name (given as a str, bytes, or
        PathLike object), in which case the named file is opened, or it can be
        an existing file object to read from or write to.

        mode can be "r" for reading (default), "w" for (over)writing, "x" for
        creating exclusively, or "a" for appending. These can equivalently be
        given as "rb", "wb", "xb" and "ab" respectively.

        In append mode ("a" or "ab"), filename argument can't be a file object,
        please use file path.

        Parameters
        level_or_option: When it's an int object, it represents compression
            level. When it's a dict object, it contains advanced compression
            parameters. Note, in read mode (decompression), it can only be a
            dict object, that represents decompression option. It doesn't
            support int type compression level in this case.
        zstd_dict: A ZstdDict object, pre-trained dictionary for compression /
            decompression.
        max_frame_content_size: In write/append modes (compression), when
            the uncompressed data size reaches max_frame_content_size, a frame
            is generated automatically. If the size is small, it will increase
            seeking speed, but reduce compression ratio. If the size is large,
            it will reduce seeking speed, but increase compression ratio. You
            can also manually generate a frame using f.flush(f.FLUSH_FRAME).
        """
        # For self.close()
        self._write_in_close = False
        # For super().close()
        self._fp = None
        self._closefp = False
        self._mode = _MODE_CLOSED

        if mode in ("r", "rb"):
            # Specified max_frame_content_size argument
            if max_frame_content_size != 1024*1024*1024:
                raise ValueError(('max_frame_content_size argument is only '
                                  'valid in write modes (compression).'))
        elif mode in ("w", "wb", "a", "ab", "x", "xb"):
            if not (0 < max_frame_content_size <= self.FRAME_MAX_D_SIZE):
                raise ValueError(
                    ('max_frame_content_size argument should be '
                     '0 < value <= %d, provided value is %d.') % \
                    (self.FRAME_MAX_D_SIZE, max_frame_content_size))

        # For seekable format
        self._max_frame_content_size = max_frame_content_size
        self._reset_frame_sizes()
        self._seek_table = _SeekTable(read_mode=False)

        append_mode = mode in ("a", "ab")
        # Load seek table in append mode
        if append_mode:
            self._load_append_seek_table(filename)

        super().__init__(filename, mode,
                         level_or_option=level_or_option,
                         zstd_dict=zstd_dict,
                         read_size=read_size,
                         write_size=write_size)

        # Overwrite seek table in append mode
        if append_mode:
            try:
                if self._fp.seekable():
                    # Cut off the old seek table frame, it's located right
                    # after the data frames. A new one is written in close().
                    self._fp.seek(self._seek_table.get_full_c_size())
                    self._fp.truncate()
                elif self._seek_table.seek_frame_size > 0:
                    # The old seek table frame can't be removed, so record
                    # it as a frame with 0 decompressed size.
                    self._seek_table.append_entry(
                        self._seek_table.seek_frame_size, 0)
                    # Emit a warning
                    warn(("SeekableZstdFile is opened in append mode "
                          "('a', 'ab'), but the underlying file object "
                          "is not seekable. Therefore the seek table (a "
                          "zstd skippable frame) at the end of the file "
                          "can't be overwritten. Each time open such file "
                          "in append mode, it will waste some storage "
                          "space. %d bytes were wasted this time.") % \
                         self._seek_table.seek_frame_size,
                         RuntimeWarning, 2)
            except BaseException:
                # Don't leak the file opened by super().__init__(). No seek
                # table is written since _write_in_close is still False.
                self.close()
                raise

        # Initialized successfully
        self._write_in_close = (self._mode == _MODE_WRITE)

    def _load_append_seek_table(self, filename):
        """Append mode: load the seek table of the existing file, if any."""
        if not isinstance(filename, (str, bytes, PathLike)):
            raise TypeError(
                ("In append mode ('a', 'ab'), "
                 "SeekableZstdFile.__init__() method can't "
                 "accept file object as filename argument. "
                 "Please use file path (str/bytes/PathLike)."))

        # Load seek table if file exists
        if isfile(filename):
            with io.open(filename, "rb") as f:
                if not hasattr(f, "seekable") or not f.seekable():
                    raise TypeError(
                        ("In SeekableZstdFile's append mode "
                         "('a', 'ab'), the opened 'rb' file "
                         "object should be seekable."))
                self._seek_table.load_seek_table(f, seek_to_0=False)

    def _reset_frame_sizes(self):
        """Start accounting for a new frame."""
        self._current_c_size = 0
        self._current_d_size = 0
        self._left_d_size = self._max_frame_content_size

    def close(self):
        """Flush and close the file.

        May be called more than once without error. Once the file is
        closed, any other operation on it will raise a ValueError.
        """
        try:
            if self._write_in_close:
                try:
                    # End the current frame, then append the seek table.
                    self.flush(self.FLUSH_FRAME)
                    self._seek_table.write_seek_table(self._fp)
                finally:
                    # For multiple calls to .close()
                    self._write_in_close = False
        finally:
            # Clear write mode's seek table.
            # Put here for failures in/after super().__init__().
            self._seek_table = None
            super().close()

    def write(self, data):
        """Write a bytes-like object to the file.

        Returns the number of uncompressed bytes written, which is
        always the length of data in bytes. Note that due to buffering,
        the file on disk may not reflect the data written until .flush()
        or .close() is called.
        """
        if self._mode != _MODE_WRITE:
            self._check_mode(_MODE_WRITE)

        # Accept any data that supports the buffer protocol.
        # And memoryview's subview is faster than slice.
        with memoryview(data) as view, view.cast('B') as byte_view:
            nbytes = byte_view.nbytes
            pos = 0

            while nbytes > 0:
                # Don't write past the current frame's content size limit.
                write_size = min(nbytes, self._left_d_size)

                # Use inserted super().write() method, to prevent
                # self._fp.tell() from reporting incorrect position.
                # -------------------------
                #   super().write() begin
                # -------------------------
                # Compress & write
                _, output_size = self._writer.write(
                                    byte_view[pos:pos+write_size])
                self._pos += write_size
                # -----------------------
                #   super().write() end
                # -----------------------

                pos += write_size
                nbytes -= write_size

                # Cumulate
                self._current_c_size += output_size
                self._current_d_size += write_size
                self._left_d_size -= write_size

                # Should flush a frame
                if self._left_d_size == 0 or \
                   self._current_c_size >= self.FRAME_MAX_C_SIZE:
                    self.flush(self.FLUSH_FRAME)

            return pos

    def flush(self, mode=ZstdFile.FLUSH_BLOCK):
        """Flush remaining data to the underlying stream.

        The mode argument can be ZstdFile.FLUSH_BLOCK, ZstdFile.FLUSH_FRAME.
        Abuse of this method will reduce compression ratio, use it only when
        necessary.

        If the program is interrupted afterwards, all data can be recovered.
        To ensure saving to disk, also need to use os.fsync(fd).

        This method does nothing in reading mode.
        """
        if self._mode != _MODE_WRITE:
            # Like IOBase.flush(), do nothing in reading mode.
            # TextIOWrapper.close() relies on this behavior.
            if self._mode == _MODE_READ:
                return
            # Closed, raise ValueError.
            self._check_mode()

        # Use inserted super().flush() method, to prevent
        # self._fp.tell() from reporting incorrect position.
        # -------------------------
        #   super().flush() begin
        # -------------------------
        # Flush zstd block/frame, and write.
        _, output_size = self._writer.flush(mode)
        # -----------------------
        #   super().flush() end
        # -----------------------

        # Cumulate
        self._current_c_size += output_size
        # self._current_d_size += 0
        # self._left_d_size -= 0

        # A frame ended: record it in the seek table (empty frames are
        # skipped since nothing was output).
        if mode == self.FLUSH_FRAME and \
           self._current_c_size != 0:
            # Add an entry to seek table
            self._seek_table.append_entry(self._current_c_size,
                                          self._current_d_size)
            self._reset_frame_sizes()

    @property
    def seek_table_info(self):
        """A tuple: (frames_number, compressed_size, decompressed_size)
        1, Frames_number and compressed_size don't count the seek table
           frame (a zstd skippable frame at the end of the file).
        2, In write modes, the part of data that has not been flushed to
           frames is not counted.
        3, If the SeekableZstdFile object is closed, it's None.
        """
        if self._mode == _MODE_WRITE:
            return self._seek_table.get_info()
        elif self._mode == _MODE_READ:
            return self._buffer.raw.get_seek_table_info()
        else:
            # Closed
            return None

    @staticmethod
    def is_seekable_format_file(filename):
        """Check if a file is Zstandard Seekable Format file or 0-size file.

        It parses the seek table at the end of the file, returns True if no
        format error.

        filename can be either a file path (str/bytes/PathLike), or can be an
        existing file object in reading mode. A file object's position is
        restored before returning.
        """
        # Check argument
        if isinstance(filename, (str, bytes, PathLike)):
            fp = io.open(filename, 'rb')
            is_file_path = True
        elif hasattr(filename, 'readable') and filename.readable() and \
             hasattr(filename, "seekable") and filename.seekable():
            fp = filename
            is_file_path = False
            orig_pos = fp.tell()
        else:
            raise TypeError(
                ('filename argument should be a str/bytes/PathLike object, '
                 'or a file object that is readable and seekable.'))

        # Write mode uses less RAM
        table = _SeekTable(read_mode=False)
        try:
            # Read/Parse the seek table
            table.load_seek_table(fp, seek_to_0=False)
        except SeekableFormatError:
            ret = False
        else:
            ret = True
        finally:
            if is_file_path:
                fp.close()
            else:
                fp.seek(orig_pos)

        return ret