1import io
2from os import PathLike
3from backports.zstd._zstd import ZstdCompressor, ZstdDecompressor, ZSTD_DStreamOutSize
4from backports.zstd import _streams
5
# Names exported by ``from <this module> import *``.
__all__ = ('ZstdFile', 'open')

# Values stored in ZstdFile._mode to track what the file may be used for.
_MODE_CLOSED = 0
_MODE_READ = 1
_MODE_WRITE = 2
11
12
13def _nbytes(dat, /):
14 if isinstance(dat, (bytes, bytearray)):
15 return len(dat)
16 with memoryview(dat) as mv:
17 return mv.nbytes
18
19
class ZstdFile(_streams.BaseStream):
    """A file-like object providing transparent Zstandard (de)compression.

    A ZstdFile can act as a wrapper for an existing file object, or refer
    directly to a named file on disk.

    ZstdFile provides a *binary* file interface. Data is read and returned as
    bytes, and only objects that support the Buffer Protocol may be written.
    """

    # Re-exported so callers can write ZstdFile.FLUSH_BLOCK / FLUSH_FRAME.
    FLUSH_BLOCK = ZstdCompressor.FLUSH_BLOCK
    FLUSH_FRAME = ZstdCompressor.FLUSH_FRAME

    def __init__(self, file, /, mode='r', *,
                 level=None, options=None, zstd_dict=None):
        """Open a Zstandard compressed file in binary mode.

        *file* can be either a file-like object, or a file name to open.

        *mode* can be 'r' for reading (default), 'w' for (over)writing, 'x' for
        creating exclusively, or 'a' for appending. These can equivalently be
        given as 'rb', 'wb', 'xb' and 'ab' respectively.

        *level* is an optional int specifying the compression level to use,
        or COMPRESSION_LEVEL_DEFAULT if not given.

        *options* is an optional dict for advanced compression parameters.
        See CompressionParameter and DecompressionParameter for the possible
        options.

        *zstd_dict* is an optional ZstdDict object, a pre-trained Zstandard
        dictionary. See train_dict() to train ZstdDict on sample data.

        Raises ValueError if *mode* is not a str or is not one of the modes
        listed above. Raises TypeError if *options* is neither a dict nor
        None, if *level* is given in read mode or is not an int, or if *file*
        is neither a path nor an object with the read()/write() method the
        mode requires.
        """
        # Start out in a consistent "closed" state so that close() is safe
        # to call even if the rest of the constructor raises.
        self._fp = None
        self._close_fp = False
        self._mode = _MODE_CLOSED
        self._buffer = None

        if not isinstance(mode, str):
            raise ValueError('mode must be a str')
        if options is not None and not isinstance(options, dict):
            raise TypeError('options must be a dict or None')
        mode = mode.removesuffix('b')  # handle rb, wb, xb, ab
        if mode == 'r':
            if level is not None:
                raise TypeError('level is illegal in read mode')
            self._mode = _MODE_READ
        elif mode in {'w', 'a', 'x'}:
            if level is not None and not isinstance(level, int):
                raise TypeError('level must be int or None')
            self._mode = _MODE_WRITE
            self._compressor = ZstdCompressor(level=level, options=options,
                                              zstd_dict=zstd_dict)
            # Number of uncompressed bytes written so far; reported by tell().
            self._pos = 0
        else:
            raise ValueError(f'Invalid mode: {mode!r}')

        if isinstance(file, (str, bytes, PathLike)):
            # We opened the file, so we are responsible for closing it.
            self._fp = io.open(file, f'{mode}b')
            self._close_fp = True
        elif ((mode == 'r' and hasattr(file, 'read'))
                or (mode != 'r' and hasattr(file, 'write'))):
            # Caller-owned file object: used as-is and never closed by us.
            self._fp = file
        else:
            raise TypeError('file must be a file-like object '
                            'or a str, bytes, or PathLike object')

        if self._mode == _MODE_READ:
            try:
                raw = _streams.DecompressReader(
                    self._fp,
                    ZstdDecompressor,
                    zstd_dict=zstd_dict,
                    options=options,
                )
                self._buffer = io.BufferedReader(raw)
            except BaseException:
                # Whatever went wrong while building the reader, release a
                # file we opened ourselves right away instead of leaving it
                # open until the object is garbage collected.
                self.close()
                raise

    def close(self):
        """Flush and close the file.

        May be called multiple times. Once the file has been closed,
        any other operation on it will raise ValueError.
        """
        if self._fp is None:
            return
        try:
            if self._mode == _MODE_READ:
                if getattr(self, '_buffer', None):
                    self._buffer.close()
                    self._buffer = None
            elif self._mode == _MODE_WRITE:
                # Finish the current frame so the output is complete.
                self.flush(self.FLUSH_FRAME)
                self._compressor = None
        finally:
            # Mark the object closed even if flushing failed, and only close
            # the underlying file when this object opened it.
            self._mode = _MODE_CLOSED
            try:
                if self._close_fp:
                    self._fp.close()
            finally:
                self._fp = None
                self._close_fp = False

    def write(self, data, /):
        """Write a bytes-like object *data* to the file.

        Returns the number of uncompressed bytes written, which is
        always the length of data in bytes. Note that due to buffering,
        the file on disk may not reflect the data written until .flush()
        or .close() is called.
        """
        self._check_can_write()

        length = _nbytes(data)

        compressed = self._compressor.compress(data)
        self._fp.write(compressed)
        self._pos += length
        return length

    def flush(self, mode=FLUSH_BLOCK):
        """Flush remaining data to the underlying stream.

        The mode argument can be FLUSH_BLOCK or FLUSH_FRAME. Abuse of this
        method will reduce compression ratio, use it only when necessary.

        If the program is interrupted afterwards, all data can be recovered.
        To make sure the data reaches the disk, os.fsync(fd) is also needed.

        This method does nothing in reading mode. Raises ValueError if
        *mode* is neither FLUSH_BLOCK nor FLUSH_FRAME.
        """
        if self._mode == _MODE_READ:
            return
        self._check_not_closed()
        if mode not in {self.FLUSH_BLOCK, self.FLUSH_FRAME}:
            raise ValueError('Invalid mode argument, expected either '
                             'ZstdFile.FLUSH_FRAME or '
                             'ZstdFile.FLUSH_BLOCK')
        # Nothing to do when the compressor's last operation already used
        # this flush mode.
        if self._compressor.last_mode == mode:
            return
        # Flush zstd block/frame, and write.
        data = self._compressor.flush(mode)
        self._fp.write(data)
        if hasattr(self._fp, 'flush'):
            self._fp.flush()

    def read(self, size=-1):
        """Read up to size uncompressed bytes from the file.

        If size is negative or omitted, read until EOF is reached.
        A size of None is treated like -1.
        Returns b'' if the file is already at EOF.
        """
        if size is None:
            size = -1
        self._check_can_read()
        return self._buffer.read(size)

    def read1(self, size=-1):
        """Read up to size uncompressed bytes, while trying to avoid
        making multiple reads from the underlying stream. Reads up to a
        buffer's worth of data if size is negative.

        Returns b'' if the file is at EOF.
        """
        self._check_can_read()
        if size < 0:
            # Note this should *not* be io.DEFAULT_BUFFER_SIZE.
            # ZSTD_DStreamOutSize is the minimum amount to read guaranteeing
            # a full block is read.
            size = ZSTD_DStreamOutSize
        return self._buffer.read1(size)

    def readinto(self, b):
        """Read bytes into b.

        Returns the number of bytes read (0 for EOF).
        """
        self._check_can_read()
        return self._buffer.readinto(b)

    def readinto1(self, b):
        """Read bytes into b, while trying to avoid making multiple reads
        from the underlying stream.

        Returns the number of bytes read (0 for EOF).
        """
        self._check_can_read()
        return self._buffer.readinto1(b)

    def readline(self, size=-1):
        """Read a line of uncompressed bytes from the file.

        The terminating newline (if present) is retained. If size is
        non-negative, no more than size bytes will be read (in which
        case the line may be incomplete). Returns b'' if already at EOF.
        """
        self._check_can_read()
        return self._buffer.readline(size)

    def seek(self, offset, whence=io.SEEK_SET):
        """Change the file position.

        The new position is specified by offset, relative to the
        position indicated by whence. Possible values for whence are:

            0: start of stream (default): offset must not be negative
            1: current stream position
            2: end of stream; offset must not be positive

        Returns the new file position.

        Note that seeking is emulated, so depending on the arguments,
        this operation may be extremely slow.
        """
        self._check_can_read()

        # BufferedReader.seek() checks seekable
        return self._buffer.seek(offset, whence)

    def peek(self, size=-1):
        """Return buffered data without advancing the file position.

        Always returns at least one byte of data, unless at EOF.
        The exact number of bytes returned is unspecified.
        """
        # Relies on the undocumented fact that BufferedReader.peek() always
        # returns at least one byte (except at EOF)
        self._check_can_read()
        return self._buffer.peek(size)

    def __next__(self):
        """Return the next line of uncompressed bytes.

        Raises StopIteration once readline() returns an empty result (EOF).
        Like the other read methods, the file must be open for reading.
        """
        # Without this check a closed or write-mode file has no read buffer
        # and would fail with an unrelated AttributeError on None.
        self._check_can_read()
        if ret := self._buffer.readline():
            return ret
        raise StopIteration

    def tell(self):
        """Return the current file position.

        In read mode this is the position reported by the read buffer; in
        write mode it is the number of uncompressed bytes written so far.
        """
        self._check_not_closed()
        if self._mode == _MODE_READ:
            return self._buffer.tell()
        elif self._mode == _MODE_WRITE:
            return self._pos

    def fileno(self):
        """Return the file descriptor for the underlying file."""
        self._check_not_closed()
        return self._fp.fileno()

    @property
    def name(self):
        """The name attribute of the underlying file object."""
        self._check_not_closed()
        return self._fp.name

    @property
    def mode(self):
        """'wb' if the file is open for writing, 'rb' otherwise."""
        return 'wb' if self._mode == _MODE_WRITE else 'rb'

    @property
    def closed(self):
        """True if this file is closed."""
        return self._mode == _MODE_CLOSED

    def seekable(self):
        """Return whether the file supports seeking."""
        return self.readable() and self._buffer.seekable()

    def readable(self):
        """Return whether the file was opened for reading."""
        self._check_not_closed()
        return self._mode == _MODE_READ

    def writable(self):
        """Return whether the file was opened for writing."""
        self._check_not_closed()
        return self._mode == _MODE_WRITE
