Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/session_manager.py: 18%
152 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training helper that checkpoints models and creates session."""
16import time
18import numpy as np
19from tensorflow.python.checkpoint import checkpoint_management
20from tensorflow.python.client import session
21from tensorflow.python.distribute import distribute_lib
22from tensorflow.python.framework import errors
23from tensorflow.python.framework import ops
24from tensorflow.python.platform import tf_logging as logging
25from tensorflow.python.util.tf_export import tf_export
28def _maybe_name(obj):
29 """Returns object name if it has one, or a message otherwise.
31 This is useful for names that apper in error messages.
32 Args:
33 obj: Object to get the name of.
34 Returns:
35 name, "None", or a "no name" message.
36 """
37 if obj is None:
38 return "None"
39 elif hasattr(obj, "name"):
40 return obj.name
41 else:
42 return "<no name for %s>" % type(obj)
def _restore_checkpoint_and_maybe_run_saved_model_initializers(
    sess, saver, path):
  """Runs SavedModel initializers, if any, then restores the checkpoint.

  Args:
    sess: Session in which the initializers and the restore are run.
    saver: Object whose `restore(sess, path)` method loads the checkpoint.
    path: Checkpoint path handed to `saver.restore`.
  """
  # NOTE: "SavedModel" here always means a SavedModel loaded through the
  # load_v2 API (which does not require the `sess` argument).
  #
  # Resources that the graph loaded from a SavedModel are not restored by
  # `saver.restore`, so the SavedModel initializer has to be run alongside it
  # for the model to be properly initialized.
  #
  # The initializer lives in the "saved_model_initializers" collection. That
  # collection is part of the MetaGraph's default_init_op, so MonitoredSession
  # already runs it as long as the saver doesn't restore any checkpoints from
  # the working dir.
  initializers = ops.get_collection("saved_model_initializers")
  if initializers:
    sess.run(initializers)

  # Order matters: the SavedModel init restores variables from the SavedModel
  # variables directory, so the saver has to run *after* it. Initializing and
  # restoring twice is not ideal, but there is no other way to do it.
  saver.restore(sess, path)
@tf_export(v1=["train.SessionManager"])
class SessionManager:
  """Training helper that restores from checkpoint and creates session.

  This class is a small wrapper that takes care of session creation and
  checkpoint recovery. It also provides functions that facilitate
  coordination among multiple training threads or processes.

  * Checkpointing trained variables as the training progresses.
  * Initializing variables on startup, restoring them from the most recent
    checkpoint after a crash, or wait for checkpoints to become available.

  ### Usage:

  ```python
  with tf.Graph().as_default():
     ...add operations to the graph...
    # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
    sm = SessionManager()
    sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
  ```

  `prepare_session()` initializes or restores a model. It requires `init_op`
  and `saver` as an argument.

  A second process could wait for the model to be ready by doing the following:

  ```python
  with tf.Graph().as_default():
     ...add operations to the graph...
    # Create a SessionManager that will wait for the model to become ready.
    sm = SessionManager()
    sess = sm.wait_for_session(master)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
  ```

  `wait_for_session()` waits for a model to be initialized by other processes.

  """

  def __init__(self,
               local_init_op=None,
               ready_op=None,
               ready_for_local_init_op=None,
               graph=None,
               recovery_wait_secs=30,
               local_init_run_options=None,
               local_init_feed_dict=None):
    """Creates a SessionManager.

    The `local_init_op` is an `Operation` that is run always after a new session
    was created. If `None`, this step is skipped.

    The `ready_op` is an `Operation` used to check if the model is ready.  The
    model is considered ready if that operation returns an empty 1D string
    tensor. If the operation returns a non empty 1D string tensor, the elements
    are concatenated and used to indicate to the user why the model is not
    ready.

    The `ready_for_local_init_op` is an `Operation` used to check if the model
    is ready to run local_init_op.  The model is considered ready if that
    operation returns an empty 1D string tensor. If the operation returns a non
    empty 1D string tensor, the elements are concatenated and used to indicate
    to the user why the model is not ready.

    If `ready_op` is `None`, the model is not checked for readiness.

    `recovery_wait_secs` is the number of seconds between checks that
    the model is ready.  It is used by processes to wait for a model to
    be initialized or restored.  Defaults to 30 seconds.

    Args:
      local_init_op: An `Operation` run immediately after session creation.
         Usually used to initialize tables and local variables.
      ready_op: An `Operation` to check if the model is initialized.
      ready_for_local_init_op: An `Operation` to check if the model is ready
         to run local_init_op.
      graph: The `Graph` that the model will use.
      recovery_wait_secs: Seconds between checks for the model to be ready.
      local_init_run_options: RunOptions to be passed to session.run when
        executing the local_init_op.
      local_init_feed_dict: Optional session feed dictionary to use when running
        the local_init_op.

    Raises:
      ValueError: If ready_for_local_init_op is not None but local_init_op is
        None
    """
    # Sets default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    self._local_init_op = local_init_op
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._graph = graph
    self._recovery_wait_secs = recovery_wait_secs
    self._target = None
    self._local_init_run_options = local_init_run_options
    self._local_init_feed_dict = local_init_feed_dict
    if ready_for_local_init_op is not None and local_init_op is None:
      raise ValueError("If you pass a ready_for_local_init_op "
                       "you must also pass a local_init_op "
                       ", ready_for_local_init_op [%s]" %
                       ready_for_local_init_op)

  def _restore_checkpoint(self,
                          master,
                          saver=None,
                          checkpoint_dir=None,
                          checkpoint_filename_with_path=None,
                          wait_for_checkpoint=False,
                          max_wait_secs=7200,
                          config=None):
    """Creates a `Session`, and tries to restore a checkpoint.


    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, is_restored) where 'is_restored' is `True` if
      the session could be restored, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
    self._target = master

    # Reject contradictory arguments before a session is created, so that the
    # error path does not leave an unclosed `Session` behind.
    if checkpoint_dir and checkpoint_filename_with_path:
      raise ValueError("Can not provide both checkpoint_dir and "
                       "checkpoint_filename_with_path.")

    # This is required so that we initialize the TPU device before
    # restoring from checkpoint since we'll be placing variables on the device
    # and TPUInitialize wipes out the memory of the device.
    strategy = distribute_lib.get_strategy()
    if strategy and hasattr(strategy.extended,
                            "_experimental_initialize_system"):
      strategy.extended._experimental_initialize_system()  # pylint: disable=protected-access

    sess = session.Session(self._target, graph=self._graph, config=config)
    # If either saver or checkpoint_* is not specified, cannot restore. Just
    # return.
    if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
      return sess, False

    if checkpoint_filename_with_path:
      _restore_checkpoint_and_maybe_run_saved_model_initializers(
          sess, saver, checkpoint_filename_with_path)
      return sess, True

    # Waits up until max_wait_secs for checkpoint to become available.
    wait_time = 0
    ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
    while not ckpt or not ckpt.model_checkpoint_path:
      if wait_for_checkpoint and wait_time < max_wait_secs:
        logging.info("Waiting for checkpoint to be available.")
        time.sleep(self._recovery_wait_secs)
        wait_time += self._recovery_wait_secs
        ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
      else:
        return sess, False

    # Loads the checkpoint.
    _restore_checkpoint_and_maybe_run_saved_model_initializers(
        sess, saver, ckpt.model_checkpoint_path)
    saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
    return sess, True

  def prepare_session(self,
                      master,
                      init_op=None,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None,
                      init_feed_dict=None,
                      init_fn=None):
    """Creates a `Session`. Makes sure the model is ready to be used.

    Creates a `Session` on 'master'. If a `saver` object is passed in, and
    `checkpoint_dir` points to a directory containing valid checkpoint
    files, then it will try to recover the model from checkpoint. If
    no checkpoint files are available, and `wait_for_checkpoint` is
    `True`, then the process would check every `recovery_wait_secs`,
    up to `max_wait_secs`, for recovery to succeed.

    If the model cannot be recovered successfully then it is initialized by
    running the `init_op` and calling `init_fn` if they are provided.
    The `local_init_op` is also run after init_op and init_fn, regardless of
    whether the model was recovered successfully, but only if
    `ready_for_local_init_op` passes.

    If the model is recovered from a checkpoint it is assumed that all
    global variables have been initialized, in particular neither `init_op`
    nor `init_fn` will be executed.

    It is an error if the model cannot be recovered and no `init_op`
    or `init_fn` or `local_init_op` are passed.

    Args:
      master: `String` representation of the TensorFlow master to use.
      init_op: Optional `Operation` used to initialize the model.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.
      init_feed_dict: Optional dictionary that maps `Tensor` objects to feed
        values.  This feed dictionary is passed to the session `run()` call when
        running the init op.
      init_fn: Optional callable used to initialize the model. Called after the
        optional `init_op` is called.  The callable must accept one argument,
        the session being initialized.

    Returns:
      A `Session` object that can be used to drive the model.

    Raises:
      RuntimeError: If the model cannot be initialized or recovered.
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """

    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)
    if not is_loaded_from_checkpoint:
      if init_op is None and not init_fn and self._local_init_op is None:
        raise RuntimeError("Model is not initialized and no init_op or "
                           "init_fn or local_init_op was given")
      if init_op is not None:
        sess.run(init_op, feed_dict=init_feed_dict)
      if init_fn:
        init_fn(sess)

    local_init_success, msg = self._try_run_local_init_op(sess)
    if not local_init_success:
      raise RuntimeError(
          "Init operations did not make model ready for local_init.  "
          "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op),
                                                   init_fn,
                                                   msg))

    is_ready, msg = self._model_ready(sess)
    if not is_ready:
      raise RuntimeError(
          "Init operations did not make model ready.  "
          "Init op: %s, init fn: %s, local_init_op: %s, error: %s" %
          (_maybe_name(init_op), init_fn, self._local_init_op, msg))
    return sess

  def recover_session(self,
                      master,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None):
    """Creates a `Session`, recovering if possible.

    Creates a new session on 'master'.  If the session is not initialized
    and can be recovered from a checkpoint, recover it.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, initialized) where 'initialized' is `True` if
      the session could be recovered and initialized, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """

    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)

    # Always try to run local_init_op
    local_init_success, msg = self._try_run_local_init_op(sess)

    if not is_loaded_from_checkpoint:
      # Do not need to run checks for readiness
      return sess, False

    restoring_file = checkpoint_dir or checkpoint_filename_with_path
    if not local_init_success:
      logging.info(
          "Restoring model from %s did not make model ready for local init:"
          " %s", restoring_file, msg)
      return sess, False

    is_ready, msg = self._model_ready(sess)
    if not is_ready:
      logging.info("Restoring model from %s did not make model ready: %s",
                   restoring_file, msg)
      return sess, False

    logging.info("Restored model from %s", restoring_file)
    return sess, is_loaded_from_checkpoint

  def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    """Creates a new `Session` and waits for model to be ready.

    Creates a new `Session` on 'master'.  Waits for the model to be
    initialized or recovered from a checkpoint.  It's expected that
    another thread or process will make the model ready, and that this
    is intended to be used by threads/processes that participate in a
    distributed training configuration where a different thread/process
    is responsible for initializing or recovering the model being trained.

    NB: The amount of time this method waits for the session is bounded
    by max_wait_secs. By default, this function will wait indefinitely.

    Args:
      master: `String` representation of the TensorFlow master to use.
      config: Optional ConfigProto proto used to configure the session.
      max_wait_secs: Maximum time to wait for the session to become available.
        `None` is treated as no limit.

    Returns:
      A `Session` on which the local init op (if any) has been run and the
      readiness check has passed.

    Raises:
      tf.DeadlineExceededError: if the session is not available after
        max_wait_secs.
    """
    self._target = master

    if max_wait_secs is None:
      max_wait_secs = float("Inf")
    timer = _CountDownTimer(max_wait_secs)

    while True:
      # A fresh session is created for every attempt; sessions that turn out
      # not to be ready are closed below before the next try.
      sess = session.Session(self._target, graph=self._graph, config=config)
      not_ready_msg = None
      local_init_success, not_ready_local_msg = self._try_run_local_init_op(
          sess)
      if local_init_success:
        # Successful if local_init_op is None, or ready_for_local_init_op passes
        is_ready, not_ready_msg = self._model_ready(sess)
        if is_ready:
          return sess

      self._safe_close(sess)

      # Do we have enough time left to try again?
      remaining_secs_after_wait = (
          timer.secs_remaining() - self._recovery_wait_secs)
      if remaining_secs_after_wait < 0:
        raise errors.DeadlineExceededError(
            None, None,
            "Session was not ready after waiting %d secs." % (max_wait_secs,))

      logging.info("Waiting for model to be ready.  "
                   "Ready_for_local_init_op:  %s, ready: %s",
                   not_ready_local_msg, not_ready_msg)
      time.sleep(self._recovery_wait_secs)

  def _safe_close(self, sess):
    """Closes a session without raising an exception.

    Just like sess.close() but ignores exceptions.

    Args:
      sess: A `Session`.
    """
    # pylint: disable=broad-except
    try:
      sess.close()
    except Exception:
      # Intentionally not logging to avoid user complaints that
      # they get cryptic errors.  We really do not care that Close
      # fails.
      pass
    # pylint: enable=broad-except

  def _model_ready(self, sess):
    """Checks if the model is ready or not.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_ready, msg), where is_ready is True if ready and False
      otherwise, and msg is `None` if the model is ready, a `String` with the
      reason why it is not ready otherwise.
    """
    return _ready(self._ready_op, sess, "Model not ready")

  def _model_ready_for_local_init(self, sess):
    """Checks if the model is ready to run local_init_op.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_ready, msg), where is_ready is True if ready to run
      local_init_op and False otherwise, and msg is `None` if the model is
      ready to run local_init_op, a `String` with the reason why it is not ready
      otherwise.
    """
    return _ready(self._ready_for_local_init_op, sess,
                  "Model not ready for local init")

  def _try_run_local_init_op(self, sess):
    """Tries to run _local_init_op, if not None, and is ready for local init.

    Args:
      sess: A `Session`.

    Returns:
      A tuple (is_successful, msg), where is_successful is True if
      _local_init_op is None, or we ran _local_init_op, and False otherwise;
      and msg is a `String` with the reason why the model was not ready to run
      local init.
    """
    if self._local_init_op is not None:
      is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
      if is_ready_for_local_init:
        logging.info("Running local_init_op.")
        sess.run(self._local_init_op, feed_dict=self._local_init_feed_dict,
                 options=self._local_init_run_options)
        logging.info("Done running local_init_op.")
        return True, None
      else:
        return False, msg
    return True, None
