1###############################################################################
2# Customizable Pickler with some basic reducers
3#
4# author: Thomas Moreau
5#
6# adapted from multiprocessing/reduction.py (17/02/2017)
7# * Replace the ForkingPickler with a similar _LokyPickler,
8# * Add CustomizableLokyPickler to allow customizing pickling process
9# on the fly.
10#
11import copyreg
12import io
13import functools
14import types
15import sys
16import os
17
18from multiprocessing import util
19from pickle import loads, HIGHEST_PROTOCOL
20
21###############################################################################
22# Enable custom pickling in Loky.
23
# Module-wide registry mapping a type to the reducer used to pickle it.
_dispatch_table = {}


def register(type_, reduce_function):
    """Record ``reduce_function`` as the reducer for ``type_``.

    The entry is stored in the module-level ``_dispatch_table``; registering
    the same type again replaces the previous reducer.
    """
    _dispatch_table.update({type_: reduce_function})
29
30
31###############################################################################
# Register extra reducers to improve pickling support in loky
33
34
# Make methods picklable
def _reduce_method(m):
    """Reduce a method to a ``getattr`` lookup by name on its owner.

    The owner is ``m.__self__``; when that is ``None`` the lookup falls back
    to ``m.__class__``. The attribute name is taken from ``m.__func__``.

    Returns a ``(callable, args)`` pair as expected from a reducer.
    """
    owner = m.__self__
    if owner is None:
        owner = m.__class__
    return getattr, (owner, m.__func__.__name__)
41
42
class _C:
    """Throwaway class used only to obtain the types of bound methods."""

    def f(self):
        """Do nothing; accessed through an instance to get a bound method."""

    @classmethod
    def h(cls):
        """Do nothing; accessed through the class to get a bound method."""
50
51
# Register the reducer for the type of a method looked up on an instance and
# for the type of a classmethod looked up on the class.
for _bound_method in (_C().f, _C.h):
    register(type(_bound_method), _reduce_method)
del _bound_method
54
55
if not hasattr(sys, "pypy_version_info"):
    # PyPy uses functions instead of method_descriptors and wrapper_descriptors
    def _reduce_method_descriptor(m):
        """Reduce a descriptor to a ``getattr`` lookup on its owning class.

        The owning class is read from ``m.__objclass__`` and the attribute
        name from ``m.__name__``.
        """
        owner_cls = m.__objclass__
        return getattr, (owner_cls, m.__name__)

    # Register the reducer for the types of ``list.append`` and
    # ``int.__add__``, in that order.
    for _descriptor in (list.append, int.__add__):
        register(type(_descriptor), _reduce_method_descriptor)
    del _descriptor
63
64
# Make functools.partial objects picklable
def _reduce_partial(p):
    """Reduce a ``functools.partial`` to its function, args and keywords.

    A falsy ``p.keywords`` is replaced by a fresh empty dict so that the
    rebuild function always receives a mapping.
    """
    func, args, kwargs = p.func, p.args, p.keywords
    if not kwargs:
        kwargs = {}
    return _rebuild_partial, (func, args, kwargs)
68
69
def _rebuild_partial(func, args, keywords):
    """Recreate a ``functools.partial`` from its reduced components.

    ``args`` is unpacked as the positional arguments and ``keywords`` as the
    keyword arguments bound to ``func``.
    """
    positional = (func, *args)
    return functools.partial(*positional, **keywords)
72
73
register(functools.partial, _reduce_partial)

# Import the platform-specific reduction module. The imported names are not
# used in this file (hence the noqa markers); presumably the import is done
# for its side effects -- confirm against the sibling modules.
if sys.platform == "win32":
    from . import _win_reduction  # noqa: F401
else:
    from ._posix_reduction import _mk_inheritable  # noqa: F401
80
# Global variables controlling the pickler behavior
try:
    from joblib.externals import cloudpickle  # noqa: F401
except ImportError:
    # If cloudpickle is not present, fall back to pickle
    DEFAULT_ENV = "pickle"
else:
    DEFAULT_ENV = "cloudpickle"

# The LOKY_PICKLER environment variable, when set, overrides the default.
ENV_LOKY_PICKLER = os.getenv("LOKY_PICKLER", DEFAULT_ENV)

