1from array import array
2from bisect import bisect_right
3from os.path import isfile
4from struct import Struct
5from warnings import warn
6
7from pyzstd._zstdfile import _ZstdDecompressReader, ZstdFile, \
8 _MODE_CLOSED, _MODE_READ, _MODE_WRITE, \
9 PathLike, io
10
# Public names exported by this module.
__all__ = ('SeekableFormatError', 'SeekableZstdFile')
12
class SeekableFormatError(Exception):
    """Raised for problems related to Zstandard Seekable Format.

    The given message is prefixed with a fixed string that names the format.
    """

    def __init__(self, msg):
        prefix = 'Zstandard Seekable Format error: '
        Exception.__init__(self, prefix + msg)
17
# Module documentation: a summary of the on-disk layout of the seek table
# that this module reads and writes.
__doc__ = '''\
Zstandard Seekable Format (Ver 0.1.0, Apr 2017)
Square brackets are used to indicate optional fields.
All numeric fields are little-endian unless specified otherwise.
A. Seek table is a skippable frame at the end of file:
 Magic_Number Frame_Size [Seek_Table_Entries] Seek_Table_Footer
 4 bytes 4 bytes 8-12 bytes each 9 bytes
 Magic_Number must be 0x184D2A5E.
B. Seek_Table_Entries:
 Compressed_Size Decompressed_Size [Checksum]
 4 bytes 4 bytes 4 bytes
 Checksum is optional.
C. Seek_Table_Footer:
 Number_Of_Frames Seek_Table_Descriptor Seekable_Magic_Number
 4 bytes 1 byte 4 bytes
 Seekable_Magic_Number must be 0x8F92EAB1.
D. Seek_Table_Descriptor:
 Bit_number Field_name
 7 Checksum_Flag
 6-2 Reserved_Bits (should ensure they are set to 0)
 1-0 Unused_Bits (should not interpret these bits)'''
