Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/coordinator.py: 23%

141 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Coordinator to help multiple threads stop when requested.""" 

16import contextlib 

17import sys 

18import threading 

19import time 

20 

21from tensorflow.python.framework import errors 

22from tensorflow.python.platform import tf_logging as logging 

23from tensorflow.python.util import compat 

24from tensorflow.python.util.tf_export import tf_export 

25 

26 

27@tf_export("train.Coordinator") 

class Coordinator:
  """A coordinator for threads.

  This class implements a simple mechanism to coordinate the termination of a
  set of threads.

  #### Usage:

  ```python
  # Create a coordinator.
  coord = Coordinator()
  # Start a number of threads, passing the coordinator to each of them.
  ...start thread 1...(coord, ...)
  ...start thread N...(coord, ...)
  # Wait for all the threads to terminate.
  coord.join(threads)
  ```

  Any of the threads can call `coord.request_stop()` to ask for all the threads
  to stop.  To cooperate with the requests, each thread must check for
  `coord.should_stop()` on a regular basis.  `coord.should_stop()` returns
  `True` as soon as `coord.request_stop()` has been called.

  A typical thread running with a coordinator will do something like:

  ```python
  while not coord.should_stop():
    ...do some work...
  ```

  #### Exception handling:

  A thread can report an exception to the coordinator as part of the
  `request_stop()` call.  The exception will be re-raised from the
  `coord.join()` call.

  Thread code:

  ```python
  try:
    while not coord.should_stop():
      ...do some work...
  except Exception as e:
    coord.request_stop(e)
  ```

  Main code:

  ```python
  try:
    ...
    coord = Coordinator()
    # Start a number of threads, passing the coordinator to each of them.
    ...start thread 1...(coord, ...)
    ...start thread N...(coord, ...)
    # Wait for all the threads to terminate.
    coord.join(threads)
  except Exception as e:
    ...exception that was passed to coord.request_stop()
  ```

  To simplify the thread implementation, the Coordinator provides a
  context handler `stop_on_exception()` that automatically requests a stop if
  an exception is raised.  Using the context handler the thread code above
  can be written as:

  ```python
  with coord.stop_on_exception():
    while not coord.should_stop():
      ...do some work...
  ```

  #### Grace period for stopping:

  After a thread has called `coord.request_stop()` the other threads have a
  fixed time to stop, this is called the 'stop grace period' and defaults to 2
  minutes.  If any of the threads is still alive after the grace period expires
  `coord.join()` raises a RuntimeError reporting the laggards.

  ```python
  try:
    ...
    coord = Coordinator()
    # Start a number of threads, passing the coordinator to each of them.
    ...start thread 1...(coord, ...)
    ...start thread N...(coord, ...)
    # Wait for all the threads to terminate, give them 10s grace period
    coord.join(threads, stop_grace_period_secs=10)
  except RuntimeError:
    ...one of the threads took more than 10s to stop after request_stop()
    ...was called.
  except Exception:
    ...exception that was passed to coord.request_stop()
  ```
  """

  def __init__(self, clean_stop_exception_types=None):
    """Create a new Coordinator.

    Args:
      clean_stop_exception_types: Optional tuple of Exception types that should
        cause a clean stop of the coordinator. If an exception of one of these
        types is reported to `request_stop(ex)` the coordinator will behave as
        if `request_stop(None)` was called.  Defaults to
        `(tf.errors.OutOfRangeError,)` which is used by input queues to signal
        the end of input. When feeding training data from a Python iterator it
        is common to add `StopIteration` to this list.
    """
    if clean_stop_exception_types is None:
      clean_stop_exception_types = (errors.OutOfRangeError,)
    self._clean_stop_exception_types = tuple(clean_stop_exception_types)
    # Protects all attributes.
    self._lock = threading.Lock()
    # Event set when threads must stop.
    self._stop_event = threading.Event()
    # Python exc_info to report.
    # If not None, it should hold the returned value of sys.exc_info(), which is
    # a tuple containing exception (type, value, traceback).
    self._exc_info_to_raise = None
    # True if we have called join() already.
    self._joined = False
    # Set of threads registered for joining when join() is called.  These
    # threads will be joined in addition to the threads passed to the join()
    # call.  It's ok if threads are both registered and passed to the join()
    # call.
    self._registered_threads = set()

  def _filter_exception(self, ex):
    """Check if the exception indicated in 'ex' should be ignored.

    This method examines `ex` to check if it is an exception that should be
    reported to the users.  If yes, it returns `ex` as is, otherwise it returns
    None.

    The code returns None for exception types listed in
    `_clean_stop_exception_types`.

    Args:
      ex: None, an `Exception`, or a Python `exc_info` tuple as returned by
        `sys.exc_info()`.

    Returns:
      ex or None.
    """
    # For an exc_info tuple the exception instance is the second element.
    if isinstance(ex, tuple):
      ex2 = ex[1]
    else:
      ex2 = ex
    if isinstance(ex2, self._clean_stop_exception_types):
      # Ignore the exception.
      ex = None
    return ex

  def request_stop(self, ex=None):
    """Request that the threads stop.

    After this is called, calls to `should_stop()` will return `True`.

    Note: If an exception is being passed in, it is normally reported from the
    handler of that exception (i.e. `try: ... except Exception as ex: ...`).
    An exception instance is recorded together with its own `__traceback__`,
    so the recorded exception is always the one passed in, not whatever
    exception happens to be current in the calling thread.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.

    Raises:
      Exception: The reported exception itself, if `join()` has already been
        called and the exception therefore can no longer be reported there.
    """
    with self._lock:
      ex = self._filter_exception(ex)
      # If we have already joined the coordinator the exception will not have a
      # chance to be reported, so just raise it normally.  This can happen if
      # you continue to use a session after having stopped and joined the
      # coordinator threads.
      if self._joined:
        if isinstance(ex, tuple):
          _, ex_instance, _ = ex
          raise ex_instance
        elif ex is not None:
          # Raise the reported exception directly.  Going through
          # sys.exc_info() here would raise `None` (a TypeError) whenever
          # request_stop() is called outside of an exception handler, and the
          # wrong exception when called from the handler of another one.
          raise ex
      if not self._stop_event.is_set():
        if ex and self._exc_info_to_raise is None:
          if isinstance(ex, tuple):
            logging.info("Error reported to Coordinator: %s",
                         compat.as_str_any(ex[1]),
                         exc_info=ex)
            self._exc_info_to_raise = ex
          else:
            logging.info("Error reported to Coordinator: %s, %s",
                         type(ex),
                         compat.as_str_any(ex))
            if isinstance(ex, BaseException):
              # Build the exc_info triple from the exception itself so that it
              # does not depend on being inside the matching `except` block.
              self._exc_info_to_raise = (type(ex), ex, ex.__traceback__)
            else:
              # Not an exception instance: fall back to the exception
              # currently being handled, if any.
              self._exc_info_to_raise = sys.exc_info()
          # self._exc_info_to_raise should contain a tuple containing exception
          # (type, value, traceback)
          if (len(self._exc_info_to_raise) != 3 or
              not self._exc_info_to_raise[0] or
              not self._exc_info_to_raise[1]):
            # Raise, catch and record the exception here so that error happens
            # where expected.
            try:
              # Wrap in a 1-tuple: formatting a 3-tuple directly with a single
              # %s raises TypeError instead of producing the message.
              raise ValueError(
                  "ex must be a tuple or sys.exc_info must return the current "
                  "exception: %s"
                  % (self._exc_info_to_raise,))
            except ValueError:
              # Record this error so it kills the coordinator properly.
              self._exc_info_to_raise = sys.exc_info()

        self._stop_event.set()

  def clear_stop(self):
    """Clears the stop flag.

    After this is called, calls to `should_stop()` will return `False`.
    The joined state and any recorded exception are reset as well.
    """
    with self._lock:
      self._joined = False
      self._exc_info_to_raise = None
      if self._stop_event.is_set():
        self._stop_event.clear()

  def should_stop(self):
    """Check if stop was requested.

    Returns:
      True if a stop was requested.
    """
    return self._stop_event.is_set()

  @contextlib.contextmanager
  def stop_on_exception(self):
    """Context manager to request stop when an Exception is raised.

    Code that uses a coordinator must catch exceptions and pass
    them to the `request_stop()` method to stop the other threads
    managed by the coordinator.

    This context handler simplifies the exception handling.
    Use it as follows:

    ```python
    with coord.stop_on_exception():
      # Any exception raised in the body of the with
      # clause is reported to the coordinator before terminating
      # the execution of the body.
      ...body...
    ```

    This is completely equivalent to the slightly longer code:

    ```python
    try:
      ...body...
    except:
      coord.request_stop(sys.exc_info())
    ```

    Yields:
      nothing.
    """
    try:
      yield
    except:  # pylint: disable=bare-except
      self.request_stop(ex=sys.exc_info())

  def wait_for_stop(self, timeout=None):
    """Wait till the Coordinator is told to stop.

    Args:
      timeout: Float.  Sleep for up to that many seconds waiting for
        should_stop() to become True.

    Returns:
      True if the Coordinator is told to stop, False if the timeout expired.
    """
    return self._stop_event.wait(timeout)

  def register_thread(self, thread):
    """Register a thread to join.

    Args:
      thread: A Python thread to join.
    """
    with self._lock:
      self._registered_threads.add(thread)

  def join(self, threads=None, stop_grace_period_secs=120,
           ignore_live_threads=False):
    """Wait for threads to terminate.

    This call blocks until a set of threads have terminated.  The set of
    threads is the union of the threads passed in the `threads` argument and
    the list of threads that registered with the coordinator by calling
    `Coordinator.register_thread()`.

    After the threads stop, if an `exc_info` was passed to `request_stop`, that
    exception is re-raised.

    Grace period handling: When `request_stop()` is called, threads are given
    'stop_grace_period_secs' seconds to terminate.  If any of them is still
    alive after that period expires, a `RuntimeError` is raised.  Note that if
    an `exc_info` was passed to `request_stop()` then it is raised instead of
    that `RuntimeError`.

    Args:
      threads: List of `threading.Threads`. The started threads to join in
        addition to the registered threads.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `request_stop()` has been called.
      ignore_live_threads: If `False`, raises an error if any of the threads are
        still alive after `stop_grace_period_secs`.

    Raises:
      RuntimeError: If any thread is still alive after `request_stop()`
        is called and the grace period expires.
    """
    # Threads registered after this call will not be joined.
    with self._lock:
      if threads is None:
        threads = self._registered_threads
      else:
        threads = self._registered_threads.union(set(threads))
      # Copy the set into a list to avoid race conditions where a new thread
      # is added while we are waiting.
      threads = list(threads)

    # Wait for all threads to stop or for request_stop() to be called.
    while any(t.is_alive() for t in threads) and not self.wait_for_stop(1.0):
      pass

    # If any thread is still alive, wait for the grace period to expire.
    # By the time this check is executed, threads may still be shutting down,
    # so we add a sleep of increasing duration to give them a chance to shut
    # down without losing too many cycles.
    # The sleep duration is limited to the remaining grace duration.
    stop_wait_secs = 0.001
    while any(t.is_alive() for t in threads) and stop_grace_period_secs >= 0.0:
      time.sleep(stop_wait_secs)
      stop_grace_period_secs -= stop_wait_secs
      stop_wait_secs = 2 * stop_wait_secs
      # Keep the waiting period within sane bounds.
      # The minimum value is to avoid decreasing stop_wait_secs to a value
      # that could cause stop_grace_period_secs to remain unchanged.
      stop_wait_secs = max(min(stop_wait_secs, stop_grace_period_secs), 0.001)

    # List the threads still alive after the grace period.
    stragglers = [t.name for t in threads if t.is_alive()]

    # Terminate with an exception if appropriate.
    with self._lock:
      self._joined = True
      self._registered_threads = set()
      if self._exc_info_to_raise:
        # A recorded exception takes precedence over the straggler error.
        _, ex_instance, _ = self._exc_info_to_raise
        raise ex_instance
      elif stragglers:
        if ignore_live_threads:
          logging.info("Coordinator stopped with threads still running: %s",
                       " ".join(stragglers))
        else:
          raise RuntimeError(
              "Coordinator stopped with threads still running: %s" %
              " ".join(stragglers))

  @property
  def joined(self):
    """True once `join()` has completed; reset by `clear_stop()`."""
    return self._joined

  def raise_requested_exception(self):
    """If an exception has been passed to `request_stop`, this raises it."""
    with self._lock:
      if self._exc_info_to_raise:
        _, ex_instance, _ = self._exc_info_to_raise
        raise ex_instance

406 

407 

408# Threads for the standard services. 

409@tf_export(v1=["train.LooperThread"]) 

class LooperThread(threading.Thread):
  """A thread that runs code repeatedly, optionally on a timer.

  This thread class is intended to be used with a `Coordinator`.  It repeatedly
  runs code specified either as `target` and `args` or by the `run_loop()`
  method.

  Before each run the thread checks if the coordinator has requested stop.  In
  that case the looper thread terminates immediately.

  If the code being run raises an exception, that exception is reported to the
  coordinator and the thread terminates.  The coordinator will then request all
  the other threads it coordinates to stop.

  You typically pass looper threads to the supervisor `Join()` method.
  """

  def __init__(self, coord, timer_interval_secs, target=None, args=None,
               kwargs=None):
    """Create a LooperThread.

    The new thread is a daemon thread and registers itself with `coord`, so it
    is joined by `coord.join()` even when not passed to it explicitly.

    Args:
      coord: A Coordinator.
      timer_interval_secs: Time boundaries at which to call Run(), or None
        if it should be called back to back.
      target: Optional callable object that will be executed in the thread.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Raises:
      ValueError: If `coord` is not a `Coordinator`, or if `args` or `kwargs`
        are given without a `target`.
    """
    if not isinstance(coord, Coordinator):
      raise ValueError("'coord' argument must be a Coordinator: %s" % coord)
    super(LooperThread, self).__init__()
    self.daemon = True
    self._coord = coord
    self._timer_interval_secs = timer_interval_secs
    self._target = target
    if self._target:
      self._args = args or ()
      self._kwargs = kwargs or {}
    elif args or kwargs:
      raise ValueError("'args' and 'kwargs' argument require that you also "
                       "pass 'target'")
    self._coord.register_thread(self)

  @staticmethod
  def loop(coord, timer_interval_secs, target, args=None, kwargs=None):
    """Start a LooperThread that calls a function periodically.

    If `timer_interval_secs` is None the thread calls
    `target(*args, **kwargs)` repeatedly.  Otherwise `target(*args, **kwargs)`
    is called every `timer_interval_secs` seconds.  The thread terminates when
    a stop of the coordinator is requested.

    Args:
      coord: A Coordinator.
      timer_interval_secs: Number. Time boundaries at which to call `target`.
      target: A callable object.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Returns:
      The started thread.
    """
    looper = LooperThread(coord, timer_interval_secs, target=target, args=args,
                          kwargs=kwargs)
    looper.start()
    return looper

  def run(self):
    """Run `start_loop()`, then `run_loop()` until stop, then `stop_loop()`.

    Any exception raised along the way is reported to the coordinator through
    `stop_on_exception()`; in that case `stop_loop()` is not called.
    """
    with self._coord.stop_on_exception():
      self.start_loop()
      if self._timer_interval_secs is None:
        # Call back-to-back.
        while not self._coord.should_stop():
          self.run_loop()
      else:
        # Next time at which to call run_loop(), starts as 'now'.
        next_timer_time = time.time()
        while not self._coord.wait_for_stop(next_timer_time - time.time()):
          next_timer_time += self._timer_interval_secs
          self.run_loop()
      self.stop_loop()

  def start_loop(self):
    """Called when the thread starts."""
    pass

  def stop_loop(self):
    """Called when the thread stops."""
    pass

  def run_loop(self):
    """Called at 'timer_interval_secs' boundaries."""
    if self._target:
      self._target(*self._args, **self._kwargs)