Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/unblob/pool.py: 38%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import abc
2import contextlib
3import multiprocessing as mp
4import os
5import queue
6import signal
7import sys
8import threading
9from collections.abc import Callable
10from multiprocessing.queues import JoinableQueue
11from typing import Any
13from .logging import multiprocessing_breakpoint
# Worker processes are created with the "fork" start method.
# NOTE(review): set_start_method raises RuntimeError if a start method has
# already been set (force=False), so this relies on this module being
# imported before anything else configures multiprocessing — confirm.
mp.set_start_method("fork")
class PoolBase(abc.ABC):
    """Abstract base for worker pools.

    Every instance registers itself in the module-level ``pools`` set when
    constructed and unregisters itself in ``close``, so the module's
    termination signal handler can find and close all live pools.

    Instances are context managers: ``__enter__`` calls ``start`` and
    ``__exit__`` calls ``close``, requesting an immediate close when the
    ``with`` body raised.
    """

    def __init__(self):
        with pools_lock:
            pools.add(self)

    @abc.abstractmethod
    def submit(self, args):
        """Hand ``args`` to the pool for processing by its handler."""

    @abc.abstractmethod
    def process_until_done(self):
        """Process outstanding submitted work (semantics defined by subclasses)."""

    def start(self):
        """Start the pool. The base implementation does nothing."""

    def close(self, *, immediate=False):  # noqa: ARG002
        """Unregister the pool from ``pools``.

        Safe to call more than once. A pool can be closed first by the
        termination signal handler and then again by ``__exit__``. Using
        ``discard`` instead of ``remove`` keeps that second call from raising
        ``KeyError``, which would otherwise mask the in-flight exception.
        """
        with pools_lock:
            pools.discard(self)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, _exc_value, _tb):
        # Close immediately when leaving the block because of an exception.
        self.close(immediate=exc_type is not None)
# Registry of live pools; the termination signal handler walks it to close
# every pool. This must be an RLock rather than a plain Lock: the signal
# handler runs in the main thread and closing a pool acquires this lock. If
# the signal arrives while the main thread is already inside
# ``with pools_lock:``, a non-reentrant lock would deadlock the process.
pools_lock = threading.RLock()
pools: "set[PoolBase]" = set()
class Queue(JoinableQueue):
    """``JoinableQueue`` that can report, without blocking, whether work is pending."""

    def is_empty(self) -> bool:
        """Return whether ``task_done`` has been called for every item put so far.

        This is the same condition ``multiprocessing.JoinableQueue.join``
        waits for, but it is checked once and returned instead of blocking.
        """
        condition = self._cond  # type: ignore
        pending = self._unfinished_tasks  # type: ignore
        with condition:
            nothing_pending = pending._semlock._is_zero()  # noqa: SLF001
        return nothing_pending
class _Sentinel:
    """Marker that tells a worker to stop and marks a worker's final output."""


# The class object itself, not an instance, serves as the sentinel. Classes are
# pickled by reference, so it keeps its identity after crossing a
# multiprocessing queue, and ``is _SENTINEL`` checks work on the receiving
# side. An unpickled instance would be a fresh object.
_SENTINEL = _Sentinel
def _worker_process(handler, input_, output):
    """Worker loop: apply ``handler`` to each item from ``input_`` until the sentinel.

    Each handler result is put on ``output``. Once the sentinel arrives, it is
    forwarded to ``output`` so the parent can count the workers that finished.
    """
    # Start a new process group so signals aimed at the main process's group
    # are not propagated to the worker processes.
    os.setpgrp()

    # Workers inherit the main process's signal handlers; reinstate the defaults.
    for signum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signum, signal.SIG_DFL)

    sys.breakpointhook = multiprocessing_breakpoint

    while True:
        item = input_.get()
        if item is _SENTINEL:
            break
        output.put(handler(item))

    output.put(_SENTINEL)