# Format version string; interpolated into the Reserved_Bits error message
# raised while loading a seek table.
__format_version__ = '0.1.0'
40
class _SeekTable:
    """Seek table of a Zstandard Seekable Format file.

    In read mode the table stores cumulated compressed/decompressed sizes,
    so .index_by_dpos() can locate a frame by bisection. In write/append
    modes it stores the per-frame (compressed_size, decompressed_size)
    pairs, which .write_seek_table() serializes as a skippable frame.
    """
    _s_2uint32 = Struct('<II')
    _s_3uint32 = Struct('<III')
    _s_footer = Struct('<IBI')

    def __init__(self, read_mode):
        """read_mode is True for read mode, False for write/append modes."""
        self._read_mode = read_mode
        self._clear_seek_table()

    def _clear_seek_table(self):
        """Reset the table to the empty state."""
        self._has_checksum = False
        # The seek table frame size, used for append mode.
        self._seek_frame_size = 0
        # The file size, used for seeking to EOF.
        self._file_size = 0

        self._frames_count = 0
        self._full_c_size = 0
        self._full_d_size = 0

        if self._read_mode:
            # Item: cumulated_size
            # Length: frames_count + 1
            # q is int64_t. On Linux/macOS/Windows, Py_off_t is signed, so
            # ZstdFile/SeekableZstdFile use int64_t as file position/size.
            self._cumulated_c_size = array('q', [0])
            self._cumulated_d_size = array('q', [0])
        else:
            # Item: (c_size1, d_size1,
            #        c_size2, d_size2,
            #        c_size3, d_size3,
            #        ...)
            # Length: frames_count * 2
            # I is uint32_t.
            self._frames = array('I')

    def append_entry(self, compressed_size, decompressed_size):
        """Add one frame to the table.

        A (0, 0) entry is silently ignored. An entry with 0 compressed size
        and non-zero decompressed size is impossible, ValueError is raised.
        """
        if compressed_size == 0:
            if decompressed_size == 0:
                # A (0, 0) frame carries no information.
                return
            raise ValueError(
                ('Impossible frame: compressed size is 0, but '
                 'decompressed size is %d.') % decompressed_size)

        self._frames_count += 1
        self._full_c_size += compressed_size
        self._full_d_size += decompressed_size

        if self._read_mode:
            self._cumulated_c_size.append(self._full_c_size)
            self._cumulated_d_size.append(self._full_d_size)
        else:
            self._frames.append(compressed_size)
            self._frames.append(decompressed_size)

    def load_seek_table(self, fp, seek_to_0):
        """Read and parse the seek table at the end of fp.

        fp is a readable and seekable binary file object. A 0-size file is
        accepted and leaves the table empty.
        seek_to_0 is True or False. When it's True, fp is seeked to
        position 0 after the seek table frame has been read. In read mode,
        seeking to 0 is necessary.

        Raises SeekableFormatError if the file is not a valid Zstandard
        Seekable Format file. If loading fails, the table is reset to the
        empty state rather than left half-populated.
        """
        try:
            self._load_seek_table(fp, seek_to_0)
        except Exception:
            self._clear_seek_table()
            raise

    def _load_seek_table(self, fp, seek_to_0):
        # Do the actual reading/parsing for .load_seek_table().
        # Get file size
        fsize = fp.seek(0, 2)  # 2 is SEEK_END
        if fsize == 0:
            return
        if fsize < 17:  # 17=4+4+9
            msg = ('File size is less than the minimal size '
                   '(17 bytes) of Zstandard Seekable Format.')
            raise SeekableFormatError(msg)

        # Read footer
        fp.seek(-9, 2)  # 2 is SEEK_END
        footer = fp.read(9)
        frames_number, descriptor, magic_number = \
            self._s_footer.unpack(footer)
        # Check format
        if magic_number != 0x8F92EAB1:
            msg = ('The last 4 bytes of the file is not Zstandard Seekable '
                   'Format Magic Number (b"\\xb1\\xea\\x92\\x8f"). '
                   'SeekableZstdFile class only supports Zstandard Seekable '
                   'Format file or 0-size file. To read a zstd file that is '
                   'not in Zstandard Seekable Format, use ZstdFile class.')
            raise SeekableFormatError(msg)

        # Seek_Table_Descriptor, bit 7 is Checksum_Flag.
        self._has_checksum = bool(descriptor & 0b10000000)
        if descriptor & 0b01111100:
            msg = ('In Zstandard Seekable Format version %s, the '
                   'Reserved_Bits in Seek_Table_Descriptor must be 0.') \
                  % __format_version__
            raise SeekableFormatError(msg)

        # Frame size
        entry_size = 12 if self._has_checksum else 8
        skippable_frame_size = 17 + frames_number * entry_size
        if fsize < skippable_frame_size:
            raise SeekableFormatError(('File size is less than expected '
                                       'size of the seek table frame.'))

        # Read seek table
        fp.seek(-skippable_frame_size, 2)  # 2 is SEEK_END
        skippable_frame = fp.read(skippable_frame_size)
        skippable_magic_number, content_size = \
            self._s_2uint32.unpack_from(skippable_frame, 0)

        # Check format
        if skippable_magic_number != 0x184D2A5E:
            msg = "Seek table frame's Magic_Number is wrong."
            raise SeekableFormatError(msg)
        if content_size != skippable_frame_size - 8:
            msg = "Seek table frame's Frame_Size is wrong."
            raise SeekableFormatError(msg)

        # No more fp operations
        if seek_to_0:
            fp.seek(0)

        # Parse seek table. The optional checksum is not used, so only
        # the two leading size fields of each entry are unpacked.
        data_size = fsize - skippable_frame_size
        unpack_sizes = self._s_2uint32.unpack_from
        for idx in range(frames_number):
            compressed_size, decompressed_size = \
                unpack_sizes(skippable_frame, 8 + idx * entry_size)

            # Check format
            if compressed_size == 0 and decompressed_size != 0:
                msg = ('Wrong seek table. The index %d frame (0-based) '
                       'is 0 size, but decompressed size is non-zero, '
                       'this is impossible.') % idx
                raise SeekableFormatError(msg)

            # Append to seek table
            self.append_entry(compressed_size, decompressed_size)

            # Check format
            if self._full_c_size > data_size:
                msg = ('Wrong seek table. Since index %d frame (0-based), '
                       'the cumulated compressed size is greater than '
                       'file size.') % idx
                raise SeekableFormatError(msg)

        # Check format
        if self._full_c_size != data_size:
            raise SeekableFormatError('The cumulated compressed size is wrong')

        # Parsed successfully, save for future use.
        self._seek_frame_size = skippable_frame_size
        self._file_size = fsize

    def index_by_dpos(self, pos):
        """Find frame index (1-based) by decompressed position.

        Returns None if pos is at or beyond the end of decompressed data.
        Only usable in read mode.
        """
        # Array's first item is 0, so need this.
        if pos < 0:
            pos = 0

        i = bisect_right(self._cumulated_d_size, pos)
        if i != self._frames_count + 1:
            return i
        # None means >= EOF
        return None

    def get_frame_sizes(self, i):
        """Return (compressed, decompressed) start positions of frame i.

        i is a 1-based index as returned by .index_by_dpos().
        Only usable in read mode.
        """
        return (self._cumulated_c_size[i-1],
                self._cumulated_d_size[i-1])

    def get_full_c_size(self):
        """Return the total compressed size of all frames in the table."""
        return self._full_c_size

    def get_full_d_size(self):
        """Return the total decompressed size of all frames in the table."""
        return self._full_d_size

    # Merge the seek table to max_frames frames.
    # The format allows up to 0xFFFF_FFFF frames. When frames
    # number exceeds, use this method to merge.
    def _merge_frames(self, max_frames):
        if self._frames_count <= max_frames:
            return

        # Clear the table
        arr = self._frames
        a, b = divmod(self._frames_count, max_frames)
        self._clear_seek_table()

        # Merge frames, the first b merged frames take one extra frame.
        pos = 0
        for i in range(max_frames):
            # Slice length
            length = (a + (1 if i < b else 0)) * 2
            end = pos + length

            # Even items are compressed sizes, odd items are
            # decompressed sizes.
            self.append_entry(sum(arr[pos:end:2]),
                              sum(arr[pos+1:end:2]))

            pos = end

    def write_seek_table(self, fp):
        """Serialize the table as a skippable frame and write it to fp.

        Only usable in write/append modes. If the table has more entries
        than the format allows (0xFFFFFFFF), a RuntimeWarning is emitted
        and the entries are merged.
        """
        # Exceeded format limit
        if self._frames_count > 0xFFFFFFFF:
            # Emit a warning
            warn(('SeekableZstdFile\'s seek table has %d entries, '
                  'which exceeds the maximal value allowed by '
                  'Zstandard Seekable Format (0xFFFFFFFF). The '
                  'entries will be merged into 0xFFFFFFFF entries, '
                  'this may reduce seeking performance.') % self._frames_count,
                 RuntimeWarning, 3)

            # Merge frames
            self._merge_frames(0xFFFFFFFF)

        # The skippable frame
        offset = 0
        size = 17 + 8 * self._frames_count
        ba = bytearray(size)

        # Header
        self._s_2uint32.pack_into(ba, offset, 0x184D2A5E, size-8)
        offset += 8
        # Entries
        for i in range(0, len(self._frames), 2):
            self._s_2uint32.pack_into(ba, offset,
                                      self._frames[i],
                                      self._frames[i+1])
            offset += 8
        # Footer, the descriptor is 0 (no checksum).
        self._s_footer.pack_into(ba, offset,
                                 self._frames_count, 0, 0x8F92EAB1)

        # Write
        fp.write(ba)

    @property
    def seek_frame_size(self):
        """Size of the loaded seek table frame, 0 if nothing was loaded."""
        return self._seek_frame_size

    @property
    def file_size(self):
        """Size of the loaded file, 0 if nothing was loaded."""
        return self._file_size

    def __len__(self):
        return self._frames_count

    def get_info(self):
        """Return (frames_count, full_compressed_size, full_decompressed_size)."""
        return (self._frames_count,
                self._full_c_size,
                self._full_d_size)
