1"""Classes and functions for managing compressors."""
2
3import io
4import zlib
5from joblib.backports import LooseVersion
6
7try:
8 from threading import RLock
9except ImportError:
10 from dummy_threading import RLock
11
12try:
13 import bz2
14except ImportError:
15 bz2 = None
16
17try:
18 import lz4
19 from lz4.frame import LZ4FrameFile
20except ImportError:
21 lz4 = None
22
23try:
24 import lzma
25except ImportError:
26 lzma = None
27
28
# Message used by the LZ4 wrapper when LZ4 support cannot be used.
LZ4_NOT_INSTALLED_ERROR = (
    'LZ4 is not installed. Install it with pip: '
    'https://python-lz4.readthedocs.io/'
)

# Registry of compressors, keyed by compressor name and filled by
# register_compressor().
_COMPRESSORS = {}

# Magic numbers (leading bytes) of the supported compression file formats.
_ZFILE_PREFIX = b'ZF'  # used with pickle files created before 0.9.3.
_ZLIB_PREFIX = b'\x78'  # zlib stream
_GZIP_PREFIX = b'\x1f\x8b'  # gzip member
_BZ2_PREFIX = b'BZ'  # bzip2 stream
_XZ_PREFIX = b'\xfd\x37\x7a\x58\x5a'  # xz container
_LZMA_PREFIX = b'\x5d\x00'  # legacy .lzma ("alone") container
_LZ4_PREFIX = b'\x04\x22\x4D\x18'  # LZ4 frame
43
44
def register_compressor(compressor_name, compressor, force=False):
    """Add a compressor to the module-level registry.

    Parameters
    ----------
    compressor_name: str.
        The name under which the compressor is stored.
    compressor: CompressorWrapper
        An instance of a 'CompressorWrapper'.
    force: bool
        When true, an existing entry with the same name is replaced
        instead of triggering an error.

    Raises
    ------
    ValueError
        If the name is not a string, if the compressor is not a
        CompressorWrapper, if its 'fileobj_factory' is set but lacks one of
        the read/write/seek/tell attributes, or if the name is already
        registered and ``force`` is false.
    """
    if not isinstance(compressor_name, str):
        raise ValueError("Compressor name should be a string, "
                         "'{}' given.".format(compressor_name))

    if not isinstance(compressor, CompressorWrapper):
        raise ValueError("Compressor should implement the CompressorWrapper "
                         "interface, '{}' given.".format(compressor))

    # A missing factory (None) is tolerated; otherwise it must expose the
    # basic file object methods.
    factory = compressor.fileobj_factory
    if factory is not None:
        needed = ('read', 'write', 'seek', 'tell')
        if not all(hasattr(factory, attr) for attr in needed):
            raise ValueError("Compressor 'fileobj_factory' attribute should "
                             "implement the file object interface, '{}' given."
                             .format(factory))

    if not force and compressor_name in _COMPRESSORS:
        raise ValueError("Compressor '{}' already registered."
                         .format(compressor_name))

    # Only the dict content is mutated, so no 'global' statement is needed.
    _COMPRESSORS[compressor_name] = compressor
79
80
class CompressorWrapper():
    """A wrapper around a compressor file object.

    Attributes
    ----------
    fileobj_factory: callable
        Set from the ``obj`` constructor argument. It is called as
        ``fileobj_factory(fileobj, 'wb'[, compresslevel=...])`` to compress
        and as ``fileobj_factory(fileobj, 'rb')`` to decompress.
    prefix: bytestring
        A bytestring corresponding to the magic number that identifies the
        file format associated to the compressor.
    extension: str
        The file extension used to automatically select this compressor during
        a dump to a file.
    """

    def __init__(self, obj, prefix=b'', extension=''):
        self.extension = extension
        self.prefix = prefix
        self.fileobj_factory = obj

    def compressor_file(self, fileobj, compresslevel=None):
        """Returns an instance of a compressor file object."""
        # Forward the compression level only when the caller supplied one,
        # so that the factory's own default applies otherwise.
        options = {}
        if compresslevel is not None:
            options['compresslevel'] = compresslevel
        return self.fileobj_factory(fileobj, 'wb', **options)

    def decompressor_file(self, fileobj):
        """Returns an instance of a decompressor file object."""
        opener = self.fileobj_factory
        return opener(fileobj, 'rb')
