1"""Helper functions for a standard streaming compression API"""
2
3from zipfile import ZipFile
4
5import fsspec.utils
6from fsspec.spec import AbstractBufferedFile
7
8
9def noop_file(file, mode, **kwargs):
10 return file
11
12
13# TODO: files should also be available as contexts
14# should be functions of the form func(infile, mode=, **kwargs) -> file-like
15compr = {None: noop_file}
16
17
def register_compression(name, callback, extensions, force=False):
    """Register an "inferable" file compression type.

    Registers transparent file compression type for use with fsspec.open.
    Compression can be specified by name in open, or "infer"-ed for any files
    ending with the given extensions.

    Args:
        name: (str) The compression type name. Eg. "gzip".
        callback: A callable of form (infile, mode, **kwargs) -> file-like.
            Accepts an input file-like object, the target mode and kwargs.
            Returns a wrapped file-like object.
        extensions: (str, Iterable[str]) A file extension, or list of file
            extensions for which to infer this compression scheme. Eg. "gz".
            Any iterable (including a one-shot generator) is accepted.
        force: (bool) Force re-registration of compression type or extensions.

    Raises:
        ValueError: If name or extensions already registered, and not force.

    """
    # Materialise the extensions once. They are walked twice below (validate,
    # then register), so a one-shot iterable such as a generator would
    # otherwise be exhausted by validation and nothing would get registered.
    if isinstance(extensions, str):
        extensions = [extensions]
    else:
        extensions = list(extensions)

    # Validate everything before touching either registry, so a rejected
    # registration leaves no partial state behind.
    if name in compr and not force:
        raise ValueError(f"Duplicate compression registration: {name}")

    for ext in extensions:
        if ext in fsspec.utils.compressions and not force:
            raise ValueError(f"Duplicate compression file extension: {ext} ({name})")

    compr[name] = callback

    # Map each extension to the compression name so it can be "infer"-ed.
    for ext in extensions:
        fsspec.utils.compressions[ext] = name
53
54
def unzip(infile, mode="rb", filename=None, **kwargs):
    """Open a single member of a ZIP archive held in *infile*.

    Read mode (``mode`` contains "r"): *infile* is opened as an existing
    archive and member *filename* is returned for reading. When *filename*
    is None, the first name in the archive's name list is used. ``kwargs``
    are forwarded to ``ZipFile.open``.

    Any other mode: a new archive is created in *infile*, with ``kwargs``
    forwarded to the ``ZipFile`` constructor. A writable handle to member
    *filename* (default ``"file"``) is returned. Closing that handle closes
    the member first and then the archive itself.
    """
    if "r" in mode:
        archive = ZipFile(infile)
        member_name = archive.namelist()[0] if filename is None else filename
        return archive.open(member_name, mode="r", **kwargs)

    archive = ZipFile(infile, mode="w", **kwargs)
    member = archive.open(filename or "file", mode="w")

    def close_member_then_archive(closer=member.close):
        # Finish the member entry, then finalise the archive that wraps it.
        return closer() or archive.close()

    member.close = close_member_then_archive
    return member
66
67
# ZIP: ``unzip`` opens a single member of an archive; inferred from ".zip".
register_compression("zip", unzip, "zip")
69
# bz2 from the standard library. The import is guarded so that a Python build
# without bz2 support simply leaves "bz2" unregistered instead of failing.
try:
    from bz2 import BZ2File
except ImportError:
    pass
else:
    # BZ2File accepts an already-open file object as its first argument, so
    # the class itself serves as the (infile, mode, **kwargs) callback.
    register_compression("bz2", BZ2File, "bz2")
76
# gzip: prefer python-isal's igzip (an ISA-L based implementation) when it is
# installed, otherwise fall back to the standard library's GzipFile. Exactly
# one of the two branches registers "gzip" for the ".gz" extension.
try:  # pragma: no cover
    from isal import igzip

    def isal(infile, mode="rb", **kwargs):
        """Wrap *infile* in ``igzip.IGzipFile`` opened with *mode*.

        Extra keyword arguments are forwarded to ``IGzipFile`` unchanged.
        """
        return igzip.IGzipFile(fileobj=infile, mode=mode, **kwargs)

    register_compression("gzip", isal, "gz")
except ImportError:
    from gzip import GzipFile

    # GzipFile takes the wrapped stream as ``fileobj=``. ``mode`` and any other
    # options must arrive as keywords through **kwargs and are passed through.
    register_compression(
        "gzip", lambda f, **kwargs: GzipFile(fileobj=f, **kwargs), "gz"
    )
