1###############################################################################
2# Customizable Pickler with some basic reducers
3#
4# author: Thomas Moreau
5#
6# adapted from multiprocessing/reduction.py (17/02/2017)
7# * Replace the ForkingPickler with a similar _LokyPickler,
# * Add CustomizablePickler to allow customizing the pickling process
#   on the fly.
10#
11import copyreg
12import io
13import functools
14import types
15import sys
16import os
17
18from multiprocessing import util
19from pickle import loads, HIGHEST_PROTOCOL
20
21###############################################################################
22# Enable custom pickling in Loky.
23
# Loky-specific reducers keyed by type. CustomizablePickler instances merge
# this mapping into their own dispatch table when they are created.
_dispatch_table = {}


def register(type_, reduce_function):
    """Register ``reduce_function`` as loky's reducer for objects of ``type_``.

    A later registration for the same type replaces the earlier one.
    """
    _dispatch_table.update({type_: reduce_function})
29
30
31###############################################################################
# Register extra reducers so that more object types can be pickled by loky
33
34
35# make methods picklable
def _reduce_method(m):
    """Reduce a bound method to ``getattr(owner, name)``.

    ``m.__self__`` is the object the method is bound to: the instance for a
    regular method, the class for a classmethod accessed through its class.
    Unpickling calls ``getattr`` on that owner, which re-binds the method.
    """
    # In Python 3 a bound method always has a non-None ``__self__``, so no
    # separate "unbound method" case is needed.
    return getattr, (m.__self__, m.__func__.__name__)
41
42
class _C:
    """Helper class whose methods supply the runtime types of bound methods."""

    def f(self):
        """Instance method; ``type(_C().f)`` is the bound-method type."""

    @classmethod
    def h(cls):
        """Classmethod; ``type(_C.h)`` is the type of a class-bound method."""
50
51
# Bound instance methods and classmethods accessed through their class are
# both reduced to ``getattr(owner, name)``.
register(type_=type(_C().f), reduce_function=_reduce_method)
register(type_=type(_C.h), reduce_function=_reduce_method)
54
55
def _reduce_method_descriptor(m):
    """Reduce an unbound C-level method to ``getattr(owner_class, name)``."""
    owner = m.__objclass__
    attr_name = m.__name__
    return getattr, (owner, attr_name)
58
59
# Method descriptors (e.g. ``list.append``) and slot wrappers
# (e.g. ``int.__add__``) are reduced to ``getattr(owner_class, name)``.
register(type_=type(list.append), reduce_function=_reduce_method_descriptor)
register(type_=type(int.__add__), reduce_function=_reduce_method_descriptor)
62
63
# Make functools.partial objects picklable
def _reduce_partial(p):
    """Reduce a ``functools.partial`` to its func, positional and keyword args.

    A falsy ``p.keywords`` is normalized to an empty dict.
    """
    keywords = p.keywords or {}
    return _rebuild_partial, (p.func, p.args, keywords)
67
68
def _rebuild_partial(func, args, keywords):
    """Recreate ``functools.partial(func, *args, **keywords)`` on unpickling."""
    make_partial = functools.partial
    return make_partial(func, *args, **keywords)
71
72
# functools.partial objects are rebuilt from (func, args, keywords).
register(type_=functools.partial, reduce_function=_reduce_partial)
74
# Import the platform-specific reduction helpers. NOTE(review): these imports
# presumably matter for their side effects / re-exports, since the names are
# otherwise unused here (hence the noqa markers) -- confirm in the submodules.
if sys.platform == "win32":
    from . import _win_reduction  # noqa: F401
else:
    from ._posix_reduction import _mk_inheritable  # noqa: F401
79
80# global variable to change the pickler behavior
# Pick cloudpickle as the default backend when the vendored copy is
# importable; otherwise fall back to the standard pickle module.
try:
    from joblib.externals import cloudpickle  # noqa: F401
except ImportError:
    # If cloudpickle is not present, fallback to pickle
    DEFAULT_ENV = "pickle"
else:
    DEFAULT_ENV = "cloudpickle"

# The LOKY_PICKLER environment variable overrides the detected default.
ENV_LOKY_PICKLER = os.environ.get("LOKY_PICKLER", DEFAULT_ENV)

