# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some common SessionRunHook classes.

Note that the symbols that are exported to v1 tf.train namespace are also
exported to v2 in tf.estimator namespace. See
https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py
"""

import os
import time

import numpy as np

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.client import timeline
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
from tensorflow.python.util.tf_export import tf_export

_HOOKS = "hooks"
_STEPS_PER_RUN_VAR = "steps_per_run"


class _HookTimer:
  """Base timer for determining when Hooks should trigger.

  Should not be instantiated directly.
  """

  def __init__(self):
    pass

  def reset(self):
    """Resets the timer."""
    pass

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step."""
    raise NotImplementedError

  def update_last_triggered_step(self, step):
    """Update the last triggered time and step number.

    Args:
      step: The current step.

    Returns:
      A pair `(elapsed_time, elapsed_steps)`, where `elapsed_time` is the
      number of seconds between the current trigger and the last one (a
      float), and `elapsed_steps` is the number of steps between the current
      trigger and the last one. Both values will be set to `None` on the
      first trigger.
    """
    raise NotImplementedError

  def last_triggered_step(self):
    """Returns the last triggered time step or None if never triggered."""
    raise NotImplementedError


@tf_export(v1=["train.SecondOrStepTimer"])
class SecondOrStepTimer(_HookTimer):
  """Timer that triggers at most once every N seconds or once every N steps.

  This symbol is also exported to v2 in tf.estimator namespace. See
  https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py
  """

  def __init__(self, every_secs=None, every_steps=None):
    self.reset()
    self._every_secs = every_secs
    self._every_steps = every_steps

    if self._every_secs is None and self._every_steps is None:
      raise ValueError("Either every_secs or every_steps should be provided.")
    if (self._every_secs is not None) and (self._every_steps is not None):
      raise ValueError("Can not provide both every_secs and every_steps.")

    super(SecondOrStepTimer, self).__init__()

  def reset(self):
    self._last_triggered_step = None
    self._last_triggered_time = None

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step.

    Args:
      step: Training step to trigger on.

    Returns:
      True if the difference between the current time and the time of the
      last trigger exceeds `every_secs`, or if the difference between the
      current step and the last triggered step exceeds `every_steps`. False
      otherwise.
    """
    if self._last_triggered_step is None:
      return True

    if self._last_triggered_step == step:
      return False

    if self._every_secs is not None:
      if time.time() >= self._last_triggered_time + self._every_secs:
        return True

    if self._every_steps is not None:
      if step >= self._last_triggered_step + self._every_steps:
        return True

    return False

  def update_last_triggered_step(self, step):
    current_time = time.time()
    if self._last_triggered_time is None:
      elapsed_secs = None
      elapsed_steps = None
    else:
      elapsed_secs = current_time - self._last_triggered_time
      elapsed_steps = step - self._last_triggered_step

    self._last_triggered_time = current_time
    self._last_triggered_step = step
    return (elapsed_secs, elapsed_steps)

  def last_triggered_step(self):
    return self._last_triggered_step
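

# Illustrative sketch (not part of this module): how a hook typically drives a
# `SecondOrStepTimer`. The step values and the 100-step cadence here are
# hypothetical.
#
#   timer = SecondOrStepTimer(every_steps=100)
#   for step in range(1000):
#     if timer.should_trigger_for_step(step):
#       # Both values are None on the very first trigger.
#       elapsed_secs, elapsed_steps = timer.update_last_triggered_step(step)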


class NeverTriggerTimer(_HookTimer):
  """Timer that never triggers."""

  def should_trigger_for_step(self, step):
    _ = step
    return False

  def update_last_triggered_step(self, step):
    _ = step
    return (None, None)

  def last_triggered_step(self):
    return None


@tf_export(v1=["train.LoggingTensorHook"])
class LoggingTensorHook(session_run_hook.SessionRunHook):
  """Prints the given tensors every N local steps, every N seconds, or at end.

  The tensors will be printed to the log, with `INFO` severity. If you are not
  seeing the logs, you might want to add the following line after your imports:

  ```python
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  ```

  Note that if `at_end` is True, `tensors` should not include any tensor
  whose evaluation produces a side effect such as consuming additional inputs.

  @compatibility(TF2)
  Please check this [notebook][notebook] on how to migrate the API to TF2.

  [notebook]:https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb

  @end_compatibility
  """

  def __init__(self,
               tensors,
               every_n_iter=None,
               every_n_secs=None,
               at_end=False,
               formatter=None):
    """Initializes a `LoggingTensorHook`.

    Args:
      tensors: `dict` that maps string-valued tags to tensors/tensor names,
        or `iterable` of tensors/tensor names.
      every_n_iter: `int`, print the values of `tensors` once every N local
        steps taken on the current worker.
      every_n_secs: `int` or `float`, print the values of `tensors` once
        every N seconds. Exactly one of `every_n_iter` and `every_n_secs`
        should be provided.
      at_end: `bool` specifying whether to print the values of `tensors` at
        the end of the run.
      formatter: function that takes a dict of `tag`->`Tensor` and returns a
        string. If `None`, uses the default formatting, which prints all
        tensors.

    Raises:
      ValueError: if `every_n_iter` is non-positive.
    """
    only_log_at_end = (
        at_end and (every_n_iter is None) and (every_n_secs is None))
    if (not only_log_at_end and
        (every_n_iter is None) == (every_n_secs is None)):
      raise ValueError(
          "either at_end and/or exactly one of every_n_iter and every_n_secs "
          "must be provided.")
    if every_n_iter is not None and every_n_iter <= 0:
      raise ValueError("invalid every_n_iter=%s." % every_n_iter)
    if not isinstance(tensors, dict):
      self._tag_order = tensors
      tensors = {item: item for item in tensors}
    else:
      self._tag_order = sorted(tensors.keys())
    self._tensors = tensors
    self._formatter = formatter
    self._timer = (
        NeverTriggerTimer() if only_log_at_end else SecondOrStepTimer(
            every_secs=every_n_secs, every_steps=every_n_iter))
    self._log_at_end = at_end

  def begin(self):
    self._timer.reset()
    self._iter_count = 0
    # Convert names to tensors if given.
    self._current_tensors = {
        tag: _as_graph_element(tensor)
        for (tag, tensor) in self._tensors.items()
    }

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    if self._should_trigger:
      return SessionRunArgs(self._current_tensors)
    else:
      return None

  def _log_tensors(self, tensor_values):
    original = np.get_printoptions()
    np.set_printoptions(suppress=True)
    elapsed_secs, _ = self._timer.update_last_triggered_step(self._iter_count)
    if self._formatter:
      logging.info(self._formatter(tensor_values))
    else:
      stats = []
      for tag in self._tag_order:
        stats.append("%s = %s" % (tag, tensor_values[tag]))
      if elapsed_secs is not None:
        logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs)
      else:
        logging.info("%s", ", ".join(stats))
    np.set_printoptions(**original)

  def after_run(self, run_context, run_values):
    _ = run_context
    if self._should_trigger:
      self._log_tensors(run_values.results)

    self._iter_count += 1

  def end(self, session):
    if self._log_at_end:
      values = session.run(self._current_tensors)
      self._log_tensors(values)
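

# Usage sketch (illustrative): log two tensors every 100 local steps. The
# tensor names "loss"/"accuracy" and `train_op` are assumptions about the
# caller's graph.
#
#   hook = LoggingTensorHook(
#       tensors={"loss": "loss", "accuracy": "accuracy"}, every_n_iter=100)
#   with tf.compat.v1.train.MonitoredTrainingSession(hooks=[hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)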


def get_or_create_steps_per_run_variable():
  """Gets or creates the steps_per_run variable.

  In Estimator, the user-provided computation, the model_fn, is wrapped
  inside a tf.while_loop for peak performance. The iterations of the loop are
  specified by this variable, which adjusts its value on the CPU after each
  device program execution and before the next execution.

  The purpose of using a variable, rather than a constant, is to allow
  Estimator to adapt the device training iterations according to the final
  steps specified by users. For example, if the user sets steps_per_run to 4
  and steps to 10 in Estimator.train(), the steps_per_run variable will have
  the following values before each training run:

  - 1st execution: steps_per_run = 4
  - 2nd execution: steps_per_run = 4
  - 3rd execution: steps_per_run = 2

  As model_fn increases the global step once per train_op invocation, the
  global step is 10 after all executions, matching the steps=10 argument
  passed in by users.

  Returns:
    A TF non-trainable resource variable.

  Raises:
    RuntimeError: If multiple steps_per_run variables were found.
  """
  graph = ops.get_default_graph()
  collection_name = "{}_{}".format(_HOOKS, _STEPS_PER_RUN_VAR)
  steps_per_run_vars = graph.get_collection(collection_name)
  if len(steps_per_run_vars) == 1:
    return steps_per_run_vars[0]
  elif len(steps_per_run_vars) > 1:
    raise RuntimeError("Multiple steps_per_run_var in collection.")

  with variable_scope.variable_scope(_HOOKS, reuse=variable_scope.AUTO_REUSE):
    return variable_scope.get_variable(
        _STEPS_PER_RUN_VAR,
        initializer=init_ops.ones_initializer(),
        shape=[],
        dtype=dtypes.int32,
        trainable=False,
        collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES],
        use_resource=True)
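

# Worked example of the schedule above (values hypothetical): with
# Estimator.train(steps=10) and steps_per_run=4, the variable is reloaded
# before each run as min(last_step - global_step, steps_per_run), as done by
# _MultiStepStopAtStepHook._update_steps_per_run_variable below:
#
#   run 1: global_step = 0 -> min(10 - 0, 4) = 4
#   run 2: global_step = 4 -> min(10 - 4, 4) = 4
#   run 3: global_step = 8 -> min(10 - 8, 4) = 2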


class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None, steps_per_run=1):
    """Initializes a `_MultiStepStopAtStepHook`.

    This hook requests stop after either a number of steps has been executed
    or a last step has been reached. Only one of the two options can be
    specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the
    `after_run()` call.

    In Estimator, the user-provided computation, the model_fn, is wrapped
    inside a tf.while_loop for peak performance. The steps_per_run variable
    determines the number of iterations of the loop before returning to the
    CPU.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.
      steps_per_run: Number of steps executed per run call.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    if steps_per_run is None or steps_per_run < 1:
      raise ValueError("steps_per_run should be greater than 0")
    self._num_steps = num_steps
    self._last_step = last_step
    self._steps_per_run_initial_value = steps_per_run

  def begin(self):
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")
    self._steps_per_run_variable = get_or_create_steps_per_run_variable()

  def _update_steps_per_run_variable(self, global_step, session):
    steps = min(self._last_step - global_step,
                self._steps_per_run_initial_value)
    self._steps_per_run_variable.load(steps, session=session)

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._last_step is None:
      self._last_step = global_step + self._num_steps
    self._update_steps_per_run_variable(global_step, session)

  def after_run(self, run_context, run_values):
    # Global step cannot be retrieved via SessionRunArgs and before_run due to
    # a race condition in hook execution.
    global_step = run_context.session.run(self._global_step_tensor)
    if global_step >= self._last_step:
      run_context.request_stop()
    else:
      self._update_steps_per_run_variable(global_step, run_context.session)


@tf_export(v1=["train.StopAtStepHook"])
class StopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step.

  @compatibility(TF2)
  Please check this [notebook][notebook] on how to migrate the API to TF2.

  [notebook]:https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb

  @end_compatibility
  """

  def __init__(self, num_steps=None, last_step=None):
    """Initializes a `StopAtStepHook`.

    This hook requests stop after either a number of steps has been executed
    or a last step has been reached. Only one of the two options can be
    specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the
    `after_run()` call.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  def begin(self):
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")

  def after_create_session(self, session, coord):
    if self._last_step is None:
      global_step = session.run(self._global_step_tensor)
      self._last_step = global_step + self._num_steps

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    global_step = run_values.results + 1
    if global_step >= self._last_step:
      # Check the latest global step to ensure that the targeted last step is
      # reached. The global_step read tensor holds the value of the global
      # step before the run; we cannot tell whether the current session.run
      # incremented the global step or not, so we check it here.
      step = run_context.session.run(self._global_step_tensor)
      if step >= self._last_step:
        run_context.request_stop()
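

# Usage sketch (illustrative; `train_op` is an assumed training op): stop
# training once the global step reaches 10000.
#
#   stop_hook = tf.compat.v1.train.StopAtStepHook(last_step=10000)
#   with tf.compat.v1.train.MonitoredTrainingSession(
#       hooks=[stop_hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)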


@tf_export(v1=["train.CheckpointSaverListener"])
class CheckpointSaverListener:
  """Interface for listeners that take action before or after checkpoint save.

  `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook`
  is triggered, and provides callbacks at the following points:
   - before using the session
   - before each call to `Saver.save()`
   - after each call to `Saver.save()`
   - at the end of session

  To use a listener, implement a class and pass the listener to a
  `CheckpointSaverHook`, as in this example:

  ```python
  class ExampleCheckpointSaverListener(CheckpointSaverListener):
    def begin(self):
      # You can add ops to the graph here.
      print('Starting the session.')
      self.your_tensor = ...

    def before_save(self, session, global_step_value):
      print('About to write a checkpoint')

    def after_save(self, session, global_step_value):
      print('Done writing checkpoint.')
      if decided_to_stop_training():
        return True

    def end(self, session, global_step_value):
      print('Done with the session.')

  ...
  listener = ExampleCheckpointSaverListener()
  saver_hook = tf.estimator.CheckpointSaverHook(
      checkpoint_dir, listeners=[listener])
  with tf.compat.v1.train.MonitoredTrainingSession(
      chief_only_hooks=[saver_hook]):
    ...
  ```

  A `CheckpointSaverListener` may simply take some action after every
  checkpoint save. It is also possible for the listener to use its own
  schedule to act less frequently, e.g. based on global_step_value. In this
  case, implementors should implement the `end()` method to handle actions
  related to the last checkpoint save. But the listener should not act twice
  if `after_save()` already handled this last checkpoint save.

  A `CheckpointSaverListener` can request training to be stopped by returning
  True in `after_save`. Please note that, in a replicated distributed
  training setting, only `chief` should use this behavior. Otherwise each
  worker will do its own evaluation, which may be wasteful of resources.
  """

  def begin(self):
    pass

  def before_save(self, session, global_step_value):
    pass

  def after_save(self, session, global_step_value):
    pass

  def end(self, session, global_step_value):
    pass


@tf_export(v1=["train.CheckpointSaverHook"])
class CheckpointSaverHook(session_run_hook.SessionRunHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None,
               save_graph_def=True):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, used to get the saver object.
      listeners: List of `CheckpointSaverListener` subclass instances. Used
        for callbacks that run immediately before or after this hook saves
        the checkpoint.
      save_graph_def: Whether to save the GraphDef and MetaGraphDef to
        `checkpoint_dir`. The GraphDef is saved after the session is created
        as `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as
        `model.ckpt-*.meta`.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of `saver` or `scaffold` should be set.
    """
    logging.info("Create CheckpointSaverHook.")
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    self._checkpoint_dir = checkpoint_dir
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    self._listeners = listeners or []
    # Set a sufficiently high default so that it never skips checking the
    # actual global step counter -- unless the user overrides it with the
    # right value for steps_per_run.
    self._steps_per_run = 1000000
    self._save_graph_def = save_graph_def

  def _set_steps_per_run(self, steps_per_run):
    self._steps_per_run = steps_per_run

  def begin(self):
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._save_graph_def:
      # We write the graph and saver_def here rather than in begin(), because
      # other hooks may still change the graph and add variables in their
      # begin() calls; the graph is finalized only after all begin() calls.
      training_util.write_graph(
          ops.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir, "graph.pbtxt")
      saver_def = self._get_saver().saver_def if self._get_saver() else None
      graph = ops.get_default_graph()
      meta_graph_def = meta_graph.create_meta_graph_def(
          graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
      self._summary_writer.add_graph(graph)
      self._summary_writer.add_meta_graph(meta_graph_def)
    # The checkpoint saved here is the state at step "global_step".
    self._save(session, global_step)
    self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # Get the real value after the train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        self._timer.update_last_triggered_step(global_step)
        if self._save(run_context.session, global_step):
          run_context.request_stop()

  def end(self, session):
    last_step = session.run(self._global_step_tensor)
    if last_step != self._timer.last_triggered_step():
      self._save(session, last_step)
    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step):
    """Saves the latest checkpoint; returns should_stop."""
    logging.info("Calling checkpoint listeners before saving checkpoint %d...",
                 step)
    for l in self._listeners:
      l.before_save(session, step)

    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
    self._get_saver().save(session, self._save_path, global_step=step,
                           write_meta_graph=self._save_graph_def)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)
    logging.info("Calling checkpoint listeners after saving checkpoint %d...",
                 step)
    should_stop = False
    for l in self._listeners:
      if l.after_save(session, step):
        logging.info(
            "A CheckpointSaverListener requested that training be stopped. "
            "listener: {}".format(l))
        should_stop = True
    return should_stop

  def _get_saver(self):
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get the saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))

    self._saver = savers[0]
    return savers[0]
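

# Usage sketch (illustrative; the directory and cadence are assumptions):
# checkpoint every 500 steps, on the chief only.
#
#   saver_hook = tf.compat.v1.train.CheckpointSaverHook(
#       checkpoint_dir="/tmp/model", save_steps=500)
#   with tf.compat.v1.train.MonitoredTrainingSession(
#       chief_only_hooks=[saver_hook]) as sess:
#     ...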


@tf_export(v1=["train.StepCounterHook"])
class StepCounterHook(session_run_hook.SessionRunHook):
  """Hook that counts steps per second."""

  def __init__(self,
               every_n_steps=100,
               every_n_secs=None,
               output_dir=None,
               summary_writer=None):

    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError(
          "exactly one of every_n_steps and every_n_secs should be provided.")
    self._timer = SecondOrStepTimer(
        every_steps=every_n_steps, every_secs=every_n_secs)

    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._last_global_step = None
    self._steps_per_run = 1

  def _set_steps_per_run(self, steps_per_run):
    self._steps_per_run = steps_per_run

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StepCounterHook.")
    self._summary_tag = training_util.get_global_step().op.name + "/sec"

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    steps_per_sec = elapsed_steps / elapsed_time
    if self._summary_writer is not None:
      summary = Summary(value=[
          Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
      ])
      self._summary_writer.add_summary(summary, global_step)
    logging.info("%s: %g", self._summary_tag, steps_per_sec)

  def after_run(self, run_context, run_values):
    _ = run_context

    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # Get the real value after the train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
            global_step)
        if elapsed_time is not None:
          self._log_and_record(elapsed_steps, elapsed_time, global_step)

    # Check whether the global step has increased. Here, we do not use
    # timer.last_triggered_step because the timer might record a different
    # global step value, making the comparison unreliable. For simplicity,
    # we just compare stale_global_step with the previously recorded version.
    if stale_global_step == self._last_global_step:
      # Warn for the first 5 occurrences of an unchanged global step. For some
      # optimizers, the global step is not increased on every iteration by
      # design. For example, SyncReplicasOptimizer does not increase the
      # global step in the worker's main train step.
      logging.log_first_n(
          logging.WARN,
          "It seems that global step (tf.train.get_global_step) has not "
          "been increased. Current value (could be stable): %s vs previous "
          "value: %s. You could increase the global step by passing "
          "tf.train.get_global_step() to Optimizer.apply_gradients or "
          "Optimizer.minimize.", 5, stale_global_step, self._last_global_step)

    self._last_global_step = stale_global_step
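

# Usage sketch (illustrative; "/tmp/log" is an assumed output directory):
# record a steps/sec summary every 100 steps.
#
#   counter_hook = tf.compat.v1.train.StepCounterHook(
#       every_n_steps=100, output_dir="/tmp/log")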


@tf_export(v1=["train.NanLossDuringTrainingError"])
class NanLossDuringTrainingError(RuntimeError):

  def __str__(self):
    return "NaN loss during training."


@tf_export(v1=["train.NanTensorHook"])
class NanTensorHook(session_run_hook.SessionRunHook):
  """Monitors the loss tensor and stops training if loss is NaN.

  Can either fail with an exception or just stop training.
  """

  def __init__(self, loss_tensor, fail_on_nan_loss=True):
    """Initializes a `NanTensorHook`.

    Args:
      loss_tensor: `Tensor`, the loss tensor.
      fail_on_nan_loss: `bool`, whether to raise an exception when loss is
        NaN.
    """
    self._loss_tensor = loss_tensor
    self._fail_on_nan_loss = fail_on_nan_loss

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._loss_tensor)

  def after_run(self, run_context, run_values):
    if np.isnan(run_values.results):
      failure_message = "Model diverged with loss = NaN."
      if self._fail_on_nan_loss:
        logging.error(failure_message)
        raise NanLossDuringTrainingError
      else:
        logging.warning(failure_message)
        # We don't raise an error but we request stop without an exception.
        run_context.request_stop()
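

# Usage sketch (illustrative; `loss` is an assumed loss tensor): request a
# clean stop instead of raising when the loss becomes NaN.
#
#   nan_hook = tf.compat.v1.train.NanTensorHook(loss, fail_on_nan_loss=False)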


@tf_export(v1=["train.SummarySaverHook"])
class SummarySaverHook(session_run_hook.SessionRunHook):
  """Saves summaries every N steps."""

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir=None,
               summary_writer=None,
               scaffold=None,
               summary_op=None):
    """Initializes a `SummarySaverHook`.

    Args:
      save_steps: `int`, save summaries every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int`, save summaries every N seconds.
      output_dir: `string`, the directory to save the summaries to. Only used
        if no `summary_writer` is supplied.
      summary_writer: `SummaryWriter`. If `None` and an `output_dir` was
        passed, one will be created accordingly.
      scaffold: `Scaffold` to get summary_op from if it is not provided.
      summary_op: `Tensor` of type `string` containing the serialized
        `Summary` protocol buffer, or a list of such `Tensor`s. They are most
        likely the output of TF summary methods like
        `tf.compat.v1.summary.scalar` or `tf.compat.v1.summary.merge_all`. A
        single tensor can be passed in directly; multiple tensors must be
        passed in as a list.

    Raises:
      ValueError: Exactly one of scaffold or summary_op should be set.
    """
    if ((scaffold is None and summary_op is None) or
        (scaffold is not None and summary_op is not None)):
      raise ValueError(
          "Exactly one of scaffold or summary_op must be provided.")
    self._summary_op = summary_op
    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    # TODO(mdan): Throw an error if output_dir and summary_writer are None.

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use SummarySaverHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._request_summary = (
        self._next_step is None or
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    if self._request_summary:
      if self._get_summary_op() is not None:
        requests["summary"] = self._get_summary_op()

    return SessionRunArgs(requests)

  def after_run(self, run_context, run_values):
    _ = run_context
    if not self._summary_writer:
      return

    stale_global_step = run_values.results["global_step"]
    global_step = stale_global_step + 1
    if self._next_step is None or self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)

    if self._next_step is None:
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START), global_step)

    if self._request_summary:
      self._timer.update_last_triggered_step(global_step)
      if "summary" in run_values.results:
        for summary in run_values.results["summary"]:
          self._summary_writer.add_summary(summary, global_step)

    self._next_step = global_step + 1

  def end(self, session=None):
    if self._summary_writer:
      self._summary_writer.flush()

  def _get_summary_op(self):
    """Fetches the summary op either from self._summary_op or self._scaffold.

    Returns:
      A list of summary `Tensor`s, or None if no summary op is available.
    """
    summary_op = None
    if self._summary_op is not None:
      summary_op = self._summary_op
    elif self._scaffold.summary_op is not None:
      summary_op = self._scaffold.summary_op

    if summary_op is None:
      return None

    if not isinstance(summary_op, list):
      return [summary_op]
    return summary_op
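

# Usage sketch (illustrative; assumes summaries were created via
# tf.compat.v1.summary.* before the hook is built, and that "/tmp/log" is the
# desired output directory):
#
#   summary_hook = tf.compat.v1.train.SummarySaverHook(
#       save_steps=100,
#       output_dir="/tmp/log",
#       summary_op=tf.compat.v1.summary.merge_all())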


@tf_export(v1=["train.GlobalStepWaiterHook"])
class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
  """Delays execution until global step reaches `wait_until_step`.

  This hook delays execution until the global step reaches `wait_until_step`.
  It is used to gradually start workers in distributed settings. One example
  usage would be setting `wait_until_step=int(K*log(task_id+1))`, assuming
  that task_id=0 is the chief.
  """

  def __init__(self, wait_until_step):
    """Initializes a `GlobalStepWaiterHook`.

    Args:
      wait_until_step: an `int` indicating the global step to wait until.
    """
    self._wait_until_step = wait_until_step

  def begin(self):
    self._worker_is_started = False
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use GlobalStepWaiterHook.")

  def before_run(self, run_context):
    if self._worker_is_started:
      return None

    if self._wait_until_step <= 0:
      self._worker_is_started = True
      return None

    logging.info("Waiting for global step %d before starting training.",
                 self._wait_until_step)
    last_logged_step = 0
    while True:
      current_step = run_context.session.run(self._global_step_tensor)
      if current_step >= self._wait_until_step:
        self._worker_is_started = True
        return None
      if current_step - last_logged_step > 1000:
        logging.info(
            "Waiting for global step %d before starting training. "
            "Current step is %d.", self._wait_until_step, current_step)
        last_logged_step = current_step
      time.sleep(0.5)
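

# Usage sketch (illustrative): stagger worker startup as suggested in the
# class docstring. `K` and `task_id` are assumptions supplied by the caller.
#
#   wait_hook = tf.compat.v1.train.GlobalStepWaiterHook(
#       wait_until_step=int(K * np.log(task_id + 1)))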


@tf_export(v1=["train.FinalOpsHook"])
class FinalOpsHook(session_run_hook.SessionRunHook):
  """A hook which evaluates `Tensors` at the end of a session."""

  def __init__(self, final_ops, final_ops_feed_dict=None):
    """Initializes `FinalOpsHook` with ops to run at the end of the session.

    Args:
      final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of
        names to `Tensors`.
      final_ops_feed_dict: A feed dictionary to use when running `final_ops`.
    """
    self._final_ops = final_ops
    self._final_ops_feed_dict = final_ops_feed_dict
    self._final_ops_values = None

  @property
  def final_ops_values(self):
    return self._final_ops_values

  def end(self, session):
    if self._final_ops is not None:
      try:
        self._final_ops_values = session.run(
            self._final_ops, feed_dict=self._final_ops_feed_dict)
      except (errors.OutOfRangeError, StopIteration) as e:
        logging.warning(
            "An OutOfRangeError or StopIteration exception is raised by the "
            "code in FinalOpsHook. This typically means the ops run by the "
            "FinalOpsHook have a dependency back to some input source, which "
            "should not happen. For example, for metrics in "
            "tf.estimator.Estimator, all metrics functions return two ops: "
            "`value_op` and `update_op`. Estimator.evaluate calls the "
            "`update_op` for each batch of the data in the input source and, "
            "once it is exhausted, it calls the `value_op` to get the metric "
            "values. The `value_op` here should depend only on variable "
            "reads, rather than reading another batch from the input. "
            "Otherwise, the `value_op`, executed by `FinalOpsHook`, triggers "
            "another data read, which ends in OutOfRangeError/StopIteration. "
            "Please fix that.")
        raise e
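

# Usage sketch (illustrative; `accuracy_value_op` is an assumed metric
# tensor): evaluate a tensor once, at the end of the session.
#
#   final_hook = tf.compat.v1.train.FinalOpsHook(
#       final_ops={"accuracy": accuracy_value_op})
#   ...  # run the session to completion with the hook attached
#   print(final_hook.final_ops_values["accuracy"])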


@tf_export(v1=["train.FeedFnHook"])
class FeedFnHook(session_run_hook.SessionRunHook):
  """Runs `feed_fn` and sets the `feed_dict` accordingly."""

  def __init__(self, feed_fn):
    """Initializes a `FeedFnHook`.

    Args:
      feed_fn: function that takes no arguments and returns a `dict` of
        `Tensor` to feed.
    """
    self.feed_fn = feed_fn

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return session_run_hook.SessionRunArgs(
        fetches=None, feed_dict=self.feed_fn())
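

# Usage sketch (illustrative; `x_placeholder` and `next_batch` are assumed to
# be defined by the caller): feed a fresh batch on every `run` call.
#
#   feed_hook = tf.compat.v1.train.FeedFnHook(
#       feed_fn=lambda: {x_placeholder: next_batch()})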


@tf_export(v1=["train.ProfilerHook"])
class ProfilerHook(session_run_hook.SessionRunHook):
  """Captures CPU/GPU profiling information every N steps or seconds.

  This produces files called "timeline-<step>.json", which are in Chrome
  Trace format.

  For more information see:
  https://github.com/catapult-project/catapult/blob/master/tracing/README.md
  """

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir="",
               show_dataflow=True,
               show_memory=False):
    """Initializes a hook that takes periodic profiling snapshots.

    The `options.run_metadata` argument of `tf.Session.run` is used to
    collect metadata about execution. This hook sets the metadata and dumps
    it in Chrome Trace format.

    Args:
      save_steps: `int`, save profile traces every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int` or `float`, save profile traces every N seconds.
      output_dir: `string`, the directory to save the profile traces to.
        Defaults to the current directory.
      show_dataflow: `bool`, if True, add flow events to the trace connecting
        producers and consumers of tensors.
      show_memory: `bool`, if True, add object snapshot events to the trace
        showing the sizes and lifetimes of tensors.
    """
    self._output_file = os.path.join(output_dir, "timeline-{}.json")
    self._file_writer = SummaryWriterCache.get(output_dir)
    self._show_dataflow = show_dataflow
    self._show_memory = show_memory
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)

  def begin(self):
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use ProfilerHook.")

  def before_run(self, run_context):
    self._request_summary = (
        self._next_step is not None and
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    opts = (
        config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
        if self._request_summary else None)

    return SessionRunArgs(requests, options=opts)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results["global_step"]
    if self._next_step is None:
      # Update the timer so that it does not activate until N steps or
      # seconds have passed.
      self._timer.update_last_triggered_step(stale_global_step)
    global_step = stale_global_step + 1
    if self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)
      self._timer.update_last_triggered_step(global_step)
      self._save(global_step, self._output_file.format(global_step),
                 run_values.run_metadata.step_stats)
      self._file_writer.add_run_metadata(run_values.run_metadata,
                                         "step_%d" % global_step)

    self._next_step = global_step + 1

  def _save(self, step, save_path, step_stats):
    logging.info("Saving timeline for %d into '%s'.", step, save_path)
    with gfile.Open(save_path, "w") as f:
      trace = timeline.Timeline(step_stats)
      f.write(
          trace.generate_chrome_trace_format(
              show_dataflow=self._show_dataflow,
              show_memory=self._show_memory))
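

# Usage sketch (illustrative; "/tmp/profile" is an assumed directory): write a
# Chrome-trace timeline every 1000 steps.
#
#   profiler_hook = tf.compat.v1.train.ProfilerHook(
#       save_steps=1000, output_dir="/tmp/profile", show_memory=True)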


def _as_graph_element(obj):
  """Retrieves a Graph element."""
  graph = ops.get_default_graph()
  if not isinstance(obj, str):
    if not hasattr(obj, "graph") or obj.graph != graph:
      raise ValueError("Passed %s should have a graph attribute that is equal "
                       "to the current graph %s." % (obj, graph))
    return obj
  if ":" in obj:
    element = graph.as_graph_element(obj)
  else:
    element = graph.as_graph_element(obj + ":0")
    # Check that there is no :1 (i.e., the op has a single output).
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      pass
    else:
      raise ValueError("Name %s is ambiguous, "
                       "as this `Operation` has multiple outputs "
                       "(at least 2)." % obj)
  return element