113
114
class BZ2CompressorWrapper(CompressorWrapper):
    """Compressor wrapper backed by ``bz2.BZ2File``.

    The factory is None when the bz2 module could not be imported; in that
    case both file-creating methods raise a ValueError.
    """

    extension = '.bz2'
    prefix = _BZ2_PREFIX

    def __init__(self):
        self.fileobj_factory = None if bz2 is None else bz2.BZ2File

    def _check_versions(self):
        """Raise a ValueError when the bz2 module is unavailable."""
        if bz2 is not None:
            return
        raise ValueError('bz2 module is not compiled on your python '
                         'standard library.')

    def compressor_file(self, fileobj, compresslevel=None):
        """Returns an instance of a compressor file object."""
        self._check_versions()
        # Pass the level through only when it was explicitly requested.
        options = {}
        if compresslevel is not None:
            options['compresslevel'] = compresslevel
        return self.fileobj_factory(fileobj, 'wb', **options)

    def decompressor_file(self, fileobj):
        """Returns an instance of a decompressor file object."""
        self._check_versions()
        return self.fileobj_factory(fileobj, 'rb')
145
146
class LZMACompressorWrapper(CompressorWrapper):
    """Compressor wrapper backed by ``lzma.LZMAFile``.

    The container format used for writing is looked up on the lzma module
    from the ``_lzma_format_name`` class attribute, so subclasses can switch
    format by overriding that attribute only.

    When the lzma module could not be imported, the wrapper can still be
    instantiated, but both file-creating methods raise a ValueError.
    """

    prefix = _LZMA_PREFIX
    extension = '.lzma'
    _lzma_format_name = 'FORMAT_ALONE'

    def __init__(self):
        if lzma is not None:
            self.fileobj_factory = lzma.LZMAFile
            self._lzma_format = getattr(lzma, self._lzma_format_name)
        else:
            self.fileobj_factory = None
            # Keep the attribute defined even without lzma support.
            self._lzma_format = None

    def _check_versions(self):
        """Raise a ValueError when the lzma module is unavailable."""
        if lzma is None:
            raise ValueError('lzma module is not compiled on your python '
                             'standard library.')

    def compressor_file(self, fileobj, compresslevel=None):
        """Returns an instance of a compressor file object.

        Raises
        ------
        ValueError
            If the lzma module is not available.
        """
        # Fail with an explicit message instead of calling a None factory.
        self._check_versions()
        if compresslevel is None:
            return self.fileobj_factory(fileobj, 'wb',
                                        format=self._lzma_format)
        else:
            return self.fileobj_factory(fileobj, 'wb',
                                        format=self._lzma_format,
                                        preset=compresslevel)

    def decompressor_file(self, fileobj):
        """Returns an instance of a decompressor file object.

        Raises
        ------
        ValueError
            If the lzma module is not available.
        """
        # Fail with an explicit message instead of an AttributeError on None.
        self._check_versions()
        # No format is passed here, so LZMAFile's own default applies.
        return lzma.LZMAFile(fileobj, 'rb')
178
179
class XZCompressorWrapper(LZMACompressorWrapper):
    """LZMA-based compressor wrapper configured for ``.xz`` files.

    All behavior is inherited from LZMACompressorWrapper; only the magic
    prefix, the file extension and the name of the lzma format constant
    differ.
    """

    extension = '.xz'
    prefix = _XZ_PREFIX
    # Name of the lzma module attribute resolved by the parent __init__.
    _lzma_format_name = 'FORMAT_XZ'