# Currently selected pickler class and its name; both are set by
# set_loky_pickler().
_LokyPickler = None
_loky_pickler_name = None
92
93
def set_loky_pickler(loky_pickler=None):
    """Select the pickler class used by loky for serialization.

    Parameters
    ----------
    loky_pickler : str or None, default=None
        Name of the serialization backend. ``"cloudpickle"`` selects
        ``joblib.externals.cloudpickle.CloudPickler``. Any other name is
        imported as a module and its ``Pickler`` attribute is used (for
        instance ``"pickle"``). ``None`` means "use ``ENV_LOKY_PICKLER``".
        An empty string selects ``"cloudpickle"``.

    Raises
    ------
    ImportError
        If the requested module cannot be imported. This includes the
        vendored cloudpickle.
    AttributeError
        If the imported module has no ``Pickler`` attribute.

    In both cases the error message is extended with the name of the
    requested pickler and a hint about ``LOKY_PICKLER``.

    Notes
    -----
    On success the module-level ``_LokyPickler`` is rebound to a new
    ``CustomizablePickler`` subclass of the selected pickler class, and
    ``_loky_pickler_name`` records the selected name. Selecting the pickler
    that is already active is a no-op.
    """
    global _LokyPickler, _loky_pickler_name

    if loky_pickler is None:
        loky_pickler = ENV_LOKY_PICKLER

    # The default loky_pickler is cloudpickle (e.g. for LOKY_PICKLER="").
    if loky_pickler in ["", None]:
        loky_pickler = "cloudpickle"

    if loky_pickler == _loky_pickler_name:
        return

    # Both resolution paths share the same error handling so that a missing
    # cloudpickle is reported with the same context as any other backend.
    try:
        if loky_pickler == "cloudpickle":
            from joblib.externals.cloudpickle import (
                CloudPickler as loky_pickler_cls,
            )
        else:
            from importlib import import_module

            loky_pickler_cls = import_module(loky_pickler).Pickler
    except (ImportError, AttributeError) as e:
        extra_info = (
            "\nThis error occurred while setting loky_pickler to"
            f" '{loky_pickler}', as required by the env variable "
            "LOKY_PICKLER or the function set_loky_pickler."
        )
        # Guard against exceptions raised without any message argument.
        base_msg = str(e.args[0]) if e.args else ""
        e.args = (base_msg + extra_info,) + e.args[1:]
        # Keep ``msg`` consistent with ``args`` (ImportError exposes ``msg``).
        e.msg = e.args[0]
        raise

    util.debug(f"Using '{loky_pickler}' for serialization.")

    class CustomizablePickler(loky_pickler_cls):
        """Pickler combining loky's reducers with per-instance reducers."""

        _loky_pickler_cls = loky_pickler_cls

        def _set_dispatch_table(self, dispatch_table):
            """Install ``dispatch_table`` on this pickler instance."""
            for ancestor_class in self._loky_pickler_cls.mro():
                dt_attribute = getattr(ancestor_class, "dispatch_table", None)
                if isinstance(dt_attribute, types.MemberDescriptorType):
                    # Ancestor class (typically _pickle.Pickler) has a
                    # member_descriptor for its "dispatch_table" attribute. Use
                    # it to set the dispatch_table as a member instead of a
                    # dynamic attribute in the __dict__ of the instance,
                    # otherwise it will not be taken into account by the C
                    # implementation of the dump method if a subclass defines a
                    # class-level dispatch_table attribute as was done in
                    # cloudpickle 1.6.0:
                    # https://github.com/joblib/loky/pull/260
                    dt_attribute.__set__(self, dispatch_table)
                    break

            # On top of member descriptor set, also use setattr such that code
            # that directly access self.dispatch_table gets a consistent view
            # of the same table.
            self.dispatch_table = dispatch_table

        def __init__(self, writer, reducers=None, protocol=HIGHEST_PROTOCOL):
            """Create a pickler writing to ``writer``.

            ``reducers`` optionally maps types to reducer functions. These
            are registered on this instance only, on top of the base table
            and loky's module-level reducers.
            """
            loky_pickler_cls.__init__(self, writer, protocol=protocol)
            if reducers is None:
                reducers = {}

            if hasattr(self, "dispatch_table"):
                # Force a copy that we will update without mutating any
                # class-level dispatch_table.
                loky_dt = dict(self.dispatch_table)
            else:
                # Use standard reducers as bases
                loky_dt = copyreg.dispatch_table.copy()

            # Register loky specific reducers
            loky_dt.update(_dispatch_table)

            # Set the new dispatch table, taking care of the fact that we
            # need to use the member_descriptor when we inherit from a
            # subclass of the C implementation of the Pickler base class
            # with a class level dispatch_table attribute.
            self._set_dispatch_table(loky_dt)

            # Register the per-instance reducers
            for type_, reduce_func in reducers.items():
                self.register(type_, reduce_func)

        def register(self, type, reduce_func):
            """Attach a reducer function to a given type in the dispatch table."""
            self.dispatch_table[type] = reduce_func

    _LokyPickler = CustomizablePickler
    _loky_pickler_name = loky_pickler
188
189
def get_loky_pickler_name():
    """Return the name of the pickler last selected by set_loky_pickler.

    The value is ``None`` if no pickler has been selected yet.
    """
    return _loky_pickler_name
193
194
def get_loky_pickler():
    """Return the pickler class last selected by set_loky_pickler.

    The value is ``None`` if no pickler has been selected yet.
    """
    return _LokyPickler
198
199
# Select the default pickler (LOKY_PICKLER or the detected default) at import
# Passing None makes set_loky_pickler fall back to ENV_LOKY_PICKLER.
set_loky_pickler(loky_pickler=None)
202
203
def dump(obj, file, reducers=None, protocol=None):
    """Replacement for pickle.dump() using _LokyPickler.

    ``reducers`` (optional type -> reducer mapping) and ``protocol`` are
    forwarded unchanged to the currently selected pickler class.
    """
    pickler = _LokyPickler(file, reducers=reducers, protocol=protocol)
    pickler.dump(obj)
208
209
def dumps(obj, reducers=None, protocol=None):
    """Serialize ``obj`` in memory with the currently selected loky pickler.

    Returns the ``getbuffer()`` view of the in-memory buffer, i.e. a
    ``memoryview`` rather than a ``bytes`` object.
    """
    buffer = io.BytesIO()
    dump(obj, buffer, reducers=reducers, protocol=protocol)
    return buffer.getbuffer()
216
217
__all__ = ["dump", "dumps", "loads", "register", "set_loky_pickler"]

# On Windows, also re-export multiprocessing's ``duplicate`` helper.
if sys.platform == "win32":
    from multiprocessing.reduction import duplicate

    __all__.append("duplicate")