Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/unblob/pool.py: 37%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import abc
2import contextlib
3import multiprocessing as mp
4import os
5import queue
6import signal
7import sys
8import threading
9from multiprocessing.queues import JoinableQueue
10from typing import Any, Callable, Union
12from .logging import multiprocessing_breakpoint
# Pin the start method for every process created through ``mp`` in this
# module. NOTE(review): presumably "fork" is chosen so workers inherit the
# handler and queues without pickling -- confirm with the authors.
mp.set_start_method(method="fork")
class PoolBase(abc.ABC):
    """Common interface of the pools defined in this module.

    Every instance adds itself to the module-level ``pools`` registry when
    constructed and takes itself out again in :meth:`close`. Instances are
    context managers: entering calls :meth:`start`, leaving calls
    :meth:`close`.
    """

    def __init__(self):
        with pools_lock:
            pools.add(self)

    @abc.abstractmethod
    def submit(self, args):
        """Hand ``args`` to the pool; concrete pools define what happens."""

    @abc.abstractmethod
    def process_until_done(self):
        """Drive the pool until its outstanding work is finished."""

    def start(self):
        """Hook called on context entry; the base implementation does nothing."""

    def close(self, *, immediate=False):  # noqa: ARG002
        """Remove this pool from the ``pools`` registry.

        ``immediate`` is accepted for subclasses and ignored here.

        Closing is idempotent: a pool can legitimately be closed twice, e.g.
        once by the module's termination signal handler and once more by
        ``__exit__``, so a pool that is already gone from the registry is
        not an error.
        """
        with pools_lock:
            # ``discard`` rather than ``remove``: a repeated close must not
            # raise ``KeyError``.
            pools.discard(self)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, _exc_value, _tb):
        # Leaving the ``with`` block because of an exception requests an
        # immediate close.
        self.close(immediate=exc_type is not None)


# Registry of live pools; ``pools_lock`` guards additions and removals.
pools_lock = threading.Lock()
pools: set[PoolBase] = set()
class Queue(JoinableQueue):
    """``JoinableQueue`` that can report whether any task is still unfinished."""

    def is_empty(self) -> bool:
        """Tell whether ``task_done`` has been called for every queued item.

        Uses the same check as ``multiprocessing.JoinableQueue.join``, but
        returns the answer instead of waiting.
        """
        with self._cond:  # type: ignore
            unfinished = self._unfinished_tasks  # type: ignore
            return unfinished._semlock._is_zero()  # noqa: SLF001
class _Sentinel:
    """Marker type used as the end-of-work signal on the pool queues."""


# The class object itself serves as the marker; it is never instantiated.
# NOTE(review): presumably because a class pickles by reference, so the
# ``is _SENTINEL`` checks still hold after a trip through a queue -- confirm.
_SENTINEL = _Sentinel
def _worker_process(handler, input_, output):
    """Worker main loop: apply ``handler`` to queued items until told to stop.

    Items are read from ``input_``; each handler result is written to
    ``output``. On receiving ``_SENTINEL`` the loop ends and ``_SENTINEL``
    is written to ``output`` as the final message.
    """
    # Move into a new process group so that signals are not propagated
    # from the main process to the worker processes.
    os.setpgrp()

    # Go back to the default handlers; otherwise the worker would keep the
    # ones inherited from the main process.
    for signum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signum, signal.SIG_DFL)

    sys.breakpointhook = multiprocessing_breakpoint

    while True:
        args = input_.get()
        if args is _SENTINEL:
            break
        output.put(handler(args))

    output.put(_SENTINEL)
