1"""Numpy pickle compatibility functions."""
2
3import pickle
4import os
5import zlib
6import inspect
7
8from io import BytesIO
9
10from .numpy_pickle_utils import _ZFILE_PREFIX
11from .numpy_pickle_utils import Unpickler
12from .numpy_pickle_utils import _ensure_native_byte_order
13
14
def hex_str(an_int):
    """Return ``an_int`` as a ``0x``-prefixed lowercase hexadecimal string.

    For example ``hex_str(255)`` gives ``'0xff'``.
    """
    return f'{an_int:#x}'
18
19
def asbytes(s):
    """Return ``s`` as ``bytes``.

    A ``bytes`` argument is returned unchanged (the very same object);
    anything else is encoded with the latin-1 codec via ``s.encode``.
    """
    return s if isinstance(s, bytes) else s.encode('latin1')
24
25
# Fixed width of the hex-encoded length field in a z-file header: the
# width of hex_str(2 ** 64), so any payload length up to that value fits.
_MAX_LEN = len(hex_str(1 << 64))
# 64 KiB. NOTE(review): not referenced in the visible part of this module;
# confirm whether other modules import it before changing or removing it.
_CHUNK_SIZE = 1 << 16
28
29
def read_zfile(file_handle):
    """Read a z-file and return its decompressed content as bytes.

    Z-files are raw data compressed with zlib used internally by joblib
    for persistence. Backward compatibility is not guaranteed. Do not
    use for external purposes.

    The layout read here is ``_ZFILE_PREFIX``, then the uncompressed length
    as a hex string space-padded to ``_MAX_LEN`` characters, optionally one
    extra space (headers written by joblib <= 0.8.4 under Python 2), then
    the zlib stream.

    Parameters
    ----------
    file_handle : file object
        Seekable handle opened in binary mode; it is rewound to offset 0.

    Returns
    -------
    data : bytes
        The decompressed payload.

    Raises
    ------
    AssertionError
        If the decompressed size differs from the length recorded in the
        header, which suggests a corrupted file.
    """
    file_handle.seek(0)
    header_length = len(_ZFILE_PREFIX) + _MAX_LEN
    header = file_handle.read(header_length)
    # int() accepts the '0x' prefix and the trailing padding spaces.
    length = int(header[len(_ZFILE_PREFIX):], 16)

    # With python2 and joblib version <= 0.8.4 compressed pickle header is one
    # character wider so we need to ignore an additional space if present.
    # Note: the first byte of the zlib data is guaranteed not to be a
    # space according to
    # https://tools.ietf.org/html/rfc6713#section-2.1
    if file_handle.read(1) != b' ':
        # The zlib compressed data has started: rewind the byte just read.
        file_handle.seek(header_length)

    # We use the known length of the data to tell zlib the size of the
    # buffer to allocate.
    data = zlib.decompress(file_handle.read(), 15, length)
    if len(data) != length:
        # Explicit check rather than ``assert`` so it still runs under
        # ``python -O``; AssertionError is kept so the exception type seen
        # by callers does not change.
        raise AssertionError(
            "Incorrect data length while decompressing %s. "
            "The file could be corrupted." % file_handle)
    return data
61
62
def write_zfile(file_handle, data, compress=1):
    """Write ``data`` to ``file_handle`` using the z-file layout.

    Z-files are raw data compressed with zlib used internally by joblib
    for persistence. Backward compatibility is not guaranteed. Do not
    use for external purposes.

    The output is ``_ZFILE_PREFIX``, then ``len(data)`` as a hex string
    space-padded to ``_MAX_LEN`` characters, then ``data`` (converted with
    ``asbytes``) compressed by zlib at level ``compress``.
    """
    write = file_handle.write
    write(_ZFILE_PREFIX)
    # Fixed-width length field so the reader can parse it with one read.
    write(asbytes(hex_str(len(data)).ljust(_MAX_LEN)))
    write(zlib.compress(asbytes(data), compress))
75
76###############################################################################
77# Utility objects for persistence.
78
79
class NDArrayWrapper(object):
    """An object to be persisted instead of numpy arrays.

    The only thing this object does, is to carry the filename in which
    the array has been persisted, and the array subclass.
    """

    def __init__(self, filename, subclass, allow_mmap=True):
        """Constructor. Store the useful information for later.

        Parameters
        ----------
        filename : str
            Name of the file holding the array (loaded with ``np.load``),
            relative to the directory of the main pickle file.
        subclass : type
            Array class to restore on load (``np.ndarray``, ``np.memmap``
            or another ndarray subclass).
        allow_mmap : bool, default True
            Whether the unpickler's ``mmap_mode`` may be used when loading.
        """
        self.filename = filename
        self.subclass = subclass
        self.allow_mmap = allow_mmap

    def read(self, unpickler):
        """Reconstruct the array.

        Parameters
        ----------
        unpickler : ZipNumpyUnpickler
            Provides ``_dirname`` (directory of the main pickle),
            ``mmap_mode`` and ``np`` (the imported numpy module).

        Returns
        -------
        array : numpy.ndarray
            The loaded array in native byte order, as an instance of
            ``self.subclass`` when that is neither ``np.ndarray`` nor
            ``np.memmap``.
        """
        np = unpickler.np
        filename = os.path.join(unpickler._dirname, self.filename)
        # Load the array from the disk
        # use getattr instead of self.allow_mmap to ensure backward compat
        # with NDArrayWrapper instances pickled with joblib < 0.9.0
        allow_mmap = getattr(self, 'allow_mmap', True)
        kwargs = {}
        if allow_mmap:
            kwargs['mmap_mode'] = unpickler.mmap_mode
        if "allow_pickle" in inspect.signature(np.load).parameters:
            # Required in numpy 1.16.3 and later to acknowledge the security
            # risk. NOTE: this executes pickled content from the companion
            # file; only load files from trusted sources.
            kwargs["allow_pickle"] = True
        array = np.load(filename, **kwargs)

        # Detect byte order mismatch and swap as needed.
        array = _ensure_native_byte_order(array)

        if self.subclass in (np.ndarray, np.memmap):
            return array

        if hasattr(array, '__array_prepare__'):
            # NumPy < 2: historical path, kept unchanged so results match
            # arrays loaded by previous joblib releases.
            new_array = np.core.multiarray._reconstruct(
                self.subclass, (0,), 'b')
            return new_array.__array_prepare__(array)

        # NumPy >= 2 removed ``ndarray.__array_prepare__``; previously the
        # subclass was silently dropped here. A view restores the recorded
        # subclass without copying the data.
        return array.view(self.subclass)