293
294
def open(file, /, mode='rb', *, level=None, options=None, zstd_dict=None,
         encoding=None, errors=None, newline=None):
    """Open a Zstandard compressed file in binary or text mode.

    file can be either a file name (given as a str, bytes, or PathLike object),
    in which case the named file is opened, or it can be an existing file object
    to read from or write to.

    The mode parameter can be 'r', 'rb' (default), 'w', 'wb', 'x', 'xb', 'a',
    'ab' for binary mode, or 'rt', 'wt', 'xt', 'at' for text mode.

    The level, options, and zstd_dict parameters specify the settings the same
    as ZstdFile.

    When using read mode (decompression), the options parameter is a dict
    representing advanced decompression options. The level parameter is not
    supported in this case. When using write mode (compression), only one of
    level, an int representing the compression level, or options, a dict
    representing advanced compression options, may be passed. In both modes,
    zstd_dict is a ZstdDict instance containing a trained Zstandard dictionary.

    For binary mode, this function is equivalent to the ZstdFile constructor:
    ZstdFile(filename, mode, ...). In this case, the encoding, errors and
    newline parameters must not be provided.

    For text mode, a ZstdFile object is created, and wrapped in an
    io.TextIOWrapper instance with the specified encoding, error handling
    behavior, and line ending(s).

    Raises ValueError if mode combines 't' with 'b', or if encoding, errors
    or newline is given in binary mode.
    """

    text_mode = 't' in mode
    # ZstdFile only understands binary modes, so drop the text flag for it
    # while keeping the caller's original mode for error reporting.
    binary_mode = mode.replace('t', '')

    if text_mode:
        if 'b' in mode:
            # Report the mode exactly as it was passed in, not the
            # 't'-stripped form, which on its own would look valid.
            raise ValueError(f'Invalid mode: {mode!r}')
    else:
        if encoding is not None:
            raise ValueError('Argument "encoding" not supported in binary mode')
        if errors is not None:
            raise ValueError('Argument "errors" not supported in binary mode')
        if newline is not None:
            raise ValueError('Argument "newline" not supported in binary mode')

    binary_file = ZstdFile(file, binary_mode, level=level, options=options,
                           zstd_dict=zstd_dict)

    if not text_mode:
        return binary_file

    try:
        return io.TextIOWrapper(binary_file, encoding, errors, newline)
    except BaseException:
        # The wrapper could not be built, so nobody else holds a reference
        # to the ZstdFile created above; close it before propagating.
        binary_file.close()
        raise