Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/unblob/pool.py: 33%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import abc
2import atexit
3import contextlib
4import fcntl
5import multiprocessing as mp
6import multiprocessing.connection
7import os
8import queue
9import signal
10import sys
11import threading
12from collections.abc import Callable
13from multiprocessing.queues import JoinableQueue, SimpleQueue
14from typing import Any
16from .logging import multiprocessing_breakpoint
# Workers are created with fork() so they inherit the handler, the queues
# and the lifeline pipe file descriptors set up in MultiPool.__init__.
# NOTE(review): set_start_method() raises RuntimeError if a start method
# was already set by the hosting application — confirm import ordering.
mp.set_start_method("fork")
class PoolBase(abc.ABC):
    """Common interface for worker pools.

    Concrete pools must implement :meth:`submit` and
    :meth:`process_until_done`.  :meth:`start` and :meth:`close` are
    optional lifecycle hooks (no-ops by default), and the class doubles
    as a context manager that starts on entry and closes on exit.
    """

    @abc.abstractmethod
    def submit(self, args):
        """Queue ``args`` for processing."""

    @abc.abstractmethod
    def process_until_done(self):
        """Block until every submitted item has been processed."""

    def start(self):
        """Lifecycle hook: bring the pool up.  No-op by default."""

    def close(self, *, immediate=False):
        """Lifecycle hook: tear the pool down.  No-op by default."""

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, _exc_value, _tb):
        # An in-flight exception requests an immediate (non-graceful) close.
        request_immediate = exc_type is not None
        self.close(immediate=request_immediate)
class Queue(JoinableQueue):
    """``JoinableQueue`` extended with a non-blocking emptiness probe."""

    def is_empty(self) -> bool:
        """Check if all ``task_done`` has been called for all items.

        Based on ``multiprocessing.JoinableQueue.join``.
        """
        # Same locking/semaphore access as JoinableQueue.join(), but
        # polls the unfinished-task counter instead of blocking on it.
        with self._cond:  # type: ignore
            return self._unfinished_tasks._semlock._is_zero()  # type: ignore # noqa: SLF001
class ResultQueue(SimpleQueue):
    """``SimpleQueue`` that exposes its read-end connection.

    The reader's handle can be multiplexed together with process
    sentinels via ``multiprocessing.connection.wait`` (see ``MultiPool``).
    """

    @property
    def reader(self) -> multiprocessing.connection.Connection:
        # Receiving end of the queue's underlying pipe (private member
        # of SimpleQueue, hence the type: ignore).
        return self._reader  # type: ignore
class _Sentinel:
    # Marker used to signal "no more work" / "no more results" through
    # the queues.
    pass


# NOTE(review): the class object itself (not an instance) is used as the
# sentinel value; identity checks like ``args is not _SENTINEL`` rely on it.
_SENTINEL = _Sentinel
class WorkerDiedError(RuntimeError):
    """Raised when a pool worker process exits unexpectedly."""
def _worker_process(handler, input_, output, lifeline_worker_side, lifeline_host_side):
    """Worker main loop: pull items from ``input_``, push ``handler(item)`` to ``output``.

    Runs until the ``_SENTINEL`` class is received, then echoes the
    sentinel on ``output`` and returns.  ``lifeline_*`` are the two ends
    of a pipe used to detect the death of the host process.
    """
    # Creates a new process group, making sure no signals are
    # propagated from the main process to the worker processes.
    os.setpgrp()

    # Restore default signal handlers, otherwise workers would inherit
    # them from main process. When used as a library, the hosting app
    # is free to set-up its own signal handlers.
    signal.signal(signal.SIGTERM, signal.SIG_DFL)
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    os.close(  # forked processes inherit open files, we don't need the host FD
        lifeline_host_side
    )

    def _exit_on_parent_death():
        # This read only returns once the host's end of the pipe is
        # closed (e.g. the host process died).
        os.read(lifeline_worker_side, 1)
        # We cannot really do anything about this, best to reliably
        # abort the process
        os._exit(1)

    parent_liveness_monitor = threading.Thread(
        target=_exit_on_parent_death, daemon=True
    )
    parent_liveness_monitor.start()

    # Make breakpoint() usable inside worker processes.
    sys.breakpointhook = multiprocessing_breakpoint
    while (args := input_.get()) is not _SENTINEL:
        result = handler(args)
        output.put(result)
    # Echo the sentinel so the host can observe this worker finishing.
    output.put(_SENTINEL)
