Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/dask/multiprocessing.py: 36%

118 statements  

from __future__ import annotations

import copyreg
import multiprocessing
import multiprocessing.pool
import os
import pickle
import sys
import traceback
from collections.abc import Mapping, Sequence
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from warnings import warn

import cloudpickle

from dask import config
from dask.local import MultiprocessingPoolExecutor, get_async, reraise
from dask.optimization import cull, fuse
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import ensure_dict


def _reduce_method_descriptor(m):
    return getattr, (m.__objclass__, m.__name__)


# type(set.union) is used as a proxy to <class 'method_descriptor'>
copyreg.pickle(type(set.union), _reduce_method_descriptor)

_dumps = partial(cloudpickle.dumps, protocol=pickle.HIGHEST_PROTOCOL)
_loads = cloudpickle.loads

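# Editor's illustration (not part of dask/multiprocessing.py): after the
# ``copyreg.pickle`` registration above, a method descriptor such as
# ``set.union`` reduces to ``getattr(set, "union")`` and should therefore
# survive a round trip through the standard pickle module as well.
def _example_method_descriptor_roundtrip():
    func, args = _reduce_method_descriptor(set.union)
    assert func is getattr
    assert args == (set, "union")
    # pickle consults copyreg's dispatch table, so this should now round-trip:
    assert pickle.loads(pickle.dumps(set.union)) is set.union
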

def _process_get_id():
    return multiprocessing.current_process().ident


# -- Remote Exception Handling --
# By default, tracebacks can't be serialized using pickle. However, the
# `tblib` library can enable support for this. Since we don't mandate
# that tblib is installed, we do the following:
#
#   - If tblib is installed, use it to serialize the traceback and reraise
#     in the scheduler process
#   - Otherwise, use a ``RemoteException`` class to contain a serialized
#     version of the formatted traceback, which will then print in the
#     scheduler process.
#
# To enable testing of the ``RemoteException`` class even when tblib is
# installed, we don't wrap the class in the try block below
class RemoteException(Exception):
    """Remote Exception

    Contains the exception and traceback from a remotely run task
    """

    def __init__(self, exception, traceback):
        self.exception = exception
        self.traceback = traceback

    def __str__(self):
        return str(self.exception) + "\n\nTraceback\n---------\n" + self.traceback

    def __dir__(self):
        return sorted(set(dir(type(self)) + list(self.__dict__) + dir(self.exception)))

    def __getattr__(self, key):
        try:
            return object.__getattribute__(self, key)
        except AttributeError:
            return getattr(self.exception, key)


exceptions: dict[type[Exception], type[Exception]] = {}


def remote_exception(exc: Exception, tb) -> Exception:
    """Wrap ``exc`` in a dynamically created ``RemoteException`` subclass"""
    if type(exc) in exceptions:
        typ = exceptions[type(exc)]
        return typ(exc, tb)
    else:
        try:
            typ = type(
                exc.__class__.__name__,
                (RemoteException, type(exc)),
                {"exception_type": type(exc)},
            )
            exceptions[type(exc)] = typ
            return typ(exc, tb)
        except TypeError:
            return exc

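# Editor's illustration (not part of dask/multiprocessing.py): the dynamic
# subclass produced by ``remote_exception`` keeps ``isinstance`` checks
# against the original exception type working, while
# ``RemoteException.__getattr__`` delegates attribute access to the wrapped
# exception. The traceback string below is a placeholder.
def _example_remote_exception():
    wrapped = remote_exception(KeyError("missing"), "Traceback (illustrative)\n")
    assert isinstance(wrapped, KeyError)
    assert isinstance(wrapped, RemoteException)
    assert wrapped.exception.args == ("missing",)
    return str(wrapped)  # message plus the packed traceback text
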

try:
    import tblib.pickling_support

    tblib.pickling_support.install()

    def _pack_traceback(tb):
        return tb

except ImportError:

    def _pack_traceback(tb):
        return "".join(traceback.format_tb(tb))

    def reraise(exc, tb=None):
        exc = remote_exception(exc, tb)
        raise exc


def pack_exception(e, dumps):
    exc_type, exc_value, exc_traceback = sys.exc_info()
    tb = _pack_traceback(exc_traceback)
    try:
        result = dumps((e, tb))
    except Exception as e:
        exc_type, exc_value, exc_traceback = sys.exc_info()
        tb = _pack_traceback(exc_traceback)
        result = dumps((e, tb))
    return result

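# Editor's illustration (not part of dask/multiprocessing.py): if a task's
# exception itself cannot be serialized, ``pack_exception`` falls back to
# shipping the serialization error instead, so the scheduler still receives
# something it can raise. ``threading.Lock`` is used here only as a handy
# unpicklable payload.
def _example_pack_exception_fallback():
    import threading

    try:
        raise ValueError(threading.Lock())  # exception argument is unpicklable
    except ValueError as err:
        payload = pack_exception(err, pickle.dumps)
    exc, tb = pickle.loads(payload)
    # The ValueError could not be pickled, so the TypeError raised by pickle
    # travels back in its place.
    assert isinstance(exc, TypeError)
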

_CONTEXT_UNSUPPORTED = """\
The 'multiprocessing.context' configuration option will be ignored on Python 2
and on Windows, because they each only support a single context.
"""


def get_context():
    """Return the current multiprocessing context."""
    # fork context does fork()-without-exec(), which can lead to deadlocks,
    # so default to "spawn".
    context_name = config.get("multiprocessing.context", "spawn")
    if sys.platform == "win32":
        if context_name != "spawn":
            # Only spawn is supported on Win32, can't change it:
            warn(_CONTEXT_UNSUPPORTED, UserWarning)
        return multiprocessing
    else:
        return multiprocessing.get_context(context_name)

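# Editor's illustration (not part of dask/multiprocessing.py): on POSIX
# platforms the start method can be chosen through dask's configuration,
# while on Windows ``get_context`` always falls back to ``multiprocessing``
# and warns if anything other than "spawn" was requested.
def _example_get_context_config():
    with config.set({"multiprocessing.context": "forkserver"}):
        ctx = get_context()
    # On Linux/macOS this should be the "forkserver" context object.
    return ctx
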

def get(
    dsk: Mapping,
    keys: Sequence[Key] | Key,
    num_workers=None,
    func_loads=None,
    func_dumps=None,
    optimize_graph=True,
    pool=None,
    initializer=None,
    chunksize=None,
    **kwargs,
):
    """Multiprocessed get function appropriate for Bags

    Parameters
    ----------
    dsk : dict
        dask graph
    keys : object or list
        Desired results from graph
    num_workers : int
        Number of worker processes (defaults to number of cores)
    func_dumps : function
        Function to use for function serialization (defaults to cloudpickle.dumps)
    func_loads : function
        Function to use for function deserialization (defaults to cloudpickle.loads)
    optimize_graph : bool
        If True [default], `fuse` is applied to the graph before computation.
    pool : Executor or Pool
        Some sort of `Executor` or `Pool` to use
    initializer: function
        Ignored if ``pool`` has been set.
        Function to initialize a worker process before running any tasks in it.
    chunksize: int, optional
        Size of chunks to use when dispatching work.
        Defaults to 6 as some batching is helpful.
        If -1, will be computed to evenly divide ready work across workers.
    """
    chunksize = chunksize or config.get("chunksize", 6)
    pool = pool or config.get("pool", None)
    initializer = initializer or config.get("multiprocessing.initializer", None)
    num_workers = num_workers or config.get("num_workers", None) or CPU_COUNT
    if pool is None:
        # In order to get consistent hashing in subprocesses, we need to set a
        # consistent seed for the Python hash algorithm. Unfortunately, there
        # is no way to specify environment variables only for the Pool
        # processes, so we have to rely on environment variables being
        # inherited.
        if os.environ.get("PYTHONHASHSEED") in (None, "0"):
            # This number is arbitrary; it was chosen to commemorate
            # https://github.com/dask/dask/issues/6640.
            os.environ["PYTHONHASHSEED"] = "6640"
        context = get_context()
        initializer = partial(initialize_worker_process, user_initializer=initializer)
        pool = ProcessPoolExecutor(
            num_workers, mp_context=context, initializer=initializer
        )
        cleanup = True
    else:
        if initializer is not None:
            warn(
                "The ``initializer`` argument is ignored when ``pool`` is provided. "
                "The user should configure ``pool`` with the needed ``initializer`` "
                "on creation."
            )
        if isinstance(pool, multiprocessing.pool.Pool):
            pool = MultiprocessingPoolExecutor(pool)
        cleanup = False

    if hasattr(dsk, "__dask_graph__"):
        dsk = dsk.__dask_graph__()

    dsk = ensure_dict(dsk)
    dsk2, dependencies = cull(dsk, keys)
    if optimize_graph:
        dsk3, dependencies = fuse(dsk2, keys, dependencies)
    else:
        dsk3 = dsk2

    # We specify marshalling functions in order to catch serialization
    # errors and report them to the user.
    loads = func_loads or config.get("func_loads", None) or _loads
    dumps = func_dumps or config.get("func_dumps", None) or _dumps

    # Note former versions used a multiprocessing Manager to share
    # a Queue between parent and workers, but this is fragile on Windows
    # (issue #1652).
    try:
        # Run
        result = get_async(
            pool.submit,
            pool._max_workers,
            dsk3,
            keys,
            get_id=_process_get_id,
            dumps=dumps,
            loads=loads,
            pack_exception=pack_exception,
            raise_exception=reraise,
            chunksize=chunksize,
            **kwargs,
        )
    finally:
        if cleanup:
            pool.shutdown()
    return result

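# Editor's illustration (not part of dask/multiprocessing.py): ``get`` is the
# entry point behind ``scheduler="processes"``. A minimal hand-built task
# graph might look like the sketch below; real code normally goes through
# collections such as ``dask.bag`` or ``dask.delayed`` instead.
def _example_get_usage():
    import operator

    dsk = {"x": 1, "y": 2, "z": (operator.add, "x", "y")}
    # Should return 3, computed in spawned worker processes.
    return get(dsk, "z", num_workers=2)
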

def default_initializer():
    # If Numpy is already imported, presumably its random state was
    # inherited from the parent => re-seed it.
    np = sys.modules.get("numpy")
    if np is not None:
        np.random.seed()


def initialize_worker_process(user_initializer=None):
    """
    Initialize a worker process before running any tasks in it.
    """
    default_initializer()

    if user_initializer is not None:
        user_initializer()
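

# Editor's illustration (not part of dask/multiprocessing.py): the worker
# initializer chain runs ``default_initializer`` first and then any
# user-supplied hook. It is called in-process here purely to show the order;
# in real use the hook is passed via ``get(..., initializer=...)`` or the
# ``multiprocessing.initializer`` config key, and since workers are spawned
# it typically needs to be picklable (e.g. defined at module level).
def _example_initializer_chain():
    calls = []
    initialize_worker_process(user_initializer=lambda: calls.append("user hook"))
    return calls  # ["user hook"]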