295
class _SeekableDecompressReader(_ZstdDecompressReader):
    """Decompress reader that uses a seek table to jump between frames."""

    def __init__(self, fp, zstd_dict, option, read_size):
        # The file object must provide, and pass, readable/seekable checks.
        if not (hasattr(fp, 'readable') and hasattr(fp, "seekable")):
            raise TypeError(
                ("In SeekableZstdFile's reading mode, the file object should "
                 "have .readable()/.seekable() methods."))
        if not fp.readable():
            raise TypeError(
                ("In SeekableZstdFile's reading mode, the file object should "
                 "be readable."))
        if not fp.seekable():
            raise TypeError(
                ("In SeekableZstdFile's reading mode, the file object should "
                 "be seekable. If the file object is not seekable, it can be "
                 "read sequentially using ZstdFile class."))

        # Parse the seek table, this leaves fp at position 0.
        self._seek_table = _SeekTable(read_mode=True)
        self._seek_table.load_seek_table(fp, seek_to_0=True)

        # Set up the base reader, the decompressed size is taken from
        # the seek table.
        super().__init__(fp, zstd_dict, option, read_size)
        self._decomp.size = self._seek_table.get_full_d_size()

    # super().seekable() returns self._fp.seekable().
    # Seekability was already verified in .__init__(), so always report
    # True here. BufferedReader.seek() checks this on each call, if
    # self._fp.seekable() becomes False at runtime, .seek() method just
    # raises OSError instead of io.UnsupportedOperation.
    def seekable(self):
        return True

    # BufferedReader may not call this method when the new position
    # is inside its own buffer.
    def seek(self, offset, whence=0):
        """Seek to a decompressed position, return the new position."""
        # Convert offset to an absolute decompressed position.
        if whence == 1:  # SEEK_CUR
            offset = self._decomp.pos + offset
        elif whence == 2:  # SEEK_END
            offset = self._decomp.size + offset
        elif whence != 0:  # not SEEK_SET
            raise ValueError("Invalid value for whence: {}".format(whence))

        table = self._seek_table
        decomp = self._decomp

        # Frame that contains the target position, None means the
        # target is at or beyond EOF.
        target_frame = table.index_by_dpos(offset)
        if target_frame is None:
            decomp.eof = True
            decomp.pos = decomp.size
            self._fp.seek(table.file_size)
            return decomp.pos

        # Prepare to jump
        current_frame = table.index_by_dpos(decomp.pos)
        c_pos, d_pos = table.get_frame_sizes(target_frame)

        # If at P1, seeking to P2 will unnecessarily read the skippable
        # frame. So check self._fp position to skip the skippable frame.
        #       |--data1--|--skippable--|--data2--|
        # cpos:             ^P1
        # dpos:           ^P1             ^P2
        can_continue = (target_frame == current_frame
                        and offset >= decomp.pos
                        and self._fp.tell() >= c_pos)
        if not can_continue:
            # Jump to the beginning of the target frame.
            decomp.eof = False
            decomp.pos = d_pos
            decomp.reset_session()
            self._fp.seek(c_pos)

        # Number of bytes to skip forward. If it's <= 0, .forward()
        # method does nothing.
        offset -= decomp.pos
        decomp.forward(offset)

        return decomp.pos

    def get_seek_table_info(self):
        """Return the seek table's .get_info() tuple."""
        return self._seek_table.get_info()