90
# Standard-library lzma: the same LZMAFile class serves both the "lzma"
# (".lzma") and "xz" (".xz") registrations. Skipped if lzma is unavailable.
# NOTE(review): per the stdlib docs, LZMAFile auto-detects the container when
# reading but uses the .xz container by default when writing. Files written
# with compression="lzma" are therefore xz-formatted rather than legacy .lzma.
# Confirm this is intended.
try:
    from lzma import LZMAFile

    register_compression("lzma", LZMAFile, "lzma")
    register_compression("xz", LZMAFile, "xz")
except ImportError:
    pass
98
# Optional third-party lzmaffi: when importable, its LZMAFile replaces any
# LZMAFile registered above for both "lzma" and "xz". force=True is needed
# because those names and extensions are normally already registered.
try:
    import lzmaffi

    register_compression("lzma", lzmaffi.LZMAFile, "lzma", force=True)
    register_compression("xz", lzmaffi.LZMAFile, "xz", force=True)
except ImportError:
    pass
106
107
class SnappyFile(AbstractBufferedFile):
    """Buffered file that snappy-(de)compresses an underlying stream.

    Reading decompresses data pulled from ``infile``. Writing compresses
    buffered data and writes it to ``infile``. Seeking is not supported.
    """

    def __init__(self, infile, mode, **kwargs):
        import snappy

        # Force a binary mode for the base class: "r"/"rb" -> "rb",
        # "w"/"wb" -> "wb".
        binary_mode = mode.strip("b") + "b"
        super().__init__(
            fs=None,
            path="snappy",
            mode=binary_mode,
            # NOTE(review): fixed sentinel size, presumably because the real
            # (decompressed) length is unknown up front.
            size=999999999,
            **kwargs,
        )
        self.infile = infile
        # Streaming decompressor when reading, streaming compressor otherwise.
        self.codec = (
            snappy.StreamDecompressor() if "r" in mode else snappy.StreamCompressor()
        )

    def _upload_chunk(self, final=False):
        """Compress everything buffered so far and write it to ``infile``."""
        self.buffer.seek(0)
        self.infile.write(self.codec.add_chunk(self.buffer.read()))
        return True

    def seek(self, loc, whence=0):
        raise NotImplementedError("SnappyFile is not seekable")

    def seekable(self):
        return False

    def _fetch_range(self, start, end):
        """Get the specified set of bytes from remote"""
        # NOTE(review): this reads ``end - start`` *compressed* bytes but
        # returns whatever the codec decompresses from them. The result
        # length need not equal ``end - start``; confirm the base class
        # tolerates that.
        return self.codec.decompress(self.infile.read(end - start))
137
138
# Snappy is registered by name only, with an empty extension list. It is
# therefore used only when requested explicitly, never inferred from a filename.
# NOTE(review): the snappy.compress(b"") call is a probe. It presumably guards
# against an unrelated module that is also importable as "snappy", hence the
# AttributeError/NameError in the except clause. Confirm.
try:
    import snappy

    snappy.compress(b"")
    # Snappy may use the .sz file extension, but this is not part of the
    # standard implementation.
    register_compression("snappy", SnappyFile, [])

except (ImportError, NameError, AttributeError):
    pass
149
# lz4 frame format: lz4.frame.open is used directly as the callback and is
# inferred from the ".lz4" extension. Skipped when the lz4 package is absent.
try:
    import lz4.frame

    register_compression("lz4", lz4.frame.open, "lz4")
except ImportError:
    pass
156
try:
    import zstandard as zstd

    def zstandard_file(infile, mode="rb", **kwargs):
        """Wrap *infile* in a zstandard streaming reader or writer.

        Follows the ``(infile, mode, **kwargs) -> file-like`` callback form
        expected by ``register_compression``.

        Args:
            infile: file-like object to read zstd data from, or write it to.
            mode: if it contains "r", a decompressing reader is returned;
                otherwise a compressing writer is returned.
            **kwargs: forwarded to ``zstd.ZstdDecompressor`` when reading or
                ``zstd.ZstdCompressor`` when writing. When writing, ``level``
                defaults to 10 if not given.
        """
        if "r" in mode:
            dctx = zstd.ZstdDecompressor(**kwargs)
            return dctx.stream_reader(infile)
        # Keep the historical default level unless the caller overrides it.
        level = kwargs.pop("level", 10)
        cctx = zstd.ZstdCompressor(level=level, **kwargs)
        return cctx.stream_writer(infile)

    register_compression("zstd", zstandard_file, "zst")
except ImportError:
    pass
171
172
def available_compressions():
    """List every registered compression name.

    The names are the registry's keys in insertion order, including the
    ``None`` ("no compression") entry. A fresh list is returned, so callers
    may modify it without affecting the registry.
    """
    return [*compr]