class MultiPool(PoolBase):
    """Pool that runs ``handler`` in ``process_num`` worker processes.

    Submitted items travel to the workers through a joinable input queue.
    Results come back on an output queue and are handed to
    ``result_callback(pool, result)`` from :meth:`process_until_done`.
    """

    def __init__(
        self,
        process_num: int,
        handler: Callable[[Any], Any],
        *,
        result_callback: Callable[["MultiPool", Any], Any],
    ):
        """Create the queues and worker processes; workers are not started yet.

        Raises:
            ValueError: if ``process_num`` is not greater than 0.
        """
        # Validate before ``super().__init__()`` registers the pool. A
        # rejected pool must not linger half-initialised in the module-level
        # registry, where the termination handler would later call
        # ``close`` on an object that has no ``_running`` attribute.
        if process_num <= 0:
            raise ValueError("At process_num must be greater than 0")
        super().__init__()

        self._running = False
        self._result_callback = result_callback
        self._input = Queue(ctx=mp.get_context())
        # See the multiprocessing docs: keeps process exit from blocking on
        # the queue's feeder thread.
        self._input.cancel_join_thread()
        self._output = mp.SimpleQueue()
        self._procs = [
            mp.Process(
                target=_worker_process,
                args=(handler, self._input, self._output),
            )
            for _ in range(process_num)
        ]
        # Remembered so ``submit`` can reject calls from any other thread.
        self._tid = threading.get_native_id()

    def start(self):
        """Mark the pool as running and start every worker process."""
        self._running = True
        for p in self._procs:
            p.start()

    def close(self, *, immediate=False):
        """Stop the workers and remove the pool from the registry.

        Does nothing when the pool is not running (never started, or
        already closed). With ``immediate=True`` the workers are terminated;
        otherwise pending input is dropped, each worker is asked to quit and
        the output queue is drained until all of them have confirmed.
        """
        if not self._running:
            return
        self._running = False

        if immediate:
            self._terminate_workers()
        else:
            self._clear_input_queue()
            self._request_workers_to_quit()
            self._clear_output_queue()

        self._wait_for_workers_to_quit()
        super().close(immediate=immediate)

    def _terminate_workers(self):
        """Terminate all workers and close both queues."""
        for proc in self._procs:
            proc.terminate()

        self._input.close()
        self._output.close()

    def _clear_input_queue(self):
        """Discard every item that is still waiting in the input queue."""
        try:
            while True:
                self._input.get_nowait()
        except queue.Empty:
            pass

    def _request_workers_to_quit(self):
        """Queue one sentinel per worker, then close the input queue."""
        for _ in self._procs:
            self._input.put(_SENTINEL)
        self._input.close()

    def _clear_output_queue(self):
        """Drain the output queue until every worker has sent its sentinel.

        Ordinary results read along the way are discarded.
        """
        process_quit_count = 0
        process_num = len(self._procs)
        while process_quit_count < process_num:
            result = self._output.get()
            if result is _SENTINEL:
                process_quit_count += 1

    def _wait_for_workers_to_quit(self):
        """Join every worker process."""
        for p in self._procs:
            p.join()

    def submit(self, args):
        """Put ``args`` on the input queue for a worker to pick up.

        Raises:
            RuntimeError: when called from a thread other than the one that
                created the pool.
        """
        if threading.get_native_id() != self._tid:
            raise RuntimeError(
                "Submit can only be called from the same "
                "thread/process where the pool is created"
            )
        self._input.put(args)

    def process_until_done(self):
        """Feed results to the callback until no submitted item is unfinished.

        Each result read from the output queue is passed to
        ``result_callback(self, result)`` and then marked with ``task_done``.
        NOTE(review): the callback receives the pool, presumably so it can
        submit follow-up work -- confirm against callers.
        """
        # NOTE(review): EOFError presumably comes from reading the output
        # queue after an immediate close -- confirm; it ends the loop quietly.
        with contextlib.suppress(EOFError):
            while not self._input.is_empty():
                result = self._output.get()
                self._result_callback(self, result)
                self._input.task_done()
class SinglePool(PoolBase):
    """Pool that runs the handler synchronously inside ``submit``."""

    def __init__(self, handler, *, result_callback):
        super().__init__()
        self._handler = handler
        self._result_callback = result_callback

    def submit(self, args):
        """Run the handler on ``args`` right away and report its result."""
        outcome = self._handler(args)
        self._result_callback(self, outcome)

    def process_until_done(self):
        """Do nothing: each item is fully handled inside ``submit``."""
def make_pool(process_num, handler, result_callback) -> Union[SinglePool, MultiPool]:
    """Build the pool matching ``process_num``.

    Exactly one process yields a ``SinglePool``; any other value is passed
    on to ``MultiPool``.
    """
    if process_num != 1:
        return MultiPool(
            process_num=process_num,
            handler=handler,
            result_callback=result_callback,
        )

    return SinglePool(handler=handler, result_callback=result_callback)
# Handlers that were installed before this module replaced them, keyed by
# signal number.
orig_signal_handlers = {}


def _on_terminate(signum, frame):
    """Close every registered pool immediately, then chain to the old handler."""
    # Walk a copy: closing a pool removes it from ``pools``.
    for pool in list(pools):
        pool.close(immediate=True)

    # The previous handler is only invoked when it is a callable.
    previous = orig_signal_handlers[signum]
    if callable(previous):
        previous(signum, frame)


# Install the handler for SIGTERM, then SIGINT, keeping what it replaces.
for _signum in (signal.SIGTERM, signal.SIGINT):
    orig_signal_handlers[_signum] = signal.signal(_signum, _on_terminate)
del _signum