Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/unblob/pool.py: 37%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

135 statements  

1import abc 

2import contextlib 

3import multiprocessing as mp 

4import os 

5import queue 

6import signal 

7import sys 

8import threading 

9from multiprocessing.queues import JoinableQueue 

10from typing import Any, Callable, Union 

11 

12from .logging import multiprocessing_breakpoint 

13 

# This module forces the ``fork`` start method for its worker processes.
# ``set_start_method`` raises RuntimeError once a context has been fixed, even
# when it is fixed to the same method, so skip the call if ``fork`` is already
# in effect (e.g. when this module is reloaded). Any other pre-set method still
# raises, exactly as before.
if mp.get_start_method(allow_none=True) != "fork":
    mp.set_start_method("fork")

15 

16 

class PoolBase(abc.ABC):
    """Common interface for worker pools.

    Every instance adds itself to the module-level ``pools`` registry when
    constructed and removes itself on ``close``. The registry is what the
    module's SIGTERM/SIGINT handler walks to shut down all live pools.
    Using a pool as a context manager calls ``start`` on entry and ``close``
    on exit.
    """

    def __init__(self):
        with pools_lock:
            pools.add(self)

    @abc.abstractmethod
    def submit(self, args):
        """Hand ``args`` to the pool for processing."""

    @abc.abstractmethod
    def process_until_done(self):
        """Process submitted work until nothing is left."""

    def start(self):
        """Start the pool; the base implementation does nothing."""

    def close(self, *, immediate=False):  # noqa: ARG002
        """Deregister the pool from ``pools``.

        ``immediate`` exists for subclasses and is ignored here. Calling this
        more than once is harmless.
        """
        with pools_lock:
            # ``discard`` instead of ``remove``: the signal handler may already
            # have closed this pool before ``__exit__`` closes it again, and a
            # KeyError there would mask the exception that is unwinding.
            pools.discard(self)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, _exc_value, _tb):
        # Leaving the block because of an exception asks for an immediate close.
        self.close(immediate=exc_type is not None)


# Registry of live pools, guarded by ``pools_lock``.
# Reentrant on purpose: Python signal handlers run in the main thread between
# bytecodes, so a handler that closes pools can fire while the main thread is
# already inside ``with pools_lock`` (e.g. in ``PoolBase.__init__``). A plain
# Lock would deadlock there.
pools_lock = threading.RLock()
pools: set[PoolBase] = set()

47 

48 

class Queue(JoinableQueue):
    """``JoinableQueue`` that can report, without blocking, whether all work is acknowledged."""

    def is_empty(self) -> bool:
        """Return True when ``task_done`` has been called for every item put so far.

        Non-blocking counterpart of ``multiprocessing.JoinableQueue.join``: it
        inspects the same unfinished-task semaphore under the same condition.
        It relies on private CPython internals.
        """
        condition = self._cond  # type: ignore
        with condition:
            pending = self._unfinished_tasks  # type: ignore
            return pending._semlock._is_zero()  # type: ignore # noqa: SLF001

57 

58 

class _Sentinel:
    """End-of-stream marker passed between the pool and its worker processes.

    The class object itself, not an instance, serves as the marker. Classes
    are pickled by reference, so after a round trip through a multiprocessing
    queue it is still the very same object and ``is`` comparisons keep working.
    """


# The marker is the class object itself (deliberately not an instance).
_SENTINEL = _Sentinel

64 

65 

def _worker_process(handler, input_, output):
    """Body of a pool worker process.

    Repeatedly takes an item from ``input_``, runs ``handler`` on it and puts
    the result on ``output``. On receiving ``_SENTINEL`` it stops, and echoes
    ``_SENTINEL`` to ``output`` so the parent can count finished workers.
    """
    # Move into a process group of our own so that signals aimed at the main
    # process's group are not propagated to the workers.
    os.setpgrp()

    # The main process installs its own SIGTERM/SIGINT handlers and a forked
    # worker inherits them; put the defaults back.
    for signum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signum, signal.SIG_DFL)

    sys.breakpointhook = multiprocessing_breakpoint

    while True:
        item = input_.get()
        if item is _SENTINEL:
            break
        output.put(handler(item))

    # Acknowledge the shutdown request.
    output.put(_SENTINEL)

81 

82 

