from __future__ import annotations

import copyreg
import multiprocessing
import multiprocessing.pool
import os
import pickle
import sys
import traceback
from collections.abc import Mapping, Sequence
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from warnings import warn

import cloudpickle

from dask import config
from dask.local import MultiprocessingPoolExecutor, get_async, reraise
from dask.optimization import cull, fuse
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import ensure_dict


def _reduce_method_descriptor(m):
    return getattr, (m.__objclass__, m.__name__)


# type(set.union) is used as a proxy to <class 'method_descriptor'>
copyreg.pickle(type(set.union), _reduce_method_descriptor)
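# Without the registration above, the standard pickler fails on method
# descriptors ("TypeError: cannot pickle 'method_descriptor' object");
# with it, ``pickle.loads(pickle.dumps(set.union)) is set.union`` holds.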

_dumps = partial(cloudpickle.dumps, protocol=pickle.HIGHEST_PROTOCOL)
_loads = cloudpickle.loads
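# Unlike the stdlib pickler, cloudpickle can also serialize lambdas and
# interactively defined functions, e.g. (illustrative only):
#
#     >>> _loads(_dumps(lambda x: x + 1))(1)  # doctest: +SKIP
#     2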


def _process_get_id():
    return multiprocessing.current_process().ident


# -- Remote Exception Handling --
# By default, tracebacks can't be serialized using pickle. However, the
# `tblib` library can enable support for this. Since we don't mandate
# that tblib is installed, we do the following:
#
# - If tblib is installed, use it to serialize the traceback and reraise
#   in the scheduler process
# - Otherwise, use a ``RemoteException`` class to contain a serialized
#   version of the formatted traceback, which will then print in the
#   scheduler process.
#
# To enable testing of the ``RemoteException`` class even when tblib is
# installed, we don't wrap the class in the try block below
class RemoteException(Exception):
    """Remote Exception

    Contains the exception and traceback from a remotely run task
    """

    def __init__(self, exception, traceback):
        self.exception = exception
        self.traceback = traceback

    def __str__(self):
        return str(self.exception) + "\n\nTraceback\n---------\n" + self.traceback

    def __dir__(self):
        return sorted(set(dir(type(self)) + list(self.__dict__) + dir(self.exception)))

    def __getattr__(self, key):
        try:
            return object.__getattribute__(self, key)
        except AttributeError:
            return getattr(self.exception, key)
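
# Attribute lookups that fail on a ``RemoteException`` fall back to the wrapped
# exception (see ``__getattr__`` above), so code inspecting custom attributes of
# the original error keeps working in the scheduler process.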


exceptions: dict[type[Exception], type[Exception]] = {}


def remote_exception(exc: Exception, tb) -> Exception:
    """Wrap *exc* in a dynamically created RemoteException subclass of its type"""
    if type(exc) in exceptions:
        typ = exceptions[type(exc)]
        return typ(exc, tb)
    else:
        try:
            typ = type(
                exc.__class__.__name__,
                (RemoteException, type(exc)),
                {"exception_type": type(exc)},
            )
            exceptions[type(exc)] = typ
            return typ(exc, tb)
        except TypeError:
            return exc
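
# A minimal sketch of the intended behaviour (illustrative only)::
#
#     >>> err = remote_exception(ValueError("bad input"), "formatted traceback")
#     >>> isinstance(err, ValueError) and isinstance(err, RemoteException)  # doctest: +SKIP
#     True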


try:
    import tblib.pickling_support

    tblib.pickling_support.install()

    def _pack_traceback(tb):
        return tb

except ImportError:

    def _pack_traceback(tb):
        return "".join(traceback.format_tb(tb))

    def reraise(exc, tb=None):
        exc = remote_exception(exc, tb)
        raise exc


def pack_exception(e, dumps):
    exc_type, exc_value, exc_traceback = sys.exc_info()
    tb = _pack_traceback(exc_traceback)
    try:
        result = dumps((e, tb))
    except Exception as e:
        exc_type, exc_value, exc_traceback = sys.exc_info()
        tb = _pack_traceback(exc_traceback)
        result = dumps((e, tb))
    return result
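
# If the task's own exception cannot be pickled, the ``except`` branch above
# packs the resulting serialization error instead, so the scheduler always
# receives a payload it can unpickle and re-raise.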


_CONTEXT_UNSUPPORTED = """\
The 'multiprocessing.context' configuration option will be ignored on Windows,
because Windows only supports the 'spawn' context.
"""


def get_context():
    """Return the current multiprocessing context."""
    # fork context does fork()-without-exec(), which can lead to deadlocks,
    # so default to "spawn".
    context_name = config.get("multiprocessing.context", "spawn")
    if sys.platform == "win32":
        if context_name != "spawn":
            # Only spawn is supported on Win32, can't change it:
            warn(_CONTEXT_UNSUPPORTED, UserWarning)
        return multiprocessing
    else:
        return multiprocessing.get_context(context_name)
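
# The context can be selected through dask's configuration system, e.g.
# (a sketch; the named context must be available on your platform)::
#
#     >>> import dask
#     >>> with dask.config.set({"multiprocessing.context": "forkserver"}):  # doctest: +SKIP
#     ...     ctx = get_context()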


def get(
    dsk: Mapping,
    keys: Sequence[Key] | Key,
    num_workers=None,
    func_loads=None,
    func_dumps=None,
    optimize_graph=True,
    pool=None,
    initializer=None,
    chunksize=None,
    **kwargs,
):
    """Multiprocessed get function appropriate for Bags

    Parameters
    ----------
    dsk : dict
        dask graph
    keys : object or list
        Desired results from graph
    num_workers : int
        Number of worker processes (defaults to number of cores)
    func_loads : function
        Function to use for function deserialization (defaults to cloudpickle.loads)
    func_dumps : function
        Function to use for function serialization (defaults to cloudpickle.dumps)
    optimize_graph : bool
        If True [default], `fuse` is applied to the graph before computation.
    pool : Executor or Pool
        Some sort of `Executor` or `Pool` to use
    initializer: function
        Function to initialize a worker process before running any tasks in it.
        Ignored if ``pool`` has been set.
    chunksize: int, optional
        Size of chunks to use when dispatching work.
        Defaults to 6 as some batching is helpful.
        If -1, will be computed to evenly divide ready work across workers.
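
    Examples
    --------
    A minimal sketch of calling this scheduler directly (it is normally
    selected via ``scheduler="processes"``)::

        >>> from operator import add
        >>> dsk = {"x": 1, "y": (add, "x", 2)}
        >>> get(dsk, "y")  # doctest: +SKIP
        3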
184 """
185 chunksize = chunksize or config.get("chunksize", 6)
186 pool = pool or config.get("pool", None)
187 initializer = initializer or config.get("multiprocessing.initializer", None)
188 num_workers = num_workers or config.get("num_workers", None) or CPU_COUNT
189 if pool is None:
190 # In order to get consistent hashing in subprocesses, we need to set a
191 # consistent seed for the Python hash algorithm. Unfortunately, there
192 # is no way to specify environment variables only for the Pool
193 # processes, so we have to rely on environment variables being
194 # inherited.
195 if os.environ.get("PYTHONHASHSEED") in (None, "0"):
196 # This number is arbitrary; it was chosen to commemorate
197 # https://github.com/dask/dask/issues/6640.
198 os.environ["PYTHONHASHSEED"] = "6640"
199 context = get_context()
200 initializer = partial(initialize_worker_process, user_initializer=initializer)
201 pool = ProcessPoolExecutor(
202 num_workers, mp_context=context, initializer=initializer
203 )
204 cleanup = True
205 else:
206 if initializer is not None:
207 warn(
208 "The ``initializer`` argument is ignored when ``pool`` is provided. "
209 "The user should configure ``pool`` with the needed ``initializer`` "
210 "on creation."
211 )
212 if isinstance(pool, multiprocessing.pool.Pool):
213 pool = MultiprocessingPoolExecutor(pool)
214 cleanup = False
215
216 if hasattr(dsk, "__dask_graph__"):
217 dsk = dsk.__dask_graph__()
218
219 dsk = ensure_dict(dsk)
220 dsk2, dependencies = cull(dsk, keys)
221 if optimize_graph:
222 dsk3, dependencies = fuse(dsk2, keys, dependencies)
223 else:
224 dsk3 = dsk2
225
226 # We specify marshalling functions in order to catch serialization
227 # errors and report them to the user.
228 loads = func_loads or config.get("func_loads", None) or _loads
229 dumps = func_dumps or config.get("func_dumps", None) or _dumps
230
231 # Note former versions used a multiprocessing Manager to share
232 # a Queue between parent and workers, but this is fragile on Windows
233 # (issue #1652).
234 try:
235 # Run
236 result = get_async(
237 pool.submit,
238 pool._max_workers,
239 dsk3,
240 keys,
241 get_id=_process_get_id,
242 dumps=dumps,
243 loads=loads,
244 pack_exception=pack_exception,
245 raise_exception=reraise,
246 chunksize=chunksize,
247 **kwargs,
248 )
249 finally:
250 if cleanup:
251 pool.shutdown()
252 return result


def default_initializer():
    # If Numpy is already imported, presumably its random state was
    # inherited from the parent => re-seed it.
    np = sys.modules.get("numpy")
    if np is not None:
        np.random.seed()


def initialize_worker_process(user_initializer=None):
    """
    Initialize a worker process before running any tasks in it.
    """
    default_initializer()

    if user_initializer is not None:
        user_initializer()
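
# A hedged sketch of supplying a custom initializer through dask's
# configuration system (``my_setup`` is a hypothetical, importable function;
# it must be picklable since it is shipped to the worker processes)::
#
#     >>> import dask
#     >>> with dask.config.set({"multiprocessing.initializer": my_setup}):  # doctest: +SKIP
#     ...     result = get(dsk, keys)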