1import inspect
2from functools import partial
3from joblib.externals.cloudpickle import dumps, loads
4
5
# Module-level cache mapping an object to the wrapper that
# _wrap_objects_when_needed created for it, so that wrapping the same object
# again returns the same wrapper instead of building a new one.
WRAP_CACHE = dict()
7
8
class CloudpickledObjectWrapper:
    """Proxy around an object that is pickled with cloudpickle's ``dumps``.

    Attribute access that is not satisfied by the wrapper itself is forwarded
    to the wrapped object, so the wrapper can be used in place of it.

    Parameters
    ----------
    obj : object
        The object to wrap.
    keep_wrapper : bool, default=False
        Controls what unpickling produces: the bare object when false, a new
        wrapper around the object when true (see ``__reduce__``).
    """

    def __init__(self, obj, keep_wrapper=False):
        self._obj = obj
        self._keep_wrapper = keep_wrapper

    def __reduce__(self):
        """Serialize the wrapped object with ``dumps``.

        Returns a ``(callable, args)`` pair: ``loads`` when the wrapper should
        be dropped on unpickling, ``_reconstruct_wrapper`` when it should be
        rebuilt around the loaded object.
        """
        _pickled_object = dumps(self._obj)
        if not self._keep_wrapper:
            return loads, (_pickled_object,)

        return _reconstruct_wrapper, (_pickled_object, self._keep_wrapper)

    def __getattr__(self, attr):
        """Forward unknown attribute lookups to the wrapped object.

        Raises
        ------
        AttributeError
            If ``attr`` is one of the wrapper's own attributes and it has not
            been set, or if the wrapped object does not have ``attr``.
        """
        # __getattr__ only runs after normal lookup failed. Reaching this
        # point for one of the wrapper's own attributes means it was never
        # set (e.g. the instance was created without running __init__), so
        # looking it up again on ``self`` would recurse forever. Report it as
        # missing instead.
        if attr in ("_obj", "_keep_wrapper"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        # Ensure that the wrapped object can be used seamlessly as the
        # previous object.
        return getattr(self._obj, attr)
27
28
class CallableObjectWrapper(CloudpickledObjectWrapper):
    """Wrapper variant that keeps the wrapped object callable.

    Calling the wrapper calls the wrapped object with the same arguments.
    """

    def __call__(self, *args, **kwargs):
        # Delegate the invocation to the wrapped callable and hand back
        # whatever it returns.
        wrapped_callable = self._obj
        return wrapped_callable(*args, **kwargs)
33
34
def _wrap_non_picklable_objects(obj, keep_wrapper):
    """Wrap ``obj`` in the wrapper class that matches its callability.

    Callables get a ``CallableObjectWrapper`` so the result can still be
    called; everything else gets a plain ``CloudpickledObjectWrapper``.
    ``keep_wrapper`` is passed through to the wrapper unchanged.
    """
    wrapper_cls = (
        CallableObjectWrapper if callable(obj) else CloudpickledObjectWrapper
    )
    return wrapper_cls(obj, keep_wrapper=keep_wrapper)
39
40
def _reconstruct_wrapper(_pickled_object, keep_wrapper):
    """Load a pickled object and wrap it again.

    This is the reconstruction callable returned by
    ``CloudpickledObjectWrapper.__reduce__`` when the wrapper must survive a
    pickling round trip.
    """
    return _wrap_non_picklable_objects(loads(_pickled_object), keep_wrapper)
44
45
46def _wrap_objects_when_needed(obj):
47 # Function to introspect an object and decide if it should be wrapped or
48 # not.
49 need_wrap = "__main__" in getattr(obj, "__module__", "")
50 if isinstance(obj, partial):
51 return partial(
52 _wrap_objects_when_needed(obj.func),
53 *[_wrap_objects_when_needed(a) for a in obj.args],
54 **{
55 k: _wrap_objects_when_needed(v)
56 for k, v in obj.keywords.items()
57 },
58 )
59 if callable(obj):
60 # Need wrap if the object is a function defined in a local scope of
61 # another function.
62 func_code = getattr(obj, "__code__", "")
63 need_wrap |= getattr(func_code, "co_flags", 0) & inspect.CO_NESTED
64
65 # Need wrap if the obj is a lambda expression
66 func_name = getattr(obj, "__name__", "")
67 need_wrap |= "<lambda>" in func_name
68
69 if not need_wrap:
70 return obj
71
72 wrapped_obj = WRAP_CACHE.get(obj)
73 if wrapped_obj is None:
74 wrapped_obj = _wrap_non_picklable_objects(obj, keep_wrapper=False)
75 WRAP_CACHE[obj] = wrapped_obj
76 return wrapped_obj
77
78
def wrap_non_picklable_objects(obj, keep_wrapper=True):
    """Make ``obj`` serializable by pickling it through cloudpickle.

    When ``obj`` is a class, a wrapper class carrying the same ``__name__`` is
    returned; instantiating it builds an instance of ``obj`` and holds it
    inside a ``CloudpickledObjectWrapper``. Any other object is wrapped
    directly.

    Be aware that cloudpickle is typically slower than pickle, so this wrapper
    tends to slow serialization down. The proper way to solve serialization
    issues is to avoid defining functions and objects in the main script and
    to implement ``__reduce__`` for complex classes.
    """
    # Anything that is not a class is wrapped as-is in a regular wrapper.
    if not inspect.isclass(obj):
        return _wrap_non_picklable_objects(obj, keep_wrapper=keep_wrapper)

    # For a class, build a wrapper class that creates the real instance
    # internally and stores it as the wrapped object.
    class CloudpickledClassWrapper(CloudpickledObjectWrapper):
        def __init__(self, *args, **kwargs):
            instance = obj(*args, **kwargs)
            self._obj = instance
            self._keep_wrapper = keep_wrapper

    # Present the wrapper class under the name of the class it stands for.
    CloudpickledClassWrapper.__name__ = obj.__name__
    return CloudpickledClassWrapper