185
186
class LZ4CompressorWrapper(CompressorWrapper):
    """Compressor wrapper backed by ``lz4.frame.LZ4FrameFile``.

    The factory is None when lz4 could not be imported. Both file-creating
    methods first verify that lz4 is importable and at least version 0.19.
    """

    extension = '.lz4'
    prefix = _LZ4_PREFIX

    def __init__(self):
        self.fileobj_factory = None if lz4 is None else LZ4FrameFile

    def _check_versions(self):
        """Raise a ValueError if lz4 is missing or older than 0.19."""
        if lz4 is None:
            raise ValueError(LZ4_NOT_INSTALLED_ERROR)
        installed = lz4.__version__
        # Drop a leading "v" before comparing version strings.
        if installed.startswith("v"):
            installed = installed[1:]
        minimum = LooseVersion('0.19')
        if LooseVersion(installed) < minimum:
            raise ValueError(LZ4_NOT_INSTALLED_ERROR)

    def compressor_file(self, fileobj, compresslevel=None):
        """Returns an instance of a compressor file object."""
        self._check_versions()
        # The factory is given the level as 'compression_level', and only
        # when one was requested.
        options = {}
        if compresslevel is not None:
            options['compression_level'] = compresslevel
        return self.fileobj_factory(fileobj, 'wb', **options)

    def decompressor_file(self, fileobj):
        """Returns an instance of a decompressor file object."""
        self._check_versions()
        opener = self.fileobj_factory
        return opener(fileobj, 'rb')
220
221
222###############################################################################
223# base file compression/decompression object definition
# Internal states of a BinaryZlibFile.
_MODE_CLOSED = 0
_MODE_READ = 1
_MODE_READ_EOF = 2  # opened for reading and the end of the data was reached
_MODE_WRITE = 3
# Number of compressed bytes requested from the underlying file per read.
_BUFFER_SIZE = 8192


