1import os
2import shutil
3from contextlib import suppress
4from shutil import (
5 ReadError,
6 _ensure_directory,
7 _get_gid,
8 _get_uid,
9 copyfileobj,
10 register_archive_format,
11 register_unpack_format,
12 unregister_archive_format,
13 unregister_unpack_format,
14)
15
# Probe each optional compression module once, at import time.  Only the
# resulting availability flag is kept: the probed module is unbound again
# right away so it does not linger in this namespace.
try:
    import zlib
except ImportError:
    _ZLIB_SUPPORTED = False
else:
    del zlib
    _ZLIB_SUPPORTED = True

try:
    import bz2
except ImportError:
    _BZ2_SUPPORTED = False
else:
    del bz2
    _BZ2_SUPPORTED = True

try:
    import lzma
except ImportError:
    _LZMA_SUPPORTED = False
else:
    del lzma
    _LZMA_SUPPORTED = True

# Zstandard is flagged as available unconditionally (no probe is performed).
_ZSTD_SUPPORTED = True
38
39
def _make_tarball(base_name, base_dir, compress="gzip", verbose=0, dry_run=0,
                  owner=None, group=None, logger=None, root_dir=None):
    """Create a (possibly compressed) tar file from all the files under
    'base_dir'.

    'compress' must be "gzip" (the default), "bzip2", "xz", "zst", or None.

    'owner' and 'group' can be used to define an owner and a group for the
    archive that is being built. They are looked up with shutil's
    '_get_uid' / '_get_gid'; when a lookup yields None the corresponding
    fields of each archive member are left as tarfile filled them in.

    'verbose' is accepted for signature compatibility and is not used.

    'dry_run', when true, skips every filesystem change (no directory is
    created and no archive is written); the archive name is still computed
    and returned.

    'logger', when given, receives 'info' messages about the directory
    creation and the start of the archiving step.

    'root_dir', when given, is joined in front of 'base_dir' to locate the
    files to add, while 'base_dir' alone is used as the name inside the
    archive. In that case the returned filename is made absolute.

    The output tar file will be named 'base_name' + ".tar", possibly plus
    the appropriate compression extension (".gz", ".bz2", ".xz", or ".zst").

    Raises ValueError if 'compress' is not one of the accepted values or the
    matching compression module is flagged as unavailable.

    Returns the output filename.
    """
    if compress is None:
        tar_compression = ''
    elif _ZLIB_SUPPORTED and compress == 'gzip':
        tar_compression = 'gz'
    elif _BZ2_SUPPORTED and compress == 'bzip2':
        tar_compression = 'bz2'
    elif _LZMA_SUPPORTED and compress == 'xz':
        tar_compression = 'xz'
    elif _ZSTD_SUPPORTED and compress == 'zst':
        tar_compression = 'zst'
    else:
        raise ValueError("bad value for 'compress', or compression format not "
                         "supported : {0}".format(compress))

    compress_ext = '.' + tar_compression if compress else ''
    archive_name = base_name + '.tar' + compress_ext
    archive_dir = os.path.dirname(archive_name)

    if archive_dir and not os.path.exists(archive_dir):
        if logger is not None:
            logger.info("creating %s", archive_dir)
        if not dry_run:
            os.makedirs(archive_dir)

    # creating the tarball
    if logger is not None:
        logger.info('Creating tar archive')

    uid = _get_uid(owner)
    gid = _get_gid(group)

    def _set_uid_gid(tarinfo):
        # tar.add() filter: override ownership only for the ids that resolved.
        if gid is not None:
            tarinfo.gid = gid
            tarinfo.gname = group
        if uid is not None:
            tarinfo.uid = uid
            tarinfo.uname = owner
        return tarinfo

    if not dry_run:
        # Resolve every path *before* the archive is opened: if the join
        # raises (for instance on mixed bytes/str arguments) no archive file
        # has been created yet and no open handle is left behind.
        arcname = base_dir
        if root_dir is not None:
            source_dir = os.path.join(root_dir, base_dir)
        else:
            source_dir = base_dir

        from backports.zstd import tarfile

        # Nothing may run between open() and the try block, so close() is
        # guaranteed for every successfully opened archive.
        tar = tarfile.open(archive_name, 'w|%s' % tar_compression)
        try:
            tar.add(source_dir, arcname, filter=_set_uid_gid)
        finally:
            tar.close()

    if root_dir is not None:
        archive_name = os.path.abspath(archive_name)
    return archive_name


# Marker attribute; presumably read by shutil.make_archive to decide that
# 'root_dir' can be passed as an argument -- confirm against shutil.
_make_tarball.supports_root_dir = True
113
def _unpack_zipfile(filename, extract_dir):
    """Extract the members of the zip archive `filename` into `extract_dir`.

    Members whose name starts with '/' or contains '..' anywhere are skipped
    without raising. Raises ReadError when `filename` is not recognised as a
    zip file.
    """
    from backports.zstd import zipfile

    if not zipfile.is_zipfile(filename):
        raise ReadError("%s is not a zip file" % filename)

    with zipfile.ZipFile(filename) as archive:
        for member in archive.infolist():
            member_name = member.filename

            # Refuse absolute names and any name with '..' in it.
            if member_name.startswith('/') or '..' in member_name:
                continue

            # Archive names use '/' separators; rebuild the path with the
            # separator of the local platform.
            destination = os.path.join(extract_dir, *member_name.split('/'))
            if not destination:
                continue

            # Let shutil's helper prepare the location before anything is
            # written there.
            _ensure_directory(destination)
            if member_name.endswith('/'):
                # Directory entry: there is no payload to copy.
                continue

            with archive.open(member_name, 'r') as source:
                with open(destination, 'wb') as target:
                    copyfileobj(source, target)
143
def _unpack_tarfile(filename, extract_dir, *, filter=None):
    """Extract the tar archive `filename` (plain, or tar.gz/tar.bz2/tar.xz/
    tar.zst) into `extract_dir`.

    `filter` is handed unchanged to the archive's extractall() call.
    Raises ReadError when the file cannot be opened as a tar archive.
    """
    from backports.zstd import tarfile

    try:
        archive = tarfile.open(filename)
    except tarfile.TarError:
        raise ReadError(
            "%s is not a compressed or uncompressed tar file" % filename)

    # The archive is released when the block exits, whether or not the
    # extraction succeeded.
    with archive:
        archive.extractall(extract_dir, filter=filter)
158
159
160
def register_shutil(*, tar=True, zip=True):
    """Hook the Zstandard-aware helpers of this module into shutil.

    tar
        When true, (re-)register the "zstdtar" archive format together with
        the ".tar.zst" and ".tzst" unpacking extensions.
        Defaults to True.
    zip
        When true, (re-)register the unpacker used for the ".zip" extension.
        Defaults to True.
    """
    if tar:
        fmt, label = "zstdtar", "zstd'ed tar-file"
        # Drop any registration already present under this name; a KeyError
        # only means there was nothing to drop.
        for unregister in (unregister_archive_format, unregister_unpack_format):
            with suppress(KeyError):
                unregister(fmt)
        register_archive_format(fmt, _make_tarball, [("compress", "zst")], label)
        register_unpack_format(fmt, [".tar.zst", ".tzst"], _unpack_tarfile, [], label)

    if zip:
        fmt, label = "zip", "ZIP file"
        with suppress(KeyError):
            unregister_unpack_format(fmt)
        register_unpack_format(fmt, [".zip"], _unpack_zipfile, [], label)