1import os
2import shutil
3from contextlib import suppress
4from shutil import (
5 ReadError,
6 _ensure_directory,
7 _get_gid,
8 _get_uid,
9 copyfileobj,
10 register_archive_format,
11 register_unpack_format,
12 unregister_archive_format,
13 unregister_unpack_format,
14)
15
# Probe the optional stdlib compression modules. Each module is imported only
# to learn whether it is available, and its name is removed from the module
# namespace straight away; the resulting booleans gate which 'compress'
# values _make_tarball accepts.
try:
    import zlib
except ImportError:
    _ZLIB_SUPPORTED = False
else:
    del zlib
    _ZLIB_SUPPORTED = True

try:
    import bz2
except ImportError:
    _BZ2_SUPPORTED = False
else:
    del bz2
    _BZ2_SUPPORTED = True

try:
    import lzma
except ImportError:
    _LZMA_SUPPORTED = False
else:
    del lzma
    _LZMA_SUPPORTED = True

# Hard-coded: zstd support comes from backports.zstd (imported directly by the
# tar/zip helpers), not from an optional stdlib module.
_ZSTD_SUPPORTED = True
38
39
def _make_tarball(base_name, base_dir, compress="gzip", verbose=0, dry_run=0,
                  owner=None, group=None, logger=None, root_dir=None):
    """Create a (possibly compressed) tar file from all the files under
    'base_dir'.

    'compress' must be "gzip" (the default), "bzip2", "xz", "zst", or None.
    ValueError is raised for any other value, or when the module backing the
    requested compression is unavailable.

    'owner' and 'group' can be used to define an owner and a group for the
    archive that is being built. If not provided, members keep the owner and
    group that tarfile records for them.

    'root_dir', when given, is the directory that 'base_dir' is resolved
    against; members are still stored under the name 'base_dir', and the
    returned archive path is made absolute.

    'logger', when given, receives progress messages via its info() method.
    If 'dry_run' is true, no directory or archive is created; the messages
    are still logged and the would-be archive name is returned.
    'verbose' is not used by this function.

    The output tar file will be named 'base_name' + ".tar", possibly plus
    the appropriate compression extension (".gz", ".bz2", ".xz", or ".zst").

    Returns the output filename.
    """
    if compress is None:
        tar_compression = ''
    elif _ZLIB_SUPPORTED and compress == 'gzip':
        tar_compression = 'gz'
    elif _BZ2_SUPPORTED and compress == 'bzip2':
        tar_compression = 'bz2'
    elif _LZMA_SUPPORTED and compress == 'xz':
        tar_compression = 'xz'
    elif _ZSTD_SUPPORTED and compress == 'zst':
        tar_compression = 'zst'
    else:
        raise ValueError("bad value for 'compress', or compression format not "
                         "supported : {0}".format(compress))

    compress_ext = '.' + tar_compression if compress else ''
    archive_name = base_name + '.tar' + compress_ext
    archive_dir = os.path.dirname(archive_name)

    if archive_dir and not os.path.exists(archive_dir):
        if logger is not None:
            logger.info("creating %s", archive_dir)
        if not dry_run:
            # exist_ok: another process may create the directory between the
            # exists() check above and this call; that must not be an error.
            os.makedirs(archive_dir, exist_ok=True)

    # creating the tarball
    if logger is not None:
        logger.info('Creating tar archive')

    uid = _get_uid(owner)
    gid = _get_gid(group)

    def _set_uid_gid(tarinfo):
        # tar.add() filter: override ownership of every member when an
        # owner/group was resolved; otherwise leave tarfile's values alone.
        if gid is not None:
            tarinfo.gid = gid
            tarinfo.gname = group
        if uid is not None:
            tarinfo.uid = uid
            tarinfo.uname = owner
        return tarinfo

    if not dry_run:
        # The backported tarfile is used because it accepts the 'zst' mode;
        # it is only imported when an archive is actually written.
        from backports.zstd import tarfile

        # 'w|<comp>' writes a compressed stream; 'w|' is an uncompressed tar.
        tar = tarfile.open(archive_name, 'w|%s' % tar_compression)
        arcname = base_dir
        if root_dir is not None:
            # Read files from root_dir/base_dir but store them as base_dir.
            base_dir = os.path.join(root_dir, base_dir)
        try:
            tar.add(base_dir, arcname, filter=_set_uid_gid)
        finally:
            tar.close()

    if root_dir is not None:
        archive_name = os.path.abspath(archive_name)
    return archive_name


# shutil.make_archive looks for this attribute and, when it is true, passes
# root_dir through to the function instead of changing the working directory.
_make_tarball.supports_root_dir = True
113
def _unpack_zipfile(filename, extract_dir):
    """Extract every member of the zip archive `filename` into `extract_dir`.

    Raises ReadError when `filename` is not recognised as a zip file.
    """
    from backports.zstd import zipfile

    if zipfile.is_zipfile(filename):
        with zipfile.ZipFile(filename) as archive:
            # Private zipfile flag; presumably makes extractall() skip member
            # names it cannot turn into a path instead of raising -- TODO
            # confirm against the backports.zstd zipfile source.
            archive._ignore_invalid_names = True
            archive.extractall(extract_dir)
        return

    raise ReadError("%s is not a zip file" % filename)
125
def _unpack_tarfile(filename, extract_dir, *, filter=None):
    """Extract a tar archive (plain, or compressed with gz/bz2/xz/zst) into
    `extract_dir`.

    `filter` is forwarded unchanged to TarFile.extractall. Raises ReadError
    when tarfile cannot open `filename` as a tar archive.
    """
    from backports.zstd import tarfile

    try:
        # The default open mode lets tarfile detect the compression itself.
        archive = tarfile.open(filename)
    except tarfile.TarError:
        message = "%s is not a compressed or uncompressed tar file" % filename
        raise ReadError(message)

    try:
        archive.extractall(extract_dir, filter=filter)
    finally:
        # Close explicitly, even if extraction fails part-way through.
        archive.close()
140
141
142
def register_shutil(*, tar=True, zip=True):
    """Install Zstandard-aware handlers into shutil's archive registries.

    tar
        When true (the default), (re)register the "zstdtar" archive format,
        which writes .tar.zst files, together with the matching unpack format
        for the .tar.zst and .tzst extensions.
    zip
        When true (the default), (re)register the "zip" unpack format for the
        .zip extension.

    Any existing registration under the same name is removed first, so the
    function can safely be called more than once.
    """
    if tar:
        tar_format = "zstdtar"
        tar_label = "zstd'ed tar-file"
        # Remove whatever is registered under this name (for example an entry
        # shutil itself provides); a missing entry raises KeyError, ignored.
        for unregister in (unregister_archive_format, unregister_unpack_format):
            with suppress(KeyError):
                unregister(tar_format)
        register_archive_format(
            tar_format, _make_tarball, [("compress", "zst")], tar_label
        )
        register_unpack_format(
            tar_format, [".tar.zst", ".tzst"], _unpack_tarfile, [], tar_label
        )

    if zip:
        zip_format = "zip"
        zip_label = "ZIP file"
        with suppress(KeyError):
            unregister_unpack_format(zip_format)
        register_unpack_format(
            zip_format, [".zip"], _unpack_zipfile, [], zip_label
        )