380
# Compared to the ZstdFile class, the seekability of the underlying file
# object must be handled carefully here, it needs to be checked in each
# situation. For example, a CD-R file system may be seekable when
# reading, but not seekable when appending.
class SeekableZstdFile(ZstdFile):
    """This class can only create/write/read Zstandard Seekable Format file,
    or read 0-size file.
    It provides relatively fast seeking ability in read mode.
    """
    # Sizes are stored as uint32_t in the format. Flushing blocks very
    # often may push compressed_size over that limit, so cap it.
    FRAME_MAX_C_SIZE = 2*1024*1024*1024
    # Zstd seekable format's example code also use 1GiB as max content size.
    FRAME_MAX_D_SIZE = 1*1024*1024*1024

    _READER_CLASS = _SeekableDecompressReader

    def __init__(self, filename, mode="r", *,
                 level_or_option=None, zstd_dict=None,
                 read_size=131075, write_size=131591,
                 max_frame_content_size=1024*1024*1024):
        """Open a Zstandard Seekable Format file in binary mode. In read mode,
        the file can be 0-size file.

        filename is either a real file name (a str, bytes, or PathLike
        object), in which case the named file is opened, or an existing
        file object to read from or write to.

        mode is "r" for reading (default), "w" for (over)writing, "x" for
        creating exclusively, or "a" for appending. "rb", "wb", "xb" and
        "ab" are accepted as equivalents.

        In append mode ("a" or "ab"), filename argument can't be a file
        object, please use file path.

        Parameters
        level_or_option: An int object represents compression level. A dict
            object contains advanced compression parameters. Note, in read
            mode (decompression), it can only be a dict object, that
            represents decompression option. It doesn't support int type
            compression level in this case.
        zstd_dict: A ZstdDict object, pre-trained dictionary for compression /
            decompression.
        read_size: In reading mode, the number of bytes read from the
            underlying file object each time, default value is zstd's
            recommended value. If use with Network File System, increasing
            it may get better performance.
        write_size: In writing modes, the output buffer's size, default
            value is zstd's recommended value. If use with Network File
            System, increasing it may get better performance.
        max_frame_content_size: In write/append modes (compression), a frame
            is generated automatically when the uncompressed data size
            reaches max_frame_content_size. A small value increases seeking
            speed but reduces compression ratio, a large value reduces
            seeking speed but increases compression ratio. A frame can also
            be generated manually using f.flush(f.FLUSH_FRAME).
        """
        # Attributes needed by self.close() and super().close(), set
        # them first so closing works even if initialization fails.
        self._write_in_close = False
        self._fp = None
        self._closefp = False
        self._mode = _MODE_CLOSED

        appending = mode in ("a", "ab")

        if mode in ("r", "rb"):
            # max_frame_content_size argument was specified
            if max_frame_content_size != 1024*1024*1024:
                raise ValueError(('max_frame_content_size argument is only '
                                  'valid in write modes (compression).'))
        elif mode in ("w", "wb", "a", "ab", "x", "xb"):
            if not (0 < max_frame_content_size <= self.FRAME_MAX_D_SIZE):
                raise ValueError(
                    ('max_frame_content_size argument should be '
                     '0 < value <= %d, provided value is %d.') % \
                    (self.FRAME_MAX_D_SIZE, max_frame_content_size))

            # State for writing seekable format
            self._max_frame_content_size = max_frame_content_size
            self._reset_frame_sizes()
            self._seek_table = _SeekTable(read_mode=False)

            if appending:
                if not isinstance(filename, (str, bytes, PathLike)):
                    raise TypeError(
                        ("In append mode ('a', 'ab'), "
                         "SeekableZstdFile.__init__() method can't "
                         "accept file object as filename argument. "
                         "Please use file path (str/bytes/PathLike)."))

                # An existing file already has a seek table, load it.
                if isfile(filename):
                    with io.open(filename, "rb") as f:
                        if not hasattr(f, "seekable") or not f.seekable():
                            raise TypeError(
                                ("In SeekableZstdFile's append mode "
                                 "('a', 'ab'), the opened 'rb' file "
                                 "object should be seekable."))
                        self._seek_table.load_seek_table(f, seek_to_0=False)

        super().__init__(filename, mode,
                         level_or_option=level_or_option,
                         zstd_dict=zstd_dict,
                         read_size=read_size,
                         write_size=write_size)

        if appending:
            if self._fp.seekable():
                # Position at the end of the data frames, so the old
                # seek table frame is overwritten.
                self._fp.seek(self._seek_table.get_full_c_size())
                # Necessary if the current table has many (0, 0) entries
                self._fp.truncate()
            else:
                # The old seek table frame can't be overwritten, record
                # it as a frame with no content.
                old_table_size = self._seek_table.seek_frame_size
                self._seek_table.append_entry(old_table_size, 0)
                # Emit a warning
                warn(("SeekableZstdFile is opened in append mode "
                      "('a', 'ab'), but the underlying file object "
                      "is not seekable. Therefore the seek table (a "
                      "zstd skippable frame) at the end of the file "
                      "can't be overwritten. Each time open such file "
                      "in append mode, it will waste some storage "
                      "space. %d bytes were wasted this time.") % \
                     old_table_size,
                     RuntimeWarning, 2)

        # Initialized successfully, .close() should write the seek table
        # only in write modes.
        self._write_in_close = (self._mode == _MODE_WRITE)

    def _reset_frame_sizes(self):
        # Start counting a new frame.
        self._current_c_size = 0
        self._current_d_size = 0
        self._left_d_size = self._max_frame_content_size

    def close(self):
        """Flush and close the file.

        Calling it more than once is not an error. After the file is
        closed, any other operation on it will raise a ValueError.
        """
        try:
            if self._write_in_close:
                try:
                    # End the current frame, then append the seek table.
                    self.flush(self.FLUSH_FRAME)
                    self._seek_table.write_seek_table(self._fp)
                finally:
                    # So repeated .close() calls don't write again.
                    self._write_in_close = False
        finally:
            # Drop write mode's seek table.
            # Done here to cover failures in/after super().__init__().
            self._seek_table = None
            super().close()

    def write(self, data):
        """Write a bytes-like object to the file.

        Returns the number of uncompressed bytes written, which is
        always the length of data in bytes. Because of buffering, the
        file on disk may not reflect the data written until .flush()
        or .close() is called.
        """
        if self._mode != _MODE_WRITE:
            self._check_mode(_MODE_WRITE)

        # Accept any data that supports the buffer protocol.
        # And memoryview's subview is faster than slice.
        with memoryview(data) as view, view.cast('B') as byte_view:
            total = byte_view.nbytes
            written = 0

            while written < total:
                # Don't write past the end of the current frame.
                chunk_size = min(total - written, self._left_d_size)

                # super().write() is inlined here, to prevent
                # self._fp.tell() from reporting incorrect position.
                # Compress & write
                _, output_size = self._writer.write(
                    byte_view[written:written+chunk_size])
                self._pos += chunk_size

                written += chunk_size

                # Cumulate
                self._current_c_size += output_size
                self._current_d_size += chunk_size
                self._left_d_size -= chunk_size

                # The frame is full, or its compressed size hit the cap.
                if self._left_d_size == 0 or \
                   self._current_c_size >= self.FRAME_MAX_C_SIZE:
                    self.flush(self.FLUSH_FRAME)

            return written

    def flush(self, mode=ZstdFile.FLUSH_BLOCK):
        """Flush remaining data to the underlying stream.

        The mode argument can be ZstdFile.FLUSH_BLOCK, ZstdFile.FLUSH_FRAME.
        Abuse of this method will reduce compression ratio, use it only when
        necessary.

        If the program is interrupted afterwards, all data can be recovered.
        To ensure saving to disk, also need to use os.fsync(fd).

        This method does nothing in reading mode.
        """
        # Like IOBase.flush(), do nothing in reading mode.
        # TextIOWrapper.close() relies on this behavior.
        if self._mode == _MODE_READ:
            return
        if self._mode != _MODE_WRITE:
            # Closed, raise ValueError.
            self._check_mode()

        # super().flush() is inlined here, to prevent
        # self._fp.tell() from reporting incorrect position.
        # Flush zstd block/frame, and write.
        _, output_size = self._writer.flush(mode)

        # Cumulate, flushing adds no uncompressed data.
        self._current_c_size += output_size

        frame_ended = (mode == self.FLUSH_FRAME
                       and self._current_c_size != 0)
        if frame_ended:
            # Add an entry to seek table
            self._seek_table.append_entry(self._current_c_size,
                                          self._current_d_size)
            self._reset_frame_sizes()

    @property
    def seek_table_info(self):
        """A tuple: (frames_number, compressed_size, decompressed_size)
        1, Frames_number and compressed_size don't count the seek table
           frame (a zstd skippable frame at the end of the file).
        2, In write modes, the part of data that has not been flushed to
           frames is not counted.
        3, If the SeekableZstdFile object is closed, it's None.
        """
        if self._mode == _MODE_WRITE:
            return self._seek_table.get_info()
        if self._mode == _MODE_READ:
            return self._buffer.raw.get_seek_table_info()
        # Closed
        return None

    @staticmethod
    def is_seekable_format_file(filename):
        """Check if a file is Zstandard Seekable Format file or 0-size file.

        The seek table at the end of the file is parsed, True is returned
        if there is no format error.

        filename can be either a file path (str/bytes/PathLike), or can be an
        existing file object in reading mode.
        """
        # Check argument
        if isinstance(filename, (str, bytes, PathLike)):
            fp = io.open(filename, 'rb')
            is_file_path = True
        elif hasattr(filename, 'readable') and filename.readable() and \
                hasattr(filename, "seekable") and filename.seekable():
            fp = filename
            is_file_path = False
            # Remember the position, it's restored afterwards.
            orig_pos = fp.tell()
        else:
            raise TypeError(
                ('filename argument should be a str/bytes/PathLike object, '
                 'or a file object that is readable and seekable.'))

        # Write mode uses less RAM
        table = _SeekTable(read_mode=False)
        try:
            # Read/Parse the seek table
            table.load_seek_table(fp, seek_to_0=False)
            ret = True
        except SeekableFormatError:
            ret = False
        finally:
            # Close the file opened here, or restore the position of
            # the caller's file object.
            if is_file_path:
                fp.close()
            else:
                fp.seek(orig_pos)

        return ret