1"""Numpy pickle compatibility functions."""
2
3import inspect
4import os
5import pickle
6import zlib
7from io import BytesIO
8
9from .numpy_pickle_utils import (
10 _ZFILE_PREFIX,
11 Unpickler,
12 _ensure_native_byte_order,
13 _reconstruct,
14)
15
16
def hex_str(an_int):
    """Return *an_int* as a ``0x``-prefixed hexadecimal string."""
    return f"{an_int:#x}"
20
21
def asbytes(s):
    """Return *s* as bytes, encoding str input with latin1."""
    return s if isinstance(s, bytes) else s.encode("latin1")
26
27
# Width of the fixed-size hexadecimal length field in a z-file header:
# wide enough to hold the hex representation of a 2**64-byte payload.
_MAX_LEN = len(hex_str(2**64))
# Chunk size (64 KiB) — not referenced in this chunk; presumably used by
# callers elsewhere in the package for streamed reads/writes.
_CHUNK_SIZE = 64 * 1024
30
31
def read_zfile(file_handle):
    """Read the z-file and return the content as a string.

    Z-files are raw data compressed with zlib used internally by joblib
    for persistence. Backward compatibility is not guaranteed. Do not
    use for external purposes.

    Parameters
    ----------
    file_handle: file object
        An open binary file positioned anywhere; it is rewound here.

    Returns
    -------
    data: bytes
        The decompressed payload.
    """
    file_handle.seek(0)
    # Header layout: _ZFILE_PREFIX magic followed by the payload length
    # encoded as a fixed-width hexadecimal field (see write_zfile).
    header_length = len(_ZFILE_PREFIX) + _MAX_LEN
    length = file_handle.read(header_length)
    length = length[len(_ZFILE_PREFIX) :]
    length = int(length, 16)

    # With python2 and joblib version <= 0.8.4 compressed pickle header is one
    # character wider so we need to ignore an additional space if present.
    # Note: the first byte of the zlib data is guaranteed not to be a
    # space according to
    # https://tools.ietf.org/html/rfc6713#section-2.1
    next_byte = file_handle.read(1)
    if next_byte != b" ":
        # The zlib compressed data has started and we need to go back
        # one byte
        file_handle.seek(header_length)

    # We use the known length of the data to tell Zlib the size of the
    # buffer to allocate.
    data = zlib.decompress(file_handle.read(), 15, length)
    # NOTE: `assert` is skipped under `python -O`; kept as an assert for
    # backward compatibility of the raised exception type. Fixed the
    # missing space between the two sentences of the message.
    assert len(data) == length, (
        "Incorrect data length while decompressing %s. "
        "The file could be corrupted." % file_handle
    )
    return data
64
65
def write_zfile(file_handle, data, compress=1):
    """Write the data in the given file as a Z-file.

    Z-files are raw data compressed with zlib used internally by joblib
    for persistence. Backward compatibility is not guaranteed. Do not
    use for external purposes.
    """
    # Header: magic prefix, then the payload length as a fixed-width
    # (space-padded) hexadecimal field.
    length_field = asbytes(hex_str(len(data)).ljust(_MAX_LEN))
    file_handle.write(_ZFILE_PREFIX)
    file_handle.write(length_field)
    file_handle.write(zlib.compress(asbytes(data), compress))
78
79
80###############################################################################
81# Utility objects for persistence.
82
83
class NDArrayWrapper(object):
    """An object to be persisted instead of numpy arrays.

    The only thing this object does, is to carry the filename in which
    the array has been persisted, and the array subclass.
    """

    def __init__(self, filename, subclass, allow_mmap=True):
        """Store the information needed to reload the array later."""
        self.filename = filename
        self.subclass = subclass
        self.allow_mmap = allow_mmap

    def read(self, unpickler):
        """Reconstruct the array."""
        filename = os.path.join(unpickler._dirname, self.filename)
        # Load the array from the disk.
        # getattr rather than self.allow_mmap: wrappers pickled with
        # joblib < 0.9.0 predate the attribute.
        load_kwargs = {}
        if getattr(self, "allow_mmap", True):
            load_kwargs["mmap_mode"] = unpickler.mmap_mode
        if "allow_pickle" in inspect.signature(unpickler.np.load).parameters:
            # numpy 1.16.3 and later require acknowledging the security
            # risk of pickle loading explicitly.
            load_kwargs["allow_pickle"] = True
        array = unpickler.np.load(filename, **load_kwargs)

        # Detect byte order mismatch and swap as needed.
        array = _ensure_native_byte_order(array)

        # Plain ndarrays/memmaps (or numpy versions lacking
        # __array_prepare__) are returned as-is; otherwise rebuild the
        # original ndarray subclass.
        plain_classes = (unpickler.np.ndarray, unpickler.np.memmap)
        if self.subclass in plain_classes or not hasattr(
            array, "__array_prepare__"
        ):
            return array
        new_array = _reconstruct(self.subclass, (0,), "b")
        return new_array.__array_prepare__(array)
