Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/unblob/pool.py: 38%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

136 statements  

1import abc 

2import contextlib 

3import multiprocessing as mp 

4import os 

5import queue 

6import signal 

7import sys 

8import threading 

9from collections.abc import Callable 

10from multiprocessing.queues import JoinableQueue 

11from typing import Any 

12 

13from .logging import multiprocessing_breakpoint 

14 

# Workers are created by fork, so they inherit the parent's memory image --
# including its signal handlers, which _worker_process resets to defaults.
# NOTE(review): presumably fork is also required so that ``handler`` callables
# need not be picklable; confirm before switching start methods.
# Raises RuntimeError if a start method was already fixed elsewhere.
mp.set_start_method(method="fork")

16 

17 

class PoolBase(abc.ABC):
    """Lifecycle shared by all pool implementations.

    Each instance adds itself to the module-level ``pools`` set when it is
    constructed and removes itself in ``close()``; ``_on_terminate`` walks
    that set to shut every live pool down on SIGTERM/SIGINT.

    As a context manager, entering calls ``start()`` and leaving calls
    ``close()``, with ``immediate=True`` when the ``with`` body raised.
    """

    def __init__(self):
        with pools_lock:
            pools.add(self)

    @abc.abstractmethod
    def submit(self, args):
        """Schedule ``args`` for processing by the pool's handler."""

    @abc.abstractmethod
    def process_until_done(self):
        """Process submitted work until none is outstanding."""

    def start(self):
        """Start the pool. The base implementation does nothing."""

    def close(self, *, immediate=False):  # noqa: ARG002
        """Unregister this pool from ``pools``.

        Safe to call more than once. A pool can be closed twice: first by
        ``_on_terminate`` from a signal handler, then again by ``__exit__``
        while the resulting exception unwinds the ``with`` block. Using
        ``discard`` rather than ``remove`` keeps the second call from raising
        ``KeyError`` and masking the original exception.

        Args:
            immediate: unused here; subclasses use it to choose between a
                hard and a graceful shutdown.
        """
        with pools_lock:
            pools.discard(self)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, _exc_value, _tb):
        # Tear down immediately when the ``with`` body raised.
        self.close(immediate=exc_type is not None)

44 

45 

# Guards ``pools``. It is re-entrant because ``_on_terminate`` runs as a signal
# handler on the main thread and, through ``PoolBase.close``, acquires this
# lock. If the signal interrupts main-thread code that already holds it (e.g.
# ``PoolBase.__init__``), a plain Lock would deadlock.
pools_lock = threading.RLock()
# Registered pools: added by ``PoolBase.__init__``, removed by ``PoolBase.close``.
pools: "set[PoolBase]" = set()

48 

49 

class Queue(JoinableQueue):
    """``JoinableQueue`` with a non-blocking "all work acknowledged" probe."""

    def is_empty(self) -> bool:
        """Return True when every ``put`` item has had ``task_done`` called.

        Modeled on ``multiprocessing.JoinableQueue.join``, which waits for
        the unfinished-task semaphore to reach zero; this only inspects it.
        """
        condition = self._cond  # type: ignore
        with condition:
            pending = self._unfinished_tasks  # type: ignore
            return pending._semlock._is_zero()  # type: ignore # noqa: SLF001

58 

59 

class _Sentinel:
    """Marker sent through the pool queues to mean "no more work".

    The class object itself is the marker; it is not instantiated here.
    Classes pickle by reference, so after crossing a queue it unpickles as
    the receiving process's own ``_Sentinel``, and ``is`` checks still hold.
    An instance would unpickle as a fresh object and fail them.
    """


# Deliberately the class, not an instance -- see the docstring above.
_SENTINEL = _Sentinel

65 

66 

