Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/supervisor.py: 29%
336 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training helper that checkpoints models and computes summaries."""
16import contextlib
17import os
18import time
20from tensorflow.core.framework.summary_pb2 import Summary
21from tensorflow.core.util.event_pb2 import SessionLog
22from tensorflow.python.eager import context
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import meta_graph
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.ops import lookup_ops
28from tensorflow.python.ops import variables
29from tensorflow.python.platform import tf_logging as logging
30from tensorflow.python.summary import summary as _summary
31from tensorflow.python.training import coordinator
32from tensorflow.python.training import saver as saver_mod
33from tensorflow.python.training import session_manager as session_manager_mod
34from tensorflow.python.training import training_util
35from tensorflow.python.util import deprecation
36from tensorflow.python.util.tf_export import tf_export
39@tf_export(v1=["train.Supervisor"])
40class Supervisor:
41 """A training helper that checkpoints models and computes summaries.
43 This class is deprecated. Please use
44 `tf.compat.v1.train.MonitoredTrainingSession` instead.
46 The Supervisor is a small wrapper around a `Coordinator`, a `Saver`,
47 and a `SessionManager` that takes care of common needs of TensorFlow
48 training programs.
50 #### Use for a single program
52 ```python
53 with tf.Graph().as_default():
54 ...add operations to the graph...
55 # Create a Supervisor that will checkpoint the model in '/tmp/mydir'.
56 sv = Supervisor(logdir='/tmp/mydir')
57 # Get a TensorFlow session managed by the supervisor.
58 with sv.managed_session(FLAGS.master) as sess:
59 # Use the session to train the graph.
60 while not sv.should_stop():
61 sess.run(<my_train_op>)
62 ```
64 Within the `with sv.managed_session()` block all variables in the graph have
65 been initialized. In addition, a few services have been started to
66 checkpoint the model and add summaries to the event log.
If the program crashes and is restarted, the managed session automatically
reinitializes variables from the most recent checkpoint.
71 The supervisor is notified of any exception raised by one of the services.
72 After an exception is raised, `should_stop()` returns `True`. In that case
73 the training loop should also stop. This is why the training loop has to
74 check for `sv.should_stop()`.
76 Exceptions that indicate that the training inputs have been exhausted,
77 `tf.errors.OutOfRangeError`, also cause `sv.should_stop()` to return `True`
78 but are not re-raised from the `with` block: they indicate a normal
79 termination.
81 #### Use for multiple replicas
83 To train with replicas you deploy the same program in a `Cluster`.
84 One of the tasks must be identified as the *chief*: the task that handles
85 initialization, checkpoints, summaries, and recovery. The other tasks
86 depend on the *chief* for these services.
88 The only change you have to do to the single program code is to indicate
89 if the program is running as the *chief*.
91 ```python
92 # Choose a task as the chief. This could be based on server_def.task_index,
93 # or job_def.name, or job_def.tasks. It's entirely up to the end user.
94 # But there can be only one *chief*.
95 is_chief = (server_def.task_index == 0)
96 server = tf.distribute.Server(server_def)
98 with tf.Graph().as_default():
99 ...add operations to the graph...
100 # Create a Supervisor that uses log directory on a shared file system.
101 # Indicate if you are the 'chief'
102 sv = Supervisor(logdir='/shared_directory/...', is_chief=is_chief)
103 # Get a Session in a TensorFlow server on the cluster.
104 with sv.managed_session(server.target) as sess:
105 # Use the session to train the graph.
106 while not sv.should_stop():
107 sess.run(<my_train_op>)
108 ```
110 In the *chief* task, the `Supervisor` works exactly as in the first example
111 above. In the other tasks `sv.managed_session()` waits for the Model to have
112 been initialized before returning a session to the training code. The
113 non-chief tasks depend on the chief task for initializing the model.
115 If one of the tasks crashes and restarts, `managed_session()`
116 checks if the Model is initialized. If yes, it just creates a session and
117 returns it to the training code that proceeds normally. If the model needs
118 to be initialized, the chief task takes care of reinitializing it; the other
119 tasks just wait for the model to have been initialized.
121 NOTE: This modified program still works fine as a single program.
122 The single program marks itself as the chief.
124 #### What `master` string to use
126 Whether you are running on your machine or in the cluster you can use the
127 following values for the --master flag:
129 * Specifying `''` requests an in-process session that does not use RPC.
131 * Specifying `'local'` requests a session that uses the RPC-based
132 "Master interface" to run TensorFlow programs. See
133 `tf.train.Server.create_local_server` for
134 details.
136 * Specifying `'grpc://hostname:port'` requests a session that uses
137 the RPC interface to a specific host, and also allows the in-process
138 master to access remote tensorflow workers. Often, it is
139 appropriate to pass `server.target` (for some `tf.distribute.Server`
named `server`).
142 #### Advanced use
144 ##### Launching additional services
146 `managed_session()` launches the Checkpoint and Summary services (threads).
147 If you need more services to run you can simply launch them in the block
148 controlled by `managed_session()`.
150 Example: Start a thread to print losses. We want this thread to run
151 every 60 seconds, so we launch it with `sv.loop()`.
153 ```python
154 ...
155 sv = Supervisor(logdir='/tmp/mydir')
156 with sv.managed_session(FLAGS.master) as sess:
157 sv.loop(60, print_loss, (sess, ))
158 while not sv.should_stop():
159 sess.run(my_train_op)
160 ```
162 ##### Launching fewer services
164 `managed_session()` launches the "summary" and "checkpoint" threads which use
either the optional `summary_op` and `saver` passed to the constructor, or
166 default ones created automatically by the supervisor. If you want to run
167 your own summary and checkpointing logic, disable these services by passing
168 `None` to the `summary_op` and `saver` parameters.
170 Example: Create summaries manually every 100 steps in the chief.
172 ```python
173 # Create a Supervisor with no automatic summaries.
174 sv = Supervisor(logdir='/tmp/mydir', is_chief=is_chief, summary_op=None)
175 # As summary_op was None, managed_session() does not start the
176 # summary thread.
177 with sv.managed_session(FLAGS.master) as sess:
178 for step in range(1000000):
179 if sv.should_stop():
180 break
181 if is_chief and step % 100 == 0:
182 # Create the summary every 100 chief steps.
183 sv.summary_computed(sess, sess.run(my_summary_op))
184 else:
185 # Train normally
186 sess.run(my_train_op)
187 ```
189 ##### Custom model initialization
191 `managed_session()` only supports initializing the model by running an
192 `init_op` or restoring from the latest checkpoint. If you have special
193 initialization needs, see how to specify a `local_init_op` when creating the
194 supervisor. You can also use the `SessionManager` directly to create a
195 session and check if it could be initialized automatically.
196 """
198 # Value to pass for the 'ready_op', 'init_op', 'summary_op', 'saver',
199 # and 'global_step' parameters of Supervisor.__init__() to indicate that
200 # the default behavior should be used.
201 USE_DEFAULT = 0
  @deprecation.deprecated(None,
                          "Please switch to tf.train.MonitoredTrainingSession")
  def __init__(self,
               graph=None,
               ready_op=USE_DEFAULT,
               ready_for_local_init_op=USE_DEFAULT,
               is_chief=True,
               init_op=USE_DEFAULT,
               init_feed_dict=None,
               local_init_op=USE_DEFAULT,
               logdir=None,
               summary_op=USE_DEFAULT,
               saver=USE_DEFAULT,
               global_step=USE_DEFAULT,
               save_summaries_secs=120,
               save_model_secs=600,
               recovery_wait_secs=30,
               stop_grace_secs=120,
               checkpoint_basename="model.ckpt",
               session_manager=None,
               summary_writer=USE_DEFAULT,
               init_fn=None,
               local_init_run_options=None):
    """Create a `Supervisor`.

    Args:
      graph: A `Graph`. The graph that the model will use. Defaults to the
        default `Graph`. The supervisor may add operations to the graph before
        creating a session, but the graph should not be modified by the caller
        after passing it to the supervisor.
      ready_op: 1-D string `Tensor`. This tensor is evaluated by supervisors in
        `prepare_or_wait_for_session()` to check if the model is ready to use.
        The model is considered ready if it returns an empty array. Defaults to
        the tensor returned from `tf.compat.v1.report_uninitialized_variables()`
        If `None`, the model is not checked for readiness.
      ready_for_local_init_op: 1-D string `Tensor`. This tensor is evaluated by
        supervisors in `prepare_or_wait_for_session()` to check if the model is
        ready to run the local_init_op. The model is considered ready if it
        returns an empty array. Defaults to `None`. If `None`, the model is not
        checked for readiness before running local_init_op.
      is_chief: If True, create a chief supervisor in charge of initializing and
        restoring the model. If False, create a supervisor that relies on a
        chief supervisor for inits and restore.
      init_op: `Operation`. Used by chief supervisors to initialize the model
        when it can not be recovered. Defaults to an `Operation` that
        initializes all global variables. If `None`, no initialization is done
        automatically unless you pass a value for `init_fn`, see below.
      init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
        This feed dictionary will be used when `init_op` is evaluated.
      local_init_op: `Operation`. Used by all supervisors to run initializations
        that should run for every new supervisor instance. By default these are
        table initializers and initializers for local variables. If `None`, no
        further per supervisor-instance initialization is done automatically.
      logdir: A string. Optional path to a directory where to checkpoint the
        model and log events for the visualizer. Used by chief supervisors. The
        directory will be created if it does not exist.
      summary_op: An `Operation` that returns a Summary for the event logs. Used
        by chief supervisors if a `logdir` was specified. Defaults to the
        operation returned from summary.merge_all(). If `None`, summaries are
        not computed automatically.
      saver: A Saver object. Used by chief supervisors if a `logdir` was
        specified. Defaults to the saver returned by Saver(). If `None`, the
        model is not saved automatically.
      global_step: An integer Tensor of size 1 that counts steps. The value
        from 'global_step' is used in summaries and checkpoint filenames.
        Default to the op named 'global_step' in the graph if it exists, is of
        rank 1, size 1, and of type tf.int32 or tf.int64. If `None` the global
        step is not recorded in summaries and checkpoint files. Used by chief
        supervisors if a `logdir` was specified.
      save_summaries_secs: Number of seconds between the computation of
        summaries for the event log. Defaults to 120 seconds. Pass 0 to
        disable summaries.
      save_model_secs: Number of seconds between the creation of model
        checkpoints. Defaults to 600 seconds. Pass 0 to disable checkpoints.
      recovery_wait_secs: Number of seconds between checks that the model is
        ready. Used by supervisors when waiting for a chief supervisor to
        initialize or restore the model. Defaults to 30 seconds.
      stop_grace_secs: Grace period, in seconds, given to running threads to
        stop when `stop()` is called. Defaults to 120 seconds.
      checkpoint_basename: The basename for checkpoint saving.
      session_manager: `SessionManager`, which manages Session creation and
        recovery. If it is `None`, a default `SessionManager` will be created
        with the set of arguments passed in for backwards compatibility.
      summary_writer: `SummaryWriter` to use or `USE_DEFAULT`. Can be `None` to
        indicate that no summaries should be written.
      init_fn: Optional callable used to initialize the model. Called after the
        optional `init_op` is called. The callable must accept one argument,
        the session being initialized.
      local_init_run_options: RunOptions to be passed as the SessionManager
        local_init_run_options parameter.

    Returns:
      A `Supervisor`.

    Raises:
      RuntimeError: If called with eager execution enabled.

    @compatibility(eager)
    `Supervisor`s are not supported when eager execution is enabled.
    @end_compatibility
    """
    if context.executing_eagerly():
      raise RuntimeError("Supervisors are incompatible with eager execution.")
    # Set default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    # The _init_* helpers may add default ops (init, ready, saver, summaries,
    # global step) to the graph, so they must run inside its context.
    with graph.as_default():
      self._init_ready_op(
          ready_op=ready_op, ready_for_local_init_op=ready_for_local_init_op)
      self._init_init_op(init_op=init_op, init_feed_dict=init_feed_dict)
      self._init_local_init_op(local_init_op=local_init_op)
      self._init_saver(saver=saver)
      self._init_summary_op(summary_op=summary_op)
      self._init_global_step(global_step=global_step)
    self._graph = graph
    # Snapshot a MetaGraphDef now; `_write_graph()` later exports it to the
    # summary writer. Requires `_init_saver()` to have run first.
    self._meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph.as_graph_def(add_shapes=True),
        saver_def=self._saver.saver_def if self._saver else None)
    self._is_chief = is_chief
    self._coord = coordinator.Coordinator()
    self._recovery_wait_secs = recovery_wait_secs
    self._stop_grace_secs = stop_grace_secs
    self._init_fn = init_fn
    self._local_init_run_options = local_init_run_options

    # Set all attributes related to checkpointing and writing events to None.
    # Afterwards, set them appropriately for chief supervisors, as these are
    # the only supervisors that can write checkpoints and events.
    self._logdir = None
    self._save_summaries_secs = None
    self._save_model_secs = None
    self._save_path = None
    self._summary_writer = None

    if self._is_chief:
      self._logdir = logdir
      self._save_summaries_secs = save_summaries_secs
      self._save_model_secs = save_model_secs
      if self._logdir:
        self._save_path = os.path.join(self._logdir, checkpoint_basename)
      if summary_writer is Supervisor.USE_DEFAULT:
        # A default writer is only created when there is somewhere to write.
        if self._logdir:
          self._summary_writer = _summary.FileWriter(self._logdir)
      else:
        self._summary_writer = summary_writer
      self._graph_added_to_summary = False

    self._init_session_manager(session_manager=session_manager)
    self._verify_setup()
    # The graph is not allowed to change anymore.
    graph.finalize()
355 def _init_session_manager(self, session_manager=None):
356 if session_manager is None:
357 self._session_manager = session_manager_mod.SessionManager(
358 local_init_op=self._local_init_op,
359 ready_op=self._ready_op,
360 ready_for_local_init_op=self._ready_for_local_init_op,
361 graph=self._graph,
362 recovery_wait_secs=self._recovery_wait_secs,
363 local_init_run_options=self._local_init_run_options)
364 else:
365 self._session_manager = session_manager
367 def _get_first_op_from_collection(self, key):
368 """Returns the first `Operation` from a collection.
370 Args:
371 key: A string collection key.
373 Returns:
374 The first Op found in a collection, or `None` if the collection is empty.
375 """
376 try:
377 op_list = ops.get_collection(key)
378 if len(op_list) > 1:
379 logging.info("Found %d %s operations. Returning the first one.",
380 len(op_list), key)
381 if op_list:
382 return op_list[0]
383 except LookupError:
384 pass
386 return None
388 def _init_ready_op(self,
389 ready_op=USE_DEFAULT,
390 ready_for_local_init_op=USE_DEFAULT):
391 """Initializes ready_op.
393 Args:
394 ready_op: `Tensor` to check if the model is initialized. If it's set to
395 USE_DEFAULT, creates an op that checks all the variables are
396 initialized.
397 ready_for_local_init_op: `Tensor` to check if the model is ready to run
398 local_init_op. If it's set to USE_DEFAULT, creates an op that checks all
399 the global variables are initialized.
400 """
401 if ready_op is Supervisor.USE_DEFAULT:
402 ready_op = self._get_first_op_from_collection(ops.GraphKeys.READY_OP)
403 if ready_op is None:
404 ready_op = variables.report_uninitialized_variables()
405 ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op)
406 self._ready_op = ready_op
408 # ready_for_local_init_op defaults to None for backward compatibility
409 if ready_for_local_init_op is Supervisor.USE_DEFAULT:
410 ready_for_local_init_op = self._get_first_op_from_collection(
411 ops.GraphKeys.READY_FOR_LOCAL_INIT_OP)
412 self._ready_for_local_init_op = ready_for_local_init_op
414 def _init_init_op(self, init_op=USE_DEFAULT, init_feed_dict=None):
415 """Initializes init_op.
417 Args:
418 init_op: `Operation` to initialize the variables. If set to USE_DEFAULT,
419 create an op that initializes all variables and tables.
420 init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
421 This feed dictionary will be used when `init_op` is evaluated.
422 """
423 if init_op is Supervisor.USE_DEFAULT:
424 init_op = self._get_first_op_from_collection(ops.GraphKeys.INIT_OP)
425 if init_op is None:
426 init_op = variables.global_variables_initializer()
427 ops.add_to_collection(ops.GraphKeys.INIT_OP, init_op)
428 self._init_op = init_op
429 self._init_feed_dict = init_feed_dict
431 def _init_local_init_op(self, local_init_op=USE_DEFAULT):
432 """Initializes local_init_op.
434 Args:
435 local_init_op: `Operation` run for every new supervisor instance. If set
436 to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP
437 collection. If the collection is empty, create an op that initializes
438 all local variables and all tables.
439 """
440 if local_init_op is Supervisor.USE_DEFAULT:
441 local_init_op = self._get_first_op_from_collection(
442 ops.GraphKeys.LOCAL_INIT_OP)
443 if local_init_op is None:
444 op_list = [
445 variables.local_variables_initializer(),
446 lookup_ops.tables_initializer()
447 ]
448 if op_list:
449 local_init_op = control_flow_ops.group(*op_list)
450 ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
451 self._local_init_op = local_init_op
453 def _init_saver(self, saver=USE_DEFAULT):
454 """Initializes saver.
456 Args:
457 saver: A `Saver` object. If set to USE_DEFAULT, create one that saves all
458 the variables.
459 """
460 if saver is Supervisor.USE_DEFAULT:
461 saver = self._get_first_op_from_collection(ops.GraphKeys.SAVERS)
462 if saver is None and variables.global_variables():
463 saver = saver_mod.Saver()
464 ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
465 self._saver = saver
467 def _init_summary_op(self, summary_op=USE_DEFAULT):
468 """Initializes summary_op.
470 Args:
471 summary_op: An Operation that returns a Summary for the event logs. If set
472 to USE_DEFAULT, create an op that merges all the summaries.
473 """
474 if summary_op is Supervisor.USE_DEFAULT:
475 summary_op = self._get_first_op_from_collection(ops.GraphKeys.SUMMARY_OP)
476 if summary_op is None:
477 summary_op = _summary.merge_all()
478 if summary_op is not None:
479 ops.add_to_collection(ops.GraphKeys.SUMMARY_OP, summary_op)
480 self._summary_op = summary_op
482 def _init_global_step(self, global_step=USE_DEFAULT):
483 """Initializes global_step.
485 Args:
486 global_step: An integer Tensor of size 1 that counts steps. If set to
487 USE_DEFAULT, creates global_step tensor.
488 """
489 if global_step is Supervisor.USE_DEFAULT:
490 global_step = self._get_first_op_from_collection(
491 ops.GraphKeys.GLOBAL_STEP)
492 if global_step is None:
493 global_step = self._default_global_step_tensor()
494 if global_step is not None:
495 ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, global_step)
496 self._global_step = global_step
498 @property
499 def is_chief(self):
500 """Return True if this is a chief supervisor.
502 Returns:
503 A bool.
504 """
505 return self._is_chief
507 @property
508 def session_manager(self):
509 """Return the SessionManager used by the Supervisor.
511 Returns:
512 A SessionManager object.
513 """
514 return self._session_manager
516 @property
517 def coord(self):
518 """Return the Coordinator used by the Supervisor.
520 The Coordinator can be useful if you want to run multiple threads
521 during your training.
523 Returns:
524 A Coordinator object.
525 """
526 return self._coord
528 @property
529 def init_op(self):
530 """Return the Init Op used by the supervisor.
532 Returns:
533 An Op or `None`.
534 """
535 return self._init_op
537 @property
538 def init_feed_dict(self):
539 """Return the feed dictionary used when evaluating the `init_op`.
541 Returns:
542 A feed dictionary or `None`.
543 """
544 return self._init_feed_dict
546 @property
547 def ready_op(self):
548 """Return the Ready Op used by the supervisor.
550 Returns:
551 An Op or `None`.
552 """
553 return self._ready_op
555 @property
556 def ready_for_local_init_op(self):
557 return self._ready_for_local_init_op
559 @property
560 def summary_writer(self):
561 """Return the SummaryWriter used by the chief supervisor.
563 Returns:
564 A SummaryWriter.
565 """
566 return self._summary_writer
568 @property
569 def summary_op(self):
570 """Return the Summary Tensor used by the chief supervisor.
572 Returns:
573 A string Tensor for the summary or `None`.
574 """
575 return self._summary_op
577 @property
578 def save_summaries_secs(self):
579 """Return the delay between summary computations.
581 Returns:
582 A timestamp.
583 """
584 return self._save_summaries_secs
586 @property
587 def global_step(self):
588 """Return the global_step Tensor used by the supervisor.
590 Returns:
591 An integer Tensor for the global_step.
592 """
593 return self._global_step
595 @property
596 def saver(self):
597 """Return the Saver used by the supervisor.
599 Returns:
600 A Saver object.
601 """
602 return self._saver
604 @property
605 def save_model_secs(self):
606 """Return the delay between checkpoints.
608 Returns:
609 A timestamp.
610 """
611 return self._save_model_secs
613 @property
614 def save_path(self):
615 """Return the save path used by the supervisor.
617 Returns:
618 A string.
619 """
620 return self._save_path
622 def _write_graph(self):
623 """Writes graph_def to `logdir` and adds it to summary if applicable."""
624 assert self._is_chief
625 if self._logdir:
626 training_util.write_graph(
627 self._graph.as_graph_def(add_shapes=True), self._logdir,
628 "graph.pbtxt")
629 if self._summary_writer and not self._graph_added_to_summary:
630 self._summary_writer.add_graph(self._graph)
631 self._summary_writer.add_meta_graph(self._meta_graph_def)
632 self._graph_added_to_summary = True
634 def start_standard_services(self, sess):
635 """Start the standard services for 'sess'.
637 This starts services in the background. The services started depend
638 on the parameters to the constructor and may include:
640 - A Summary thread computing summaries every save_summaries_secs.
641 - A Checkpoint thread saving the model every save_model_secs.
642 - A StepCounter thread measure step time.
644 Args:
645 sess: A Session.
647 Returns:
648 A list of threads that are running the standard services. You can use
649 the Supervisor's Coordinator to join these threads with:
650 sv.coord.Join(<list of threads>)
652 Raises:
653 RuntimeError: If called with a non-chief Supervisor.
654 ValueError: If not `logdir` was passed to the constructor as the
655 services need a log directory.
656 """
657 if not self._is_chief:
658 raise RuntimeError("Only chief supervisor can start standard services. "
659 "Because only chief supervisors can write events.")
661 if not self._logdir:
662 logging.warning("Standard services need a 'logdir' "
663 "passed to the SessionManager")
664 return
666 if self._global_step is not None and self._summary_writer:
667 # Only add the session log if we keep track of global step.
668 # TensorBoard cannot use START message for purging expired events
669 # if there is no step value.
670 current_step = training_util.global_step(sess, self._global_step)
671 self._summary_writer.add_session_log(
672 SessionLog(status=SessionLog.START), current_step)
674 threads = []
675 if self._save_summaries_secs and self._summary_writer:
676 if self._summary_op is not None:
677 threads.append(SVSummaryThread(self, sess))
678 if self._global_step is not None:
679 threads.append(SVStepCounterThread(self, sess))
680 if self.saver and self._save_model_secs:
681 threads.append(SVTimerCheckpointThread(self, sess))
682 for t in threads:
683 t.start()
684 return threads
  def prepare_or_wait_for_session(self,
                                  master="",
                                  config=None,
                                  wait_for_checkpoint=False,
                                  max_wait_secs=7200,
                                  start_standard_services=True):
    """Make sure the model is ready to be used.

    Create a session on 'master', recovering or initializing the model as
    needed, or wait for a session to be ready. If running as the chief
    and `start_standard_service` is set to True, also call the session
    manager to start the standard services.

    Args:
      master: name of the TensorFlow master to use. See the
        `tf.compat.v1.Session` constructor for how this is interpreted.
      config: Optional ConfigProto proto used to configure the session, which is
        passed as-is to create the session.
      wait_for_checkpoint: Whether we should wait for the availability of a
        checkpoint before creating Session. Defaults to False.
      max_wait_secs: Maximum time to wait for the session to become available.
      start_standard_services: Whether to start the standard services and the
        queue runners.

    Returns:
      A Session object that can be used to drive the model.
    """
    # For users who recreate the session with prepare_or_wait_for_session(), we
    # need to clear the coordinator's stop_event so that threads managed by the
    # coordinator can run.
    self._coord.clear_stop()
    # Reopen the writer in case a previous stop() closed it.
    if self._summary_writer:
      self._summary_writer.reopen()

    if self._is_chief:
      # The chief initializes or recovers the model itself.
      sess = self._session_manager.prepare_session(
          master,
          init_op=self.init_op,
          saver=self.saver,
          checkpoint_dir=self._logdir,
          wait_for_checkpoint=wait_for_checkpoint,
          max_wait_secs=max_wait_secs,
          config=config,
          init_feed_dict=self._init_feed_dict,
          init_fn=self._init_fn)
      self._write_graph()
      if start_standard_services:
        logging.info("Starting standard services.")
        self.start_standard_services(sess)
    else:
      # Non-chief tasks wait for the chief to make the model ready.
      sess = self._session_manager.wait_for_session(
          master, config=config, max_wait_secs=max_wait_secs)
    # Queue runners are started for every task, chief or not.
    if start_standard_services:
      logging.info("Starting queue runners.")
      self.start_queue_runners(sess)
    return sess
743 def start_queue_runners(self, sess, queue_runners=None):
744 """Start threads for `QueueRunners`.
746 Note that the queue runners collected in the graph key `QUEUE_RUNNERS`
747 are already started automatically when you create a session with the
748 supervisor, so unless you have non-collected queue runners to start
749 you do not need to call this explicitly.
751 Args:
752 sess: A `Session`.
753 queue_runners: A list of `QueueRunners`. If not specified, we'll use the
754 list of queue runners gathered in the graph under the key
755 `GraphKeys.QUEUE_RUNNERS`.
757 Returns:
758 The list of threads started for the `QueueRunners`.
760 Raises:
761 RuntimeError: If called with eager execution enabled.
763 @compatibility(eager)
764 Queues are not compatible with eager execution. To ingest data when eager
765 execution is enabled, use the `tf.data` API.
766 @end_compatibility
767 """
768 if context.executing_eagerly():
769 raise RuntimeError("Queues are not compatible with eager execution.")
770 if queue_runners is None:
771 queue_runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
772 threads = []
773 for qr in queue_runners:
774 threads.extend(
775 qr.create_threads(sess, coord=self._coord, daemon=True, start=True))
776 return threads
778 def loop(self, timer_interval_secs, target, args=None, kwargs=None):
779 """Start a LooperThread that calls a function periodically.
781 If `timer_interval_secs` is None the thread calls `target(*args, **kwargs)`
782 repeatedly. Otherwise it calls it every `timer_interval_secs`
783 seconds. The thread terminates when a stop is requested.
785 The started thread is added to the list of threads managed by the supervisor
786 so it does not need to be passed to the `stop()` method.
788 Args:
789 timer_interval_secs: Number. Time boundaries at which to call `target`.
790 target: A callable object.
791 args: Optional arguments to pass to `target` when calling it.
792 kwargs: Optional keyword arguments to pass to `target` when calling it.
794 Returns:
795 The started thread.
796 """
797 looper = coordinator.LooperThread(
798 self._coord,
799 timer_interval_secs,
800 target=target,
801 args=args,
802 kwargs=kwargs)
803 looper.start()
804 return looper
  def stop(self,
           threads=None,
           close_summary_writer=True,
           ignore_live_threads=False):
    """Stop the services and the coordinator.

    This does not close the session.

    Args:
      threads: Optional list of threads to join with the coordinator. If
        `None`, defaults to the threads running the standard services, the
        threads started for `QueueRunners`, and the threads started by the
        `loop()` method. To wait on additional threads, pass the list in this
        parameter.
      close_summary_writer: Whether to close the `summary_writer`. Defaults to
        `True` if the summary writer was created by the supervisor, `False`
        otherwise.
      ignore_live_threads: If `True` ignores threads that remain running after a
        grace period when joining threads via the coordinator, instead of
        raising a RuntimeError.
    """
    self._coord.request_stop()
    try:
      # coord.join() re-raises the first reported exception; the "finally"
      # block ensures that we clean up whether or not an exception was
      # reported.
      self._coord.join(
          threads,
          stop_grace_period_secs=self._stop_grace_secs,
          ignore_live_threads=ignore_live_threads)
    finally:
      # Close the writer last, in case one of the running threads was using it.
      if close_summary_writer and self._summary_writer:
        # Stop messages are not logged with event.step,
        # since the session may have already terminated.
        self._summary_writer.add_session_log(SessionLog(status=SessionLog.STOP))
        self._summary_writer.close()
        # Allow the graph to be re-added if a new session is prepared later.
        self._graph_added_to_summary = False
845 def request_stop(self, ex=None):
846 """Request that the coordinator stop the threads.
848 See `Coordinator.request_stop()`.
850 Args:
851 ex: Optional `Exception`, or Python `exc_info` tuple as returned by
852 `sys.exc_info()`. If this is the first call to `request_stop()` the
853 corresponding exception is recorded and re-raised from `join()`.
854 """
855 self._coord.request_stop(ex=ex)
857 def should_stop(self):
858 """Check if the coordinator was told to stop.
860 See `Coordinator.should_stop()`.
862 Returns:
863 True if the coordinator was told to stop, False otherwise.
864 """
865 return self._coord.should_stop()
867 def stop_on_exception(self):
868 """Context handler to stop the supervisor when an exception is raised.
870 See `Coordinator.stop_on_exception()`.
872 Returns:
873 A context handler.
874 """
875 return self._coord.stop_on_exception()
877 def wait_for_stop(self):
878 """Block waiting for the coordinator to stop."""
879 self._coord.wait_for_stop()
881 def summary_computed(self, sess, summary, global_step=None):
882 """Indicate that a summary was computed.
884 Args:
885 sess: A `Session` object.
886 summary: A Summary proto, or a string holding a serialized summary proto.
887 global_step: Int. global step this summary is associated with. If `None`,
888 it will try to fetch the current step.
890 Raises:
891 TypeError: if 'summary' is not a Summary proto or a string.
892 RuntimeError: if the Supervisor was created without a `logdir`.
893 """
894 if not self._summary_writer:
895 raise RuntimeError("Writing a summary requires a summary writer.")
896 if global_step is None and self.global_step is not None:
897 global_step = training_util.global_step(sess, self.global_step)
898 self._summary_writer.add_summary(summary, global_step)
900 def _default_global_step_tensor(self):
901 """Returns the global_step from the default graph.
903 Returns:
904 The global step `Tensor` or `None`.
905 """
906 try:
907 gs = ops.get_default_graph().get_tensor_by_name("global_step:0")
908 if gs.dtype.base_dtype in [dtypes.int32, dtypes.int64]:
909 return gs
910 else:
911 logging.warning("Found 'global_step' is not an int type: %s", gs.dtype)
912 return None
913 except KeyError:
914 return None
916 def _verify_setup(self):
917 """Check that all is good.
919 Raises:
920 ValueError: If something is not good.
921 """
922 # Not running as chief means that replicas are used.
923 # In that case all Variables must have their device set.
924 if not self._is_chief:
925 for op in self._graph.get_operations():
926 if op.type in ["Variable", "VariableV2"] and not op.device:
927 raise ValueError("When using replicas, all Variables must have "
928 "their device set: %s" % op)
  # pylint: disable=g-doc-return-or-yield,broad-except
  @contextlib.contextmanager
  def managed_session(self,
                      master="",
                      config=None,
                      start_standard_services=True,
                      close_summary_writer=True):
    """Returns a context manager for a managed session.

    This context manager creates and automatically recovers a session. It
    optionally starts the standard services that handle checkpoints and
    summaries. It monitors exceptions raised from the `with` block or from the
    services and stops the supervisor as needed.

    The context manager is typically used as follows:

    ```python
    def train():
      sv = tf.compat.v1.train.Supervisor(...)
      with sv.managed_session(<master>) as sess:
        for step in range(..):
          if sv.should_stop():
            break
          sess.run(<my training op>)
          ...do other things needed at each training step...
    ```

    An exception raised from the `with` block or one of the service threads is
    raised again when the block exits. This is done after stopping all threads
    and closing the session. For example, an `AbortedError` exception, raised
    in case of preemption of one of the workers in a distributed model, is
    raised again when the block exits.

    If you want to retry the training loop in case of preemption you can do it
    as follows:

    ```python
    def main(...):
      while True:
        try:
          train()
        except tf.errors.Aborted:
          pass
    ```

    As a special case, exceptions used for control flow, such as
    `OutOfRangeError` which reports that input queues are exhausted, are not
    raised again from the `with` block: they indicate a clean termination of
    the training loop and are considered normal termination.

    Args:
      master: name of the TensorFlow master to use. See the
        `tf.compat.v1.Session` constructor for how this is interpreted.
      config: Optional `ConfigProto` proto used to configure the session. Passed
        as-is to create the session.
      start_standard_services: Whether to start the standard services, such as
        checkpoint, summary and step counter.
      close_summary_writer: Whether to close the summary writer when closing the
        session. Defaults to True.

    Returns:
      A context manager that yields a `Session` restored from the latest
      checkpoint or initialized from scratch if no checkpoint exists. The
      session is closed when the `with` block exits.
    """
    try:
      sess = self.prepare_or_wait_for_session(
          master=master,
          config=config,
          start_standard_services=start_standard_services)
      yield sess
    except Exception as e:
      # Not re-raised here: the exception is recorded with the coordinator,
      # and stop() in the `finally` clause re-raises the first exception
      # reported to the coordinator (via coord.join()).
      self.request_stop(e)
    finally:
      try:
        # Request all the threads to stop and wait for them to do so. Any
        # exception raised by the threads is raised again from stop().
        # Passing stop_grace_period_secs is for blocked enqueue/dequeue
        # threads which are not checking for `should_stop()`. They
        # will be stopped when we close the session further down.
        self.stop(close_summary_writer=close_summary_writer)
      finally:
        # Close the session to finish up all pending calls. We do not care
        # about exceptions raised when closing. This takes care of
        # blocked enqueue/dequeue calls.
        try:
          sess.close()
        except Exception:
          # Silently ignore exceptions raised by close().
          pass
  # pylint: enable=g-doc-return-or-yield,broad-except
class SVSummaryThread(coordinator.LooperThread):
  """A thread to save summaries on a timer."""

  def __init__(self, sv, sess):
    """Create a SVSummaryThread.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
    """
    super(SVSummaryThread, self).__init__(sv.coord, sv.save_summaries_secs)
    self._sv = sv
    self._sess = sess

  def run_loop(self):
    sv = self._sv
    # Fetch the summaries together with the step when a step tensor exists so
    # both come from the same run() call.
    if sv.global_step is None:
      summary_strs = self._sess.run(sv.summary_op)
      current_step = None
    else:
      summary_strs, current_step = self._sess.run(
          [sv.summary_op, sv.global_step])
    writer = sv.summary_writer
    if writer:
      logging.info("Recording summary at step %s.", current_step)
      writer.add_summary(summary_strs, current_step)
class SVStepCounterThread(coordinator.LooperThread):
  """Threads to count steps and measure their duration."""

  def __init__(self, sv, sess, step_counter=None):
    """Create a `SVStepCounterThread`.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
      step_counter: A `Tensor` holding the step counter. By defaults, it uses
        sv.global_step.
    """
    super(SVStepCounterThread, self).__init__(sv.coord, sv.save_summaries_secs)
    self._sv = sv
    self._sess = sess
    self._last_time = 0.0
    self._last_step = 0
    if step_counter is None:
      step_counter = sv.global_step
    self._step_counter = step_counter
    self._summary_tag = "%s/sec" % self._step_counter.op.name

  def start_loop(self):
    # Initialize the baselines so the first run_loop() measures a real delta.
    self._last_time = time.time()
    self._last_step = training_util.global_step(self._sess, self._step_counter)

  def run_loop(self):
    # Steps completed since the previous visit.
    now_step = training_util.global_step(self._sess, self._step_counter)
    step_delta = now_step - self._last_step
    self._last_step = now_step
    # Wall time elapsed since the previous visit.
    now_time = time.time()
    time_delta = now_time - self._last_time
    self._last_time = now_time
    # Report the rate; guard against a zero elapsed time.
    steps_per_sec = step_delta / time_delta if time_delta > 0. else float("inf")
    summary = Summary(value=[
        Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
    ])
    if self._sv.summary_writer:
      self._sv.summary_writer.add_summary(summary, now_step)
    logging.log_first_n(logging.INFO, "%s: %g", 10, self._summary_tag,
                        steps_per_sec)
class SVTimerCheckpointThread(coordinator.LooperThread):
  """A thread to checkpoint on a timer."""

  def __init__(self, sv, sess):
    """Create a `SVTimerCheckpointThread`.

    Args:
      sv: A `Supervisor`.
      sess: A `Session`.
    """
    super(SVTimerCheckpointThread, self).__init__(sv.coord, sv.save_model_secs)
    self._sv = sv
    self._sess = sess

  def run_loop(self):
    sv = self._sv
    logging.info("Saving checkpoint to path %s", sv.save_path)
    sv.saver.save(self._sess, sv.save_path, global_step=sv.global_step)
    writer = sv.summary_writer
    # Record the checkpoint event against the current step when one exists.
    if writer and sv.global_step is not None:
      current_step = training_util.global_step(self._sess, sv.global_step)
      writer.add_session_log(
          SessionLog(
              status=SessionLog.CHECKPOINT, checkpoint_path=sv.save_path),
          current_step)
# TODO(sherrym): All non-PEP8 compliant names will be deprecated shortly.
# Legacy CamelCase aliases kept for backward compatibility with old callers.
for _camel_name, _method in (
    ("PrepareSession", Supervisor.prepare_or_wait_for_session),
    ("StartQueueRunners", Supervisor.start_queue_runners),
    ("StartStandardServices", Supervisor.start_standard_services),
    ("Stop", Supervisor.stop),
    ("RequestStop", Supervisor.request_stop),
    ("Loop", Supervisor.loop),
    ("ShouldStop", Supervisor.should_stop),
    ("StopOnException", Supervisor.stop_on_exception),
    ("WaitForStop", Supervisor.wait_for_stop),
    ("SummaryComputed", Supervisor.summary_computed),
):
  setattr(Supervisor, _camel_name, _method)
# Keep the module namespace identical to the original setattr sequence.
del _camel_name, _method