class MultiPool(PoolBase):
    """Pool that distributes work across ``process_num`` worker processes.

    Items passed to ``submit`` go to a joinable input queue shared by the
    workers. Each worker puts one result per item on the output queue.
    ``process_until_done`` drains that queue in the creating thread and hands
    every result to ``result_callback(pool, result)``.
    """

    def __init__(
        self,
        process_num: int,
        handler: Callable[[Any], Any],
        *,
        result_callback: Callable[["MultiPool", Any], Any],
    ):
        # Validate before registering, so a rejected pool never ends up in the
        # module-level registry.
        if process_num <= 0:
            raise ValueError("process_num must be greater than 0")

        self._running = False
        self._result_callback = result_callback
        self._input = Queue(ctx=mp.get_context())
        # Do not make interpreter exit wait for buffered input to be flushed.
        self._input.cancel_join_thread()
        self._output = mp.SimpleQueue()
        self._procs = [
            mp.Process(
                target=_worker_process,
                args=(handler, self._input, self._output),
            )
            for _ in range(process_num)
        ]
        # ``submit`` is only allowed from the thread that created the pool.
        self._tid = threading.get_native_id()

        # Register last: the signal handler closes every registered pool, and
        # it must never see one whose attributes are not set up yet.
        super().__init__()

    def start(self):
        """Launch all worker processes."""
        self._running = True
        for p in self._procs:
            p.start()

    def close(self, *, immediate=False):
        """Stop the workers and deregister the pool.

        With ``immediate=True`` the workers are terminated at once. Otherwise
        pending input is discarded, each worker is asked to quit, and output
        is drained until every worker has acknowledged. Closing a pool that
        was never started, or closing it twice, only ensures it is
        deregistered.
        """
        if not self._running:
            # Nothing to shut down. Still drop the registry entry, otherwise a
            # never-started pool would stay referenced for the whole process
            # lifetime. A pool that is already deregistered is fine.
            with contextlib.suppress(KeyError):
                super().close(immediate=immediate)
            return
        self._running = False

        if immediate:
            self._terminate_workers()
        else:
            self._clear_input_queue()
            self._request_workers_to_quit()
            self._clear_output_queue()

        self._wait_for_workers_to_quit()
        super().close(immediate=immediate)

    def _terminate_workers(self):
        """Terminate every worker process and close both queues."""
        for proc in self._procs:
            proc.terminate()

        self._input.close()
        self._output.close()

    def _clear_input_queue(self):
        """Discard whatever input items can be fetched without blocking."""
        with contextlib.suppress(queue.Empty):
            while True:
                self._input.get_nowait()

    def _request_workers_to_quit(self):
        """Send one ``_SENTINEL`` per worker, then close the input queue."""
        for _ in self._procs:
            self._input.put(_SENTINEL)
        self._input.close()

    def _clear_output_queue(self):
        """Drain output until one ``_SENTINEL`` per worker has been seen.

        Results still arriving here are discarded.
        NOTE(review): blocks indefinitely if a worker exits without sending its
        sentinel (e.g. its handler raised).
        """
        remaining = len(self._procs)
        while remaining:
            if self._output.get() is _SENTINEL:
                remaining -= 1

    def _wait_for_workers_to_quit(self):
        """Join every worker process."""
        for p in self._procs:
            p.join()

    def submit(self, args):
        """Queue ``args`` for a worker.

        Raises:
            RuntimeError: if called from a thread other than the creating one.
        """
        if threading.get_native_id() != self._tid:
            raise RuntimeError(
                "Submit can only be called from the same "
                "thread/process where the pool is created"
            )
        self._input.put(args)

    def process_until_done(self):
        """Feed results to ``result_callback`` until every submitted item is done.

        The callback receives the pool and may ``submit`` further work, which
        keeps this loop running. An ``EOFError`` from the output queue ends
        processing silently.
        """
        with contextlib.suppress(EOFError):
            while not self._input.is_empty():
                result = self._output.get()
                self._result_callback(self, result)
                # Acknowledge only after the callback, so anything it submitted
                # is counted before this item is marked finished.
                self._input.task_done()

174 

175 

class SinglePool(PoolBase):
    """Degenerate pool that runs ``handler`` synchronously in the calling thread.

    Each ``submit`` computes the result immediately and passes it to
    ``result_callback(pool, result)``.
    """

    def __init__(self, handler, *, result_callback):
        super().__init__()
        self._handler = handler
        self._result_callback = result_callback

    def submit(self, args):
        """Run the handler on ``args`` and report the result right away."""
        self._result_callback(self, self._handler(args))

    def process_until_done(self):
        """No-op: ``submit`` has already processed everything."""

188 

189 

def make_pool(process_num, handler, result_callback) -> Union[SinglePool, MultiPool]:
    """Build the pool implementation that fits ``process_num``.

    Exactly one process yields an in-process ``SinglePool``. Any other value
    yields a ``MultiPool``, which itself rejects values below 1.
    """
    shared = {"handler": handler, "result_callback": result_callback}
    if process_num == 1:
        return SinglePool(**shared)
    return MultiPool(process_num=process_num, **shared)

199 

200 

# Dispositions that SIGTERM/SIGINT had before ``_on_terminate`` replaced them,
# keyed by signal number.
orig_signal_handlers = {}


def _on_terminate(signum, frame):
    """Close every live pool immediately, then defer to the previous disposition.

    - A callable previous handler (such as Python's default SIGINT handler,
      which raises KeyboardInterrupt) is invoked as before.
    - A previous ``SIG_DFL`` is restored and the signal re-raised, so the
      process terminates as it would have without this handler rather than
      the signal being silently swallowed.
    - ``SIG_IGN``, or a disposition not set from Python (``None``), leaves
      nothing further to do.
    """
    # Iterate over a copy: ``close`` removes each pool from ``pools``.
    for pool in list(pools):
        pool.close(immediate=True)

    orig_handler = orig_signal_handlers.get(signum)
    if callable(orig_handler):
        orig_handler(signum, frame)
    elif orig_handler == signal.SIG_DFL:
        signal.signal(signum, signal.SIG_DFL)
        signal.raise_signal(signum)


orig_signal_handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, _on_terminate)
orig_signal_handlers[signal.SIGINT] = signal.signal(signal.SIGINT, _on_terminate)