def _worker_process(handler, input_, output):
    """Body of each MultiPool worker process.

    Reads argument objects from ``input_`` until ``_SENTINEL`` arrives and
    sends ``handler(args)`` for each one to ``output``. It then echoes
    ``_SENTINEL`` on ``output`` so the parent can tell this worker finished.

    NOTE(review): an exception from ``handler`` ends the worker without the
    trailing sentinel; confirm that handlers never raise.
    """
    # Move into a new process group so signals aimed at the parent's group
    # (e.g. Ctrl-C from the terminal) are not delivered to workers.
    os.setpgrp()

    # Forked workers inherit the parent's handlers; put the defaults back.
    for signum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signum, signal.SIG_DFL)

    sys.breakpointhook = multiprocessing_breakpoint

    while True:
        work_item = input_.get()
        if work_item is _SENTINEL:
            break
        output.put(handler(work_item))

    output.put(_SENTINEL)

82 

83 

class MultiPool(PoolBase):
    """Pool that runs ``handler`` in ``process_num`` worker processes.

    Submitted arguments reach the workers through a joinable input queue.
    Each worker sends ``handler(args)`` back on a result queue.
    ``process_until_done`` drains that queue in the calling process and
    passes each result to ``result_callback(pool, result)``, which may
    ``submit`` further work.

    Used as a context manager: entering starts the workers, leaving shuts
    them down.
    """

    def __init__(
        self,
        process_num: int,
        handler: Callable[[Any], Any],
        *,
        result_callback: Callable[["MultiPool", Any], Any],
    ):
        """Create (but do not start) the worker processes.

        Args:
            process_num: number of worker processes; must be positive.
            handler: called in a worker with each submitted argument; its
                return value is sent back to the parent.
            result_callback: called as ``result_callback(pool, result)`` for
                every result, from ``process_until_done``.

        Raises:
            ValueError: if ``process_num`` is not positive.
        """
        # Validate and build all state *before* PoolBase.__init__ registers
        # ``self`` in the global ``pools`` set. A half-initialised instance
        # left there would make ``_on_terminate`` crash (AttributeError on
        # ``_running``) when it later calls ``close()``.
        if process_num <= 0:
            raise ValueError("process_num must be greater than 0")

        self._running = False
        self._result_callback = result_callback
        self._input = Queue(ctx=mp.get_context())
        # Don't let process exit block on joining the input queue's
        # background feeder thread.
        self._input.cancel_join_thread()
        self._output = mp.SimpleQueue()
        self._procs = [
            mp.Process(
                target=_worker_process,
                args=(handler, self._input, self._output),
            )
            for _ in range(process_num)
        ]
        # submit() must only be called from the creating thread.
        self._tid = threading.get_native_id()

        super().__init__()

    def start(self):
        """Start all worker processes."""
        self._running = True
        for p in self._procs:
            p.start()

    def close(self, *, immediate=False):
        """Shut the pool down; a no-op unless the pool is running.

        With ``immediate=True`` the workers are terminated right away.
        Otherwise pending input is discarded, each worker is sent a sentinel,
        and results are drained until every worker has echoed it back. In
        both cases the worker processes are joined before the pool
        unregisters itself.
        """
        if not self._running:
            return
        self._running = False

        if immediate:
            self._terminate_workers()
        else:
            self._clear_input_queue()
            self._request_workers_to_quit()
            self._clear_output_queue()

        self._wait_for_workers_to_quit()
        super().close(immediate=immediate)

    def _terminate_workers(self):
        # Process.terminate() sends SIGTERM on POSIX; _worker_process restored
        # the default SIGTERM disposition, so this kills the workers.
        for proc in self._procs:
            proc.terminate()

        self._input.close()
        self._output.close()

    def _clear_input_queue(self):
        # Drop work that no worker has picked up yet.
        # NOTE(review): items still buffered in the queue's feeder thread may
        # not be visible to get_nowait() yet.
        try:
            while True:
                self._input.get_nowait()
        except queue.Empty:
            pass

    def _request_workers_to_quit(self):
        # One sentinel per worker: each worker exits after reading one.
        for _ in self._procs:
            self._input.put(_SENTINEL)
        self._input.close()

    def _clear_output_queue(self):
        # Discard remaining results until every worker has echoed its
        # sentinel, i.e. has left its processing loop.
        # NOTE(review): blocks indefinitely if a worker dies without sending it.
        process_quit_count = 0
        process_num = len(self._procs)
        while process_quit_count < process_num:
            result = self._output.get()
            if result is _SENTINEL:
                process_quit_count += 1

    def _wait_for_workers_to_quit(self):
        for p in self._procs:
            p.join()

    def submit(self, args):
        """Queue ``args`` for a worker.

        Raises:
            RuntimeError: if called from a thread other than the creator.
        """
        if threading.get_native_id() != self._tid:
            raise RuntimeError(
                "Submit can only be called from the same "
                "thread/process where the pool is created"
            )
        self._input.put(args)

    def process_until_done(self):
        """Deliver results to ``result_callback`` until no task is pending.

        Each iteration takes one result and marks one input task done.
        Workers emit exactly one result per item, so the loop ends once every
        submitted item has produced a result, including items the callback
        submitted. ``EOFError`` from the result queue ends processing
        silently. NOTE(review): presumably that happens once the pool has been
        torn down; confirm.
        """
        with contextlib.suppress(EOFError):
            while not self._input.is_empty():
                result = self._output.get()
                self._result_callback(self, result)
                self._input.task_done()

