Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/saver.py: 20%
511 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16# pylint: disable=invalid-name
17"""Save and restore variables.
19Symbols in this file are deprecated. See replacements in
20tensorflow/python/training/trackable and tensorflow/python/training/saving.
21"""
22import collections
23import glob
24import os.path
25import threading
26import time
28import numpy as np
29from tensorflow.core.protobuf import meta_graph_pb2
30from tensorflow.core.protobuf import saver_pb2
31from tensorflow.core.protobuf import trackable_object_graph_pb2
32from tensorflow.python.checkpoint import checkpoint_management
33from tensorflow.python.client import session
34from tensorflow.python.eager import context
35from tensorflow.python.framework import constant_op
36from tensorflow.python.framework import device as pydev
37from tensorflow.python.framework import errors
38from tensorflow.python.framework import meta_graph
39from tensorflow.python.framework import ops
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import control_flow_ops
42from tensorflow.python.ops import gen_io_ops
43from tensorflow.python.ops import io_ops
44from tensorflow.python.ops import string_ops
45from tensorflow.python.ops import variables
46from tensorflow.python.platform import gfile
47from tensorflow.python.platform import tf_logging as logging
48from tensorflow.python.saved_model.pywrap_saved_model import metrics
49from tensorflow.python.trackable import base as trackable
50from tensorflow.python.training import py_checkpoint_reader
51from tensorflow.python.training import training_util
52from tensorflow.python.training.saving import saveable_object
53from tensorflow.python.training.saving import saveable_object_util
54from tensorflow.python.util import compat
55from tensorflow.python.util.tf_export import tf_export
57# TODO(allenl): Remove these aliases once all users are migrated off.
58get_checkpoint_state = checkpoint_management.get_checkpoint_state
59update_checkpoint_state = checkpoint_management.update_checkpoint_state
60generate_checkpoint_state_proto = (
61 checkpoint_management.generate_checkpoint_state_proto)
62latest_checkpoint = checkpoint_management.latest_checkpoint
63checkpoint_exists = checkpoint_management.checkpoint_exists
64get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes
65remove_checkpoint = checkpoint_management.remove_checkpoint
67# Captures the timestamp of the first Saver object instantiation or end of a
68# save operation. Can be accessed by multiple Saver instances.
69_END_TIME_OF_LAST_WRITE = None
70_END_TIME_OF_LAST_WRITE_LOCK = threading.Lock()
72# API label for cell name used in checkpoint metrics.
73_SAVER_LABEL = "saver_v1"
76def _get_duration_microseconds(start_time_seconds, end_time_seconds):
77 if end_time_seconds < start_time_seconds:
78 # Avoid returning a negative value in case of clock skew.
79 return 0
80 return round((end_time_seconds - start_time_seconds) * 1000000)
83def _get_checkpoint_size(prefix):
84 """Calculates filesize of checkpoint based on prefix."""
85 size = 0
86 # Gather all files beginning with prefix (.index plus sharded data files).
87 files = glob.glob("{}*".format(prefix))
88 for file in files:
89 # Use TensorFlow's C++ FileSystem API.
90 size += metrics.CalculateFileSize(file)
91 return size
94class BaseSaverBuilder:
95 """Base class for Savers.
97 Can be extended to create different Ops.
98 """
100 SaveSpec = saveable_object.SaveSpec
101 SaveableObject = saveable_object.SaveableObject
103 # Aliases for code which was moved but still has lots of users.
104 VariableSaveable = saveable_object_util.ReferenceVariableSaveable
105 ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable
107 def __init__(self, write_version=saver_pb2.SaverDef.V2):
108 self._write_version = write_version
110 def save_op(self, filename_tensor, saveables):
111 """Create an Op to save 'saveables'.
113 This is intended to be overridden by subclasses that want to generate
114 different Ops.
116 Args:
117 filename_tensor: String Tensor.
118 saveables: A list of BaseSaverBuilder.SaveableObject objects.
120 Returns:
121 An Operation that saves the variables.
123 Raises:
124 RuntimeError: (implementation detail) if "self._write_version" is an
125 unexpected value.
126 """
127 # pylint: disable=protected-access
128 tensor_names = []
129 tensors = []
130 tensor_slices = []
131 for saveable in saveables:
132 for spec in saveable.specs:
133 tensor_names.append(spec.name)
134 tensors.append(spec.tensor)
135 tensor_slices.append(spec.slice_spec)
136 if self._write_version == saver_pb2.SaverDef.V1:
137 return io_ops._save(
138 filename=filename_tensor,
139 tensor_names=tensor_names,
140 tensors=tensors,
141 tensor_slices=tensor_slices)
142 elif self._write_version == saver_pb2.SaverDef.V2:
143 # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
144 # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
145 return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
146 tensors)
147 else:
148 raise RuntimeError("Unexpected write_version: " + str(self._write_version))
150 def bulk_restore(self, filename_tensor, saveables, preferred_shard,
151 restore_sequentially):
152 """Restore all tensors contained in saveables.
154 By default, this issues separate calls to `restore_op` for each saveable.
155 Subclasses may override to load multiple saveables in a single call.
157 Args:
158 filename_tensor: String Tensor.
159 saveables: List of BaseSaverBuilder.SaveableObject objects.
160 preferred_shard: Int. Shard to open first when loading a sharded file.
161 restore_sequentially: Unused. Bool. If true, each restore is sequential.
163 Returns:
164 A list of Tensors resulting from reading 'saveable' from
165 'filename'.
167 """
168 del restore_sequentially
169 all_tensors = []
170 for saveable in saveables:
171 if saveable.device:
172 device = saveable_object_util.set_cpu0(saveable.device)
173 else:
174 device = None
175 with ops.device(device):
176 all_tensors.extend(
177 self.restore_op(filename_tensor, saveable, preferred_shard))
178 return all_tensors
180 # pylint: disable=unused-argument
181 def restore_op(self, filename_tensor, saveable, preferred_shard):
182 """Create ops to restore 'saveable'.
184 This is intended to be overridden by subclasses that want to generate
185 different Ops.
187 Args:
188 filename_tensor: String Tensor.
189 saveable: A BaseSaverBuilder.SaveableObject object.
190 preferred_shard: Int. Shard to open first when loading a sharded file.
192 Returns:
193 A list of Tensors resulting from reading 'saveable' from
194 'filename'.
195 """
196 # pylint: disable=protected-access
197 tensors = []
198 for spec in saveable.specs:
199 tensors.append(
200 io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec],
201 [spec.dtype])[0])
203 return tensors
205 # pylint: enable=unused-argument
207 def sharded_filename(self, filename_tensor, shard, num_shards):
208 """Append sharding information to a filename.
210 Args:
211 filename_tensor: A string tensor.
212 shard: Integer. The shard for the filename.
213 num_shards: An int Tensor for the number of shards.
215 Returns:
216 A string tensor.
217 """
218 return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)
220 def _AddSaveOps(self, filename_tensor, saveables):
221 """Add ops to save variables that are on the same shard.
223 Args:
224 filename_tensor: String Tensor.
225 saveables: A list of SaveableObject objects.
227 Returns:
228 A tensor with the filename used to save.
229 """
230 save = self.save_op(filename_tensor, saveables)
231 return control_flow_ops.with_dependencies([save], filename_tensor)
233 def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device):
234 """Add ops to save the params per shard, for the V2 format.
236 Note that the sharded save procedure for the V2 format is different from
237 V1: there is a special "merge" step that merges the small metadata produced
238 from each device.
240 Args:
241 checkpoint_prefix: scalar String Tensor. Interpreted *NOT AS A FILENAME*,
242 but as a prefix of a V2 checkpoint;
243 per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
244 returned by _GroupByDevices().
246 Returns:
247 An op to save the variables, which, when evaluated, returns the prefix
248 "<user-fed prefix>" only and does not include the sharded spec suffix.
249 """
250 # IMPLEMENTATION DETAILS: most clients should skip.
251 #
252 # Suffix for any well-formed "checkpoint_prefix", when sharded.
253 # Transformations:
254 # * Users pass in "save_path" in save() and restore(). Say "myckpt".
255 # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>.
256 # * If checkpoint_prefix is an S3 bucket path, ".part" is appended to it.
257 # * Otherwise "_temp/part" is appended, normalized for the OS path separator.
258 # Example:
259 # During runtime, a temporary directory is first created, which contains
260 # files
261 #
262 # <train dir>/myckpt_temp/
263 # part-?????-of-?????{.index, .data-00000-of-00001}
264 #
265 # Before .save() finishes, they will be (hopefully, atomically) renamed to
266 #
267 # <train dir>/
268 # myckpt{.index, .data-?????-of-?????}
269 #
270 # Filesystems with eventual consistency (such as S3) don't need a
271 # temporary location. Using a temporary directory in those cases might
272 # cause situations where files are not available during copy.
273 #
274 # Users only need to interact with the user-specified prefix, which is
275 # "<train dir>/myckpt" in this case. Save() and Restore() work with the
276 # prefix directly, instead of any physical pathname. (On failure and
277 # subsequent restore, an outdated and orphaned temporary directory can be
278 # safely removed.)
279 with ops.device("CPU"):
280 _SHARDED_SUFFIX = array_ops.where(
281 string_ops.regex_full_match(checkpoint_prefix, "^s3://.*"),
282 constant_op.constant(".part"),
283 constant_op.constant(os.path.normpath("_temp/part")))
284 tmp_checkpoint_prefix = string_ops.string_join(
285 [checkpoint_prefix, _SHARDED_SUFFIX])
287 num_shards = len(per_device)
288 sharded_saves = []
289 sharded_prefixes = []
290 num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
291 last_device = None
292 for shard, (device, saveables) in enumerate(per_device):
293 last_device = device
294 with ops.device(saveable_object_util.set_cpu0(device)):
295 sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
296 num_shards_tensor)
297 sharded_prefixes.append(sharded_filename)
298 sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
300 with ops.control_dependencies([x.op for x in sharded_saves]):
301 # Co-locates the merge step with the last device.
302 with ops.device(saveable_object_util.set_cpu0(last_device)):
303 # V2 format write path consists of a metadata merge step. Once merged,
304 # attempts to delete the temporary directory, "<user-fed prefix>_temp".
305 merge_step = gen_io_ops.merge_v2_checkpoints(
306 sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
307 with ops.control_dependencies([merge_step]):
308 # Returns the prefix "<user-fed prefix>" only. DOES NOT include the
309 # sharded spec suffix.
310 return array_ops.identity(checkpoint_prefix)
312 def _AddShardedSaveOps(self, filename_tensor, per_device):
313 """Add ops to save the params per shard.
315 Args:
316 filename_tensor: a scalar String Tensor.
317 per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
318 returned by _GroupByDevices().
320 Returns:
321 An op to save the variables.
322 """
323 if self._write_version == saver_pb2.SaverDef.V2:
324 return self._AddShardedSaveOpsForV2(filename_tensor, per_device)
326 num_shards = len(per_device)
327 sharded_saves = []
328 num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
329 for shard, (device, saveables) in enumerate(per_device):
330 with ops.device(device):
331 sharded_filename = self.sharded_filename(filename_tensor, shard,
332 num_shards_tensor)
333 sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
334 # Return the sharded name for the save path.
335 with ops.control_dependencies([x.op for x in sharded_saves]):
336 return gen_io_ops.sharded_filespec(filename_tensor, num_shards_tensor)
338 def _AddRestoreOps(self,
339 filename_tensor,
340 saveables,
341 restore_sequentially,
342 reshape,
343 preferred_shard=-1,
344 name="restore_all"):
345 """Add operations to restore saveables.
347 Args:
348 filename_tensor: Tensor for the path of the file to load.
349 saveables: A list of SaveableObject objects.
350 restore_sequentially: True if we want to restore variables sequentially
351 within a shard.
352 reshape: True if we want to reshape loaded tensors to the shape of the
353 corresponding variable.
354 preferred_shard: Shard to open first when loading a sharded file.
355 name: Name for the returned op.
357 Returns:
358 An Operation that restores the variables.
359 """
360 all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
361 restore_sequentially)
363 assign_ops = []
364 idx = 0
365 # Load and optionally reshape on the CPU, as string tensors are not
366 # available on the GPU.
367 # TODO(touts): Re-enable restore on GPU when we can support annotating
368 # string tensors as "HostMemory" inputs.
369 for saveable in saveables:
370 shapes = None
371 if reshape:
372 # Compute the shapes, let the restore op decide if and how to do
373 # the reshape.
374 shapes = []
375 for spec in saveable.specs:
376 v = spec.tensor
377 shape = v.get_shape()
378 if not shape.is_fully_defined():
379 shape = array_ops.shape(v)
380 shapes.append(shape)
381 saveable_tensors = all_tensors[idx:idx + len(saveable.specs)]
382 idx += len(saveable.specs)
383 assign_ops.append(saveable.restore(saveable_tensors, shapes))
385 # Create a Noop that has control dependencies from all the updates.
386 return control_flow_ops.group(*assign_ops, name=name)
388 def _AddShardedRestoreOps(self, filename_tensor, per_device,
389 restore_sequentially, reshape):
390 """Add Ops to restore variables from multiple devices.
392 Args:
393 filename_tensor: Tensor for the path of the file to load.
394 per_device: A list of (device, SaveableObject) pairs, as returned by
395 _GroupByDevices().
396 restore_sequentially: True if we want to restore variables sequentially
397 within a shard.
398 reshape: True if we want to reshape loaded tensors to the shape of the
399 corresponding variable.
401 Returns:
402 An Operation that restores the variables.
403 """
404 sharded_restores = []
405 for shard, (device, saveables) in enumerate(per_device):
406 with ops.device(device):
407 sharded_restores.append(
408 self._AddRestoreOps(
409 filename_tensor,
410 saveables,
411 restore_sequentially,
412 reshape,
413 preferred_shard=shard,
414 name="restore_shard"))
415 return control_flow_ops.group(*sharded_restores, name="restore_all")
417 def _GroupByDevices(self, saveables):
418 """Group Variable tensor slices per device.
420 TODO(touts): Make sure that all the devices found are on different
421 job/replica/task/cpu|gpu. It would be bad if 2 were on the same device.
422 It can happen if the devices are unspecified.
424 Args:
425 saveables: A list of BaseSaverBuilder.SaveableObject objects.
427 Returns:
428 A list of (device_name, BaseSaverBuilder.SaveableObject) tuples.
429 The list is sorted by ascending device_name.
431 Raises:
432 ValueError: If the tensors of a saveable are on different devices.
433 """
434 per_device = collections.defaultdict(lambda: [])
435 for saveable in saveables:
436 canonical_device = set(
437 pydev.canonical_name(spec.device) for spec in saveable.specs)
438 if len(canonical_device) != 1:
439 raise ValueError("All tensors of a saveable object must be "
440 "on the same device: %s" % saveable.name)
441 per_device[canonical_device.pop()].append(saveable)
442 return sorted(per_device.items(), key=lambda t: t[0])
444 def build(self,
445 names_to_saveables,
446 reshape=False,
447 sharded=False,
448 max_to_keep=5,
449 keep_checkpoint_every_n_hours=10000.0,
450 name=None,
451 restore_sequentially=False,
452 filename="model"):
453 """Builds save/restore graph nodes or runs save/restore in eager mode.
455 Args:
456 names_to_saveables: A dictionary mapping name to a Variable or
457 SaveableObject. Each name will be associated with the corresponding
458 variable in the checkpoint.
459 reshape: If True, allows restoring parameters from a checkpoint where
460 the parameters have a different shape. This is only needed when you try
461 to restore from a Dist-Belief checkpoint, and only sometimes.
462 sharded: If True, shard the checkpoints, one per device that has Variable
463 nodes.
464 max_to_keep: Maximum number of checkpoints to keep. As new checkpoints
465 are created, old ones are deleted. If None or 0, no checkpoints are
466 deleted from the filesystem but only the last one is kept in the
467 `checkpoint` file. Presently the number is only roughly enforced. For
468 example in case of restarts more than max_to_keep checkpoints may be
469 kept.
470 keep_checkpoint_every_n_hours: How often checkpoints should be kept.
471 Defaults to 10,000 hours.
472 name: String. Optional name to use as a prefix when adding operations.
473 restore_sequentially: A Bool, which if true, causes restore of different
474 variables to happen sequentially within each device.
475 filename: If known at graph construction time, filename used for variable
476 loading/saving. If None, then the default name "model" will be used.
478 Returns:
479 A SaverDef proto.
481 Raises:
482 TypeError: If 'names_to_saveables' is not a dictionary mapping string
483 keys to variable Tensors.
484 ValueError: If any of the keys or values in 'names_to_saveables' is not
485 unique.
486 """
487 return self._build_internal(
488 names_to_saveables=names_to_saveables,
489 reshape=reshape,
490 sharded=sharded,
491 max_to_keep=max_to_keep,
492 keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
493 name=name,
494 restore_sequentially=restore_sequentially,
495 filename=filename)
497 def _build_internal(self,
498 names_to_saveables,
499 reshape=False,
500 sharded=False,
501 max_to_keep=5,
502 keep_checkpoint_every_n_hours=10000.0,
503 name=None,
504 restore_sequentially=False,
505 filename="model",
506 build_save=True,
507 build_restore=True):
508 """build() with option to only perform save and restore."""
509 if not context.executing_eagerly() and (not build_save or
510 not build_restore):
511 raise ValueError("save and restore operations need to be built together "
512 "when eager execution is not enabled.")
514 if not isinstance(names_to_saveables, dict):
515 names_to_saveables = saveable_object_util.op_list_to_dict(
516 names_to_saveables)
517 saveables = saveable_object_util.validate_and_slice_inputs(
518 names_to_saveables)
519 if max_to_keep is None:
520 max_to_keep = 0
522 with ops.name_scope(name, "save",
523 [saveable.op for saveable in saveables]) as name:
524 # Add a placeholder string tensor for the filename.
525 filename_tensor = array_ops.placeholder_with_default(
526 filename or "model", shape=(), name="filename")
527 # Keep the name "Const" for backwards compatibility.
528 filename_tensor = array_ops.placeholder_with_default(
529 filename_tensor, shape=(), name="Const")
531 # Add the save ops.
532 if sharded:
533 per_device = self._GroupByDevices(saveables)
534 if build_save:
535 save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
536 if build_restore:
537 restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
538 restore_sequentially, reshape)
539 else:
540 if build_save:
541 save_tensor = self._AddSaveOps(filename_tensor, saveables)
542 if build_restore:
543 restore_op = self._AddRestoreOps(filename_tensor, saveables,
544 restore_sequentially, reshape)
546 # In the following use case, it's possible to have restore_ops be called
547 # something else:
548 # - Build inference graph and export a meta_graph.
549 # - Import the inference meta_graph
550 # - Extend the inference graph to a train graph.
551 # - Export a new meta_graph.
552 # Now the second restore_op will be called "restore_all_1".
553 # As such, comment out the assert for now until we know whether supporting
554 # such usage model makes sense.
555 #
556 # assert restore_op.name.endswith("restore_all"), restore_op.name
557 if context.executing_eagerly():
558 # Store the tensor values to the tensor_names.
559 save_tensor_name = save_tensor.numpy() if build_save else ""
560 return saver_pb2.SaverDef(
561 filename_tensor_name=filename_tensor.numpy(),
562 save_tensor_name=save_tensor_name,
563 restore_op_name="",
564 max_to_keep=max_to_keep,
565 sharded=sharded,
566 keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
567 version=self._write_version)
568 else:
569 graph = ops.get_default_graph()
570 # Do some sanity checking on collections containing
571 # PartitionedVariables. If a saved collection has a PartitionedVariable,
572 # the GraphDef needs to include concat ops to get the value (or there'll
573 # be a lookup error on load).
574 check_collection_list = graph.get_all_collection_keys()
575 for collection_type in check_collection_list:
576 for element in graph.get_collection(collection_type):
577 if isinstance(element, variables.PartitionedVariable):
578 try:
579 graph.get_operation_by_name(element.name)
580 except KeyError:
581 # Create a concat op for this PartitionedVariable. The user may
582 # not need it, but we'll try looking it up on MetaGraph restore
583 # since it's in a collection.
584 element.as_tensor()
585 return saver_pb2.SaverDef(
586 filename_tensor_name=filename_tensor.name,
587 save_tensor_name=save_tensor.name,
588 restore_op_name=restore_op.name,
589 max_to_keep=max_to_keep,
590 sharded=sharded,
591 keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
592 version=self._write_version)
595class BulkSaverBuilder(BaseSaverBuilder):
596 """SaverBuilder with support for bulk restoring multiple saveables."""
598 def bulk_restore(self, filename_tensor, saveables, preferred_shard,
599 restore_sequentially):
601 # Ignored: bulk restore is internally sequential.
602 del restore_sequentially
603 restore_specs = []
604 for saveable in saveables:
605 for spec in saveable.specs:
606 restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
608 names, slices, dtypes = zip(*restore_specs)
609 # Load all tensors onto CPU 0 for compatibility with existing code.
610 with ops.device("cpu:0"):
611 return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
614def _get_saver_or_default():
615 """Returns the saver from SAVERS collection, or creates a default one.
617 This method is used by other members of the training module, such as
618 `Scaffold`, or `CheckpointSaverHook`.
620 Returns:
621 `Saver`.
623 Raises:
624 RuntimeError: If the SAVERS collection already has more than one item.
625 """
626 collection_key = ops.GraphKeys.SAVERS
627 savers = ops.get_collection(collection_key)
628 if savers:
629 if len(savers) > 1:
630 raise RuntimeError(
631 "More than one item in collection {}. "
632 "Please indicate which one to use by passing it to the constructor."
633 .format(collection_key))
634 return savers[0]
635 saver = Saver(sharded=True, allow_empty=True)
636 if saver is not None:
637 ops.add_to_collection(collection_key, saver)
638 return saver
641@tf_export(v1=["train.Saver"])
642class Saver:
643 # pylint: disable=line-too-long
644 """Saves and restores variables.
646 @compatibility(TF2)
647 `tf.compat.v1.train.Saver` is not supported for saving and restoring
648 checkpoints in TF2. Please switch to `tf.train.Checkpoint` or
649 `tf.keras.Model.save_weights`, which perform a more robust [object-based
650 saving](https://www.tensorflow.org/guide/checkpoint#loading_mechanics).
652 ### How to Rewrite Checkpoints
654 Please rewrite your checkpoints immediately using the object-based checkpoint
655 APIs.
657 You can load a name-based checkpoint written by `tf.compat.v1.train.Saver`
658 using `tf.train.Checkpoint.restore` or `tf.keras.Model.load_weights`. However,
659 you may have to change the names of the variables in your model to match the
660 variable names in the name-based checkpoint, which can be viewed with
661 `tf.train.list_variables(path)`.
663 Another option is to create an `assignment_map` that maps the name of the
664 variables in the name-based checkpoint to the variables in your model, eg:
665 ```
666 {
667 'sequential/dense/bias': model.variables[0],
668 'sequential/dense/kernel': model.variables[1]
669 }
670 ```
671 and use `tf.compat.v1.train.init_from_checkpoint(path, assignment_map)` to
672 restore the name-based checkpoint.
674 After restoring, re-encode your checkpoint
675 using `tf.train.Checkpoint.save` or `tf.keras.Model.save_weights`.
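  For illustration, a minimal sketch (assuming `model` is a built
  `tf.keras.Model`, `path` points to the name-based checkpoint, and the
  variable names are the hypothetical ones above):

  ```python
  assignment_map = {
      'sequential/dense/bias': model.variables[0],
      'sequential/dense/kernel': model.variables[1],
  }
  # Map checkpoint names onto the model's variables.
  tf.compat.v1.train.init_from_checkpoint(path, assignment_map)
  # Re-encode as an object-based checkpoint (the target prefix is illustrative).
  tf.train.Checkpoint(model=model).save('/tmp/converted/ckpt')
  ```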
677 See the [Checkpoint compatibility](
678 https://www.tensorflow.org/guide/migrate#checkpoint_compatibility)
679 section of the migration guide for more details.
682 ### Checkpoint Management in TF2
684 Use `tf.train.CheckpointManager` to manage checkpoints in TF2.
685 `tf.train.CheckpointManager` offers equivalent `keep_checkpoint_every_n_hours`
686 and `max_to_keep` parameters.
688 To recover the latest checkpoint,
690 ```
691 checkpoint = tf.train.Checkpoint(model)
692 manager = tf.train.CheckpointManager(checkpoint, '/path/to/ckpts', max_to_keep=5)
693 status = checkpoint.restore(manager.latest_checkpoint)
694 ```
696 `tf.train.CheckpointManager` also writes a [`CheckpointState` proto]
697 (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/checkpoint_state.proto)
698 which contains the timestamp when each checkpoint was created.
700 ### Writing `MetaGraphDef`s in TF2
702 To replace `tf.compat.v1.train.Saver.save(write_meta_graph=True)`, use
703 `tf.saved_model.save` to write the `MetaGraphDef` (which is contained in
704 `saved_model.pb`).
706 @end_compatibility
708 See [Variables](https://tensorflow.org/guide/variables)
709 for an overview of variables, saving and restoring.
711 The `Saver` class adds ops to save and restore variables to and from
712 *checkpoints*. It also provides convenience methods to run these ops.
714 Checkpoints are binary files in a proprietary format which map variable names
715 to tensor values. The best way to examine the contents of a checkpoint is to
716 load it using a `Saver`.
718 Savers can automatically number checkpoint filenames with a provided counter.
719 This lets you keep multiple checkpoints at different steps while training a
720 model. For example you can number the checkpoint filenames with the training
721 step number. To avoid filling up disks, savers manage checkpoint files
722 automatically. For example, they can keep only the N most recent files, or
723 one checkpoint for every N hours of training.
725 You number checkpoint filenames by passing a value to the optional
726 `global_step` argument to `save()`:
728 ```python
729 saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
730 ...
731 saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
732 ```
734 Additionally, optional arguments to the `Saver()` constructor let you control
735 the proliferation of checkpoint files on disk:
737 * `max_to_keep` indicates the maximum number of recent checkpoint files to
738 keep. As new files are created, older files are deleted. If None or 0,
739 no checkpoints are deleted from the filesystem but only the last one is
740 kept in the `checkpoint` file. Defaults to 5 (that is, the 5 most recent
741 checkpoint files are kept.)
743 * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent
744 `max_to_keep` checkpoint files, you might want to keep one checkpoint file
745 for every N hours of training. This can be useful if you want to later
746 analyze how a model progressed during a long training session. For
747 example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep
748 one checkpoint file for every 2 hours of training. The default value of
749 10,000 hours effectively disables the feature.
751 Note that you still have to call the `save()` method to save the model.
752 Passing these arguments to the constructor will not save variables
753 automatically for you.
755 A training program that saves regularly looks like:
757 ```python
758 ...
759 # Create a saver.
760 saver = tf.compat.v1.train.Saver(...variables...)
761 # Launch the graph and train, saving the model every 1,000 steps.
762 sess = tf.compat.v1.Session()
763 for step in range(1000000):
764 sess.run(..training_op..)
765 if step % 1000 == 0:
766 # Append the step number to the checkpoint name:
767 saver.save(sess, 'my-model', global_step=step)
768 ```
770 In addition to checkpoint files, savers keep a protocol buffer on disk with
771 the list of recent checkpoints. This is used to manage numbered checkpoint
772 files and by `latest_checkpoint()`, which makes it easy to discover the path
773 to the most recent checkpoint. That protocol buffer is stored in a file named
774 'checkpoint' next to the checkpoint files.
776 If you create several savers, you can specify a different filename for the
777 protocol buffer file in the call to `save()`.
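  For example, a minimal sketch (the savers, session, checkpoint prefixes, and
  state filenames below are illustrative):

  ```python
  saver_a.save(sess, 'model-a', latest_filename='checkpoint_a')
  saver_b.save(sess, 'model-b', latest_filename='checkpoint_b')
  ```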
778 """
780 # pylint: enable=line-too-long
782 def __init__(self,
783 var_list=None,
784 reshape=False,
785 sharded=False,
786 max_to_keep=5,
787 keep_checkpoint_every_n_hours=10000.0,
788 name=None,
789 restore_sequentially=False,
790 saver_def=None,
791 builder=None,
792 defer_build=False,
793 allow_empty=False,
794 write_version=saver_pb2.SaverDef.V2,
795 pad_step_number=False,
796 save_relative_paths=False,
797 filename=None):
798 """Creates a `Saver`.
800 The constructor adds ops to save and restore variables.
802 `var_list` specifies the variables that will be saved and restored. It can
803 be passed as a `dict` or a list:
805 * A `dict` of names to variables: The keys are the names that will be
806 used to save or restore the variables in the checkpoint files.
807 * A list of variables: The variables will be keyed with their op name in
808 the checkpoint files.
810 For example:
812 ```python
813 v1 = tf.Variable(..., name='v1')
814 v2 = tf.Variable(..., name='v2')
816 # Pass the variables as a dict:
817 saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2})
819 # Or pass them as a list.
820 saver = tf.compat.v1.train.Saver([v1, v2])
821 # Passing a list is equivalent to passing a dict with the variable op names
822 # as keys:
823 saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]})
824 ```
826 Note: the newer `AutoTrackable` API is not supported by `Saver`. In this
827 case, the `tf.train.Checkpoint` class should be used.
829 The optional `reshape` argument, if `True`, allows restoring a variable from
830 a save file where the variable had a different shape, but the same number
831 of elements and type. This is useful if you have reshaped a variable and
832 want to reload it from an older checkpoint.
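  For illustration, a minimal sketch (the variable name and shapes are
  hypothetical): a tensor saved with shape `[2, 3]` can be restored into a
  variable of shape `[3, 2]`, since the dtype and element count match.

  ```python
  v = tf.Variable(tf.zeros([3, 2]), name='v1')
  saver = tf.compat.v1.train.Saver({'v1': v}, reshape=True)
  ```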
834 The optional `sharded` argument, if `True`, instructs the saver to shard
835 checkpoints per device.
837 Args:
838 var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
839 names to `SaveableObject`s. If `None`, defaults to the list of all
840 saveable objects.
841 reshape: If `True`, allows restoring parameters from a checkpoint where
842 the variables have a different shape.
843 sharded: If `True`, shard the checkpoints, one per device.
844 max_to_keep: Maximum number of recent checkpoints to keep. Defaults to 5.
845 keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to
846 10,000 hours.
847 name: String. Optional name to use as a prefix when adding operations.
848 restore_sequentially: A `Bool`, which if true, causes restore of different
849 variables to happen sequentially within each device. This can lower
850 memory usage when restoring very large models.
851 saver_def: Optional `SaverDef` proto to use instead of running the
852 builder. This is only useful for specialty code that wants to recreate a
853 `Saver` object for a previously built `Graph` that had a `Saver`. The
854 `saver_def` proto should be the one returned by the `as_saver_def()`
855 call of the `Saver` that was created for that `Graph`.
856 builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
857 Defaults to `BulkSaverBuilder()`.
858 defer_build: If `True`, defer adding the save and restore ops to the
859 `build()` call. In that case `build()` should be called before
860 finalizing the graph or using the saver.
861 allow_empty: If `False` (default) raise an error if there are no variables
862 in the graph. Otherwise, construct the saver anyway and make it a no-op.
863 write_version: controls what format to use when saving checkpoints. It
864 also affects certain filepath matching logic. The V2 format is the
865 recommended choice: it is much more optimized than V1 in terms of memory
866 required and latency incurred during restore. Regardless of this flag,
867 the Saver is able to restore from both V2 and V1 checkpoints.
868 pad_step_number: if True, pads the global step number in the checkpoint
869 filepaths to some fixed width (8 by default). This is turned off by
870 default.
871 save_relative_paths: If `True`, will write relative paths to the
872 checkpoint state file. This is needed if the user wants to copy the
873 checkpoint directory and reload from the copied directory.
874 filename: If known at graph construction time, filename used for variable
875 loading/saving.
877 Raises:
878 TypeError: If `var_list` is invalid.
879 ValueError: If any of the keys or values in `var_list` are not unique.
880 RuntimeError: If eager execution is enabled and `var_list` does not specify
881 a list of variables to save.
883 @compatibility(eager)
884 When eager execution is enabled, `var_list` must specify a `list` or `dict`
885 of variables to save. Otherwise, a `RuntimeError` will be raised.
887 Although Saver works in some cases when executing eagerly, it is
888 fragile. Please switch to `tf.train.Checkpoint` or
889 `tf.keras.Model.save_weights`, which perform a more robust object-based
890 saving. These APIs will load checkpoints written by `Saver`.
891 @end_compatibility
892 """
893 global _END_TIME_OF_LAST_WRITE
894 with _END_TIME_OF_LAST_WRITE_LOCK:
895 if _END_TIME_OF_LAST_WRITE is None:
896 _END_TIME_OF_LAST_WRITE = time.time()
898 if defer_build and var_list:
899 raise ValueError(
900 "If `var_list` is provided then build cannot be deferred. "
901 "Either set defer_build=False or var_list=None.")
902 if context.executing_eagerly():
903 logging.warning(
904 "Saver is deprecated, please switch to tf.train.Checkpoint or "
905 "tf.keras.Model.save_weights for training checkpoints. When "
906 "executing eagerly variables do not necessarily have unique names, "
907 "and so the variable.name-based lookups Saver performs are "
908 "error-prone.")
909 if var_list is None:
910 raise RuntimeError(
911 "When eager execution is enabled, `var_list` must specify a list "
912 "or dict of variables to save")
913 self._var_list = var_list
914 self._reshape = reshape
915 self._sharded = sharded
916 self._max_to_keep = max_to_keep
917 self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
918 self._name = name
919 self._restore_sequentially = restore_sequentially
920 self.saver_def = saver_def
921 self._builder = builder
922 self._is_built = False
923 self._allow_empty = allow_empty
924 self._is_empty = None
925 self._write_version = write_version
926 self._pad_step_number = pad_step_number
927 self._filename = filename
928 self._last_checkpoints = []
929 self._checkpoints_to_be_deleted = []
930 if context.executing_eagerly():
931 self._next_checkpoint_time = (
932 time.time() + self._keep_checkpoint_every_n_hours * 3600)
933 elif not defer_build:
934 self.build()
935 if self.saver_def:
936 self._check_saver_def()
937 self._write_version = self.saver_def.version
938 self._save_relative_paths = save_relative_paths
939 # For compatibility with object-based checkpoints, we may build a second
940 # Saver to read the renamed keys.
941 self._object_restore_saver = None
943 def build(self):
944 if context.executing_eagerly():
945 raise RuntimeError("Use save/restore instead of build in eager mode.")
946 self._build(self._filename, build_save=True, build_restore=True)
948 def _build_eager(self, checkpoint_path, build_save, build_restore):
949 self._build(
950 checkpoint_path, build_save=build_save, build_restore=build_restore)
952 def _build(self, checkpoint_path, build_save, build_restore):
953 """Builds saver_def."""
954 if not context.executing_eagerly():
955 if self._is_built:
956 return
957 self._is_built = True
959 if not self.saver_def or context.executing_eagerly():
960 if self._builder is None:
961 self._builder = BulkSaverBuilder(self._write_version)
963 if self._var_list is None:
964 # pylint: disable=protected-access
965 self._var_list = variables._all_saveable_objects()
966 if not self._var_list:
967 if self._allow_empty:
968 self._is_empty = True
969 return
970 else:
971 raise ValueError("No variables to save")
972 self._is_empty = False
974 self.saver_def = self._builder._build_internal( # pylint: disable=protected-access
975 self._var_list,
976 reshape=self._reshape,
977 sharded=self._sharded,
978 max_to_keep=self._max_to_keep,
979 keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
980 name=self._name,
981 restore_sequentially=self._restore_sequentially,
982 filename=checkpoint_path,
983 build_save=build_save,
984 build_restore=build_restore)
985 elif self.saver_def and self._name:
986 # Since self._name is used as a name_scope by builder(), we are
987 # overloading the use of this field to represent the "import_scope" as
988 # well.
989 self.saver_def.filename_tensor_name = ops.prepend_name_scope(
990 self.saver_def.filename_tensor_name, self._name)
991 self.saver_def.save_tensor_name = ops.prepend_name_scope(
992 self.saver_def.save_tensor_name, self._name)
993 self.saver_def.restore_op_name = ops.prepend_name_scope(
994 self.saver_def.restore_op_name, self._name)
996 self._check_saver_def()
997 if not context.executing_eagerly():
998 # Updates next checkpoint time.
999 # Set in __init__ when executing eagerly.
1000 self._next_checkpoint_time = (
1001 time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)
1003 def _check_saver_def(self):
1004 if not isinstance(self.saver_def, saver_pb2.SaverDef):
1005 raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
1006 self.saver_def)
1007 if not context.executing_eagerly():
1008 if not self.saver_def.save_tensor_name:
1009 raise ValueError("saver_def must specify the save_tensor_name: %s" %
1010 str(self.saver_def))
1011 if not self.saver_def.restore_op_name:
1012 raise ValueError("saver_def must specify the restore_op_name: %s" %
1013 str(self.saver_def))
1015 def _CheckpointFilename(self, p):
1016 """Returns the checkpoint filename given a `(filename, time)` pair.
1018 Args:
1019 p: (filename, time) pair.
1021 Returns:
1022 Checkpoint file name.
1023 """
1024 name, _ = p
1025 return name
1027 def _RecordLastCheckpoint(self, latest_save_path):
1028 """Manages the list of the latest checkpoints."""
1029 if not self.saver_def.max_to_keep:
1030 return
1031 # Remove first from list if the same name was used before.
1032 for p in self._last_checkpoints:
1033 if latest_save_path == self._CheckpointFilename(p):
1034 self._last_checkpoints.remove(p)
1035 # Append new path to list
1036 self._last_checkpoints.append((latest_save_path, time.time()))
1038 # If more than max_to_keep, remove oldest.
1039 if len(self._last_checkpoints) > self.saver_def.max_to_keep:
1040 self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0))
1042 def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"):
1043 """Deletes old checkpoints if necessary.
1045 `self._checkpoints_to_be_deleted` is going to contain checkpoints that are
1046 over `max_to_keep`. They are going to be deleted. If
1047 `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
1048 every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
1049 kept for every 0.5 hours of training; if `N` is 10, an additional
1050 checkpoint is kept for every 10 hours of training.
1052 Args:
1053 meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
1054 """
1055 if self._checkpoints_to_be_deleted:
1056 p = self._checkpoints_to_be_deleted.pop(0)
1057 # Do not delete the file if keep_checkpoint_every_n_hours is set and we
1058 # have reached N hours of training.
1059 should_keep = p[1] > self._next_checkpoint_time
1060 if should_keep:
1061 self._next_checkpoint_time += (
1062 self.saver_def.keep_checkpoint_every_n_hours * 3600)
1063 return
1065 # Otherwise delete the files.
1066 try:
1067 checkpoint_management.remove_checkpoint(
1068 self._CheckpointFilename(p), self.saver_def.version,
1069 meta_graph_suffix)
1070 except Exception as e: # pylint: disable=broad-except
1071 logging.warning("Ignoring: %s", str(e))
1073 def as_saver_def(self):
1074 """Generates a `SaverDef` representation of this saver.
1076 Returns:
1077 A `SaverDef` proto.
1078 """
1079 return self.saver_def
1081 def to_proto(self, export_scope=None):
1082 """Converts this `Saver` to a `SaverDef` protocol buffer.
1084 Args:
1085 export_scope: Optional `string`. Name scope to remove.
1087 Returns:
1088 A `SaverDef` protocol buffer.
1089 """
1090 if export_scope is None:
1091 return self.saver_def
1093 if not (self.saver_def.filename_tensor_name.startswith(export_scope) and
1094 self.saver_def.save_tensor_name.startswith(export_scope) and
1095 self.saver_def.restore_op_name.startswith(export_scope)):
1096 return None
1098 saver_def = saver_pb2.SaverDef()
1099 saver_def.CopyFrom(self.saver_def)
1100 saver_def.filename_tensor_name = ops.strip_name_scope(
1101 saver_def.filename_tensor_name, export_scope)
1102 saver_def.save_tensor_name = ops.strip_name_scope(
1103 saver_def.save_tensor_name, export_scope)
1104 saver_def.restore_op_name = ops.strip_name_scope(saver_def.restore_op_name,
1105 export_scope)
1106 return saver_def
1108 @staticmethod
1109 def from_proto(saver_def, import_scope=None):
1110 """Returns a `Saver` object created from `saver_def`.
1112 Args:
1113 saver_def: a `SaverDef` protocol buffer.
1114 import_scope: Optional `string`. Name scope to use.
1116 Returns:
1117 A `Saver` built from saver_def.
1118 """
1119 return Saver(saver_def=saver_def, name=import_scope)
1121 @property
1122 def last_checkpoints(self):
1123 """List of not-yet-deleted checkpoint filenames.
1125 You can pass any of the returned values to `restore()`.
1127 Returns:
1128 A list of checkpoint filenames, sorted from oldest to newest.
1129 """
1130 return list(self._CheckpointFilename(p) for p in self._last_checkpoints)
1132 def set_last_checkpoints(self, last_checkpoints):
1133 """DEPRECATED: Use set_last_checkpoints_with_time.
1135 Sets the list of old checkpoint filenames.
1137 Args:
1138 last_checkpoints: A list of checkpoint filenames.
1140 Raises:
1141 AssertionError: If last_checkpoints is not a list.
1142 """
1143 assert isinstance(last_checkpoints, list)
1144 # We use a timestamp of +inf so that this checkpoint will never be
1145 # deleted. This is both safe and backwards compatible to a previous
1146 # version of the code which used s[1] as the "timestamp".
1147 self._last_checkpoints = [(s, np.inf) for s in last_checkpoints]
1149 def set_last_checkpoints_with_time(self, last_checkpoints_with_time):
1150 """Sets the list of old checkpoint filenames and timestamps.
1152 Args:
1153 last_checkpoints_with_time: A list of tuples of checkpoint filenames and
1154 timestamps.
1156 Raises:
1157 AssertionError: If last_checkpoints_with_time is not a list.
1158 """
1159 assert isinstance(last_checkpoints_with_time, list)
1160 self._last_checkpoints = last_checkpoints_with_time
1162 def recover_last_checkpoints(self, checkpoint_paths):
1163 """Recovers the internal saver state after a crash.
1165 This method is useful for recovering the "self._last_checkpoints" state.
1167 Globs for the checkpoints pointed to by `checkpoint_paths`. If the files
1168 exist, use their mtime as the checkpoint timestamp.
1170 Args:
1171 checkpoint_paths: a list of checkpoint paths.
1172 """
1173 checkpoints_with_mtimes = []
1174 for checkpoint_path in checkpoint_paths:
1175 try:
1176 mtime = checkpoint_management.get_checkpoint_mtimes([checkpoint_path])
1177 except errors.NotFoundError:
1178 # It's fine if some other thread/process is deleting some older
1179 # checkpoint concurrently.
1180 continue
1181 if mtime:
1182 checkpoints_with_mtimes.append((checkpoint_path, mtime[0]))
1183 self.set_last_checkpoints_with_time(checkpoints_with_mtimes)
1185 def save(self,
1186 sess,
1187 save_path,
1188 global_step=None,
1189 latest_filename=None,
1190 meta_graph_suffix="meta",
1191 write_meta_graph=True,
1192 write_state=True,
1193 strip_default_attrs=False,
1194 save_debug_info=False):
1195 # pylint: disable=line-too-long
1196 """Saves variables.
1198 This method runs the ops added by the constructor for saving variables.
1199 It requires a session in which the graph was launched. The variables to
1200 save must also have been initialized.
1202 The method returns the path prefix of the newly created checkpoint files.
1203 This string can be passed directly to a call to `restore()`.
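    For example, a minimal sketch (the checkpoint prefix is illustrative):

    ```python
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      prefix = saver.save(sess, '/tmp/my-model', global_step=100)
      # Later, possibly from another process:
      saver.restore(sess, prefix)
    ```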
1205 Args:
1206 sess: A Session to use to save the variables.
1207 save_path: String. Prefix of filenames created for the checkpoint.
1208 global_step: If provided, the global step number is appended to `save_path`
1209 to create the checkpoint filenames. The optional argument can be a
1210 `Tensor`, a `Tensor` name or an integer.
1211 latest_filename: Optional name for the protocol buffer file that will
1212 contain the list of most recent checkpoints. That file, kept in the
1213 same directory as the checkpoint files, is automatically managed by the
1214 saver to keep track of recent checkpoints. Defaults to 'checkpoint'.
1215 meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
1216 write_meta_graph: `Boolean` indicating whether or not to write the meta
1217 graph file.
1218 write_state: `Boolean` indicating whether or not to write the
1219 `CheckpointStateProto`.
1220 strip_default_attrs: Boolean. If `True`, default-valued attributes will be
1221 removed from the NodeDefs. For a detailed guide, see [Stripping
1222 Default-Valued
1223 Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
1224 save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
1225 which is in the same directory as save_path, with `_debug` added before
1226 the file extension. This is only enabled when `write_meta_graph` is
1227 `True`.
1229 Returns:
1230 A string: path prefix used for the checkpoint files. If the saver is
1231 sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
1232 is the number of shards created.
1233 If the saver is empty, returns None.
1235 Raises:
1236 TypeError: If `sess` is not a `Session`.
1237 ValueError: If `latest_filename` contains path components, or if it
1238 collides with `save_path`.
1239 RuntimeError: If save and restore ops weren't built.
1240 """
1241 # pylint: enable=line-too-long
1242 start_time = time.time()
1243 if not self._is_built and not context.executing_eagerly():
1244 raise RuntimeError(
1245 "`build()` should be called before save if defer_build==True")
1246 if latest_filename is None:
1247 latest_filename = "checkpoint"
1248 if self._write_version != saver_pb2.SaverDef.V2:
1249 logging.warning("*******************************************************")
1250 logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
1251 logging.warning("Consider switching to the more efficient V2 format:")
1252 logging.warning(" `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
1253 logging.warning("now on by default.")
1254 logging.warning("*******************************************************")
1256 if os.path.split(latest_filename)[0]:
1257 raise ValueError("'latest_filename' must not contain path components")
1259 save_path = compat.as_str(save_path)
1260 if global_step is not None:
1261 if not isinstance(global_step, compat.integral_types):
1262 global_step = training_util.global_step(sess, global_step)
1263 checkpoint_file = "%s-%d" % (save_path, global_step)
1264 if self._pad_step_number:
1265 # Zero-pads the step numbers, so that they are sorted when listed.
1266 checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
1267 else:
1268 checkpoint_file = save_path
1269 if os.path.basename(save_path) == latest_filename and not self._sharded:
1270 # Guard against collision between data file and checkpoint state file.
1271 raise ValueError(
1272 "'latest_filename' collides with 'save_path': '%s' and '%s'" %
1273 (latest_filename, save_path))
1275 if (not context.executing_eagerly() and
1276 not isinstance(sess, session.SessionInterface)):
1277 raise TypeError("'sess' must be a Session; %s" % sess)
1279 save_path_parent = os.path.dirname(save_path)
1280 if not self._is_empty:
1281 try:
1282 if context.executing_eagerly():
1283 self._build_eager(
1284 checkpoint_file, build_save=True, build_restore=False)
1285 model_checkpoint_path = self.saver_def.save_tensor_name
1286 else:
1287 model_checkpoint_path = sess.run(
1288 self.saver_def.save_tensor_name,
1289 {self.saver_def.filename_tensor_name: checkpoint_file})
1291 model_checkpoint_path = compat.as_str(model_checkpoint_path)
1292 if write_state:
1293 self._RecordLastCheckpoint(model_checkpoint_path)
1294 checkpoint_management.update_checkpoint_state_internal(
1295 save_dir=save_path_parent,
1296 model_checkpoint_path=model_checkpoint_path,
1297 all_model_checkpoint_paths=self.last_checkpoints,
1298 latest_filename=latest_filename,
1299 save_relative_paths=self._save_relative_paths)
1300 self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
1301 except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
1302 if not gfile.IsDirectory(save_path_parent):
1303 exc = ValueError(
1304 "Parent directory of {} doesn't exist, can't save.".format(
1305 save_path))
1306 raise exc
1308 end_time = time.time()
1309 metrics.AddCheckpointWriteDuration(
1310 api_label=_SAVER_LABEL,
1311 microseconds=_get_duration_microseconds(start_time, end_time))
1312 global _END_TIME_OF_LAST_WRITE
1313 with _END_TIME_OF_LAST_WRITE_LOCK:
1314 metrics.AddTrainingTimeSaved(
1315 api_label=_SAVER_LABEL,
1316 microseconds=_get_duration_microseconds(_END_TIME_OF_LAST_WRITE,
1317 end_time))
1318 _END_TIME_OF_LAST_WRITE = end_time
1320 if write_meta_graph:
1321 meta_graph_filename = checkpoint_management.meta_graph_filename(
1322 checkpoint_file, meta_graph_suffix=meta_graph_suffix)
1323 if not context.executing_eagerly():
1324 with sess.graph.as_default():
1325 self.export_meta_graph(
1326 meta_graph_filename,
1327 strip_default_attrs=strip_default_attrs,
1328 save_debug_info=save_debug_info)
1330 if self._is_empty:
1331 return None
1332 else:
1333 metrics.RecordCheckpointSize(
1334 api_label=_SAVER_LABEL,
1335 filesize=_get_checkpoint_size(model_checkpoint_path))
1336 return model_checkpoint_path
1338 def export_meta_graph(self,
1339 filename=None,
1340 collection_list=None,
1341 as_text=False,
1342 export_scope=None,
1343 clear_devices=False,
1344 clear_extraneous_savers=False,
1345 strip_default_attrs=False,
1346 save_debug_info=False):
1347 # pylint: disable=line-too-long
1348 """Writes `MetaGraphDef` to save_path/filename.
1350 Args:
1351 filename: Optional meta_graph filename including the path.
1352 collection_list: List of string keys to collect.
1353 as_text: If `True`, writes the meta_graph as an ASCII proto.
1354 export_scope: Optional `string`. Name scope to remove.
1355 clear_devices: Whether or not to clear the device field for an `Operation`
1356 or `Tensor` during export.
1357 clear_extraneous_savers: Remove any Saver-related information from the
1358 graph (both Save/Restore ops and SaverDefs) that is not associated with
1359 this Saver.
1360 strip_default_attrs: Boolean. If `True`, default-valued attributes will be
1361 removed from the NodeDefs. For a detailed guide, see [Stripping
1362 Default-Valued
1363 Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
1364 save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
1365 which is in the same directory as filename, with `_debug` added before
1366 the file extension.
1368 Returns:
1369 A `MetaGraphDef` proto.
1370 """
1371 # pylint: enable=line-too-long
1372 return export_meta_graph(
1373 filename=filename,
1374 graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
1375 saver_def=self.saver_def,
1376 collection_list=collection_list,
1377 as_text=as_text,
1378 export_scope=export_scope,
1379 clear_devices=clear_devices,
1380 clear_extraneous_savers=clear_extraneous_savers,
1381 strip_default_attrs=strip_default_attrs,
1382 save_debug_info=save_debug_info)
1384 def restore(self, sess, save_path):
1385 """Restores previously saved variables.
1387 This method runs the ops added by the constructor for restoring variables.
1388 It requires a session in which the graph was launched. The variables to
1389 restore do not have to have been initialized, as restoring is itself a way
1390 to initialize variables.
1392 The `save_path` argument is typically a value previously returned from a
1393 `save()` call, or a call to `latest_checkpoint()`.
1395 Args:
1396 sess: A `Session` to use to restore the parameters. None in eager mode.
1397 save_path: Path where parameters were previously saved.
1399 Raises:
1400 ValueError: If save_path is None or not a valid checkpoint.
1401 """
1402 start_time = time.time()
1403 if self._is_empty:
1404 return
1405 if save_path is None:
1406 raise ValueError("Can't load save_path when it is None.")
1408 checkpoint_prefix = compat.as_text(save_path)
1409 if not checkpoint_management.checkpoint_exists_internal(checkpoint_prefix):
1410 raise ValueError("The passed save_path is not a valid checkpoint: " +
1411 checkpoint_prefix)
1413 logging.info("Restoring parameters from %s", checkpoint_prefix)
1414 try:
1415 if context.executing_eagerly():
1416 self._build_eager(save_path, build_save=False, build_restore=True)
1417 else:
1418 sess.run(self.saver_def.restore_op_name,
1419 {self.saver_def.filename_tensor_name: save_path})
1420 except errors.NotFoundError as err:
1421 # There are three common conditions that might cause this error:
1422 # 0. The file is missing. We ignore it here, as this is checked above.
1423 # 1. This is an object-based checkpoint trying name-based loading.
1424 # 2. The graph has been altered and a variable or other name is missing.
1426 # 1. The checkpoint would not be loaded successfully as is. Try to parse
1427 # it as an object-based checkpoint.
1428 try:
1429 names_to_keys = object_graph_key_mapping(save_path)
1430 except errors.NotFoundError:
1431 # 2. This is not an object-based checkpoint, which likely means there
1432 # is a graph mismatch. Re-raise the original error with
1433 # a helpful message (b/110263146)
1434 raise _wrap_restore_error_with_msg(
1435 err, "a Variable name or other graph key that is missing")
1437 # This is an object-based checkpoint. We'll print a warning and then do
1438 # the restore.
1439 logging.warning(
1440 "Restoring an object-based checkpoint using a name-based saver. This "
1441 "may be somewhat fragile, and will re-build the Saver. Instead, "
1442 "consider loading object-based checkpoints using "
1443 "tf.train.Checkpoint().")
1444 self._object_restore_saver = saver_from_object_based_checkpoint(
1445 checkpoint_path=save_path,
1446 var_list=self._var_list,
1447 builder=self._builder,
1448 names_to_keys=names_to_keys,
1449 cached_saver=self._object_restore_saver)
1450 self._object_restore_saver.restore(sess=sess, save_path=save_path)
1451 except errors.InvalidArgumentError as err:
1452 # There is a mismatch between the graph and the checkpoint being loaded.
1453 # We add a more reasonable error message here to help users (b/110263146)
1454 raise _wrap_restore_error_with_msg(
1455 err, "a mismatch between the current graph and the graph")
1456 metrics.AddCheckpointReadDuration(
1457 api_label=_SAVER_LABEL,
1458 microseconds=_get_duration_microseconds(start_time, time.time()))
1460 @staticmethod
1461 def _add_collection_def(meta_graph_def, key, export_scope=None):
1462 """Adds a collection to MetaGraphDef protocol buffer.
1464 Args:
1465 meta_graph_def: MetaGraphDef protocol buffer.
1466 key: One of the GraphKeys or user-defined string.
1467 export_scope: Optional `string`. Name scope to remove.
1468 """
1469 meta_graph.add_collection_def(
1470 meta_graph_def, key, export_scope=export_scope)
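# Illustrative usage sketch (not part of the original module): a minimal
# save/restore round trip with the name-based Saver above. Assumes the
# `tf.compat.v1` API and a writable `checkpoint_dir`; the helper name is
# hypothetical.
def _example_saver_round_trip(checkpoint_dir):
  """Sketch: save a variable with `Saver.save` and reload it with `restore`."""
  import tensorflow.compat.v1 as tf1  # local import keeps the sketch self-contained
  graph = tf1.Graph()
  with graph.as_default():
    # Entering a graph scope disables eager execution inside the block.
    v = tf1.get_variable("v", shape=[2], initializer=tf1.zeros_initializer())
    saver = tf1.train.Saver([v])
    with tf1.Session(graph=graph) as sess:
      sess.run(tf1.global_variables_initializer())
      sess.run(v.assign([1.0, 2.0]))
      prefix = saver.save(sess, os.path.join(checkpoint_dir, "model"))
    # A fresh session restores the saved value; no re-initialization is needed.
    with tf1.Session(graph=graph) as sess:
      saver.restore(sess, prefix)
      return sess.run(v)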
1473@tf_export(v1=["train.import_meta_graph"])
1474def import_meta_graph(meta_graph_or_file,
1475 clear_devices=False,
1476 import_scope=None,
1477 **kwargs):
1478 """Recreates a Graph saved in a `MetaGraphDef` proto.
1480 This function takes a `MetaGraphDef` protocol buffer as input. If
1481 the argument is a file containing a `MetaGraphDef` protocol buffer,
1482 it constructs a protocol buffer from the file content. The function
1483 then adds all the nodes from the `graph_def` field to the
1484 current graph, recreates all the collections, and returns a saver
1485 constructed from the `saver_def` field.
1487 In combination with `export_meta_graph()`, this function can be used to
1489 * Serialize a graph along with other Python objects such as `QueueRunner`
1490 and `Variable` into a `MetaGraphDef`.
1492 * Restart training from a saved graph and checkpoints.
1494 * Run inference from a saved graph and checkpoints.
1496 ```Python
1497 ...
1498 # Create a saver.
1499 saver = tf.compat.v1.train.Saver(...variables...)
1500 # Remember the training_op we want to run by adding it to a collection.
1501 tf.compat.v1.add_to_collection('train_op', train_op)
1502 sess = tf.compat.v1.Session()
1503 for step in range(1000000):
1504 sess.run(train_op)
1505 if step % 1000 == 0:
1506 # Saves checkpoint, which by default also exports a meta_graph
1507 # named 'my-model-global_step.meta'.
1508 saver.save(sess, 'my-model', global_step=step)
1509 ```
1511 Later we can continue training from this saved `meta_graph` without building
1512 the model from scratch.
1514 ```Python
1515 with tf.Session() as sess:
1516 new_saver = tf.train.import_meta_graph(
1517 'my-save-dir/my-model-10000.meta')
1518 new_saver.restore(sess, 'my-save-dir/my-model-10000')
1519 # tf.get_collection() returns a list. In this example we only want
1520 # the first one.
1521 train_op = tf.get_collection('train_op')[0]
1522 for step in range(1000000):
1523 sess.run(train_op)
1524 ```
1526 NOTE: Restarting training from saved `meta_graph` only works if the
1527 device assignments have not changed.
1529 Example:
1530 Variables, placeholders, and independent operations can also be stored, as
1531 shown in the following example.
1533 ```Python
1534 # Saving contents and operations.
1535 v1 = tf.placeholder(tf.float32, name="v1")
1536 v2 = tf.placeholder(tf.float32, name="v2")
1537 v3 = tf.math.multiply(v1, v2)
1538 vx = tf.Variable(10.0, name="vx")
1539 v4 = tf.add(v3, vx, name="v4")
1540 saver = tf.train.Saver([vx])
1541 sess = tf.Session()
1542 sess.run(tf.global_variables_initializer())
1543 sess.run(vx.assign(tf.add(vx, vx)))
1544 result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
1545 print(result)
1546 saver.save(sess, "./model_ex1")
1547 ```
1549 Later this model can be restored and contents loaded.
1551 ```Python
1552 # Restoring variables and running operations.
1553 saver = tf.train.import_meta_graph("./model_ex1.meta")
1554 sess = tf.Session()
1555 saver.restore(sess, "./model_ex1")
1556 result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
1557 print(result)
1558 ```
1560 Args:
1561 meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
1562 the path) containing a `MetaGraphDef`.
1563 clear_devices: Whether or not to clear the device field for an `Operation`
1564 or `Tensor` during import.
1565 import_scope: Optional `string`. Name scope to add. Only used when
1566 initializing from protocol buffer.
1567 **kwargs: Optional keyword arguments.
1569 Returns:
1570 A saver constructed from `saver_def` in `MetaGraphDef` or None.
1572 A None value is returned if no variables exist in the `MetaGraphDef`
1573 (i.e., there are no variables to restore).
1575 Raises:
1576 RuntimeError: If called with eager execution enabled.
1578 @compatibility(eager)
1579 Exporting/importing meta graphs is not supported. No graph exists when eager
1580 execution is enabled.
1581 @end_compatibility
1582 """ # pylint: disable=g-doc-exception
1583 return _import_meta_graph_with_return_elements(meta_graph_or_file,
1584 clear_devices, import_scope,
1585 **kwargs)[0]
1588def _import_meta_graph_with_return_elements(meta_graph_or_file,
1589 clear_devices=False,
1590 import_scope=None,
1591 return_elements=None,
1592 **kwargs):
1593 """Import MetaGraph, and return both a saver and returned elements."""
1594 if context.executing_eagerly():
1595 raise RuntimeError("Exporting/importing meta graphs is not supported when "
1596 "eager execution is enabled. No graph exists when eager "
1597 "execution is enabled.")
1598 if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
1599 meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
1600 else:
1601 meta_graph_def = meta_graph_or_file
1603 imported_vars, imported_return_elements = (
1604 meta_graph.import_scoped_meta_graph_with_return_elements(
1605 meta_graph_def,
1606 clear_devices=clear_devices,
1607 import_scope=import_scope,
1608 return_elements=return_elements,
1609 **kwargs))
1611 saver = _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
1612 imported_vars)
1613 return saver, imported_return_elements
1616def _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
1617 imported_vars):
1618 """Return a saver for restoring variable values to an imported MetaGraph."""
1619 if meta_graph_def.HasField("saver_def"):
1620 # Infer the scope that is prepended by `import_scoped_meta_graph`.
1621 scope = import_scope
1622 var_names = list(imported_vars.keys())
1623 if var_names:
1624 sample_key = var_names[0]
1625 sample_var = imported_vars[sample_key]
1626 scope = sample_var.name[:-len(sample_key)]
1628 return Saver(saver_def=meta_graph_def.saver_def, name=scope)
1629 else:
1630 if variables._all_saveable_objects(scope=import_scope): # pylint: disable=protected-access
1631 # Return the default saver instance for all graph variables.
1632 return Saver()
1633 else:
1634 # If no graph variables exist, then a Saver cannot be constructed.
1635 logging.info("Saver not created because there are no variables in the"
1636 " graph to restore")
1637 return None
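# Worked example (illustrative, not part of the original module) of the scope
# inference in `_create_saver_from_imported_meta_graph` above: the imported
# variables are keyed by their original names, while the live variables carry
# the import scope as a prefix, so stripping the key from the end of the
# variable name recovers that prefix. The names below are hypothetical.
def _example_scope_inference():
  """Sketch: recover the import-scope prefix from one imported variable."""
  sample_key = "layer1/kernel:0"                # key as stored in `imported_vars`
  sample_var_name = "imported/layer1/kernel:0"  # name of the live variable
  scope = sample_var_name[:-len(sample_key)]
  return scope  # -> "imported/"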
1640@tf_export(v1=["train.export_meta_graph"])
1641def export_meta_graph(filename=None,
1642 meta_info_def=None,
1643 graph_def=None,
1644 saver_def=None,
1645 collection_list=None,
1646 as_text=False,
1647 graph=None,
1648 export_scope=None,
1649 clear_devices=False,
1650 clear_extraneous_savers=False,
1651 strip_default_attrs=False,
1652 save_debug_info=False,
1653 **kwargs):
1654 # pylint: disable=line-too-long
1655 """Returns `MetaGraphDef` proto.
1657 Optionally writes it to filename.
1659 This function exports the graph, saver, and collection objects into a
1660 `MetaGraphDef` protocol buffer so that it can be imported at a later time
1661 or in a different location to restart training, run inference, or serve as
1662 a subgraph.
1664 Args:
1665 filename: Optional filename including the path for writing the generated
1666 `MetaGraphDef` protocol buffer.
1667 meta_info_def: `MetaInfoDef` protocol buffer.
1668 graph_def: `GraphDef` protocol buffer.
1669 saver_def: `SaverDef` protocol buffer.
1670 collection_list: List of string keys to collect.
1671 as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
1672 graph: The `Graph` to export. If `None`, use the default graph.
1673 export_scope: Optional `string`. Name scope under which to extract the
1674 subgraph. The scope name will be stripped from the node definitions for
1675 easy import later into new name scopes. If `None`, the whole graph is
1676 exported. graph_def and export_scope cannot both be specified.
1677 clear_devices: Whether or not to clear the device field for an `Operation`
1678 or `Tensor` during export.
1679 clear_extraneous_savers: Remove any Saver-related information from the graph
1680 (both Save/Restore ops and SaverDefs) that is not associated with the
1681 provided SaverDef.
1682 strip_default_attrs: Boolean. If `True`, default-valued attributes will be
1683 removed from the NodeDefs. For a detailed guide, see [Stripping
1684 Default-Valued
1685 Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
1686 save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
1687 which is in the same directory as `filename` and has `_debug` added before
1688 the file extension.
1689 **kwargs: Optional keyword arguments.
1691 Returns:
1692 A `MetaGraphDef` proto.
1694 Raises:
1695 ValueError: When the `GraphDef` is larger than 2GB.
1696 RuntimeError: If called with eager execution enabled.
1698 @compatibility(eager)
1699 Exporting/importing meta graphs is not supported unless both `graph_def` and
1700 `graph` are provided. No graph exists when eager execution is enabled.
1701 @end_compatibility
1702 """
1703 # pylint: enable=line-too-long
1704 if context.executing_eagerly() and not (graph_def is not None and
1705 graph is not None):
1706 raise RuntimeError("Exporting/importing meta graphs is not supported when "
1707 "eager execution is enabled. No graph exists when eager "
1708 "execution is enabled.")
1709 meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
1710 filename=filename,
1711 meta_info_def=meta_info_def,
1712 graph_def=graph_def,
1713 saver_def=saver_def,
1714 collection_list=collection_list,
1715 as_text=as_text,
1716 graph=graph,
1717 export_scope=export_scope,
1718 clear_devices=clear_devices,
1719 clear_extraneous_savers=clear_extraneous_savers,
1720 strip_default_attrs=strip_default_attrs,
1721 save_debug_info=save_debug_info,
1722 **kwargs)
1723 return meta_graph_def
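# Illustrative usage sketch (not part of the original module): export a
# `MetaGraphDef` with `export_meta_graph` and recreate the graph later with
# `import_meta_graph`. Assumes the `tf.compat.v1` API and a writable
# `export_dir`; the helper name is hypothetical.
def _example_meta_graph_round_trip(export_dir):
  """Sketch: write a .meta file for a small graph and re-import it by name."""
  import tensorflow.compat.v1 as tf1  # local import keeps the sketch self-contained
  meta_path = os.path.join(export_dir, "example.meta")
  with tf1.Graph().as_default():
    x = tf1.placeholder(tf1.float32, name="x")
    tf1.multiply(x, 2.0, name="y")
    export_meta_graph(filename=meta_path)
  # Re-import into a fresh graph and address tensors by name.
  with tf1.Graph().as_default() as g, tf1.Session(graph=g) as sess:
    import_meta_graph(meta_path)
    return sess.run("y:0", feed_dict={"x:0": 21.0})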
1726def _wrap_restore_error_with_msg(err, extra_verbiage):
1727 err_msg = ("Restoring from checkpoint failed. This is most likely "
1728 "due to {} from the checkpoint. Please ensure that you "
1729 "have not altered the graph expected based on the checkpoint. "
1730 "Original error:\n\n{}").format(extra_verbiage, err.message)
1731 return err.__class__(err.node_def, err.op, err_msg)
1734ops.register_proto_function(
1735 ops.GraphKeys.SAVERS,
1736 proto_type=saver_pb2.SaverDef,
1737 to_proto=Saver.to_proto,
1738 from_proto=Saver.from_proto)
1741def object_graph_key_mapping(checkpoint_path):
1742 """Return name to key mappings from the checkpoint.
1744 Args:
1745 checkpoint_path: string, path to object-based checkpoint
1747 Returns:
1748 Dictionary mapping tensor names to checkpoint keys.
1749 """
1750 reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
1751 object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
1752 object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
1753 object_graph_proto.ParseFromString(object_graph_string)
1754 names_to_keys = {}
1755 for node in object_graph_proto.nodes:
1756 for attribute in node.attributes:
1757 names_to_keys[attribute.full_name] = attribute.checkpoint_key
1758 return names_to_keys
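# Illustrative usage sketch (not part of the original module): inspect the
# name-to-key mapping of an object-based checkpoint, e.g. one written by
# `tf.train.Checkpoint`. `checkpoint_prefix` is assumed to point at such a
# checkpoint; the helper name is hypothetical.
def _example_print_object_graph_keys(checkpoint_prefix):
  """Sketch: list which checkpoint key backs each saved tensor name."""
  names_to_keys = object_graph_key_mapping(checkpoint_prefix)
  for name, key in sorted(names_to_keys.items()):
    # Entries typically map a variable name to an object-path key, roughly
    # "dense/kernel" -> "dense/kernel/.ATTRIBUTES/VARIABLE_VALUE" (illustrative).
    print("%s -> %s" % (name, key))
  return names_to_keys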
1761def saver_from_object_based_checkpoint(checkpoint_path,
1762 var_list=None,
1763 builder=None,
1764 names_to_keys=None,
1765 cached_saver=None):
1766 """Return a `Saver` which reads from an object-based checkpoint.
1768 This function validates that all variables in `var_list` are remapped in the
1769 object-based checkpoint (or in the `names_to_keys` dict if provided). A
1770 saver is then created with the list of remapped variables.
1772 The `cached_saver` argument allows the user to pass in a previously created
1773 saver, so that multiple `saver.restore()` calls don't pollute the graph during
1774 graph building. This assumes that the checkpoint keys are consistent, meaning
1775 that both 1) the checkpoint at `checkpoint_path` and
1776 2) the checkpoint used to create the `cached_saver`
1777 are the same type of object-based checkpoint. If this argument is set, this
1778 function simply validates that all variables have been remapped by the
1779 checkpoint at `checkpoint_path`.
1781 Note that in general, `tf.train.Checkpoint` should be used to restore/save an
1782 object-based checkpoint.
1784 Args:
1785 checkpoint_path: string, path to object-based checkpoint
1786 var_list: list of `Variables` that appear in the checkpoint. If `None`,
1787 `var_list` will be set to all saveable objects.
1788 builder: a `BaseSaverBuilder` instance. If `None`, a new `BulkSaverBuilder`
1789 will be created.
1790 names_to_keys: dict mapping string tensor names to checkpoint keys. If
1791 `None`, this dict will be generated from the checkpoint file.
1792 cached_saver: Cached `Saver` object with remapped variables.
1794 Returns:
1795 `Saver` with remapped variables for reading from an object-based checkpoint.
1797 Raises:
1798 ValueError: If the checkpoint provided is not an object-based checkpoint.
1799 NotFoundError: If one of the variables in `var_list` cannot be found in the
1800 checkpoint. This could mean the checkpoint or `names_to_keys` mapping is
1801 missing the variable.
1802 """
1803 if names_to_keys is None:
1804 try:
1805 names_to_keys = object_graph_key_mapping(checkpoint_path)
1806 except errors.NotFoundError:
1807 raise ValueError("Checkpoint in %s not an object-based checkpoint." %
1808 checkpoint_path)
1809 if var_list is None:
1810 var_list = variables._all_saveable_objects() # pylint: disable=protected-access
1811 if builder is None:
1812 builder = BulkSaverBuilder()
1814 if not isinstance(var_list, dict):
1815 var_list = saveable_object_util.op_list_to_dict(var_list)
1816 saveables = saveable_object_util.validate_and_slice_inputs(var_list)
1817 current_names = set()
1818 for saveable in saveables:
1819 for spec in saveable.specs:
1820 current_names.add(spec.name)
1821 previous_names = set(names_to_keys.keys())
1822 missing_names = current_names - previous_names
1823 if missing_names:
1824 extra_names = previous_names - current_names
1825 intersecting_names = previous_names.intersection(current_names)
1826 raise errors.NotFoundError(
1827 None,
1828 None,
1829 message=(
1830 "\n\nExisting variables not in the checkpoint: %s\n\n"
1831 "Variables names when this checkpoint was written which don't "
1832 "exist now: %s\n\n"
1833 "(%d variable name(s) did match)\n\n"
1834 "Could not find some variables in the checkpoint (see names "
1835 "above). Saver was attempting to load an object-based checkpoint "
1836 "(saved using tf.train.Checkpoint or tf.keras.Model.save_weights) "
1837 "using variable names. If the checkpoint was written with eager "
1838 "execution enabled, it's possible that variable names have "
1839 "changed (for example missing a '_1' suffix). It's also "
1840 "possible that there are new variables which did not exist "
1841 "when the checkpoint was written. You can construct a "
1842 "Saver(var_list=...) with only the variables which previously "
1843 "existed, and if variable names have changed you may need to "
1844 "make this a dictionary with the old names as keys. If you're "
1845 "using an Estimator, you'll need to return a tf.train.Saver "
1846 "inside a tf.train.Scaffold from your model_fn.") %
1847 (", ".join(sorted(missing_names)), ", ".join(
1848 sorted(extra_names)), len(intersecting_names)))
1849 for saveable in saveables:
1850 for spec in saveable.specs:
1851 spec.name = names_to_keys[spec.name]
1852 if cached_saver is None:
1853 return Saver(saveables)
1854 return cached_saver
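# Illustrative usage sketch (not part of the original module): build a
# name-based `Saver` that can read an object-based checkpoint (for example one
# written by `tf.train.Checkpoint`) into the variables of the current graph.
# Assumes graph mode, that `checkpoint_prefix` points at an object-based
# checkpoint, and that the graph's variables match those recorded in it; the
# helper name is hypothetical.
def _example_restore_from_object_checkpoint(sess, checkpoint_prefix):
  """Sketch: remap graph variables to object-based keys and restore them."""
  remapped_saver = saver_from_object_based_checkpoint(checkpoint_prefix)
  remapped_saver.restore(sess, checkpoint_prefix)
  return remapped_saver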