Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/coordinator.py: 23%
141 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Coordinator to help multiple threads stop when requested."""
16import contextlib
17import sys
18import threading
19import time
21from tensorflow.python.framework import errors
22from tensorflow.python.platform import tf_logging as logging
23from tensorflow.python.util import compat
24from tensorflow.python.util.tf_export import tf_export
@tf_export("train.Coordinator")
class Coordinator:
  """A coordinator for threads.

  This class implements a simple mechanism to coordinate the termination of a
  set of threads.

  #### Usage:

  ```python
  # Create a coordinator.
  coord = Coordinator()
  # Start a number of threads, passing the coordinator to each of them.
  ...start thread 1...(coord, ...)
  ...start thread N...(coord, ...)
  # Wait for all the threads to terminate.
  coord.join(threads)
  ```

  Any of the threads can call `coord.request_stop()` to ask for all the threads
  to stop.  To cooperate with the requests, each thread must check for
  `coord.should_stop()` on a regular basis.  `coord.should_stop()` returns
  `True` as soon as `coord.request_stop()` has been called.

  A typical thread running with a coordinator will do something like:

  ```python
  while not coord.should_stop():
    ...do some work...
  ```

  #### Exception handling:

  A thread can report an exception to the coordinator as part of the
  `request_stop()` call.  The exception will be re-raised from the
  `coord.join()` call.

  Thread code:

  ```python
  try:
    while not coord.should_stop():
      ...do some work...
  except Exception as e:
    coord.request_stop(e)
  ```

  Main code:

  ```python
  try:
    ...
    coord = Coordinator()
    # Start a number of threads, passing the coordinator to each of them.
    ...start thread 1...(coord, ...)
    ...start thread N...(coord, ...)
    # Wait for all the threads to terminate.
    coord.join(threads)
  except Exception as e:
    ...exception that was passed to coord.request_stop()
  ```

  To simplify the thread implementation, the Coordinator provides a
  context handler `stop_on_exception()` that automatically requests a stop if
  an exception is raised.  Using the context handler the thread code above
  can be written as:

  ```python
  with coord.stop_on_exception():
    while not coord.should_stop():
      ...do some work...
  ```

  #### Grace period for stopping:

  After a thread has called `coord.request_stop()` the other threads have a
  fixed time to stop, this is called the 'stop grace period' and defaults to 2
  minutes.  If any of the threads is still alive after the grace period expires
  `coord.join()` raises a RuntimeError reporting the laggards.

  ```python
  try:
    ...
    coord = Coordinator()
    # Start a number of threads, passing the coordinator to each of them.
    ...start thread 1...(coord, ...)
    ...start thread N...(coord, ...)
    # Wait for all the threads to terminate, give them 10s grace period
    coord.join(threads, stop_grace_period_secs=10)
  except RuntimeError:
    ...one of the threads took more than 10s to stop after request_stop()
    ...was called.
  except Exception:
    ...exception that was passed to coord.request_stop()
  ```
  """

  def __init__(self, clean_stop_exception_types=None):
    """Create a new Coordinator.

    Args:
      clean_stop_exception_types: Optional tuple of Exception types that should
        cause a clean stop of the coordinator. If an exception of one of these
        types is reported to `request_stop(ex)` the coordinator will behave as
        if `request_stop(None)` was called.  Defaults to
        `(tf.errors.OutOfRangeError,)` which is used by input queues to signal
        the end of input. When feeding training data from a Python iterator it
        is common to add `StopIteration` to this list.
    """
    if clean_stop_exception_types is None:
      clean_stop_exception_types = (errors.OutOfRangeError,)
    self._clean_stop_exception_types = tuple(clean_stop_exception_types)
    # Protects all attributes.
    self._lock = threading.Lock()
    # Event set when threads must stop.
    self._stop_event = threading.Event()
    # Python exc_info to report.
    # If not None, it should hold the returned value of sys.exc_info(), which is
    # a tuple containing exception (type, value, traceback).
    self._exc_info_to_raise = None
    # True if we have called join() already.
    self._joined = False
    # Set of threads registered for joining when join() is called.  These
    # threads will be joined in addition to the threads passed to the join()
    # call.  It's ok if threads are both registered and passed to the join()
    # call.
    self._registered_threads = set()

  def _filter_exception(self, ex):
    """Check if the exception indicated in 'ex' should be ignored.

    This method examines `ex` to check if it is an exception that should be
    reported to the users.  If yes, it returns `ex` as is, otherwise it returns
    None.

    The code returns None for exception types listed in
    `_clean_stop_exception_types`.

    Args:
      ex: None, an `Exception`, or a Python `exc_info` tuple as returned by
        `sys.exc_info()`.

    Returns:
      ex or None.
    """
    if isinstance(ex, tuple):
      ex2 = ex[1]
    else:
      ex2 = ex
    if isinstance(ex2, self._clean_stop_exception_types):
      # Ignore the exception.
      ex = None
    return ex

  def request_stop(self, ex=None):
    """Request that the threads stop.

    After this is called, calls to `should_stop()` will return `True`.

    Note: If an exception is being passed in, it must be in the context of
    handling the exception (i.e. `try: ... except Exception as ex: ...`) and not
    a newly created one.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.

    Raises:
      Exception: If `join()` has already been called and `ex` is not a clean
        stop exception, the reported exception is raised immediately since it
        can no longer be delivered through `join()`.
    """
    with self._lock:
      ex = self._filter_exception(ex)
      # If we have already joined the coordinator the exception will not have a
      # chance to be reported, so just raise it normally.  This can happen if
      # you continue to use a session after having stopped and joined the
      # coordinator threads.
      if self._joined:
        if isinstance(ex, tuple):
          _, ex_instance, _ = ex
          raise ex_instance
        elif ex is not None:
          # Prefer the exception currently being handled (this keeps the
          # behavior of callers that report from their `except` block).  When
          # request_stop() is called outside of an exception handler,
          # sys.exc_info() holds no exception; fall back to `ex` itself rather
          # than executing `raise None`, which would surface as an unrelated
          # TypeError.
          _, ex_instance, _ = sys.exc_info()
          raise ex if ex_instance is None else ex_instance
      if not self._stop_event.is_set():
        if ex and self._exc_info_to_raise is None:
          if isinstance(ex, tuple):
            logging.info("Error reported to Coordinator: %s",
                         compat.as_str_any(ex[1]),
                         exc_info=ex)
            self._exc_info_to_raise = ex
          else:
            logging.info("Error reported to Coordinator: %s, %s",
                         type(ex),
                         compat.as_str_any(ex))
            self._exc_info_to_raise = sys.exc_info()
          # self._exc_info_to_raise should contain a tuple containing exception
          # (type, value, traceback)
          if (len(self._exc_info_to_raise) != 3 or
              not self._exc_info_to_raise[0] or
              not self._exc_info_to_raise[1]):
            # Raise, catch and record the exception here so that error happens
            # where expected.
            try:
              # The offending value is itself a tuple, so it must be wrapped in
              # a 1-tuple: passing it bare to `%` would try to spread its items
              # over the single `%s` and raise TypeError, skipping both the
              # ValueError bookkeeping and the `_stop_event.set()` below.
              raise ValueError(
                  "ex must be a tuple or sys.exc_info must return the current "
                  "exception: %s"
                  % (self._exc_info_to_raise,))
            except ValueError:
              # Record this error so it kills the coordinator properly.
              # NOTE(touts): As above, this is bogus if request_stop() is not
              # called from the exception handler that raised ex.
              self._exc_info_to_raise = sys.exc_info()

        self._stop_event.set()

  def clear_stop(self):
    """Clears the stop flag.

    After this is called, calls to `should_stop()` will return `False`.
    The joined state and any recorded exception are reset as well.
    """
    with self._lock:
      self._joined = False
      self._exc_info_to_raise = None
      if self._stop_event.is_set():
        self._stop_event.clear()

  def should_stop(self):
    """Check if stop was requested.

    Returns:
      True if a stop was requested.
    """
    return self._stop_event.is_set()

  @contextlib.contextmanager
  def stop_on_exception(self):
    """Context manager to request stop when an Exception is raised.

    Code that uses a coordinator must catch exceptions and pass
    them to the `request_stop()` method to stop the other threads
    managed by the coordinator.

    This context handler simplifies the exception handling.
    Use it as follows:

    ```python
    with coord.stop_on_exception():
      # Any exception raised in the body of the with
      # clause is reported to the coordinator before terminating
      # the execution of the body.
      ...body...
    ```

    This is completely equivalent to the slightly longer code:

    ```python
    try:
      ...body...
    except:
      coord.request_stop(sys.exc_info())
    ```

    Yields:
      nothing.
    """
    try:
      yield
    except:  # pylint: disable=bare-except
      # Deliberately catches everything: whatever terminated the body must be
      # handed to the coordinator so that the other threads are told to stop.
      self.request_stop(ex=sys.exc_info())

  def wait_for_stop(self, timeout=None):
    """Wait till the Coordinator is told to stop.

    Args:
      timeout: Float.  Sleep for up to that many seconds waiting for
        should_stop() to become True.

    Returns:
      True if the Coordinator is told stop, False if the timeout expired.
    """
    return self._stop_event.wait(timeout)

  def register_thread(self, thread):
    """Register a thread to join.

    Args:
      thread: A Python thread to join.
    """
    with self._lock:
      self._registered_threads.add(thread)

  def join(self, threads=None, stop_grace_period_secs=120,
           ignore_live_threads=False):
    """Wait for threads to terminate.

    This call blocks until a set of threads have terminated.  The set of thread
    is the union of the threads passed in the `threads` argument and the list
    of threads that registered with the coordinator by calling
    `Coordinator.register_thread()`.

    After the threads stop, if an `exc_info` was passed to `request_stop`, that
    exception is re-raised.

    Grace period handling: When `request_stop()` is called, threads are given
    'stop_grace_period_secs' seconds to terminate.  If any of them is still
    alive after that period expires, a `RuntimeError` is raised.  Note that if
    an `exc_info` was passed to `request_stop()` then it is raised instead of
    that `RuntimeError`.

    Args:
      threads: List of `threading.Threads`. The started threads to join in
        addition to the registered threads.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `request_stop()` has been called.
      ignore_live_threads: If `False`, raises an error if any of the threads are
        still alive after `stop_grace_period_secs`.

    Raises:
      RuntimeError: If any thread is still alive after `request_stop()`
        is called and the grace period expires.
    """
    # Threads registered after this call will not be joined.
    with self._lock:
      if threads is None:
        threads = self._registered_threads
      else:
        threads = self._registered_threads.union(set(threads))
      # Copy the set into a list to avoid race conditions where a new thread
      # is added while we are waiting.
      threads = list(threads)

    # Wait for all threads to stop or for request_stop() to be called.
    while any(t.is_alive() for t in threads) and not self.wait_for_stop(1.0):
      pass

    # If any thread is still alive, wait for the grace period to expire.
    # By the time this check is executed, threads may still be shutting down,
    # so we add a sleep of increasing duration to give them a chance to shut
    # down without losing too many cycles.
    # The sleep duration is limited to the remaining grace duration.
    stop_wait_secs = 0.001
    while any(t.is_alive() for t in threads) and stop_grace_period_secs >= 0.0:
      time.sleep(stop_wait_secs)
      stop_grace_period_secs -= stop_wait_secs
      stop_wait_secs = 2 * stop_wait_secs
      # Keep the waiting period within sane bounds.
      # The minimum value is to avoid decreasing stop_wait_secs to a value
      # that could cause stop_grace_period_secs to remain unchanged.
      stop_wait_secs = max(min(stop_wait_secs, stop_grace_period_secs), 0.001)

    # List the threads still alive after the grace period.
    stragglers = [t.name for t in threads if t.is_alive()]

    # Terminate with an exception if appropriate.
    with self._lock:
      self._joined = True
      self._registered_threads = set()
      if self._exc_info_to_raise:
        _, ex_instance, _ = self._exc_info_to_raise
        raise ex_instance
      elif stragglers:
        if ignore_live_threads:
          logging.info("Coordinator stopped with threads still running: %s",
                       " ".join(stragglers))
        else:
          raise RuntimeError(
              "Coordinator stopped with threads still running: %s" %
              " ".join(stragglers))

  @property
  def joined(self):
    """True once `join()` has completed; reset to False by `clear_stop()`."""
    return self._joined

  def raise_requested_exception(self):
    """If an exception has been passed to `request_stop`, this raises it."""
    with self._lock:
      if self._exc_info_to_raise:
        _, ex_instance, _ = self._exc_info_to_raise
        raise ex_instance
408# Threads for the standard services.
@tf_export(v1=["train.LooperThread"])
class LooperThread(threading.Thread):
  """A thread that runs code repeatedly, optionally on a timer.

  This thread class is intended to be used with a `Coordinator`.  It repeatedly
  runs code specified either as `target` and `args` or by the `run_loop()`
  method.

  Before each run the thread checks if the coordinator has requested stop.  In
  that case the looper thread terminates immediately.

  If the code being run raises an exception, that exception is reported to the
  coordinator and the thread terminates.  The coordinator will then request all
  the other threads it coordinates to stop.

  You typically pass looper threads to the supervisor `Join()` method.
  """

  def __init__(self, coord, timer_interval_secs, target=None, args=None,
               kwargs=None):
    """Create a LooperThread.

    The new thread is marked as a daemon and registers itself with `coord`, so
    it is joined by `coord.join()` even when not passed to it explicitly.

    Args:
      coord: A Coordinator.
      timer_interval_secs: Time boundaries at which to call Run(), or None
        if it should be called back to back.
      target: Optional callable object that will be executed in the thread.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if not isinstance(coord, Coordinator):
      # Wrap `coord` in a 1-tuple: if the caller mistakenly passed a tuple,
      # bare `%` formatting would raise TypeError instead of the documented
      # ValueError.
      raise ValueError("'coord' argument must be a Coordinator: %s" % (coord,))
    super().__init__()
    self.daemon = True
    self._coord = coord
    self._timer_interval_secs = timer_interval_secs
    self._target = target
    if self._target:
      self._args = args or ()
      self._kwargs = kwargs or {}
    elif args or kwargs:
      raise ValueError("'args' and 'kwargs' argument require that you also "
                       "pass 'target'")
    self._coord.register_thread(self)

  @staticmethod
  def loop(coord, timer_interval_secs, target, args=None, kwargs=None):
    """Start a LooperThread that calls a function periodically.

    If `timer_interval_secs` is None the thread calls
    `target(*args, **kwargs)` repeatedly.  Otherwise `target(*args, **kwargs)`
    is called every `timer_interval_secs` seconds.  The thread terminates when
    a stop of the coordinator is requested.

    Args:
      coord: A Coordinator.
      timer_interval_secs: Number. Time boundaries at which to call `target`.
      target: A callable object.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Returns:
      The started thread.
    """
    looper = LooperThread(coord, timer_interval_secs, target=target, args=args,
                          kwargs=kwargs)
    looper.start()
    return looper

  def run(self):
    """Thread body: run `run_loop()` repeatedly until a stop is requested.

    `start_loop()` is called once before the first iteration and
    `stop_loop()` once after the loop exits normally.  Any exception raised
    by these hooks or by `run_loop()` is reported to the coordinator, in which
    case `stop_loop()` is not called.
    """
    with self._coord.stop_on_exception():
      self.start_loop()
      if self._timer_interval_secs is None:
        # Call back-to-back.
        while not self._coord.should_stop():
          self.run_loop()
      else:
        # Next time at which to call run_loop(), starts as 'now'.
        next_timer_time = time.time()
        # Sleep until the next boundary (or return at once when already past
        # it); wait_for_stop() returning True means a stop was requested.
        while not self._coord.wait_for_stop(next_timer_time - time.time()):
          next_timer_time += self._timer_interval_secs
          self.run_loop()
      self.stop_loop()

  def start_loop(self):
    """Called when the thread starts."""
    pass

  def stop_loop(self):
    """Called when the thread stops."""
    pass

  def run_loop(self):
    """Called at 'timer_interval_secs' boundaries."""
    if self._target:
      self._target(*self._args, **self._kwargs)
507 self._target(*self._args, **self._kwargs)