175 

176 

class SinglePool(PoolBase):
    """In-process pool that runs every submission synchronously.

    No worker processes are involved: ``submit`` calls the handler directly
    and hands its result to ``result_callback(pool, result)`` before
    returning.
    """

    def __init__(self, handler, *, result_callback):
        super().__init__()
        self._handler, self._result_callback = handler, result_callback

    def submit(self, args):
        """Run the handler on ``args`` now and report the result."""
        self._result_callback(self, self._handler(args))

    def process_until_done(self):
        """Nothing to wait for: each ``submit`` has already completed."""

189 

190 

def make_pool(process_num, handler, result_callback) -> SinglePool | MultiPool:
    """Build a pool that runs ``handler`` on ``process_num`` processes.

    Exactly one process yields an in-process ``SinglePool``. Any other
    value goes to ``MultiPool``, which rejects non-positive counts with
    ``ValueError``.
    """
    if process_num == 1:
        pool = SinglePool(handler=handler, result_callback=result_callback)
    else:
        pool = MultiPool(
            process_num=process_num,
            handler=handler,
            result_callback=result_callback,
        )
    return pool

200 

201 

# Maps each signal number to the handler it had before _on_terminate replaced
# it. The value is whatever signal.signal() returned: a callable,
# signal.SIG_DFL, signal.SIG_IGN, or None.
orig_signal_handlers: dict[int, Any] = {}

203 

204 

def _on_terminate(signum, frame):
    """Signal handler: kill every live pool, then defer to the previous handler.

    The previous disposition is looked up in ``orig_signal_handlers``:

    * a callable is invoked with ``(signum, frame)`` (e.g. Python's default
      SIGINT handler, which raises KeyboardInterrupt);
    * ``SIG_DFL`` is restored and the signal re-raised, so the default
      action (e.g. terminating on SIGTERM) still happens. Previously the
      signal was silently swallowed after the pools were killed;
    * ``SIG_IGN`` or ``None`` means nothing further is done.
    """
    # Snapshot first: close() removes each pool from ``pools`` as we go.
    for pool in list(pools):
        pool.close(immediate=True)

    orig_handler = orig_signal_handlers[signum]
    if callable(orig_handler):
        orig_handler(signum, frame)
    elif orig_handler == signal.SIG_DFL:
        signal.signal(signum, signal.SIG_DFL)
        signal.raise_signal(signum)

212 

213 

# Route SIGTERM and SIGINT through _on_terminate, remembering what each one
# replaced so the handler can chain to it. signal.signal() must be called
# from the main thread, so this module must be imported there.
for _signum in (signal.SIGTERM, signal.SIGINT):
    orig_signal_handlers[_signum] = signal.signal(_signum, _on_terminate)
del _signum