class MultiPool(PoolBase):
    """Pool that distributes submitted work across ``process_num`` worker processes.

    Each worker runs ``_worker_process`` with ``handler``. Work items travel
    through a joinable input queue. Results come back through an output queue,
    and ``process_until_done`` passes them to ``result_callback``.
    """

    def __init__(
        self,
        process_num: int,
        handler: Callable[[Any], Any],
        *,
        result_callback: Callable[["MultiPool", Any], Any],
    ):
        # Validate before ``super().__init__()`` registers the pool in
        # ``pools``. Raising after registration would leave an unusable,
        # never-closed pool registered forever.
        if process_num <= 0:
            raise ValueError("process_num must be greater than 0")
        super().__init__()

        self._running = False
        self._result_callback = result_callback
        self._input = Queue(ctx=mp.get_context())
        # Don't make process exit wait for the queue's feeder thread to flush
        # buffered input items.
        self._input.cancel_join_thread()
        self._output = mp.SimpleQueue()
        self._procs = [
            mp.Process(
                target=_worker_process,
                args=(handler, self._input, self._output),
            )
            for _ in range(process_num)
        ]
        # Native id of the creating thread; ``submit`` is restricted to it.
        self._tid = threading.get_native_id()

    def start(self):
        """Start every worker process."""
        self._running = True
        for p in self._procs:
            p.start()

    def close(self, *, immediate=False):
        """Shut the workers down and unregister the pool.

        With ``immediate=True`` the workers are terminated at once. Otherwise
        pending input is discarded, each worker is sent the sentinel, and
        output is drained until every worker has reported that it quit.

        A pool that is not running (never started, or already closed) is only
        unregistered, so calling ``close`` more than once is safe.
        """
        if not self._running:
            # Nothing to stop, but a pool that was never started must not stay
            # in the module-level ``pools`` registry. ``discard`` tolerates a
            # pool that an earlier close already removed.
            with pools_lock:
                pools.discard(self)
            return
        self._running = False

        if immediate:
            self._terminate_workers()
        else:
            self._clear_input_queue()
            self._request_workers_to_quit()
            self._clear_output_queue()
            # Every worker has sent its final sentinel, so nothing else will
            # arrive. Release the output pipe, as the immediate path does.
            self._output.close()

        self._wait_for_workers_to_quit()
        super().close(immediate=immediate)

    def _terminate_workers(self):
        """Terminate every worker process and close both queues."""
        for proc in self._procs:
            proc.terminate()

        self._input.close()
        self._output.close()

    def _clear_input_queue(self):
        """Discard work items still waiting in the input queue.

        Best effort: ``get_nowait`` can raise ``queue.Empty`` while recently put
        items are still being flushed by the queue's feeder thread. Such items
        are simply processed by the workers.
        """
        with contextlib.suppress(queue.Empty):
            while True:
                self._input.get_nowait()

    def _request_workers_to_quit(self):
        """Put one sentinel per worker on the input queue, then close it."""
        for _ in self._procs:
            self._input.put(_SENTINEL)
        self._input.close()

    def _clear_output_queue(self):
        """Consume (and discard) output until every worker has sent its sentinel."""
        remaining = len(self._procs)
        while remaining:
            if self._output.get() is _SENTINEL:
                remaining -= 1

    def _wait_for_workers_to_quit(self):
        """Join every worker process."""
        for p in self._procs:
            p.join()

    def submit(self, args):
        """Queue ``args`` for processing by a worker.

        Raises:
            RuntimeError: if called from a thread other than the one that
                created the pool.
        """
        if threading.get_native_id() != self._tid:
            raise RuntimeError(
                "Submit can only be called from the same "
                "thread/process where the pool is created"
            )
        self._input.put(args)

    def process_until_done(self):
        """Pass results to ``result_callback`` until no submitted item is pending.

        Each received result marks one input item as done. The callback gets
        the pool, so items it submits are waited for as well. ``EOFError``
        from the output queue ends processing quietly (presumably raised when
        an immediate ``close`` closed the queue; TODO confirm).
        """
        with contextlib.suppress(EOFError):
            while not self._input.is_empty():
                result = self._output.get()
                self._result_callback(self, result)
                self._input.task_done()
class SinglePool(PoolBase):
    """In-process pool: each submission is handled synchronously in the caller."""

    def __init__(self, handler, *, result_callback):
        super().__init__()
        self._handler = handler
        self._result_callback = result_callback

    def submit(self, args):
        """Run the handler on ``args`` right away and report the result."""
        outcome = self._handler(args)
        self._result_callback(self, outcome)

    def process_until_done(self):
        """No-op: ``submit`` has already processed everything."""
def make_pool(process_num, handler, result_callback) -> SinglePool | MultiPool:
    """Build the pool implementation that fits ``process_num``.

    Exactly one process yields a ``SinglePool``, which runs work in the calling
    process. Any other value yields a ``MultiPool`` with ``process_num``
    workers; ``MultiPool`` rejects values ``<= 0``.
    """
    if process_num != 1:
        return MultiPool(
            process_num=process_num,
            handler=handler,
            result_callback=result_callback,
        )
    return SinglePool(handler=handler, result_callback=result_callback)
# Handlers that were in place for SIGTERM/SIGINT before this module installed
# ``_on_terminate``, keyed by signal number. Values are whatever
# ``signal.signal`` returned: a callable, ``SIG_DFL``/``SIG_IGN``, or ``None``.
orig_signal_handlers: dict[int, Any] = {}
def _on_terminate(signum, frame):
    """Close every live pool immediately, then defer to the previous handler.

    Installed for SIGTERM and SIGINT. After the pools are torn down:

    * if the previous handler is callable (e.g. Python's default SIGINT
      handler), it is called with the same arguments;
    * if the previous disposition was ``SIG_DFL``, the default disposition is
      restored and the signal is re-sent to this process, so SIGTERM still
      terminates it instead of being silently swallowed;
    * otherwise (``SIG_IGN`` or ``None``), nothing more happens.
    """
    # Snapshot first: closing a pool removes it from ``pools``.
    pools_snapshot = list(pools)
    for pool in pools_snapshot:
        pool.close(immediate=True)

    orig_handler = orig_signal_handlers.get(signum)
    if callable(orig_handler):
        orig_handler(signum, frame)
    elif orig_handler == signal.SIG_DFL:
        # Re-deliver the signal with its default action (e.g. terminate).
        signal.signal(signum, signal.SIG_DFL)
        os.kill(os.getpid(), signum)
# Install ``_on_terminate`` for SIGTERM and SIGINT at import time, remembering
# the previous handlers so ``_on_terminate`` can chain to them.
# NOTE(review): ``signal.signal`` only works in the main thread of the main
# interpreter, so importing this module from another thread raises ValueError.
orig_signal_handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, _on_terminate)
orig_signal_handlers[signal.SIGINT] = signal.signal(signal.SIGINT, _on_terminate)