536def _ready(op, sess, msg):
537 """Checks if the model is ready or not, as determined by op.
539 Args:
540 op: An op, either _ready_op or _ready_for_local_init_op, which defines the
541 readiness of the model.
542 sess: A `Session`.
543 msg: A message to log to warning if not ready
545 Returns:
546 A tuple (is_ready, msg), where is_ready is True if ready and False
547 otherwise, and msg is `None` if the model is ready, a `String` with the
548 reason why it is not ready otherwise.
549 """
550 if op is None:
551 return True, None
552 else:
553 try:
554 ready_value = sess.run(op)
555 # The model is considered ready if ready_op returns an empty 1-D tensor.
556 # Also compare to `None` and dtype being int32 for backward
557 # compatibility.
558 if (ready_value is None or ready_value.dtype == np.int32 or
559 ready_value.size == 0):
560 return True, None
561 else:
562 # TODO(sherrym): If a custom ready_op returns other types of tensor,
563 # or strings other than variable names, this message could be
564 # confusing.
565 non_initialized_varnames = ", ".join(
566 [i.decode("utf-8") for i in ready_value])
567 return False, "Variables not initialized: " + non_initialized_varnames
568 except errors.FailedPreconditionError as e:
569 if "uninitialized" not in str(e):
570 logging.warning("%s : error [%s]", msg, str(e))
571 raise e
572 return False, str(e)
575class _CountDownTimer:
576 """A timer that tracks a duration since creation."""
578 __slots__ = ["_start_time_secs", "_duration_secs"]
580 def __init__(self, duration_secs):
581 self._start_time_secs = time.time()
582 self._duration_secs = duration_secs
584 def secs_remaining(self):
585 diff = self._duration_secs - (time.time() - self._start_time_secs)
586 return max(0, diff)