1"""Pickle-related utilities.."""

# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.
import copy
import functools
import pickle
import sys
from types import FunctionType

from traitlets import import_item
from traitlets.log import get_logger

from . import codeutil  # noqa This registers a hook when it's imported


def _get_cell_type(a=None):
    """the type of a closure cell doesn't seem to be importable,
    so just create one
    """

    def inner():
        return a

    return type(inner.__closure__[0])


cell_type = _get_cell_type()

# -------------------------------------------------------------------------------
# Functions
# -------------------------------------------------------------------------------


def interactive(f):
    """decorator for making functions appear as interactively defined.
    This results in the function being linked to the user_ns as globals()
    instead of the module globals().
    """

    # build new FunctionType, so it can have the right globals
    # interactive functions never have closures, that's kind of the point
    if isinstance(f, FunctionType):
        mainmod = __import__('__main__')
        f = FunctionType(
            f.__code__,
            mainmod.__dict__,
            f.__name__,
            f.__defaults__,
        )
    # associate with __main__ for uncanning
    f.__module__ = '__main__'
    return f
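
# Usage sketch (hypothetical interactive session, shown as a comment so the
# module's import-time behavior is unchanged): interactive() rebuilds the
# function around __main__'s namespace and tags it accordingly.
#
#     >>> def f():
#     ...     pass
#     >>> interactive(f).__module__
#     '__main__'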


def use_dill():
    """use dill to expand serialization support

    adds support for object methods and closures to serialization.
    """
    import dill

    from . import serialize

    serialize.pickle = dill

    # disable special function handling, let dill take care of it
    can_map.pop(FunctionType, None)


def use_cloudpickle():
    """use cloudpickle to expand serialization support

    adds support for object methods and closures to serialization.
    """
    import cloudpickle

    from . import serialize

    serialize.pickle = cloudpickle

    # disable special function handling, let cloudpickle take care of it
    can_map.pop(FunctionType, None)


def use_pickle():
    """revert to using stdlib pickle

    Reverts custom serialization enabled by use_dill or use_cloudpickle.
    """

    from . import serialize

    serialize.pickle = pickle

    # restore special function handling
    can_map[FunctionType] = _original_can_map[FunctionType]
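
# Usage sketch (assumes the optional dill package is installed): switching the
# serializer is process-wide, and use_pickle() restores the stdlib default
# together with the special FunctionType handling in can_map.
#
#     >>> use_dill()      # closures and object methods now go through dill
#     >>> use_pickle()    # back to stdlib pickle + CannedFunction handling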


# -------------------------------------------------------------------------------
# Classes
# -------------------------------------------------------------------------------


class CannedObject:
    def __init__(self, obj, keys=[], hook=None):
        """can an object for safe pickling

        Parameters
        ----------
        obj
            The object to be canned
        keys : list (optional)
            list of attribute names that will be explicitly canned / uncanned
        hook : callable (optional)
            An optional extra callable,
            which can do additional processing of the uncanned object.

        Large data may be offloaded into the buffers list,
        used for zero-copy transfers.
        """
        self.keys = keys
        self.obj = copy.copy(obj)
        self.hook = can(hook)
        for key in keys:
            setattr(self.obj, key, can(getattr(obj, key)))

        self.buffers = []

    def get_object(self, g=None):
        if g is None:
            g = {}
        obj = self.obj
        for key in self.keys:
            setattr(obj, key, uncan(getattr(obj, key), g))

        if self.hook:
            self.hook = uncan(self.hook, g)
            self.hook(obj, g)
        return self.obj
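
# Usage sketch (hypothetical Task class): `keys` names attributes that are
# canned individually, e.g. function-valued attributes that plain pickle would
# reject when they are defined in __main__.
#
#     >>> class Task:
#     ...     pass
#     >>> t = Task()
#     >>> t.f = lambda x: x + 1
#     >>> CannedObject(t, keys=['f']).get_object().f(1)
#     2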


class Reference(CannedObject):
    """object for wrapping a remote reference by name."""

    def __init__(self, name):
        if not isinstance(name, str):
            raise TypeError(f"illegal name: {name!r}")
        self.name = name
        self.buffers = []

    def __repr__(self):
        return f"<Reference: {self.name!r}>"

    def get_object(self, g=None):
        if g is None:
            g = {}

        return eval(self.name, g)
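
# Usage sketch: a Reference is resolved by name in the namespace passed to
# get_object(), typically an engine's user namespace.
#
#     >>> Reference('a').get_object({'a': 5})
#     5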


class CannedCell(CannedObject):
    """Can a closure cell"""

    def __init__(self, cell):
        self.cell_contents = can(cell.cell_contents)

    def get_object(self, g=None):
        cell_contents = uncan(self.cell_contents, g)

        def inner():
            return cell_contents

        return inner.__closure__[0]


class CannedFunction(CannedObject):
    def __init__(self, f):
        self._check_type(f)
        self.code = f.__code__
        if f.__defaults__:
            self.defaults = [can(fd) for fd in f.__defaults__]
        else:
            self.defaults = None

        if f.__kwdefaults__:
            self.kwdefaults = can_dict(f.__kwdefaults__)
        else:
            self.kwdefaults = None

        if f.__annotations__:
            self.annotations = can_dict(f.__annotations__)
        else:
            self.annotations = None

        closure = f.__closure__
        if closure:
            self.closure = tuple(can(cell) for cell in closure)
        else:
            self.closure = None

        self.module = f.__module__ or '__main__'
        self.__name__ = f.__name__
        self.buffers = []

    def _check_type(self, obj):
        assert isinstance(obj, FunctionType), "Not a function type"

    def get_object(self, g=None):
        # try to load function back into its module:
        if not self.module.startswith('__'):
            __import__(self.module)
            g = sys.modules[self.module].__dict__

        if g is None:
            g = {}
        if self.defaults:
            defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
        else:
            defaults = None

        if self.kwdefaults:
            kwdefaults = uncan_dict(self.kwdefaults)
        else:
            kwdefaults = None
        if self.annotations:
            annotations = uncan_dict(self.annotations)
        else:
            annotations = {}

        if self.closure:
            closure = tuple(uncan(cell, g) for cell in self.closure)
        else:
            closure = None
        newFunc = FunctionType(self.code, g, self.__name__, defaults, closure)
        if kwdefaults:
            newFunc.__kwdefaults__ = kwdefaults
        if annotations:
            newFunc.__annotations__ = annotations
        return newFunc
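
# Usage sketch (hypothetical interactive session): functions defined in
# __main__, even with closures, round-trip through can()/uncan() although
# plain pickle cannot serialize them.
#
#     >>> def make_adder(n):
#     ...     def add(x):
#     ...         return x + n
#     ...     return add
#     >>> canned = can(make_adder(3))   # CannedFunction wrapping a CannedCell
#     >>> uncan(canned)(4)
#     7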


class CannedPartial(CannedObject):
    def __init__(self, f):
        self._check_type(f)
        self.func = can(f.func)
        self.args = [can(a) for a in f.args]
        self.keywords = {k: can(v) for k, v in f.keywords.items()}
        self.buffers = []
        self.arg_buffer_counts = []
        self.keyword_buffer_counts = {}
        # consolidate buffers
        for canned_arg in self.args:
            if not isinstance(canned_arg, CannedObject):
                self.arg_buffer_counts.append(0)
                continue
            self.arg_buffer_counts.append(len(canned_arg.buffers))
            self.buffers.extend(canned_arg.buffers)
            canned_arg.buffers = []
        for key in sorted(self.keywords):
            canned_kwarg = self.keywords[key]
            if not isinstance(canned_kwarg, CannedObject):
                continue
            self.keyword_buffer_counts[key] = len(canned_kwarg.buffers)
            self.buffers.extend(canned_kwarg.buffers)
            canned_kwarg.buffers = []

    def _check_type(self, obj):
        if not isinstance(obj, functools.partial):
            raise ValueError(f"Not a functools.partial: {obj!r}")

    def get_object(self, g=None):
        if g is None:
            g = {}
        if self.buffers:
            # reconstitute buffers
            for canned_arg, buf_count in zip(self.args, self.arg_buffer_counts):
                if not buf_count:
                    continue
                canned_arg.buffers = self.buffers[:buf_count]
                self.buffers = self.buffers[buf_count:]
            for key in sorted(self.keyword_buffer_counts):
                buf_count = self.keyword_buffer_counts[key]
                canned_kwarg = self.keywords[key]
                canned_kwarg.buffers = self.buffers[:buf_count]
                self.buffers = self.buffers[buf_count:]
            assert len(self.buffers) == 0

        args = [uncan(a, g) for a in self.args]
        keywords = {k: uncan(v, g) for k, v in self.keywords.items()}
        func = uncan(self.func, g)
        return functools.partial(func, *args, **keywords)
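
# Usage sketch: a functools.partial is taken apart so its func, args, and
# keywords are canned individually and their buffers consolidated into one
# flat list for transport.
#
#     >>> p = functools.partial(lambda a, b: a + b, 1)
#     >>> uncan(can(p))(b=2)
#     3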


class CannedClass(CannedObject):
    def __init__(self, cls):
        self._check_type(cls)
        self.name = cls.__name__
        self.old_style = not isinstance(cls, type)
        self._canned_dict = {}
        for k, v in cls.__dict__.items():
            if k not in ('__weakref__', '__dict__'):
                self._canned_dict[k] = can(v)
        if self.old_style:
            mro = []
        else:
            mro = cls.mro()

        self.parents = [can(c) for c in mro[1:]]
        self.buffers = []

    def _check_type(self, obj):
        assert isinstance(obj, type), "Not a class type"

    def get_object(self, g=None):
        parents = tuple(uncan(p, g) for p in self.parents)
        return type(self.name, parents, uncan_dict(self._canned_dict, g=g))


class CannedArray(CannedObject):
    def __init__(self, obj):
        from numpy import ascontiguousarray

        self.shape = obj.shape
        self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
        self.pickled = False
        if sum(obj.shape) == 0:
            self.pickled = True
        elif obj.dtype == 'O':
            # can't handle object dtype with buffer approach
            self.pickled = True
        elif obj.dtype.fields and any(
            dt == 'O' for dt, sz in obj.dtype.fields.values()
        ):
            self.pickled = True
        if self.pickled:
            # just pickle it
            from . import serialize

            self.buffers = [serialize.pickle.dumps(obj, serialize.PICKLE_PROTOCOL)]
        else:
            # ensure contiguous
            obj = ascontiguousarray(obj, dtype=None)
            self.buffers = [memoryview(obj)]

    def get_object(self, g=None):
        from numpy import frombuffer

        data = self.buffers[0]
        if self.pickled:
            from . import serialize

            # we just pickled it
            return serialize.pickle.loads(data)
        else:
            return frombuffer(data, dtype=self.dtype).reshape(self.shape)
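
# Usage sketch (assumes numpy is installed): contiguous arrays travel as a
# single memoryview buffer, while empty or object-dtype arrays fall back to
# pickling.
#
#     >>> import numpy as np
#     >>> a = np.arange(6).reshape(2, 3)
#     >>> np.array_equal(CannedArray(a).get_object(), a)
#     True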


class CannedBytes(CannedObject):
    def __init__(self, obj):
        self.buffers = [obj]

    def get_object(self, g=None):
        data = self.buffers[0]
        return self.wrap(data)

    @staticmethod
    def wrap(data):
        if isinstance(data, bytes):
            return data
        else:
            return memoryview(data).tobytes()


class CannedMemoryView(CannedBytes):
    wrap = memoryview


CannedBuffer = CannedMemoryView

# -------------------------------------------------------------------------------
# Functions
# -------------------------------------------------------------------------------


def _import_mapping(mapping, original=None):
    """import any string-keys in a type mapping"""
    log = get_logger()
    log.debug("Importing canning map")
    for key, value in list(mapping.items()):
        if isinstance(key, str):
            try:
                cls = import_item(key)
            except Exception:
                if original and key not in original:
                    # only message on user-added classes
                    log.error("canning class not importable: %r", key, exc_info=True)
                mapping.pop(key)
            else:
                mapping[cls] = mapping.pop(key)


def istype(obj, check):
    """like isinstance(obj, check), but strict

    This won't catch subclasses.
    """
    if isinstance(check, tuple):
        for cls in check:
            if type(obj) is cls:
                return True
        return False
    else:
        return type(obj) is check
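
# Usage sketch: istype() deliberately ignores subclasses, so e.g. dict
# subclasses are not intercepted by handlers registered for plain dict.
#
#     >>> class MyDict(dict):
#     ...     pass
#     >>> istype(MyDict(), dict), isinstance(MyDict(), dict)
#     (False, True)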


def can(obj):
    """prepare an object for pickling"""

    import_needed = False

    for cls, canner in can_map.items():
        if isinstance(cls, str):
            import_needed = True
            break
        elif istype(obj, cls):
            return canner(obj)

    if import_needed:
        # perform can_map imports, then try again
        # this will usually only happen once
        _import_mapping(can_map, _original_can_map)
        return can(obj)

    return obj


def can_class(obj):
    if isinstance(obj, type) and obj.__module__ == '__main__':
        return CannedClass(obj)
    else:
        return obj


def can_dict(obj):
    """can the *values* of a dict"""
    if istype(obj, dict):
        newobj = {}
        for k, v in obj.items():
            newobj[k] = can(v)
        return newobj
    else:
        return obj


sequence_types = (list, tuple, set)


def can_sequence(obj):
    """can the elements of a sequence"""
    if istype(obj, sequence_types):
        t = type(obj)
        return t([can(i) for i in obj])
    else:
        return obj
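
# Usage sketch: can_dict() and can_sequence() are shallow; they can values or
# elements one level deep while leaving the container type unchanged.
#
#     >>> canned = can_sequence((1, lambda: None))
#     >>> isinstance(canned, tuple), isinstance(canned[1], CannedFunction)
#     (True, True)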


def uncan(obj, g=None):
    """invert canning"""

    import_needed = False
    for cls, uncanner in uncan_map.items():
        if isinstance(cls, str):
            import_needed = True
            break
        elif isinstance(obj, cls):
            return uncanner(obj, g)

    if import_needed:
        # perform uncan_map imports, then try again
        # this will usually only happen once
        _import_mapping(uncan_map, _original_uncan_map)
        return uncan(obj, g)

    return obj


def uncan_dict(obj, g=None):
    if istype(obj, dict):
        newobj = {}
        for k, v in obj.items():
            newobj[k] = uncan(v, g)
        return newobj
    else:
        return obj


def uncan_sequence(obj, g=None):
    if istype(obj, sequence_types):
        t = type(obj)
        return t([uncan(i, g) for i in obj])
    else:
        return obj


def _uncan_dependent_hook(dep, g=None):
    dep.check_dependency()


def can_dependent(obj):
    return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)


# -------------------------------------------------------------------------------
# API dictionaries
# -------------------------------------------------------------------------------

# These dicts can be extended for custom serialization of new objects

can_map = {
    'numpy.ndarray': CannedArray,
    FunctionType: CannedFunction,
    functools.partial: CannedPartial,
    bytes: CannedBytes,
    memoryview: CannedMemoryView,
    cell_type: CannedCell,
    type: can_class,
    'ipyparallel.dependent': can_dependent,
}

uncan_map = {
    CannedObject: lambda obj, g: obj.get_object(g),
    dict: uncan_dict,
}

# for use in _import_mapping:
_original_can_map = can_map.copy()
_original_uncan_map = uncan_map.copy()
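
# Extension sketch (hypothetical MyType / CannedMyType names): new types can be
# registered by adding an entry to can_map; CannedObject subclasses are already
# uncanned by the default uncan_map entry, and string keys (like
# 'numpy.ndarray' above) are resolved lazily by _import_mapping() on first use.
#
#     class CannedMyType(CannedObject):
#         def __init__(self, obj):
#             self.value = can(obj.value)
#             self.buffers = []
#
#         def get_object(self, g=None):
#             return MyType(uncan(self.value, g))
#
#     can_map[MyType] = CannedMyType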