123
124
class ZNDArrayWrapper(NDArrayWrapper):
    """Placeholder persisted instead of a numpy array, backed by a z-file.

    It records the name of the z-file holding the array's raw buffer,
    together with the metadata needed to rebuild the array around it.
    Storing the raw buffer plus metadata, rather than an array
    representation routine such as ``tobytes``, lets the strided memory
    model be used as-is and avoids memory copies (``a`` and ``a.T`` are
    stored equally fast). Keeping the bulky data in a separate file also
    avoids creating large temporary buffers when unpickling objects that
    contain large arrays.
    """

    def __init__(self, filename, init_args, state):
        """Remember the z-file name and the array reconstruction metadata.

        ``init_args`` is unpacked into numpy's ``_reconstruct``; ``state`` is
        a tuple that, extended with the decompressed buffer, is passed to
        the new array's ``__setstate__``. Note that ``NDArrayWrapper``'s
        constructor is not called, so ``subclass``/``allow_mmap`` are unset.
        """
        self.filename = filename
        self.state = state
        self.init_args = init_args

    def read(self, unpickler):
        """Rebuild the array from its metadata and the z-file contents.

        This reproduces numpy's own array unpickling: create an empty array
        from ``init_args``, then restore its state with the raw buffer read
        back from the z-file.
        """
        path = os.path.join(unpickler._dirname, self.filename)
        reconstruct = unpickler.np.core.multiarray._reconstruct
        array = reconstruct(*self.init_args)
        with open(path, 'rb') as zfile:
            raw = read_zfile(zfile)
        array.__setstate__(self.state + (raw,))
        return array
157
158
class ZipNumpyUnpickler(Unpickler):
    """Unpickler for joblib's legacy zlib-compressed numpy pickles.

    ``NDArrayWrapper`` placeholders met while unpickling are swapped, on
    the unpickler stack, for the arrays they describe.
    """

    dispatch = Unpickler.dispatch.copy()

    def __init__(self, filename, file_handle, mmap_mode=None):
        """Decompress ``file_handle`` and get ready to unpickle its content.

        ``filename`` is split into directory and base name; the directory
        is where companion array files are looked up. ``mmap_mode`` is made
        available to the placeholders' ``read`` methods.
        """
        head, tail = os.path.split(filename)
        self._filename = tail
        self._dirname = head
        self.mmap_mode = mmap_mode
        self.file_handle = self._open_pickle(file_handle)
        Unpickler.__init__(self, self.file_handle)
        try:
            import numpy as np
        except ImportError:
            np = None
        self.np = np

    def _open_pickle(self, file_handle):
        """Return an in-memory binary stream over the decompressed pickle."""
        decompressed = read_zfile(file_handle)
        return BytesIO(decompressed)

    def load_build(self):
        """Run the BUILD opcode, then replace an array placeholder if present.

        When the object on top of the stack after the regular BUILD step is
        an ``NDArrayWrapper``, it is popped and the array it reads is pushed
        in its place.

        Raises
        ------
        ImportError
            If a placeholder is found but numpy could not be imported.
        """
        Unpickler.load_build(self)
        if not isinstance(self.stack[-1], NDArrayWrapper):
            return
        if self.np is None:
            raise ImportError("Trying to unpickle an ndarray, "
                              "but numpy didn't import correctly")
        placeholder = self.stack.pop()
        self.stack.append(placeholder.read(self))

    dispatch[pickle.BUILD[0]] = load_build
197
198
def load_compatibility(filename):
    """Load an object saved by ``joblib.dump`` in the legacy format.

    Supports the persistence format of joblib <= 0.9.3, including numpy
    arrays that were written to separate companion files during the dump.

    Parameters
    ----------
    filename: string
        Path of the file to read the object from.

    Returns
    -------
    result: any Python object
        The deserialized object.

    Raises
    ------
    ValueError
        If decoding fails with ``UnicodeDecodeError``, which typically means
        the pickle was produced by Python 2.

    See Also
    --------
    joblib.dump : function to save an object
    """
    with open(filename, 'rb') as raw_file:
        # The handle is opened early and kept open to avoid race conditions
        # on renames. Companion files are still opened later by path, so
        # moving the directory meanwhile can race with this load.
        unpickler = ZipNumpyUnpickler(filename, file_handle=raw_file)
        try:
            return unpickler.load()
        except UnicodeDecodeError as err:
            # Replace the cryptic decoding error with a clearer message.
            raise ValueError(
                'You may be trying to read with '
                'python 3 a joblib pickle generated with python 2. '
                'This feature is not supported by joblib.') from err
        finally:
            if hasattr(unpickler, 'file_handle'):
                unpickler.file_handle.close()