class BinaryZlibFile(io.BufferedIOBase):
    """A file object providing transparent zlib (de)compression.

    A BinaryZlibFile can act as a wrapper for an existing file object, or
    refer directly to a named file on disk.

    Note that BinaryZlibFile provides only a *binary* file interface: data read
    is returned as bytes, and data to be written should be given as bytes.

    This object is an adaptation of the BZ2File object.

    If filename is a str or bytes object, it gives the name
    of the file to be opened. Otherwise, it should be a file object,
    which will be used to read or write the compressed data.

    mode can be 'rb' for reading (default) or 'wb' for (over)writing

    If mode is 'wb', compresslevel can be a number between 1
    and 9 specifying the level of compression: 1 produces the least
    compression, and 9 produces the most compression. 3 is the default.

    Reading stops at the end of the compressed stream: bytes that follow the
    end-of-stream marker in the underlying file are ignored.
    """

    # TODO python2_drop: is this class still needed since we dropped
    # Python 2 support?

    # Window size argument handed to zlib (de)compression objects.
    wbits = zlib.MAX_WBITS

    def __init__(self, filename, mode="rb", compresslevel=3):
        # This lock must be recursive, so that BufferedIOBase's
        # readline(), readlines() and writelines() don't deadlock.
        self._lock = RLock()
        self._fp = None
        self._closefp = False
        self._mode = _MODE_CLOSED
        self._pos = 0
        self._size = -1
        self.compresslevel = compresslevel

        if not isinstance(compresslevel, int) or not (1 <= compresslevel <= 9):
            raise ValueError("'compresslevel' must be an integer "
                             "between 1 and 9. You provided 'compresslevel={}'"
                             .format(compresslevel))

        if mode not in ("rb", "wb"):
            raise ValueError("Invalid mode: %r" % (mode,))

        # Acquire the underlying file before switching to an open mode, so
        # that a failure here leaves the object in the closed state.
        if isinstance(filename, (str, bytes)):
            self._fp = io.open(filename, mode)
            self._closefp = True
        elif hasattr(filename, "read") or hasattr(filename, "write"):
            self._fp = filename
        else:
            raise TypeError("filename must be a str or bytes object, "
                            "or a file")

        if mode == "rb":
            self._mode = _MODE_READ
            self._decompressor = zlib.decompressobj(self.wbits)
            self._buffer = b""
            self._buffer_offset = 0
        else:
            self._mode = _MODE_WRITE
            self._compressor = zlib.compressobj(self.compresslevel,
                                                zlib.DEFLATED, self.wbits,
                                                zlib.DEF_MEM_LEVEL, 0)

    def close(self):
        """Flush and close the file.

        May be called more than once without error. Once the file is
        closed, any other operation on it will raise a ValueError.
        """
        with self._lock:
            if self._mode == _MODE_CLOSED:
                return
            try:
                if self._mode in (_MODE_READ, _MODE_READ_EOF):
                    self._decompressor = None
                elif self._mode == _MODE_WRITE:
                    # Emit whatever the compressor still holds.
                    self._fp.write(self._compressor.flush())
                    self._compressor = None
            finally:
                try:
                    # Only close files this object opened itself.
                    if self._closefp:
                        self._fp.close()
                finally:
                    self._fp = None
                    self._closefp = False
                    self._mode = _MODE_CLOSED
                    self._buffer = b""
                    self._buffer_offset = 0

    @property
    def closed(self):
        """True if this file is closed."""
        return self._mode == _MODE_CLOSED

    def fileno(self):
        """Return the file descriptor for the underlying file."""
        self._check_not_closed()
        return self._fp.fileno()

    def seekable(self):
        """Return whether the file supports seeking."""
        return self.readable() and self._fp.seekable()

    def readable(self):
        """Return whether the file was opened for reading."""
        self._check_not_closed()
        return self._mode in (_MODE_READ, _MODE_READ_EOF)

    def writable(self):
        """Return whether the file was opened for writing."""
        self._check_not_closed()
        return self._mode == _MODE_WRITE

    # Mode-checking helper functions.

    def _check_not_closed(self):
        if self.closed:
            fname = getattr(self._fp, 'name', None)
            msg = "I/O operation on closed file"
            if fname is not None:
                msg += " {}".format(fname)
            msg += "."
            raise ValueError(msg)

    def _check_can_read(self):
        if self._mode not in (_MODE_READ, _MODE_READ_EOF):
            self._check_not_closed()
            raise io.UnsupportedOperation("File not open for reading")

    def _check_can_write(self):
        if self._mode != _MODE_WRITE:
            self._check_not_closed()
            raise io.UnsupportedOperation("File not open for writing")

    def _check_can_seek(self):
        if self._mode not in (_MODE_READ, _MODE_READ_EOF):
            self._check_not_closed()
            raise io.UnsupportedOperation("Seeking is only supported "
                                          "on files open for reading")
        if not self._fp.seekable():
            raise io.UnsupportedOperation("The underlying file object "
                                          "does not support seeking")

    # Fill the readahead buffer if it is empty. Returns False on EOF.
    def _fill_buffer(self):
        if self._mode == _MODE_READ_EOF:
            return False
        # Depending on the input data, our call to the decompressor may not
        # return any data. In this case, try again after reading another block.
        while self._buffer_offset == len(self._buffer):
            if self._decompressor.eof:
                # The end-of-stream marker was reached. Feeding the
                # decompressor again (e.g. with its unused_data) can never
                # yield more output, so any trailing bytes are ignored.
                rawblock = b""
            else:
                try:
                    rawblock = self._fp.read(_BUFFER_SIZE)
                except EOFError:
                    # An underlying file signalling EOF this way is treated
                    # like an empty read.
                    rawblock = b""
            if not rawblock:
                self._mode = _MODE_READ_EOF
                self._size = self._pos
                return False
            self._buffer = self._decompressor.decompress(rawblock)
            self._buffer_offset = 0
        return True

    # Read data until EOF.
    # If return_data is false, consume the data without returning it.
    def _read_all(self, return_data=True):
        # The loop assumes that _buffer_offset is 0. Ensure that this is true.
        self._buffer = self._buffer[self._buffer_offset:]
        self._buffer_offset = 0

        blocks = []
        while self._fill_buffer():
            if return_data:
                blocks.append(self._buffer)
            self._pos += len(self._buffer)
            self._buffer = b""
        if return_data:
            return b"".join(blocks)

    # Read a block of up to n bytes.
    # If return_data is false, consume the data without returning it.
    def _read_block(self, n_bytes, return_data=True):
        # If we have enough data buffered, return immediately.
        end = self._buffer_offset + n_bytes
        if end <= len(self._buffer):
            data = self._buffer[self._buffer_offset: end]
            self._buffer_offset = end
            self._pos += len(data)
            return data if return_data else None

        # The loop assumes that _buffer_offset is 0. Ensure that this is true.
        self._buffer = self._buffer[self._buffer_offset:]
        self._buffer_offset = 0

        blocks = []
        while n_bytes > 0 and self._fill_buffer():
            if n_bytes < len(self._buffer):
                data = self._buffer[:n_bytes]
                self._buffer_offset = n_bytes
            else:
                data = self._buffer
                self._buffer = b""
            if return_data:
                blocks.append(data)
            self._pos += len(data)
            n_bytes -= len(data)
        if return_data:
            return b"".join(blocks)

    def read(self, size=-1):
        """Read up to size uncompressed bytes from the file.

        If size is negative or omitted, read until EOF is reached.
        Returns b'' if the file is already at EOF.
        """
        with self._lock:
            self._check_can_read()
            if size == 0:
                return b""
            elif size < 0:
                return self._read_all()
            else:
                return self._read_block(size)

    def readinto(self, b):
        """Read up to len(b) bytes into b.

        Returns the number of bytes read (0 for EOF).
        """
        with self._lock:
            return io.BufferedIOBase.readinto(self, b)

    def write(self, data):
        """Write a byte string to the file.

        Returns the number of uncompressed bytes written, which is
        always len(data). Note that due to buffering, the file on disk
        may not reflect the data written until close() is called.
        """
        with self._lock:
            self._check_can_write()
            # Convert data type if called by io.BufferedWriter.
            if isinstance(data, memoryview):
                data = data.tobytes()

            compressed = self._compressor.compress(data)
            self._fp.write(compressed)
            self._pos += len(data)
            return len(data)

    # Rewind the file to the beginning of the data stream.
    def _rewind(self):
        self._fp.seek(0, 0)
        self._mode = _MODE_READ
        self._pos = 0
        self._decompressor = zlib.decompressobj(self.wbits)
        self._buffer = b""
        self._buffer_offset = 0

    def seek(self, offset, whence=0):
        """Change the file position.

        The new position is specified by offset, relative to the
        position indicated by whence. Values for whence are:

            0: start of stream (default); offset must not be negative
            1: current stream position
            2: end of stream; offset must not be positive

        A target located before the start of the stream is treated as the
        start of the stream.

        Returns the new file position.

        Note that seeking is emulated, so depending on the parameters,
        this operation may be extremely slow.
        """
        with self._lock:
            self._check_can_seek()

            # Recalculate offset as an absolute file position.
            if whence == 0:
                pass
            elif whence == 1:
                offset = self._pos + offset
            elif whence == 2:
                # Seeking relative to EOF - we need to know the file's size.
                if self._size < 0:
                    self._read_all(return_data=False)
                offset = self._size + offset
            else:
                raise ValueError("Invalid value for whence: %s" % (whence,))

            # A negative target would leave the read buffer with a negative
            # offset and corrupt subsequent reads; clamp it to the start.
            if offset < 0:
                offset = 0

            # Make it so that offset is the number of bytes to skip forward.
            if offset < self._pos:
                self._rewind()
            else:
                offset -= self._pos

            # Read and discard data until we reach the desired position.
            self._read_block(offset, return_data=False)

            return self._pos

    def tell(self):
        """Return the current file position."""
        with self._lock:
            self._check_not_closed()
            return self._pos
540
541
class ZlibCompressorWrapper(CompressorWrapper):
    """Compressor wrapper using BinaryZlibFile for '.z' files."""

    def __init__(self):
        super().__init__(obj=BinaryZlibFile, prefix=_ZLIB_PREFIX,
                         extension='.z')
547
548
class BinaryGzipFile(BinaryZlibFile):
    """A file object providing transparent gzip (de)compression.

    The implementation is inherited from BinaryZlibFile; only the ``wbits``
    value handed to zlib differs.

    filename is either the name of the file to open (a str or bytes object)
    or a file object used to read or write the compressed data.

    mode is 'rb' for reading (the default) or 'wb' for (over)writing.

    With mode 'wb', compresslevel is a number between 1 and 9 selecting the
    level of compression: 1 compresses the least and 9 the most. The default
    is 3.
    """

    # zlib compressor/decompressor wbits value for gzip format.
    wbits = 31
564
565
class GzipCompressorWrapper(CompressorWrapper):
    """Compressor wrapper using BinaryGzipFile for '.gz' files."""

    def __init__(self):
        super().__init__(obj=BinaryGzipFile, prefix=_GZIP_PREFIX,
                         extension='.gz')