127
128
class ZNDArrayWrapper(NDArrayWrapper):
    """An object to be persisted instead of numpy arrays.

    This object store the Zfile filename in which
    the data array has been persisted, and the meta information to
    retrieve it.
    The reason that we store the raw buffer data of the array and
    the meta information, rather than array representation routine
    (tobytes) is that it enables us to use completely the strided
    model to avoid memory copies (a and a.T store as fast). In
    addition saving the heavy information separately can avoid
    creating large temporary buffers when unpickling data with
    large arrays.
    """

    def __init__(self, filename, init_args, state):
        """Store the information needed to rebuild the array."""
        self.filename = filename
        self.state = state
        self.init_args = init_args

    def read(self, unpickler):
        """Reconstruct the array from the meta-information and the z-file."""
        # Mirror numpy's own unpickling: build an empty array from the
        # saved constructor arguments, then restore its state with the
        # raw buffer read back from the z-file.
        zfile_path = os.path.join(unpickler._dirname, self.filename)
        array = _reconstruct(*self.init_args)
        with open(zfile_path, "rb") as zfile:
            buffer_data = read_zfile(zfile)
        array.__setstate__(self.state + (buffer_data,))
        return array
161
162
class ZipNumpyUnpickler(Unpickler):
    """A subclass of the Unpickler to unpickle our numpy pickles."""

    # Own copy of the dispatch table so the BUILD override below does
    # not leak into the base Unpickler class.
    dispatch = Unpickler.dispatch.copy()

    def __init__(self, filename, file_handle, mmap_mode=None):
        """Constructor."""
        self._filename = os.path.basename(filename)
        # Companion array files are resolved relative to this directory
        # (see NDArrayWrapper.read).
        self._dirname = os.path.dirname(filename)
        self.mmap_mode = mmap_mode
        # The pickle stream itself is zlib-compressed: decompress it
        # fully into memory before handing it to the base Unpickler.
        self.file_handle = self._open_pickle(file_handle)
        Unpickler.__init__(self, self.file_handle)
        # numpy is optional here: keep None and only fail if an ndarray
        # is actually encountered in the stream (see load_build).
        try:
            import numpy as np
        except ImportError:
            np = None
        self.np = np

    def _open_pickle(self, file_handle):
        # Decompress the z-file into an in-memory buffer.
        return BytesIO(read_zfile(file_handle))

    def load_build(self):
        """Set the state of a newly created object.

        We capture it to replace our place-holder objects,
        NDArrayWrapper, by the array we are interested in. We
        replace them directly in the stack of pickler.
        """
        Unpickler.load_build(self)
        # If the object just built is one of our placeholders, swap it
        # on the pickle stack for the real array it stands for.
        if isinstance(self.stack[-1], NDArrayWrapper):
            if self.np is None:
                raise ImportError(
                    "Trying to unpickle an ndarray, but numpy didn't import correctly"
                )
            nd_array_wrapper = self.stack.pop()
            array = nd_array_wrapper.read(self)
            self.stack.append(array)

    # Route the BUILD opcode through load_build above so placeholders
    # are substituted as soon as their state has been set.
    dispatch[pickle.BUILD[0]] = load_build
202
203
def load_compatibility(filename):
    """Reconstruct a Python object from a file persisted with joblib.dump.

    This function ensures the compatibility with joblib old persistence format
    (<= 0.9.3).

    Parameters
    ----------
    filename: string
        The name of the file from which to load the object

    Returns
    -------
    result: any Python object
        The object stored in the file.

    Raises
    ------
    ValueError
        If the pickle looks like it was generated with python 2 (raised
        with the original ``UnicodeDecodeError`` chained as its cause).

    See Also
    --------
    joblib.dump : function to save an object

    Notes
    -----

    This function can load numpy array files saved separately during the
    dump.
    """
    with open(filename, "rb") as file_handle:
        # We are careful to open the file handle early and keep it open to
        # avoid race-conditions on renames. That said, if data is stored in
        # companion files, moving the directory will create a race when
        # joblib tries to access the companion files.
        unpickler = ZipNumpyUnpickler(filename, file_handle=file_handle)
        try:
            obj = unpickler.load()
        except UnicodeDecodeError as exc:
            # More user-friendly error message, with the original error
            # chained as the cause (equivalent to the previous manual
            # __cause__ assignment, but idiomatic).
            raise ValueError(
                "You may be trying to read with "
                "python 3 a joblib pickle generated with python 2. "
                "This feature is not supported by joblib."
            ) from exc
        finally:
            # The unpickler owns an in-memory decompressed copy of the
            # stream; close it regardless of success.
            if hasattr(unpickler, "file_handle"):
                unpickler.file_handle.close()
    return obj