# Both are assigned by set_loky_pickler().
_LokyPickler = _loky_pickler_name = None
93
94
def set_loky_pickler(loky_pickler=None):
    """Select the pickler implementation wrapped by ``_LokyPickler``.

    Builds a ``CustomizablePickler`` subclass of the requested pickler class
    and stores it, together with its name, in the module-level
    ``_LokyPickler`` and ``_loky_pickler_name``. Requesting the pickler that
    is already active returns immediately.

    Parameters
    ----------
    loky_pickler : str or None, default=None
        ``"cloudpickle"`` to use ``joblib.externals.cloudpickle.CloudPickler``
        or the name of an importable module exposing a ``Pickler`` attribute.
        ``None`` selects ``ENV_LOKY_PICKLER``; an empty string selects
        cloudpickle.

    Raises
    ------
    ImportError
        If the requested module cannot be imported.
    AttributeError
        If the imported module has no ``Pickler`` attribute.
    """
    global _LokyPickler, _loky_pickler_name

    if loky_pickler is None:
        loky_pickler = ENV_LOKY_PICKLER

    loky_pickler_cls = None

    # The default loky_pickler is cloudpickle
    if loky_pickler in ["", None]:
        loky_pickler = "cloudpickle"

    # Nothing to rebuild when the requested pickler is already active.
    if loky_pickler == _loky_pickler_name:
        return

    if loky_pickler == "cloudpickle":
        from joblib.externals.cloudpickle import CloudPickler as loky_pickler_cls
    else:
        from importlib import import_module

        try:
            module_pickle = import_module(loky_pickler)
            loky_pickler_cls = module_pickle.Pickler
        except (ImportError, AttributeError) as e:
            extra_info = (
                "\nThis error occurred while setting loky_pickler to"
                f" '{loky_pickler}', as required by the env variable "
                "LOKY_PICKLER or the function set_loky_pickler."
            )
            # The exception may carry no argument or a non-string first
            # argument; format rather than concatenate so that augmenting the
            # message can never mask the original error with an IndexError
            # or a TypeError.
            original_msg = e.args[0] if e.args else ""
            e.args = (f"{original_msg}{extra_info}",) + e.args[1:]
            # Keep the ``msg`` attribute in sync with the augmented message.
            e.msg = e.args[0]
            raise

    # loky_pickler is always a non-empty name at this point.
    util.debug(f"Using '{loky_pickler}' for serialization.")

    class CustomizablePickler(loky_pickler_cls):
        """Pickler extended with loky's reducers and per-instance reducers."""

        _loky_pickler_cls = loky_pickler_cls

        def _set_dispatch_table(self, dispatch_table):
            """Install ``dispatch_table`` on this pickler instance."""
            for ancestor_class in self._loky_pickler_cls.mro():
                dt_attribute = getattr(ancestor_class, "dispatch_table", None)
                if isinstance(dt_attribute, types.MemberDescriptorType):
                    # Ancestor class (typically _pickle.Pickler) has a
                    # member_descriptor for its "dispatch_table" attribute. Use
                    # it to set the dispatch_table as a member instead of a
                    # dynamic attribute in the __dict__ of the instance,
                    # otherwise it will not be taken into account by the C
                    # implementation of the dump method if a subclass defines a
                    # class-level dispatch_table attribute as was done in
                    # cloudpickle 1.6.0:
                    # https://github.com/joblib/loky/pull/260
                    dt_attribute.__set__(self, dispatch_table)
                    break

            # On top of member descriptor set, also use setattr such that code
            # that directly access self.dispatch_table gets a consistent view
            # of the same table.
            self.dispatch_table = dispatch_table

        def __init__(self, writer, reducers=None, protocol=HIGHEST_PROTOCOL):
            """Create a pickler writing to ``writer``.

            ``reducers`` is an optional mapping from type to reduce function;
            its entries are registered after, and therefore override, the
            reducers from the module-level ``_dispatch_table``.
            """
            loky_pickler_cls.__init__(self, writer, protocol=protocol)
            if reducers is None:
                reducers = {}

            if hasattr(self, "dispatch_table"):
                # Force a copy that we will update without mutating any
                # class-level dispatch_table.
                loky_dt = dict(self.dispatch_table)
            else:
                # Use standard reducers as bases
                loky_dt = copyreg.dispatch_table.copy()

            # Register loky specific reducers
            loky_dt.update(_dispatch_table)

            # Set the new dispatch table, taking care of the fact that we
            # need to use the member_descriptor when we inherit from a
            # subclass of the C implementation of the Pickler base class
            # with a class-level dispatch_table attribute.
            self._set_dispatch_table(loky_dt)

            # Register the reducers
            for type_, reduce_func in reducers.items():
                self.register(type_, reduce_func)

        def register(self, type, reduce_func):
            """Attach a reducer function to a given type in the dispatch table."""
            self.dispatch_table[type] = reduce_func

    _LokyPickler = CustomizablePickler
    _loky_pickler_name = loky_pickler
189
190
def get_loky_pickler_name():
    """Return the pickler name last stored by ``set_loky_pickler``."""
    # Reading a module-level name needs no ``global`` declaration.
    return _loky_pickler_name
194
195
def get_loky_pickler():
    """Return the pickler class last stored by ``set_loky_pickler``."""
    # Reading a module-level name needs no ``global`` declaration.
    return _LokyPickler
199
200
# Select the pickler at import time: passing None makes set_loky_pickler use
# ENV_LOKY_PICKLER.
set_loky_pickler(loky_pickler=None)
203
204
def dump(obj, file, reducers=None, protocol=None):
    """Replacement for pickle.dump() using _LokyPickler.

    ``reducers`` and ``protocol`` are forwarded unchanged to the
    ``_LokyPickler`` constructor, which writes to ``file``.
    """
    pickler = _LokyPickler(file, reducers=reducers, protocol=protocol)
    pickler.dump(obj)
209
210
def dumps(obj, reducers=None, protocol=None):
    """Pickle ``obj`` into memory with ``dump`` and return the result.

    The return value is the object given by ``io.BytesIO.getbuffer()`` on the
    in-memory stream that received the pickle, not a ``bytes`` object.
    """
    stream = io.BytesIO()
    dump(obj, stream, reducers=reducers, protocol=protocol)
    return stream.getbuffer()
217
218
# Public names of the module; ``loads`` is the function imported from pickle.
__all__ = ["dump", "dumps", "loads", "register", "set_loky_pickler"]

if sys.platform == "win32":
    # On Windows, ``duplicate`` is additionally imported and exported.
    from multiprocessing.reduction import duplicate

    __all__.append("duplicate")