class MultiPool(PoolBase):
    """Pool of forked worker processes fed through a joinable queue.

    Work is queued with :meth:`submit`; results are delivered to
    ``result_callback`` from :meth:`process_until_done`, which must be
    driven on the thread that created the pool.  A "lifeline" pipe lets
    workers detect the death of the host process and abort themselves.
    """

    def __init__(
        self,
        process_num: int,
        handler: Callable[[Any], Any],
        *,
        result_callback: Callable[["MultiPool", Any], Any],
    ):
        """Set up queues, the lifeline pipe and the (unstarted) workers.

        Args:
            process_num: number of worker processes, must be positive.
            handler: called in a worker process with each submitted item.
            result_callback: called in the host with ``(pool, result)``.

        Raises:
            ValueError: if ``process_num`` is not positive.
        """
        if process_num <= 0:
            # Fixed garbled message (was: "At process_num must be greater than 0")
            raise ValueError("process_num must be greater than 0")

        self._running = False
        self._result_callback = result_callback
        self._input = Queue(ctx=mp.get_context())
        # Don't let the queue's feeder thread block interpreter shutdown.
        self._input.cancel_join_thread()
        self._output = ResultQueue(ctx=mp.get_context())
        # "death-pipe" / "forkfd" pattern: each worker blocks reading the
        # worker side and aborts once the host side is closed.
        # (see search results for "death-pipe" or "forkfd concept")
        (self._lifeline_worker_side, self._lifeline_host_side) = os.pipe()
        fcntl.fcntl(self._lifeline_host_side, fcntl.F_SETFD, fcntl.FD_CLOEXEC)
        fcntl.fcntl(self._lifeline_worker_side, fcntl.F_SETFD, fcntl.FD_CLOEXEC)

        self._procs = [
            mp.Process(
                target=_worker_process,
                args=(
                    handler,
                    self._input,
                    self._output,
                    self._lifeline_worker_side,
                    self._lifeline_host_side,
                ),
            )
            for _ in range(process_num)
        ]
        # submit() is restricted to this thread — see submit().
        self._tid = threading.get_native_id()

    def start(self):
        """Fork the workers and register emergency cleanup at interpreter exit."""
        self._running = True
        for p in self._procs:
            p.start()
        # We are the host process, we don't need this anymore.
        # Had to keep the file alive until inherited by the forked subprocess
        os.close(self._lifeline_worker_side)
        atexit.register(self._close_immediate)

    def _any_worker_exited(self) -> bool:
        # A process sentinel becomes ready when that process exits;
        # timeout=0 makes this a non-blocking poll.
        sentinels = [p.sentinel for p in self._procs]
        return bool(multiprocessing.connection.wait(sentinels, timeout=0))

    def close(self, *, immediate=False):
        """Shut the pool down; idempotent.

        With ``immediate=False`` attempts a graceful shutdown (drain the
        input queue, send quit sentinels, drain the output queue) and
        falls back to terminating workers on any failure.  A graceful
        shutdown is also skipped if any worker has already exited.
        """
        if not self._running:
            return
        self._running = False
        atexit.unregister(self._close_immediate)
        # Graceful shutdown could hang if a worker is already gone.
        immediate = immediate or self._any_worker_exited()

        termination_exception = None
        if not immediate:
            try:
                self._clear_input_queue()
                self._request_workers_to_quit()
                self._clear_output_queue()
            except BaseException as exc:
                termination_exception = exc
                immediate = True

        if immediate:
            self._terminate_workers()

        self._wait_for_workers_to_quit()

        # closing this FD any sooner would cause workers to abort
        # immediately, should close only after workers quit
        os.close(self._lifeline_host_side)

        if termination_exception:
            raise termination_exception

    def _close_immediate(self):
        # atexit hook: no time for a graceful drain at interpreter exit.
        self.close(immediate=True)

    def _terminate_workers(self):
        # Force-terminate all workers and release queue resources.
        for proc in self._procs:
            proc.terminate()

        self._input.close()
        self._output.close()

    def _clear_input_queue(self):
        # Discard all not-yet-processed work items.
        with contextlib.suppress(queue.Empty):
            while True:
                self._input.get_nowait()

    def _request_workers_to_quit(self):
        # One sentinel per worker terminates every worker's main loop.
        for _ in self._procs:
            self._input.put(_SENTINEL)
        self._input.close()

    def _clear_output_queue(self):
        # Drain (and discard) results until every live worker has exited,
        # watching process sentinels so a dead worker cannot hang us.
        alive = {p.sentinel: p for p in self._procs if p.exitcode is None}
        while alive:
            ready = multiprocessing.connection.wait([self._output.reader, *alive])
            for fd in ready:
                alive.pop(fd, None)  # type: ignore[arg-type]
            if self._output.reader in ready:
                self._output.get()

    def _wait_for_workers_to_quit(self):
        for p in self._procs:
            p.join()

    def submit(self, args):
        """Queue ``args`` for processing by a worker.

        Raises:
            RuntimeError: when called from a thread other than the one
                that created the pool.
        """
        if threading.get_native_id() != self._tid:
            raise RuntimeError(
                "Submit can only be called from the same "
                "thread/process where the pool is created"
            )
        self._input.put(args)

    def _check_worker_deaths(self, sentinels, ready):
        # Remove reaped sentinels from the watch set and raise for any
        # worker that exited with a non-zero status.
        for fd in ready:
            if fd not in sentinels:
                continue
            proc = sentinels.pop(fd)
            if proc.exitcode != 0:
                exitcode = proc.exitcode
                if exitcode is not None and exitcode < 0:
                    # Negative exitcode means the process was signal-killed.
                    reason = f"killed by signal {-exitcode}"
                else:
                    reason = f"exited with code {exitcode}"
                raise WorkerDiedError(
                    f"Worker process {proc.pid} exited unexpectedly ({reason})"
                )

    def process_until_done(self):
        """Dispatch results to the callback until the input queue drains.

        The callback may ``submit()`` more work; the loop runs until
        ``task_done`` has been accounted for every submitted item.

        Raises:
            WorkerDiedError: if a worker exits with a non-zero status.
        """
        sentinels = {p.sentinel: p for p in self._procs}
        with contextlib.suppress(EOFError):
            while not self._input.is_empty():
                ready = multiprocessing.connection.wait(
                    [self._output.reader, *sentinels]
                )
                self._check_worker_deaths(sentinels, ready)
                if self._output.reader in ready:
                    result = self._output.get()
                    self._result_callback(self, result)
                    self._input.task_done()
class SinglePool(PoolBase):
    """Degenerate pool running the handler inline in the host process.

    Each :meth:`submit` executes ``handler`` synchronously and feeds the
    result straight to ``result_callback``; there is nothing left for
    :meth:`process_until_done` to do.
    """

    def __init__(self, handler, *, result_callback):
        self._handler = handler
        self._result_callback = result_callback

    def submit(self, args):
        # Process synchronously and deliver the result immediately.
        outcome = self._handler(args)
        self._result_callback(self, outcome)

    def process_until_done(self):
        """No-op: submit() already processed everything inline."""
def make_pool(process_num, handler, result_callback) -> SinglePool | MultiPool:
    """Build the pool variant appropriate for ``process_num``.

    Exactly one process means no multiprocessing overhead is needed, so a
    :class:`SinglePool` is returned; otherwise a :class:`MultiPool` with
    ``process_num`` workers.
    """
    if process_num != 1:
        return MultiPool(
            process_num=process_num,
            handler=handler,
            result_callback=result_callback,
        )
    return SinglePool(handler=handler, result_callback=result_callback)