1"""Utilities for fast persistence of big data, with optional compression."""
2
3# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
4# Copyright (c) 2009 Gael Varoquaux
5# License: BSD Style, 3 clauses.
6
7import contextlib
8import io
9import pickle
10import sys
11import warnings
12
13from .compressor import _COMPRESSORS, _ZFILE_PREFIX
14
15try:
16 import numpy as np
17except ImportError:
18 np = None
19
20Unpickler = pickle._Unpickler
21Pickler = pickle._Pickler
22xrange = range
23
24
25try:
26 # The python standard library can be built without bz2 so we make bz2
27 # usage optional.
28 # see https://github.com/scikit-learn/scikit-learn/issues/7526 for more
29 # details.
30 import bz2
31except ImportError:
32 bz2 = None
33
34# Buffer size used in io.BufferedReader and io.BufferedWriter
35_IO_BUFFER_SIZE = 1024**2


def _is_raw_file(fileobj):
    """Check if fileobj is a raw file object, e.g. created with open."""
    fileobj = getattr(fileobj, "raw", fileobj)
    return isinstance(fileobj, io.FileIO)
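

# Illustrative usage sketch (not part of the original module): a binary file
# created with open() exposes an io.FileIO through its ``raw`` attribute and
# counts as raw, while an in-memory buffer does not. The path below is
# hypothetical.
def _example_is_raw_file(path="/tmp/example.bin"):
    with open(path, "wb") as f:
        assert _is_raw_file(f)
    assert not _is_raw_file(io.BytesIO(b"data"))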


def _get_prefixes_max_len():
    # Compute the maximum prefix length over all registered compressors.
    prefixes = [len(compressor.prefix) for compressor in _COMPRESSORS.values()]
    prefixes += [len(_ZFILE_PREFIX)]
    return max(prefixes)


def _is_numpy_array_byte_order_mismatch(array):
    """Check if a numpy array's byte order differs from the host's."""
    return (
        sys.byteorder == "big"
        and (
            array.dtype.byteorder == "<"
            or (
                array.dtype.byteorder == "|"
                and array.dtype.fields
                and all(e[0].byteorder == "<" for e in array.dtype.fields.values())
            )
        )
    ) or (
        sys.byteorder == "little"
        and (
            array.dtype.byteorder == ">"
            or (
                array.dtype.byteorder == "|"
                and array.dtype.fields
                and all(e[0].byteorder == ">" for e in array.dtype.fields.values())
            )
        )
    )


def _ensure_native_byte_order(array):
    """Use the byte order of the host while preserving values.

    Does nothing if the array already uses the system byte order.
    """
    if _is_numpy_array_byte_order_mismatch(array):
        array = array.byteswap().view(array.dtype.newbyteorder("="))
    return array
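

# Illustrative sketch (not part of the original module): on a little-endian
# host, an explicitly big-endian ">i4" array is byte-swapped to the native
# order by _ensure_native_byte_order while its values are preserved. Guarded
# because numpy is an optional dependency here.
def _example_ensure_native_byte_order():
    if np is None:
        return
    arr = np.arange(3, dtype=np.dtype(">i4"))
    native = _ensure_native_byte_order(arr)
    assert native.tolist() == [0, 1, 2]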


###############################################################################
# Cache file utilities
def _detect_compressor(fileobj):
    """Return the compressor matching fileobj.

    Parameters
    ----------
    fileobj: file object

    Returns
    -------
    str in {'zlib', 'gzip', 'bz2', 'lzma', 'xz', 'compat', 'not-compressed'}
    """
    # Read the magic number in the first bytes of the file.
    max_prefix_len = _get_prefixes_max_len()
    if hasattr(fileobj, "peek"):
        # Peek allows reading those bytes without moving the file cursor.
        first_bytes = fileobj.peek(max_prefix_len)
    else:
        # Fall back to seek if the file object is not peekable.
        first_bytes = fileobj.read(max_prefix_len)
        fileobj.seek(0)

    if first_bytes.startswith(_ZFILE_PREFIX):
        return "compat"
    else:
        for name, compressor in _COMPRESSORS.items():
            if first_bytes.startswith(compressor.prefix):
                return name

    return "not-compressed"
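

# Illustrative sketch (not part of the original module): writing a small
# gzip file and letting _detect_compressor recognize it from its magic
# number, assuming the gzip compressor is registered in _COMPRESSORS. The
# path is hypothetical.
def _example_detect_compressor(path="/tmp/example.pkl.gz"):
    import gzip

    with gzip.open(path, "wb") as f:
        f.write(b"payload")
    with open(path, "rb") as f:
        return _detect_compressor(f)  # expected: "gzip"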


def _buffered_read_file(fobj):
    """Return a buffered version of a read file object."""
    return io.BufferedReader(fobj, buffer_size=_IO_BUFFER_SIZE)


def _buffered_write_file(fobj):
    """Return a buffered version of a write file object."""
    return io.BufferedWriter(fobj, buffer_size=_IO_BUFFER_SIZE)


@contextlib.contextmanager
def _validate_fileobject_and_memmap(fileobj, filename, mmap_mode=None):
    """Utility context manager opening the right file object from a filename.

    The magic number is used to choose the type of file object to open:
    * regular file object (default)
    * zlib file object
    * gzip file object
    * bz2 file object
    * lzma file object (for the xz and lzma compressors)

    Parameters
    ----------
    fileobj: file object
    filename: str
        filename path corresponding to the fileobj parameter.
    mmap_mode: str
        memory map mode that should be used to open the pickle file. This
        parameter is useful to verify that the user is not trying to use
        memory mapping on a compressed file. Default: None.

    Returns
    -------
    a tuple with a file-like object and the validated mmap_mode.

    """
    # Detect if the fileobj contains compressed data.
    compressor = _detect_compressor(fileobj)
    validated_mmap_mode = mmap_mode

    if compressor == "compat":
        # Compatibility with old pickle mode: simply return the input
        # filename "as-is" and let the compatibility function be called by
        # the caller.
        warnings.warn(
            "The file '%s' has been generated with a joblib "
            "version less than 0.10. "
            "Please regenerate this pickle file." % filename,
            DeprecationWarning,
            stacklevel=2,
        )
        yield filename, validated_mmap_mode
    else:
        if compressor in _COMPRESSORS:
            # Based on the compressor detected in the file, we open the
            # correct decompressor file object, wrapped in a buffer.
            compressor_wrapper = _COMPRESSORS[compressor]
            inst = compressor_wrapper.decompressor_file(fileobj)
            fileobj = _buffered_read_file(inst)

        # Check for load parameters that are incompatible with the type of
        # file: mmap_mode cannot be used with compressed files or in-memory
        # buffers such as io.BytesIO.
        if mmap_mode is not None:
            validated_mmap_mode = None
            if isinstance(fileobj, io.BytesIO):
                warnings.warn(
                    "In memory persistence is not compatible with "
                    'mmap_mode "%(mmap_mode)s" flag passed. '
                    "mmap_mode option will be ignored." % locals(),
                    stacklevel=2,
                )
            elif compressor != "not-compressed":
                warnings.warn(
                    'mmap_mode "%(mmap_mode)s" is not compatible '
                    "with compressed file %(filename)s. "
                    '"%(mmap_mode)s" flag will be ignored.' % locals(),
                    stacklevel=2,
                )
            elif not _is_raw_file(fileobj):
                warnings.warn(
                    '"%(fileobj)r" is not a raw file, mmap_mode '
                    '"%(mmap_mode)s" flag will be ignored.' % locals(),
                    stacklevel=2,
                )
            else:
                validated_mmap_mode = mmap_mode

        yield fileobj, validated_mmap_mode
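

# Illustrative sketch (not part of the original module): mmap_mode is passed
# through unchanged for a plain uncompressed file on disk, but invalidated
# (set to None, with a warning) for an in-memory buffer. The path is
# hypothetical.
def _example_validate_mmap_mode(path="/tmp/example.raw"):
    with open(path, "wb") as f:
        f.write(b"payload")
    with open(path, "rb") as f:
        with _validate_fileobject_and_memmap(
            f, path, mmap_mode="r"
        ) as (_, mode):
            assert mode == "r"
    buf = io.BytesIO(b"payload")
    with _validate_fileobject_and_memmap(
        buf, "<buffer>", mmap_mode="r"
    ) as (_, mode):
        assert mode is None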


def _write_fileobject(filename, compress=("zlib", 3)):
    """Return the right compressor file object in write mode."""
    compressmethod = compress[0]
    compresslevel = compress[1]

    # Fall back to zlib if the requested compression method is not
    # registered.
    if compressmethod not in _COMPRESSORS:
        compressmethod = "zlib"

    file_instance = _COMPRESSORS[compressmethod].compressor_file(
        filename, compresslevel=compresslevel
    )
    return _buffered_write_file(file_instance)
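

# Illustrative sketch (not part of the original module): round-tripping a
# compressed blob through _write_fileobject and
# _validate_fileobject_and_memmap. The path is hypothetical.
def _example_write_then_read(path="/tmp/example.pkl.z"):
    with _write_fileobject(path, compress=("zlib", 3)) as f:
        f.write(b"payload")
    with open(path, "rb") as f:
        with _validate_fileobject_and_memmap(f, path) as (fobj, _):
            return fobj.read()  # b"payload"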


# Utility functions/variables from numpy required for writing arrays.
# We need at least the functions introduced in version 1.9 of numpy. Here,
# we use the ones from numpy 1.10.2.
BUFFER_SIZE = 2**18  # size of buffer for reading npz files in bytes


def _read_bytes(fp, size, error_template="ran out of data"):
    """Read from file-like object until size bytes are read.

    TODO python2_drop: is it still needed? The docstring mentions python 2.6
    and it looks like this can be at least simplified ...

    Raises ValueError if EOF is encountered before size bytes are read.
    Non-blocking objects are only supported if they derive from io objects.

    Required as e.g. ZipExtFile in python 2.6 can return less data than
    requested.

    This function was taken from numpy/lib/format.py in version 1.10.2.

    Parameters
    ----------
    fp: file-like object
    size: int
    error_template: str

    Returns
    -------
    a bytes object
        The data read in bytes.

    """
    data = bytes()
    while True:
        # io files (the default in python3) return None or raise on
        # would-block; python2 files will truncate, and probably nothing can
        # be done about that. Note that regular files can't be non-blocking.
        try:
            r = fp.read(size - len(data))
            data += r
            if len(r) == 0 or len(data) == size:
                break
        except io.BlockingIOError:
            pass
    if len(data) != size:
        msg = "EOF: reading %s, expected %d bytes got %d"
        raise ValueError(msg % (error_template, size, len(data)))
    else:
        return data
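

# Illustrative sketch (not part of the original module): _read_bytes keeps
# reading until exactly ``size`` bytes are accumulated, and raises
# ValueError on a short read.
def _example_read_bytes():
    buf = io.BytesIO(b"0123456789")
    assert _read_bytes(buf, 4) == b"0123"
    try:
        _read_bytes(buf, 100, "array data")
    except ValueError:
        pass  # only 6 bytes remain, so the short read raises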


def _reconstruct(*args, **kwargs):
    # Wrapper for numpy._core.multiarray._reconstruct with backward
    # compatibility for numpy 1.X.
    #
    # XXX: Remove this function when numpy 1.X is not supported anymore.

    np_major_version = np.__version__[:2]
    if np_major_version == "1.":
        from numpy.core.multiarray import _reconstruct as np_reconstruct
    elif np_major_version == "2.":
        from numpy._core.multiarray import _reconstruct as np_reconstruct

    return np_reconstruct(*args, **kwargs)