Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/checkpoint.py: 21%
814 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""Utilities for saving/loading Trackable objects."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
17import abc
18import collections
19import functools
20import glob
21import os
22import threading
23import time
24import weakref
26from tensorflow.core.protobuf import trackable_object_graph_pb2
27from tensorflow.python.checkpoint import async_checkpoint_helper
28from tensorflow.python.checkpoint import checkpoint_context
29from tensorflow.python.checkpoint import checkpoint_management
30from tensorflow.python.checkpoint import checkpoint_options
31from tensorflow.python.checkpoint import functional_saver
32from tensorflow.python.checkpoint import graph_view as graph_view_lib
33from tensorflow.python.checkpoint import restore as restore_lib
34from tensorflow.python.checkpoint import save_util
35from tensorflow.python.checkpoint import save_util_v1
36from tensorflow.python.checkpoint import util
37from tensorflow.python.client import session as session_lib
38from tensorflow.python.eager import context
39from tensorflow.python.eager import def_function
40from tensorflow.python.eager import monitoring
41from tensorflow.python.framework import constant_op
42from tensorflow.python.framework import dtypes
43from tensorflow.python.framework import errors_impl
44from tensorflow.python.framework import ops
45from tensorflow.python.framework import tensor_shape
46from tensorflow.python.framework import tensor_util
47from tensorflow.python.lib.io import file_io
48from tensorflow.python.ops import array_ops
49from tensorflow.python.ops import gen_io_ops as io_ops
50from tensorflow.python.ops import init_ops
51from tensorflow.python.ops import variable_scope
52from tensorflow.python.ops import variable_v1
53from tensorflow.python.platform import gfile
54from tensorflow.python.platform import tf_logging as logging
55from tensorflow.python.saved_model import path_helpers
56from tensorflow.python.saved_model.pywrap_saved_model import metrics
57from tensorflow.python.trackable import autotrackable
58from tensorflow.python.trackable import base
59from tensorflow.python.trackable import data_structures
60from tensorflow.python.training import py_checkpoint_reader
61from tensorflow.python.training import saver as v1_saver_lib
62from tensorflow.python.training.saving import saveable_object as saveable_object_lib
63from tensorflow.python.training.saving import saveable_object_util
64from tensorflow.python.util import compat
65from tensorflow.python.util import deprecation
66from tensorflow.python.util import object_identity
67from tensorflow.python.util import tf_contextlib
68from tensorflow.python.util import tf_inspect
69from tensorflow.python.util.tf_export import tf_export
72# The callable that provide Keras default session that is needed for saving.
73_SESSION_PROVIDER = None
75# Captures the timestamp of the first Checkpoint instantiation or end of a write
76# operation. Can be accessed by multiple Checkpoint instances.
77_END_TIME_OF_LAST_WRITE = None
78_END_TIME_OF_LAST_WRITE_LOCK = threading.Lock()
80# API labels for cell names used in checkpoint metrics.
81_CHECKPOINT_V1 = "checkpoint_v1"
82_CHECKPOINT_V2 = "checkpoint_v2"
84# Async thread used for asynchronous checkpoint.
85_ASYNC_CHECKPOINT_THREAD = None
88def _get_duration_microseconds(start_time_seconds, end_time_seconds):
89 if end_time_seconds < start_time_seconds:
90 # Avoid returning negative value in case of clock skew.
91 return 0
92 return round((end_time_seconds - start_time_seconds) * 1000000)
95@tf_export("__internal__.tracking.register_session_provider", v1=[])
96def register_session_provider(session_provider):
97 global _SESSION_PROVIDER
98 # TODO(scottzhu): Change it back to only allow one time setting for session
99 # provider once we finished the keras repo split.
100 # if _SESSION_PROVIDER is None:
101 _SESSION_PROVIDER = session_provider
104def get_session():
105 # Prefer TF's default session since get_session from Keras has side-effects.
106 session = ops.get_default_session()
107 if session is None:
108 global _SESSION_PROVIDER
109 if _SESSION_PROVIDER is not None:
110 session = _SESSION_PROVIDER() # pylint: disable=not-callable
111 return session
114def _get_checkpoint_size(prefix):
115 """Calculates filesize of checkpoint based on prefix."""
116 size = 0
117 # Gather all files beginning with prefix (.index plus sharded data files).
118 files = glob.glob("{}*".format(prefix))
119 for file in files:
120 # Use TensorFlow's C++ FileSystem API.
121 size += metrics.CalculateFileSize(file)
122 return size
125class ObjectGraphProtoPrettyPrinter:
126 """Lazily traverses an object graph proto to pretty print names.
128 If no calls to `node_names` are made this object has no performance
129 overhead. On the other hand, it will only traverse the object graph once, so
130 repeated naming is cheap after the first.
131 """
133 __slots__ = ["_object_graph_proto", "_node_name_cache"]
135 def __init__(self, object_graph_proto):
136 self._object_graph_proto = object_graph_proto
137 self._node_name_cache = None
139 @property
140 def node_names(self):
141 """Lazily creates a mapping from node id to ("path", "to", "root")."""
142 if self._node_name_cache is not None:
143 return self._node_name_cache
144 path_to_root = {}
145 path_to_root[0] = ("(root)",)
146 to_visit = collections.deque([0])
147 while to_visit:
148 node_id = to_visit.popleft()
149 obj = self._object_graph_proto.nodes[node_id]
150 for child in obj.children:
151 if child.node_id not in path_to_root:
152 path_to_root[child.node_id] = (
153 path_to_root[node_id] + (child.local_name,))
154 to_visit.append(child.node_id)
156 node_names = {}
157 for node_id, path_to_root in path_to_root.items():
158 node_names[node_id] = ".".join(path_to_root)
160 for node_id, node in enumerate(self._object_graph_proto.nodes):
161 for slot_reference in node.slot_variables:
162 node_names[slot_reference.slot_variable_node_id] = (
163 f"{node_names[node_id]}'s state '{slot_reference.slot_name}' for "
164 f"{node_names[slot_reference.original_variable_node_id]}")
165 self._node_name_cache = node_names
166 return node_names
169class _CheckpointRestoreCoordinatorDeleter:
170 """Deleter to avoid overriding _CheckpointRestoreCoordinator.__del__()."""
172 __slots__ = [
173 "expect_partial", "object_graph_proto", "matched_proto_ids",
174 "unused_attributes"
175 ]
177 def __init__(self, expect_partial, object_graph_proto, matched_proto_ids,
178 unused_attributes):
179 self.expect_partial = expect_partial
180 self.object_graph_proto = object_graph_proto
181 self.matched_proto_ids = matched_proto_ids
182 self.unused_attributes = unused_attributes
184 def set_expect_partial(self, expect_partial):
185 self.expect_partial = expect_partial
187 def __del__(self):
188 if self.expect_partial:
189 return
190 if logging is None:
191 # The logging module may have been unloaded when __del__ is called.
192 log_fn = print
193 else:
194 log_fn = logging.warning
195 unused_nodes_in_checkpoint = []
196 unrestored_attributes_in_object = []
197 pretty_printer = ObjectGraphProtoPrettyPrinter(self.object_graph_proto)
198 for node_id, node in enumerate(self.object_graph_proto.nodes):
199 if not node.attributes:
200 continue
201 if node_id not in self.matched_proto_ids:
202 unused_nodes_in_checkpoint.append(pretty_printer.node_names[node_id])
203 for node_id, attribute_name in self.unused_attributes.items():
204 unrestored_attributes_in_object.append((
205 pretty_printer.node_names[node_id], attribute_name))
206 if unused_nodes_in_checkpoint or unrestored_attributes_in_object:
207 # pylint:disable=line-too-long
208 log_fn("Detecting that an object or model or tf.train.Checkpoint is being"
209 " deleted with unrestored values. See the following logs for the "
210 "specific values in question. To silence these warnings, use "
211 "`status.expect_partial()`. See "
212 "https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restore"
213 "for details about the status object returned by the restore "
214 "function.")
215 # pylint:enable=line-too-long
216 for node_path in unused_nodes_in_checkpoint:
217 log_fn("Value in checkpoint could not be found in the restored object: "
218 f"{node_path}")
219 for node_path, attr in unrestored_attributes_in_object:
220 log_fn("An attribute in the restored object could not be found in the "
221 f"checkpoint. Object: {node_path}, attribute: {attr}")
224class _CheckpointRestoreCoordinator:
225 """Holds the status of an object-based checkpoint load."""
227 def __init__(self, object_graph_proto, save_path, save_path_tensor, reader,
228 restore_op_cache, graph_view, options, saveables_cache):
229 """Specify the checkpoint being loaded.
231 Args:
232 object_graph_proto: The TrackableObjectGraph protocol buffer associated
233 with this checkpoint.
234 save_path: A string, the path to the checkpoint, as returned by
235 `tf.train.latest_checkpoint`.
236 save_path_tensor: A string `Tensor` which contains or will be fed the save
237 path.
238 reader: A `CheckpointReader` for `save_path`. If None,
239 `_CheckpointRestoreCoordinator` will initialize one itself.
240 restore_op_cache: A dictionary shared between
241 `_CheckpointRestoreCoordinator`s for the same Python objects, used to
242 look up restore ops by name to avoid re-creating them across multiple
243 `restore()` calls.
244 graph_view: A graph_view_lib.ObjectGraphView object for the restored
245 objects.
246 options: A CheckpointOptions object.
247 saveables_cache: An optional cache storing previously created
248 SaveableObjects created for each Trackable. Maps Trackables to a
249 dictionary of attribute names to Trackable.
250 """
251 self.options = options
252 self.object_graph_proto = object_graph_proto
253 self.restore_uid = ops.uid()
254 # Maps from proto ids to lists of attributes which were in the checkpoint
255 # but not loaded into any object, for error checking.
256 self.unused_attributes = {}
257 # Dictionary mapping from an id in the protocol buffer flat array to
258 # Trackable Python objects. This mapping may be deferred if a
259 # checkpoint is restored before all dependencies have been tracked. Uses
260 # weak references so that partial restorations don't create reference cycles
261 # (as objects with deferred dependencies will generally have references to
262 # this object).
263 self.object_by_proto_id = weakref.WeakValueDictionary()
264 self.matched_proto_ids = set()
265 # A set of all Python objects we've seen as dependencies, even if we didn't
266 # use them (for example because of inconsistent references when
267 # loading). Used to make status assertions fail when loading checkpoints
268 # that don't quite match.
269 self.all_python_objects = object_identity.ObjectIdentityWeakSet()
270 self.save_path_tensor = save_path_tensor
271 self.save_path_string = save_path
272 self.dtype_map = reader.get_variable_to_dtype_map()
273 self.shape_map = reader.get_variable_to_shape_map()
274 # A NewCheckpointReader for the most recent checkpoint, for streaming Python
275 # state restoration.
276 # When graph building, contains a list of ops to run to restore objects from
277 # this checkpoint.
278 self.restore_ops = []
279 self.restore_ops_by_name = restore_op_cache
280 self.graph_view = graph_view
281 self.new_restore_ops_callback = None
282 # A mapping from optimizer proto ids to lists of slot variables to be
283 # restored when the optimizer is tracked. Only includes slot variables whose
284 # regular variables have already been created, and only for optimizer
285 # objects which have not yet been created/tracked.
286 self.deferred_slot_restorations = {}
287 # A mapping from variable proto ids to lists of slot variables to be
288 # restored when the variable is created/tracked. These get shifted over to
289 # deferred_slot_restorations if the optimizer hasn't been created when that
290 # happens.
291 self.slot_restorations = {}
292 # Controls whether errors are printed in __del__ if some objects did not
293 # match.
294 self.expect_partial_attr = False
295 for node_index, node in enumerate(self.object_graph_proto.nodes):
296 for slot_reference in node.slot_variables:
297 # `node` refers to an `Optimizer`, since only these have slot variables.
298 self.slot_restorations.setdefault(
299 slot_reference.original_variable_node_id, []).append(
300 base._SlotVariableRestoration( # pylint: disable=protected-access
301 optimizer_id=node_index,
302 slot_variable_id=slot_reference.slot_variable_node_id,
303 slot_name=slot_reference.slot_name))
305 self._deleter = _CheckpointRestoreCoordinatorDeleter(
306 self.expect_partial_attr,
307 self.object_graph_proto,
308 self.matched_proto_ids,
309 self.unused_attributes)
311 self.saveables_cache = saveables_cache
313 @property
314 def expect_partial(self):
315 return self.expect_partial_attr
317 @expect_partial.setter
318 def expect_partial(self, expect_partial):
319 self.expect_partial_attr = expect_partial
320 self._deleter.set_expect_partial(expect_partial)
322 def new_restore_ops(self, new_ops):
323 self.restore_ops.extend(new_ops)
324 if self.new_restore_ops_callback:
325 self.new_restore_ops_callback(new_ops) # pylint: disable=not-callable
327 def restore_saveables(
328 self,
329 tensor_saveables,
330 python_positions,
331 registered_savers=None,
332 reader=None,
333 ):
334 """Run or build restore operations for SaveableObjects.
336 Args:
337 tensor_saveables: `SaveableObject`s which correspond to Tensors.
338 python_positions: List of CheckpointPositions bound to `PythonState`
339 objects which must be restored eagerly.
340 registered_savers: a dict mapping saver names-> object name -> Trackable.
341 reader: A `CheckpointReader`. If None, a new instance will be created.
343 Returns:
344 When graph building, a list of restore operations, either cached or newly
345 created, to restore `tensor_saveables`.
346 """
347 if reader is None:
348 reader = py_checkpoint_reader.NewCheckpointReader(self.save_path_string)
350 restore_ops = []
351 # Eagerly run restorations for Python state.
352 for position in python_positions:
353 key = position.object_proto.attributes[0].checkpoint_key
354 position.trackable.deserialize(reader.get_tensor(key))
356 # If we have new SaveableObjects, extract and cache restore ops.
357 if tensor_saveables or registered_savers:
358 flat_saveables = saveable_object_util.validate_and_slice_inputs(
359 tensor_saveables)
360 new_restore_ops = functional_saver.MultiDeviceSaver.from_saveables(
361 flat_saveables,
362 registered_savers).restore(self.save_path_tensor, self.options)
363 if not context.executing_eagerly():
364 for name, restore_op in sorted(new_restore_ops.items()):
365 restore_ops.append(restore_op)
366 assert name not in self.restore_ops_by_name
367 self.restore_ops_by_name[name] = restore_op
368 return restore_ops
371class _NameBasedRestoreCoordinator:
372 """Keeps the status of a name-based checkpoint restore."""
374 def __init__(self, save_path, dtype_map=None):
375 self.save_path = save_path
376 self.dtype_map = dtype_map
377 # A map from trackable objects to unused attribute names. We don't have
378 # proto IDs when doing a name-based restore, so the map keys differ from
379 # those in _CheckpointRestoreCoordinator.
380 self.unused_attributes = object_identity.ObjectIdentityWeakKeyDictionary()
381 self.restore_uid = ops.uid()
383 def globally_named_object_attributes(self, trackable):
384 """Create globally named SaveableObjects from attributes.
386 If an object's attribute has no global name specified (default construction
387 for the SaveableObject factory), records the failure in
388 `self.unused_attributes` (which can then be used to make status assertions
389 fail; see `NameBasedSaverStatus`).
391 Args:
392 trackable: An object to save.
394 Yields:
395 SaveableObjects for `trackable`'s attributes.
396 """
397 for (
398 attribute_name,
399 saveable_factory,
400 ) in saveable_object_util.saveable_objects_from_trackable(
401 trackable, tf1_saver=True,
402 ).items():
403 if callable(saveable_factory):
404 try:
405 # This saveable object factory does not have a default name= argument,
406 # which means there's no way to save/restore it using a name-based
407 # checkpoint. Ignore the error now and make sure assert_consumed()
408 # fails.
409 saveable = saveable_factory()
410 except TypeError:
411 self.unused_attributes.setdefault(trackable,
412 []).append(attribute_name)
413 continue
414 else:
415 saveable = saveable_factory
416 names_to_saveables = saveable_object_util.op_list_to_dict(
417 [saveable], convert_variable_to_tensor=False)
418 for name, op in names_to_saveables.items():
419 for saveable_object in saveable_object_util.saveable_objects_for_op(
420 op=op, name=name):
421 yield saveable_object
423 def eager_restore(self, trackable):
424 """Runs restore ops for `trackable`'s attributes."""
425 # When graph building, we don't add any restore ops to the graph until
426 # run_restore_ops/initialize_or_restore on the status object for name-based
427 # checkpoints.
428 assert context.executing_eagerly()
429 for saveable in self.globally_named_object_attributes(trackable):
430 restored_tensors = []
431 tensor_missing = False
432 for spec in saveable.specs:
433 if spec.name in self.dtype_map:
434 with ops.device("cpu:0"):
435 restored, = io_ops.restore_v2(
436 prefix=self.save_path,
437 tensor_names=[spec.name],
438 shape_and_slices=[""],
439 dtypes=[self.dtype_map[spec.name]],
440 name="%s_checkpoint_read" % (spec.name,))
441 restored_tensors.append(array_ops.identity(restored))
442 else:
443 tensor_missing = True
445 if tensor_missing:
446 # Record that this variable didn't match so assertions will fail.
447 self.unused_attributes.setdefault(trackable, []).append(saveable.name)
448 else:
449 # Ignores values missing from the checkpoint, as with object-based
450 # restore. Status assertions can be used to check exact matches,
451 # although it's unlikely to ever happen for name-based checkpoints.
452 saveable.restore(
453 restored_tensors=restored_tensors, restored_shapes=None)
456# TODO(allenl): If this ends up in a public API, consider adding LINT.If Change
457# or consolidating the implementation with get_variable.
458def _default_getter(name,
459 shape,
460 dtype,
461 initializer=None,
462 partition_info=None,
463 **kwargs):
464 """A pared-down version of get_variable which does not reuse variables."""
465 dtype = dtypes.as_dtype(dtype)
466 shape_object = tensor_shape.as_shape(shape)
467 with ops.init_scope():
468 if initializer is None:
469 initializer, initializing_from_value = (
470 variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access
471 name=name,
472 shape=shape_object,
473 dtype=dtype))
474 else:
475 initializing_from_value = not callable(initializer)
476 # Same logic as get_variable
477 variable_dtype = dtype.base_dtype
478 if initializing_from_value:
479 if shape is not None:
480 raise ValueError("If initializer is a constant, do not specify shape.")
481 initial_value = initializer
482 else:
483 # Instantiate initializer if provided initializer is a type object.
484 if isinstance(initializer, type(init_ops.Initializer)):
485 initializer = initializer(dtype=dtype)
486 shape_list = None if shape is None else shape_object.as_list()
487 if "partition_info" in tf_inspect.getargspec(initializer).args:
488 initial_value = functools.partial(initializer,
489 shape_list,
490 dtype=dtype,
491 partition_info=partition_info)
492 else:
493 initial_value = functools.partial(initializer,
494 shape_list,
495 dtype=dtype)
497 return variable_v1.VariableV1(
498 initial_value=initial_value,
499 name=name,
500 dtype=variable_dtype,
501 use_resource=True,
502 **kwargs)
505def add_variable(trackable,
506 name,
507 shape=None,
508 dtype=dtypes.float32,
509 initializer=None,
510 trainable=True):
511 """Add a variable to a Trackable with no scope influence."""
512 return trackable._add_variable_with_custom_getter( # pylint: disable=protected-access
513 name=name,
514 shape=shape,
515 dtype=dtype,
516 initializer=initializer,
517 getter=_default_getter,
518 trainable=trainable)
521def object_metadata(save_path):
522 """Retrieves information about the objects in a checkpoint.
524 Example usage:
526 ```python
527 object_graph = tf.contrib.checkpoint.object_metadata(
528 tf.train.latest_checkpoint(checkpoint_directory))
529 ckpt_variable_names = set()
530 for node in object_graph.nodes:
531 for attribute in node.attributes:
532 ckpt_variable_names.add(attribute.full_name)
533 ```
535 Args:
536 save_path: The path to the checkpoint, as returned by `save` or
537 `tf.train.latest_checkpoint`.
539 Returns:
540 A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer.
541 Raises:
542 ValueError: If an object graph was not found in the checkpoint.
543 """
544 reader = py_checkpoint_reader.NewCheckpointReader(save_path)
545 try:
546 object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
547 except errors_impl.NotFoundError:
548 raise ValueError(
549 f"The specified checkpoint \"{save_path}\" does not appear to be "
550 "object-based (saved with TF2) since it is missing the key "
551 f"\"{base.OBJECT_GRAPH_PROTO_KEY}\". Likely it was created with the "
552 "TF1 name-based saver and does not contain an object dependency graph.")
553 object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
554 object_graph_proto.ParseFromString(object_graph_string)
555 return object_graph_proto
558def list_objects(root_trackable):
559 """Traverse the object graph and list all accessible objects.
561 Looks for `Trackable` objects which are dependencies of
562 `root_trackable`. Includes slot variables only if the variable they are
563 slotting for and the optimizer are dependencies of `root_trackable`
564 (i.e. if they would be saved with a checkpoint).
566 Args:
567 root_trackable: A `Trackable` object whose dependencies should be flattened.
569 Returns:
570 A flat list of objects.
571 """
572 return util.list_objects(graph_view_lib.ObjectGraphView(root_trackable))
575def gather_initializers(root_trackable):
576 """Traverse the object graph and find initialization ops.
578 Looks for `Trackable` objects which are dependencies of
579 `root_trackable` and which have an `initializer` property. Includes
580 initializers for slot variables only if the variable they are slotting for and
581 the optimizer are dependencies of `root_trackable` (i.e. if they would be
582 saved with a checkpoint).
584 Args:
585 root_trackable: A `Trackable` object to gather initializers for.
587 Returns:
588 A list of initialization ops.
589 """
590 trackable_objects = list_objects(root_trackable)
591 return [
592 c.initializer
593 for c in trackable_objects
594 if hasattr(c, "initializer") and c.initializer is not None
595 ]
598@tf_contextlib.contextmanager
599def capture_dependencies(template):
600 """Capture variables created within this scope as `Template` dependencies.
602 Requires that `template.variable_scope` is active.
604 This scope is intended as a compatibility measure, allowing a trackable
605 object to add dependencies on variables created in a block of code which is
606 not aware of object-based saving (and instead uses variable names
607 heavily). This is how `Template` objects add dependencies on variables and
608 sub-`Template`s. Where possible, use `tf.compat.v1.make_template` directly.
610 Args:
611 template: The `Template` object to register dependencies with.
613 Yields:
614 None (when used as a context manager).
615 """
616 name_prefix = template.variable_scope.name
618 def _trackable_custom_creator(next_creator,
619 name,
620 initial_value,
621 trackable_parent=None,
622 **kwargs):
623 """A variable creation hook which adds Trackable dependencies.
625 Set for example during a `Template`'s first wrapped function
626 execution. Ensures that (a) `template` depends on any trackable
627 objects using their own `capture_dependencies` scope inside this scope which
628 create variables, and (b) that any variables not in a more deeply nested
629 scope are added as dependencies directly.
631 The `trackable_parent` argument is passed between custom creators but
632 ignored when the variable object itself is created. This argument indicates
633 (if not `None`) that a more deeply nested scope has already added the
634 variable as a dependency, and that parent scopes should add a dependency on
635 that object rather than on the variable directly.
637 Args:
638 next_creator: See `variable_scope.variable_creator_scope`; the next
639 creator in the chain.
640 name: The (full, scope-influenced) name of the variable. The `name_prefix`
641 itself is stripped for the purposes of object-based dependency tracking,
642 but scopes opened within this scope are respected.
643 initial_value: See `variable_scope.variable_creator_scope`. Taken
644 explicitly so the argument can be re-named and used with
645 `Trackable._add_variable_with_custom_getter`.
646 trackable_parent: If not None, a more deeply nested trackable object and
647 its name prefix which were passed to `capture_dependencies` to add a
648 dependency on (rather than depending on the variable directly).
649 **kwargs: Passed through to the next creator.
651 Returns:
652 The output of `next_creator`: the fetched/created variable object.
653 """
655 def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
656 inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
657 # we don't want to propagate.
658 return next_creator(initial_value=initializer, name=name, **inner_kwargs)
660 if name is not None and name.startswith(name_prefix):
661 scope_stripped_name = name[len(name_prefix) + 1:]
662 if not trackable_parent:
663 return template._add_variable_with_custom_getter( # pylint: disable=protected-access
664 initializer=initial_value,
665 name=scope_stripped_name,
666 getter=_call_next_creator_renaming_initializer,
667 # Disable error checking for Trackable. Exceptions are instead
668 # raised if necessary when the object-based saver tries to
669 # save/restore the object.
670 overwrite=True,
671 trackable_parent=(template, name_prefix),
672 **kwargs)
673 else:
674 parent_object, parent_name_prefix = trackable_parent
675 template._track_trackable( # pylint: disable=protected-access
676 parent_object,
677 name=parent_name_prefix[len(name_prefix) + 1:],
678 overwrite=True)
679 return next_creator(
680 name=name,
681 initial_value=initial_value,
682 trackable_parent=(template, name_prefix),
683 **kwargs)
685 with variable_scope.variable_creator_scope(_trackable_custom_creator):
686 yield
689class _LoadStatus:
690 """Abstract base for load status callbacks."""
692 @abc.abstractmethod
693 def assert_consumed(self):
694 """Raises an exception unless a non-trivial restoration has completed."""
695 pass
697 @abc.abstractmethod
698 def assert_existing_objects_matched(self):
699 """Raises an exception unless existing Python objects have been matched."""
700 pass
702 @abc.abstractmethod
703 def assert_nontrivial_match(self):
704 """Raises an exception if only the root object matched."""
705 pass
707 @abc.abstractmethod
708 def run_restore_ops(self, session=None):
709 """Runs restore ops from the checkpoint. Requires a valid checkpoint."""
710 pass
712 @abc.abstractmethod
713 def initialize_or_restore(self, session=None):
714 """Runs restore ops from the checkpoint, or initializes variables."""
715 pass
717 def expect_partial(self):
718 """Silence warnings about incomplete checkpoint restores."""
719 return self
722@tf_export("__internal__.tracking.streaming_restore", v1=[])
723def streaming_restore(status, session=None):
724 """When graph building, runs restore ops as soon as they come in.
726 Args:
727 status: A _LoadStatus objects from an object-based saver's restore().
728 Streaming restore from name-based checkpoints is not currently supported.
729 session: A session to run new restore ops in.
730 """
731 if context.executing_eagerly():
732 # Streaming restore is the default/only behavior when executing eagerly.
733 return
734 if session is None:
735 session = get_session()
736 if isinstance(status, NameBasedSaverStatus):
737 raise NotImplementedError(
738 "Streaming restore not supported from name-based checkpoints when "
739 "graph building. File a feature request if this limitation bothers "
740 "you. As a workaround, consider either using tf.train.Checkpoint to "
741 "load name-based checkpoints or enabling eager execution.")
742 status.run_restore_ops(session=session)
743 # pylint: disable=protected-access
744 status._checkpoint.new_restore_ops_callback = (
745 lambda ops: session.run(ops, feed_dict=status._feed_dict))
746 # pylint: enable=protected-access
749def _objects_with_attributes(full_list):
750 """Filters out objects with no direct variable dependencies for assertions."""
751 return [
752 o for o in full_list
753 if saveable_object_util.saveable_objects_from_trackable(o)
754 ]
757class CheckpointLoadStatus(_LoadStatus):
758 """Checks the status of checkpoint loading and manages restore ops.
760 Returned from `Saver.restore`. Since `restore` may defer the loading of values
761 in the checkpoint which don't yet have corresponding Python objects,
762 `CheckpointLoadStatus` provides a callback to verify that checkpoint loading
763 is complete (`assert_consumed`).
765 When graph building, `restore` does not run restore ops itself since their
766 creation may be deferred. The `run_restore_ops` method must be called once all
767 Python objects with values to restore have been created and added to the
768 dependency graph (this does not necessarily have to be the whole checkpoint;
769 calling `run_restore_ops` while `assert_consumed` fails is supported and will
770 partially restore the checkpoint).
772 See `Saver.restore` for usage examples.
773 """
775 def __init__(self, checkpoint, feed_dict, graph_view):
776 self._checkpoint = checkpoint
777 self._feed_dict = feed_dict
778 self._object_graph_view = graph_view
779 # Keep a reference to the root, since object_graph_view might only have a
780 # weakref.
781 self._root = graph_view.root
783 def assert_consumed(self):
784 """Asserts that all objects in the checkpoint have been created/matched.
786 Returns:
787 `self` for chaining.
788 Raises:
789 AssertionError: If there are any Python objects in the dependency graph
790 which have not been restored from this checkpoint or a later `restore`,
791 or if there are any checkpointed values which have not been matched to
792 Python objects.
793 """
794 pretty_printer = ObjectGraphProtoPrettyPrinter(
795 self._checkpoint.object_graph_proto)
796 self.assert_existing_objects_matched()
797 for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
798 if not node.attributes:
799 # Only raise exceptions for the nodes with attributes themselves. Either
800 # they're ultimately not important, or they have a child with an
801 # attribute.
802 continue
803 trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
804 if trackable is None:
805 raise AssertionError(
806 "Unresolved object in checkpoint "
807 f"{pretty_printer.node_names[node_id]}: {node}")
808 if self._checkpoint.slot_restorations:
809 # Sanity check; this collection should be clear if everything has been
810 # restored.
811 raise AssertionError(
812 f"Unresolved slot restorations: {self._checkpoint.slot_restorations}")
813 if self._checkpoint.unused_attributes:
814 unused_attribute_messages = []
815 for node_id, attribute in self._checkpoint.unused_attributes.items():
816 obj = self._checkpoint.object_by_proto_id[node_id]
817 unused_attribute_messages.append(
818 f"{pretty_printer.node_names[node_id]} ({obj}): {attribute}")
819 joined_attribute_messages = "\n".join(unused_attribute_messages)
820 raise AssertionError(
821 "Unused attributes in these objects (the attributes exist in the "
822 f"checkpoint but were not restored):\n{joined_attribute_messages}")
823 return self
825 def assert_existing_objects_matched(self):
826 """Asserts that trackable Python objects have been matched.
828 Note that this is a weaker assertion than `assert_consumed`. It will only
829 fail for existing Python objects which are (transitive) dependencies of the
830 root object and which do not have an entry in the checkpoint.
832 It will not fail, for example, if a `tf.keras.Layer` object has not yet been
833 built and so has not created any `tf.Variable` objects.
835 Returns:
836 `self` for chaining.
838 Raises:
839 AssertionError: If a Python object exists in the transitive dependencies
840 of the root object but does not have a value in the checkpoint.
841 """
842 for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
843 trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
844 if (trackable is not None and
845 trackable._update_uid < self._checkpoint.restore_uid): # pylint: disable=protected-access
846 raise AssertionError(
847 f"Object {node} not assigned a value from checkpoint.")
848 for trackable_object in util.list_objects(self._object_graph_view):
849 # Remove data structures that do not contain any variables from
850 # restoration checks.
851 if (isinstance(trackable_object,
852 data_structures.TrackableDataStructure) and
853 not trackable_object._trackable_children( # pylint: disable=protected-access
854 save_type=base.SaveType.CHECKPOINT)):
855 continue
856 self._checkpoint.all_python_objects.add(trackable_object)
857 unused_python_objects = (
858 object_identity.ObjectIdentitySet(
859 _objects_with_attributes(
860 self._checkpoint.all_python_objects)) -
861 object_identity.ObjectIdentitySet(
862 self._checkpoint.object_by_proto_id.values()))
863 if unused_python_objects:
864 num_unused_python_objects = len(list(unused_python_objects))
865 # Display max number of 10 variables in error message.
866 num_variables_to_show = min(10, num_unused_python_objects)
867 raise AssertionError(
868 f"Found {num_unused_python_objects} Python objects that were "
869 "not bound to checkpointed values, likely due to changes in the "
870 f"Python program. Showing {num_variables_to_show} of "
871 f"{num_unused_python_objects} unmatched objects: "
872 f"{list(unused_python_objects)[:num_variables_to_show]}")
873 return self
875 def assert_nontrivial_match(self):
876 """Raises an exception if only the root object matched."""
877 for trackable_object in util.list_objects(self._object_graph_view):
878 self._checkpoint.all_python_objects.add(trackable_object)
879 if len(self._checkpoint.object_by_proto_id) <= 1:
880 unused_python_objects = (
881 object_identity.ObjectIdentitySet(
882 _objects_with_attributes(self._checkpoint.all_python_objects)) -
883 object_identity.ObjectIdentitySet(
884 self._checkpoint.object_by_proto_id.values()))
885 if unused_python_objects:
886 raise AssertionError(
887 "Nothing except the root object matched a checkpointed value. "
888 "Typically this means that the checkpoint does not match the "
889 "Python program. The following objects have no matching "
890 f"checkpointed value: {list(unused_python_objects)}")
891 else:
892 raise AssertionError(
893 "Nothing to load. No dependencies have been added to "
894 f"{self._object_graph_view.root} yet.")
895 return self
897 def run_restore_ops(self, session=None):
898 """Run operations to restore objects in the dependency graph."""
899 if context.executing_eagerly():
900 return # Run eagerly
901 if session is None:
902 session = get_session()
903 session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict)
905 def initialize_or_restore(self, session=None):
906 """Run operations to initialize or restore objects in the dependency graph.
908 Any objects in the dependency graph which have initializers but are not in
909 the checkpoint will have those initializers run, unless those variables are
910 being restored by a later call to `tf.train.Checkpoint.restore()`.
912 This method has a sibling in `InitializationOnlyStatus` which instead
913 initializes variables. That type is returned if no checkpoint is specified
914 in `Saver.restore`.
916 Args:
917 session: The session to run init/restore ops in. If `None`, uses the
918 default session.
919 """
920 if context.executing_eagerly():
921 return # Initialization and restoration ops are run eagerly
922 if session is None:
923 session = get_session()
924 all_objects = util.list_objects(self._object_graph_view)
925 already_initialized_objects = object_identity.ObjectIdentitySet(
926 self._checkpoint.object_by_proto_id.values())
927 initializers_for_non_restored_variables = [
928 c.initializer for c in all_objects
929 if hasattr(c, "initializer")
930 and c not in already_initialized_objects
931 and (getattr(c, "_update_uid", self._checkpoint.restore_uid - 1)
932 < self._checkpoint.restore_uid)
933 ]
934 self.run_restore_ops(session=session)
935 session.run(initializers_for_non_restored_variables)
937 def expect_partial(self):
938 """Silence warnings about incomplete checkpoint restores."""
939 self._checkpoint.expect_partial = True
940 return self
943class InitializationOnlyStatus(_LoadStatus):
944 """Returned from `Saver.restore` when no checkpoint has been specified.
946 Objects of this type have the same `assert_consumed` method as
947 `CheckpointLoadStatus`, but it always fails. However,
948 `initialize_or_restore` works on objects of both types, and will
949 initialize variables in `InitializationOnlyStatus` objects or restore them
950 otherwise.
951 """
953 def __init__(self, object_graph_view, restore_uid):
954 self._restore_uid = restore_uid
955 self._object_graph_view = object_graph_view
956 # Keep a reference to the root, since graph_view might only have a weakref.
957 self._root = object_graph_view.root
959 def assert_consumed(self):
960 """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
961 raise AssertionError(
962 "No checkpoint specified (save_path=None); nothing is being restored.")
964 def assert_existing_objects_matched(self):
965 """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
966 raise AssertionError(
967 "No checkpoint specified (save_path=None); nothing is being restored.")
969 def assert_nontrivial_match(self):
970 """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
971 raise AssertionError(
972 "No checkpoint specified (save_path=None); nothing is being restored.")
974 def run_restore_ops(self, session=None):
975 """For consistency with `CheckpointLoadStatus`.
977 Use `initialize_or_restore` for initializing if no checkpoint was passed
978 to `Saver.restore` and restoring otherwise.
980 Args:
981 session: Not used.
982 """
983 raise AssertionError(
984 "No checkpoint specified, so no restore ops are available "
985 "(save_path=None to Saver.restore).")
987 def initialize_or_restore(self, session=None):
988 """Runs initialization ops for variables.
990 Objects which would be saved by `Saver.save` will be initialized, unless
991 those variables are being restored by a later call to
992 `tf.train.Checkpoint.restore()`.
994 This method does nothing when executing eagerly (initializers get run
995 eagerly).
997 Args:
998 session: The session to run initialization ops in. If `None`, uses the
999 default session.
1000 """
1001 if context.executing_eagerly():
1002 return # run eagerly
1003 if session is None:
1004 session = get_session()
1005 trackable_objects = util.list_objects(self._object_graph_view)
1006 initializers = [
1007 c.initializer for c in trackable_objects
1008 if hasattr(c, "initializer") and c.initializer is not None
1009 and (getattr(c, "_update_uid", self._restore_uid - 1)
1010 < self._restore_uid)
1011 ]
1012 session.run(initializers)
1015_DEPRECATED_RESTORE_INSTRUCTIONS = (
1016 "Restoring a name-based tf.train.Saver checkpoint using the object-based "
1017 "restore API. This mode uses global names to match variables, and so is "
1018 "somewhat fragile. It also adds new restore ops to the graph each time it "
1019 "is called when graph building. Prefer re-encoding training checkpoints in "
1020 "the object-based format: run save() on the object-based saver (the same "
1021 "one this message is coming from) and use that checkpoint in the future.")
1024class NameBasedSaverStatus(_LoadStatus):
1025 """Status for loading a name-based training checkpoint."""
1027 # Ideally this deprecation decorator would be on the class, but that
1028 # interferes with isinstance checks.
1029 @deprecation.deprecated(
1030 date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
1031 def __init__(self, checkpoint, object_graph_view):
1032 self._checkpoint = checkpoint
1033 self._object_graph_view = object_graph_view
1034 self._optionally_restored = []
1035 # Keep a reference to the root, since graph_view might only have a weakref.
1036 self._root = object_graph_view.root
1038 def add_to_optionally_restored(self, var):
1039 """Add a variable to the list of optionally restored variables.
1041 There are situations where certain variables should be ignored in assertions
1042 such as assert_existing_objects_matched(). One example is that of a
1043 checkpoint saved with train.Saver(), and restored with train.Checkpoint():
1044 it is possible for the train.Saver() checkpoint to be missing the internal
1045 `save_counter` variable, which we want to ignore on restore.
1047 Args:
1048 var: The variable to treat as optionally restored.
1049 """
1050 self._optionally_restored.append(var)
1052 def assert_consumed(self):
1053 """Raises an exception if any variables are unmatched."""
1054 unused_attributes = list(self._checkpoint.unused_attributes.items())
1055 unused_attributes = [
1056 a for a in unused_attributes
1057 if all(a[0] is not x for x in self._optionally_restored)
1058 ]
1059 if unused_attributes:
1060 unused_attribute_string = "".join(
1061 f"\n {obj}: {attributes}" for obj, attributes in unused_attributes)
1062 raise AssertionError(
1063 "Some objects had attributes which were not restored: "
1064 f"{unused_attribute_string}")
1065 for trackable in util.list_objects(self._object_graph_view):
1066 # pylint: disable=protected-access
1067 trackable._maybe_initialize_trackable()
1068 if trackable._update_uid < self._checkpoint.restore_uid:
1069 raise AssertionError(f"Object not restored: {trackable}")
1070 # pylint: enable=protected-access
1071 return self
1073 def assert_existing_objects_matched(self):
1074 """Raises an exception if currently created objects are unmatched."""
1075 # For name-based checkpoints there's no object information in the
1076 # checkpoint, so there's no distinction between
1077 # assert_existing_objects_matched and assert_consumed (and both are less
1078 # useful since we don't touch Python objects or Python state).
1079 return self.assert_consumed()
1081 def assert_nontrivial_match(self):
1082 """Raises an exception if currently created objects are unmatched."""
1083 # For name-based checkpoints there's no object information in the
1084 # checkpoint, so there's no distinction between
1085 # assert_nontrivial_match and assert_consumed (and both are less
1086 # useful since we don't touch Python objects or Python state).
1087 return self.assert_consumed()
1089 def _gather_saveable_objects(self):
1090 """Walk the object graph, using global names for SaveableObjects."""
1091 objects = util.list_objects(self._object_graph_view)
1092 saveable_objects = []
1093 for trackable in objects:
1094 # pylint: disable=protected-access
1095 trackable._maybe_initialize_trackable()
1096 if trackable._update_uid < self._checkpoint.restore_uid:
1097 trackable._update_uid = self._checkpoint.restore_uid
1098 else:
1099 continue
1100 # pylint: enable=protected-access
1101 saveable_objects.extend(
1102 self._checkpoint.globally_named_object_attributes(trackable))
1103 return saveable_objects
1105 def run_restore_ops(self, session=None):
1106 """Load the name-based checkpoint using a new `tf.compat.v1.train.Saver`."""
1107 if context.executing_eagerly():
1108 return # Nothing to do, variables are restored on creation.
1109 if session is None:
1110 session = get_session()
1111 with ops.device("/cpu:0"):
1112 saveables = self._gather_saveable_objects()
1113 v1_saver_lib.Saver(saveables).restore(
1114 sess=session, save_path=self._checkpoint.save_path)
1116 def initialize_or_restore(self, session=None):
1117 """Alias for `run_restore_ops`."""
1118 self.run_restore_ops(session=session)
1121class _SessionWithFeedDictAdditions(session_lib.SessionInterface):
1122 """Pretends to be a session, inserts extra feeds on run()."""
1124 def __init__(self, session, feed_additions):
1125 self._wrapped_session = session
1126 self._feed_additions = feed_additions
1128 def run(self, fetches, feed_dict=None, **kwargs):
1129 if feed_dict is None:
1130 feed_dict = {}
1131 else:
1132 feed_dict = feed_dict.copy()
1133 feed_dict.update(self._feed_additions)
1134 return self._wrapped_session.run(
1135 fetches=fetches, feed_dict=feed_dict, **kwargs)
1138class TrackableSaver:
1139 """Saves and restores a `Trackable` object and its dependencies.
1141 See `Trackable` for details of dependency management. `Saver` wraps
1142 `tf.compat.v1.train.Saver` for saving, including extra information about the
1143 graph of
1144 dependencies between Python objects. When restoring, it uses this information
1145 about the save-time dependency graph to more robustly match objects with their
1146 checkpointed values. When executing eagerly, it supports restoring variables
1147 on object creation (see `Saver.restore`).
1149 Values in a checkpoint are mapped to `Trackable` Python objects
1150 (`Variable`s, `Optimizer`s, `Layer`s) based on the names provided when the
1151 checkpoint was written. To avoid breaking existing checkpoints when modifying
1152 a class, dependency names (the names of attributes to which `Trackable`
1153 objects are assigned) may not change. These names are local to objects, in
1154 contrast to the `Variable.name`-based save/restore from
1155 `tf.compat.v1.train.Saver`, and
1156 so allow additional program transformations.
1157 """
1159 def __init__(self, graph_view):
1160 """Configure saving.
1162 Args:
1163 graph_view: An `ObjectGraphView` object containing a description of the
1164 object graph to save.
1165 """
1166 self._graph_view = graph_view
1168 # The following attributes are used when graph building.
1170 # self._cache: A more generic cache used to cache the serialized tensors and
1171 # TrackableObjectGraph proto attributes.
1172 # self._saveables_cache: A dictionary mapping `Trackable` objects ->
1173 # attribute names -> SaveableObjects, used to avoid re-creating
1174 # SaveableObjects when graph building.
1175 if context.executing_eagerly():
1176 self._cache = None
1177 self._saveables_cache = None
1178 else:
1179 self._cache = object_identity.ObjectIdentityWeakKeyDictionary()
1180 self._saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
1182 # The file prefix placeholder is created lazily when graph building (and not
1183 # at all when executing eagerly) to avoid creating ops in the constructor
1184 # (when they may never be necessary).
1185 self._file_prefix_placeholder = None
1187 # Op caching for save
1188 self._object_graph_feed_tensor = None
1189 self._last_save_object_graph = None
1190 self._file_prefix_feed_tensor = None
1191 self._cached_save_operation = None
1193 # Op caching for restore, shared between _CheckpointRestoreCoordinators
1194 self._restore_op_cache = {}
1196 # Object map used for checkpoint. This attribute is to be overridden by a
1197 # Checkpoint subclass, e.g., AsyncCheckpoint, to replace the trackable
1198 # objects for checkpoint saving.
1199 self._object_map = None
1201 def _gather_serialized_tensors(self, object_graph_tensor=None):
1202 """Gathers tensors to save to ckpt and includes the object graph proto."""
1203 serialized_tensors, feed_additions, registered_savers, graph_proto = (
1204 save_util.serialize_graph_view(self._graph_view,
1205 self._object_map,
1206 cache=self._cache))
1208 if self._saveables_cache is not None:
1209 # Store saveables cache for restoration purposes.
1210 self._saveables_cache = (
1211 saveable_object_util.serialized_tensors_to_saveable_cache(
1212 serialized_tensors))
1214 if object_graph_tensor is None:
1215 with ops.device("/cpu:0"):
1216 object_graph_tensor = constant_op.constant(
1217 graph_proto.SerializeToString(), dtype=dtypes.string)
1218 else:
1219 feed_additions.update(
1220 {object_graph_tensor: graph_proto.SerializeToString()})
1221 assert base.OBJECT_GRAPH_PROTO_KEY not in serialized_tensors.get(None, {})
1222 serialized_tensors.setdefault(None, {})[base.OBJECT_GRAPH_PROTO_KEY] = (
1223 object_graph_tensor)
1224 return serialized_tensors, feed_additions, registered_savers, graph_proto
1226 def _save_cached_when_graph_building(self, file_prefix, object_graph_tensor,
1227 options):
1228 """Create or retrieve save ops.
1230 Args:
1231 file_prefix: The prefix for saved checkpoint files.
1232 object_graph_tensor: A `Tensor` to which the current object graph will be
1233 fed.
1234 options: `CheckpointOptions` object.
1236 Returns:
1237 A two-element tuple with a filename tensor and a feed_dict of tensors to
1238 feed when running it (if graph building). The feed dict contains the
1239 current object graph and any Python state to be saved in the
1240 checkpoint. When executing eagerly only the first argument is meaningful.
1241 """
1242 serialized_tensors, feed_additions, registered_savers, graph_proto = (
1243 self._gather_serialized_tensors(object_graph_tensor))
1245 if (self._last_save_object_graph != graph_proto
1246 # When executing eagerly, we need to re-create SaveableObjects each
1247 # time save() is called so they pick up new Tensors passed to their
1248 # constructors. That means the Saver needs to be copied with a new
1249 # var_list.
1250 or context.executing_eagerly() or ops.inside_function()):
1251 saver = functional_saver.MultiDeviceSaver(serialized_tensors,
1252 registered_savers)
1253 save_op = saver.save(file_prefix, options=options)
1254 with ops.device("/cpu:0"):
1255 with ops.control_dependencies([save_op]):
1256 self._cached_save_operation = array_ops.identity(file_prefix)
1257 self._last_save_object_graph = graph_proto
1258 return self._cached_save_operation, feed_additions
1260 def save(self,
1261 file_prefix,
1262 checkpoint_number=None,
1263 session=None,
1264 options=None):
1265 """Save a training checkpoint.
1267 The saved checkpoint includes variables created by this object and any
1268 Trackable objects it depends on at the time `Saver.save()` is called.
1270 Args:
1271 file_prefix: A prefix to use for the checkpoint filenames
1272 (/path/to/directory/and_a_prefix). Names are generated based on this
1273 prefix and `checkpoint_number`, if provided.
1274 checkpoint_number: An integer variable or Tensor, used to number
1275 checkpoints. Typically this value is saved along with other variables in
1276 training checkpoints, which will happen automatically if it was created
1277 by `root_trackable` or one of its dependencies (via
1278 `Trackable._add_variable`).
1279 session: The session to evaluate variables in. Ignored when executing
1280 eagerly. If not provided when graph building, the default session is
1281 used.
1282 options: Optional `tf.train.CheckpointOptions` object.
1284 Returns:
1285 The full path to the checkpoint.
1287 Raises:
1288 RuntimeError: if called in V1 Graph mode without a default session.
1289 """
1290 options = options or checkpoint_options.CheckpointOptions()
1291 feed_dict = {}
1292 use_session = (not context.executing_eagerly() and
1293 not ops.inside_function())
1294 if checkpoint_number:
1295 file_prefix = "%s-%d" % (file_prefix, checkpoint_number)
1296 if use_session:
1297 if self._object_graph_feed_tensor is None:
1298 with ops.device("/cpu:0"):
1299 self._object_graph_feed_tensor = constant_op.constant(
1300 "", dtype=dtypes.string)
1301 self._file_prefix_feed_tensor = constant_op.constant(
1302 "", dtype=dtypes.string)
1303 object_graph_tensor = self._object_graph_feed_tensor
1304 file_prefix_tensor = self._file_prefix_feed_tensor
1305 feed_dict[file_prefix_tensor] = file_prefix
1306 else:
1307 with ops.device("/cpu:0"):
1308 file_prefix_tensor = ops.convert_to_tensor(
1309 file_prefix, dtype=dtypes.string)
1310 object_graph_tensor = None
1312 if not tensor_util.is_tensor(file_prefix):
1313 file_io.recursive_create_dir(os.path.dirname(file_prefix))
1315 save_path, new_feed_additions = self._save_cached_when_graph_building(
1316 file_prefix_tensor, object_graph_tensor, options)
1318 if new_feed_additions:
1319 feed_dict.update(new_feed_additions)
1320 if not use_session:
1321 session = None
1322 elif session is None:
1323 session = get_session()
1325 if session:
1326 return session.run(save_path, feed_dict=feed_dict)
1327 elif use_session:
1328 raise RuntimeError(f"Unable to save checkpoint to \"{file_prefix}\" "
1329 "in graph mode without a default session. Please use "
1330 "`with tf.Session():` to create a session.")
1331 else:
1332 return save_path
1334 def restore(self, save_path, options=None):
1335 """Restore a training checkpoint.
1337 Restores `root_trackable` and any objects that it tracks
1338 (transitive). Either assigns values immediately if variables to restore have
1339 been created already, or defers restoration until the variables are
1340 created. Dependencies added to the `root_trackable` passed to the
1341 constructor after this call will be matched if they have a corresponding
1342 object in the checkpoint.
1344 When building a graph, restorations are added to the graph but not run.
1346 ```python
1347 saver = Saver(root)
1348 saver.restore(path)
1349 ```
1351 To ensure that loading is complete and no more deferred restorations will
1352 take place, you can use the `assert_consumed()` method of the status object
1353 returned by the `restore` call.
1355 The assert will raise an exception unless every object was matched and all
1356 checkpointed values have a matching variable object.
1358 ```python
1359 saver = Saver(root)
1360 saver.restore(path).assert_consumed()
1361 ```
1363 When graph building, `assert_consumed()` indicates that all of the restore
1364 ops which will be created for this checkpoint have been created. They can be
1365 run via the `run_restore_ops()` function of the status object:
1367 ```python
1368 saver.restore(path).assert_consumed().run_restore_ops()
1369 ```
1371 If the checkpoint has not been consumed completely, then the list of restore
1372 ops will grow as more objects are added to the dependency graph.
1374 Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
1375 method. There is no deferred loading, and names are used to match
1376 variables. No restore ops are created/run until `run_restore_ops()` or
1377 `initialize_or_restore()` are called on the returned status object, even
1378 when executing eagerly. Re-encode name-based checkpoints using this
1379 object-based `Saver.save` as soon as possible.
1381 Args:
1382 save_path: The path to the checkpoint, as returned by `save` or
1383 `tf.train.latest_checkpoint`. If None (as when there is no latest
1384 checkpoint for `tf.train.latest_checkpoint` to return), returns an
1385 object which may run initializers for objects in the dependency graph.
1386 If the checkpoint was written by the name-based
1387 `tf.compat.v1.train.Saver`, names are used to match variables.
1388 options: Optional `tf.train.CheckpointOptions` object.
1390 Returns:
1391 A load status object, which can be used to make assertions about the
1392 status of checkpoint restoration and run initialization/restore ops
1393 (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if
1394 `save_path` is `None`).
1396 If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus`
1397 object is returned which runs restore ops from a name-based saver.
1399 Raises:
1400 RuntimeError: When a checkpoint file saved by async checkpoint is not
1401 available upon restore().
1402 """
1403 options = options or checkpoint_options.CheckpointOptions()
1404 if save_path is None:
1405 return InitializationOnlyStatus(self._graph_view, ops.uid())
1407 # Wait until the ongoing checkpoint to finish.
1408 # TODO(chienchunh): Allow to load the file while other checkpoint events
1409 # are still ongiing. Need to add timeout mechanism along
1410 # with conditional variables to notify when the checkpoint
1411 # file is ready.
1412 global _ASYNC_CHECKPOINT_THREAD
1413 if _ASYNC_CHECKPOINT_THREAD is not None:
1414 _ASYNC_CHECKPOINT_THREAD.join()
1415 reader = py_checkpoint_reader.NewCheckpointReader(save_path)
1416 graph_building = not context.executing_eagerly()
1417 if graph_building:
1418 dtype_map = None
1419 else:
1420 dtype_map = reader.get_variable_to_dtype_map()
1421 try:
1422 object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
1423 except errors_impl.NotFoundError:
1424 # The object graph proto does not exist in this checkpoint. Try the
1425 # name-based compatibility mode.
1426 restore_coordinator = _NameBasedRestoreCoordinator(
1427 save_path=save_path,
1428 dtype_map=dtype_map)
1429 if not graph_building:
1430 for existing_trackable in util.list_objects(self._graph_view):
1431 # pylint: disable=protected-access
1432 existing_trackable._maybe_initialize_trackable()
1433 existing_trackable._name_based_restores.add(restore_coordinator)
1434 existing_trackable._name_based_attribute_restore(restore_coordinator)
1435 # pylint: enable=protected-access
1436 return NameBasedSaverStatus(
1437 restore_coordinator,
1438 object_graph_view=self._graph_view)
1440 if graph_building:
1441 if self._file_prefix_placeholder is None:
1442 with ops.device("/cpu:0"):
1443 self._file_prefix_placeholder = constant_op.constant("model")
1444 file_prefix_tensor = self._file_prefix_placeholder
1445 file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
1446 else:
1447 with ops.device("/cpu:0"):
1448 file_prefix_tensor = constant_op.constant(save_path)
1449 file_prefix_feed_dict = None
1450 object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
1451 object_graph_proto.ParseFromString(object_graph_string)
1452 checkpoint = _CheckpointRestoreCoordinator(
1453 object_graph_proto=object_graph_proto,
1454 save_path=save_path,
1455 save_path_tensor=file_prefix_tensor,
1456 reader=reader,
1457 restore_op_cache=self._restore_op_cache,
1458 graph_view=self._graph_view,
1459 options=options,
1460 saveables_cache=self._saveables_cache)
1461 restore_lib.CheckpointPosition(
1462 checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root,
1463 reader)
1465 # Attached dependencies are not attached to the root, so should be restored
1466 # separately.
1467 if self._graph_view.attached_dependencies:
1468 for ref in self._graph_view.attached_dependencies:
1469 if ref.name == "root":
1470 # Root dependency is automatically added to attached dependencies --
1471 # this can be ignored since it maps back to the root object.
1472 continue
1473 proto_id = None
1474 # Find proto ID of attached dependency (if it is in the proto).
1475 for proto_ref in object_graph_proto.nodes[0].children:
1476 if proto_ref.local_name == ref.name:
1477 proto_id = proto_ref.node_id
1478 break
1480 if proto_id in checkpoint.object_by_proto_id:
1481 # Object has already been restored. This can happen when there's an
1482 # indirect connection from the attached object to the root.
1483 continue
1485 if proto_id is None:
1486 # Could not find attached dependency in proto.
1487 continue
1489 restore_lib.CheckpointPosition(
1490 checkpoint=checkpoint,
1491 proto_id=proto_id).restore(ref.ref, reader)
1493 load_status = CheckpointLoadStatus(
1494 checkpoint,
1495 graph_view=self._graph_view,
1496 feed_dict=file_prefix_feed_dict)
1497 return load_status
1500def frozen_saver(root_trackable):
1501 """Creates a static `tf.compat.v1.train.Saver` from a trackable object.
1503 The returned `Saver` saves object-based checkpoints, but these checkpoints
1504 will no longer reflect structural changes to the object graph, only changes to
1505 the values of `Variable`s added as dependencies of the root object before
1506 `freeze` was called.
1508 `restore` works on the returned `Saver`, but requires that the object graph of
1509 the checkpoint being loaded exactly matches the object graph when `freeze` was
1510 called. This is in contrast the object-based restore performed by
1511 `tf.train.Checkpoint` which attempts a fuzzy matching between a checkpoint's
1512 object graph and the current Python object graph.
1514 Args:
1515 root_trackable: A trackable object to save.
1517 Returns:
1518 A saver which saves object-based checkpoints for the object graph frozen at
1519 the time `frozen_saver` was called.
1520 """
1521 named_saveable_objects, registered_savers = (
1522 save_util_v1.frozen_saveables_and_savers(
1523 graph_view_lib.ObjectGraphView(root_trackable)))
1524 return functional_saver.MultiDeviceSaver.from_saveables(
1525 named_saveable_objects, registered_savers)
1528def _assert_trackable(obj, name):
1529 if not isinstance(
1530 obj, (base.Trackable, def_function.Function)):
1531 raise ValueError(
1532 f"`Checkpoint` was expecting {name} to be a trackable object (an "
1533 f"object derived from `Trackable`), got {obj}. If you believe this "
1534 "object should be trackable (i.e. it is part of the "
1535 "TensorFlow Python API and manages state), please open an issue.")
1538def _update_checkpoint_state_internal(file_path):
1539 """Update internal checkpoint state."""
1540 checkpoint_management.update_checkpoint_state_internal(
1541 save_dir=os.path.dirname(file_path),
1542 model_checkpoint_path=file_path,
1543 all_model_checkpoint_paths=[file_path],
1544 save_relative_paths=True)
1547def _convert_file_name_tensor_to_string(tensor):
1548 """Convert file name tensor to string."""
1549 output = tensor
1550 if tensor_util.is_tf_type(output):
1551 # Convert to numpy if not `tf.function` building.
1552 if context.executing_eagerly():
1553 output = compat.as_str(output.numpy())
1554 else:
1555 # Graph + Session, so we already session.ran it.
1556 output = compat.as_str(output)
1557 return output
1560def _copy_single_tensor(tensor):
1561 """Copies a single Tensor / SaveSpec onto the CPU device."""
1562 device = tensor.device
1563 if isinstance(tensor, saveable_object_lib.SaveSpec):
1564 # Pin the device according to the tensor's device location to
1565 # avoid unnecessary data copies when reading the variables. This is
1566 # aligned with the behavior in MultiDeviceSaver.save().
1567 with ops.device(device):
1568 tensor = tensor.tensor
1570 if tensor is not None:
1571 with ops.device(saveable_object_util.set_cpu0(device)):
1572 tensor = array_ops.identity(tensor) # pylint: disable=protected-access
1573 return tensor
1576# Mentions graph building / Sessions. The v2 version is below.
1577@tf_export(v1=["train.Checkpoint"])
1578class CheckpointV1(autotrackable.AutoTrackable):
1579 """Groups trackable objects, saving and restoring them.
1581 `Checkpoint`'s constructor accepts keyword arguments whose values are types
1582 that contain trackable state, such as `tf.compat.v1.train.Optimizer`
1583 implementations, `tf.Variable`, `tf.keras.Layer` implementations, or
1584 `tf.keras.Model` implementations. It saves these values with a checkpoint, and
1585 maintains a `save_counter` for numbering checkpoints.
1587 Example usage when graph building:
1589 ```python
1590 import tensorflow as tf
1591 import os
1593 checkpoint_directory = "/tmp/training_checkpoints"
1594 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
1596 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
1597 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
1598 train_op = optimizer.minimize( ... )
1599 status.assert_consumed() # Optional sanity checks.
1600 with tf.compat.v1.Session() as session:
1601 # Use the Session to restore variables, or initialize them if
1602 # tf.train.latest_checkpoint returned None.
1603 status.initialize_or_restore(session)
1604 for _ in range(num_training_steps):
1605 session.run(train_op)
1606 checkpoint.save(file_prefix=checkpoint_prefix)
1607 ```
1609 Example usage with eager execution enabled:
1611 ```python
1612 import tensorflow as tf
1613 import os
1615 tf.compat.v1.enable_eager_execution()
1617 checkpoint_directory = "/tmp/training_checkpoints"
1618 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
1620 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
1621 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
1622 for _ in range(num_training_steps):
1623 optimizer.minimize( ... ) # Variables will be restored on creation.
1624 status.assert_consumed() # Optional sanity checks.
1625 checkpoint.save(file_prefix=checkpoint_prefix)
1626 ```
1628 `Checkpoint.save` and `Checkpoint.restore` write and read object-based
1629 checkpoints, in contrast to `tf.compat.v1.train.Saver` which writes and reads
1630 `variable.name` based checkpoints. Object-based checkpointing saves a graph of
1631 dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s,
1632 etc.) with named edges, and this graph is used to match variables when
1633 restoring a checkpoint. It can be more robust to changes in the Python
1634 program, and helps to support restore-on-create for variables when executing
1635 eagerly. Prefer `tf.train.Checkpoint` over `tf.compat.v1.train.Saver` for new
1636 code.
1638 `Checkpoint` objects have dependencies on the objects passed as keyword
1639 arguments to their constructors, and each dependency is given a name that is
1640 identical to the name of the keyword argument for which it was created.
1641 TensorFlow classes like `Layer`s and `Optimizer`s will automatically add
1642 dependencies on their variables (e.g. "kernel" and "bias" for
1643 `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing
1644 dependencies easy in user-defined classes, since `Model` hooks into attribute
1645 assignment. For example:
1647 ```python
1648 class Regress(tf.keras.Model):
1650 def __init__(self):
1651 super().__init__()
1652 self.input_transform = tf.keras.layers.Dense(10)
1653 # ...
1655 def call(self, inputs):
1656 x = self.input_transform(inputs)
1657 # ...
1658 ```
1660 This `Model` has a dependency named "input_transform" on its `Dense` layer,
1661 which in turn depends on its variables. As a result, saving an instance of
1662 `Regress` using `tf.train.Checkpoint` will also save all the variables created
1663 by the `Dense` layer.
1665 When variables are assigned to multiple workers, each worker writes its own
1666 section of the checkpoint. These sections are then merged/re-indexed to behave
1667 as a single checkpoint. This avoids copying all variables to one worker, but
1668 does require that all workers see a common filesystem.
1670 While `tf.keras.Model.save_weights` and `tf.train.Checkpoint.save` save in the
1671 same format, note that the root of the resulting checkpoint is the object the
1672 save method is attached to. This means saving a `tf.keras.Model` using
1673 `save_weights` and loading into a `tf.train.Checkpoint` with a `Model`
1674 attached (or vice versa) will not match the `Model`'s variables. See the
1675 [guide to training
1676 checkpoints](https://www.tensorflow.org/guide/checkpoint) for
1677 details. Prefer `tf.train.Checkpoint` over `tf.keras.Model.save_weights` for
1678 training checkpoints.
1680 Attributes:
1681 save_counter: Incremented when `save()` is called. Used to number
1682 checkpoints.
1683 """
1685 def __init__(self, **kwargs):
1686 """Group objects into a training checkpoint.
1688 Args:
1689 **kwargs: Keyword arguments are set as attributes of this object, and are
1690 saved with the checkpoint. Values must be trackable objects.
1692 Raises:
1693 ValueError: If objects in `kwargs` are not trackable.
1694 """
1695 super().__init__()
1696 global _END_TIME_OF_LAST_WRITE
1697 with _END_TIME_OF_LAST_WRITE_LOCK:
1698 if _END_TIME_OF_LAST_WRITE is None:
1699 _END_TIME_OF_LAST_WRITE = time.time()
1701 for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
1702 setattr(self, k, v)
1703 if not isinstance(
1704 getattr(self, k), (base.Trackable, def_function.Function)):
1705 raise ValueError(
1706 "`Checkpoint` was expecting a trackable object (an object "
1707 f"derived from `Trackable`), got {v}. If you believe this "
1708 "object should be trackable (i.e. it is part of the "
1709 "TensorFlow Python API and manages state), please open an issue.")
1710 self._save_counter = None # Created lazily for restore-on-create.
1711 self._save_assign_op = None
1712 self._saver = TrackableSaver(graph_view_lib.ObjectGraphView(self))
1714 def _maybe_create_save_counter(self):
1715 """Create a save counter if it does not yet exist."""
1716 if self._save_counter is None:
1717 # Initialized to 0 and incremented before saving.
1718 with ops.device("/cpu:0"):
1719 # add_variable creates a dependency named "save_counter"; NoDependency
1720 # prevents creating a second dependency named "_save_counter".
1721 self._save_counter = data_structures.NoDependency(
1722 add_variable(
1723 self,
1724 name="save_counter",
1725 initializer=0,
1726 dtype=dtypes.int64,
1727 trainable=False))
1729 def write(self, file_prefix, session=None):
1730 """Writes a training checkpoint.
1732 The checkpoint includes variables created by this object and any
1733 trackable objects it depends on at the time `Checkpoint.write()` is
1734 called.
1736 `write` does not number checkpoints, increment `save_counter`, or update the
1737 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for
1738 use by higher level checkpoint management utilities. `save` provides a very
1739 basic implementation of these features.
1741 Args:
1742 file_prefix: A prefix to use for the checkpoint filenames
1743 (/path/to/directory/and_a_prefix).
1744 session: The session to evaluate variables in. Ignored when executing
1745 eagerly. If not provided when graph building, the default session is
1746 used.
1748 Returns:
1749 The full path to the checkpoint (i.e. `file_prefix`).
1750 """
1751 return self._write(file_prefix, session)
1753 def _write(self, file_prefix, session=None, write_done_callback=None):
1754 """Writes a training checkpoint.
1756 The checkpoint includes variables created by this object and any
1757 trackable objects it depends on at the time `Checkpoint.write()` is
1758 called.
1760 `write` does not number checkpoints, increment `save_counter`, or update the
1761 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for
1762 use by higher level checkpoint management utilities. `save` provides a very
1763 basic implementation of these features.
1765 Args:
1766 file_prefix: A prefix to use for the checkpoint filenames
1767 (/path/to/directory/and_a_prefix).
1768 session: The session to evaluate variables in. Ignored when executing
1769 eagerly. If not provided when graph building, the default session is
1770 used.
1771 write_done_callback: Optional callback function to be executed once
1772 the underlying checkpoint saving is finished. Example usage includes
1773 updating the checkpoint internal state.
1775 Returns:
1776 The full path to the checkpoint (i.e. `file_prefix`).
1777 """
1778 start_time = time.time()
1779 output = self._saver.save(file_prefix=file_prefix, session=session)
1780 end_time = time.time()
1782 metrics.AddCheckpointWriteDuration(
1783 api_label=_CHECKPOINT_V1,
1784 microseconds=_get_duration_microseconds(start_time, end_time))
1786 global _END_TIME_OF_LAST_WRITE
1787 with _END_TIME_OF_LAST_WRITE_LOCK:
1788 metrics.AddTrainingTimeSaved(
1789 api_label=_CHECKPOINT_V1,
1790 microseconds=_get_duration_microseconds(_END_TIME_OF_LAST_WRITE,
1791 end_time))
1793 if checkpoint_context.in_preemption_save_context():
1794 _preemption_checkpoint_saved_time_usecs.get_cell().increase_by(
1795 _get_duration_microseconds(_END_TIME_OF_LAST_WRITE, end_time)
1796 )
1798 _END_TIME_OF_LAST_WRITE = end_time
1800 if tensor_util.is_tf_type(output):
1801 # Convert to numpy if not `tf.function` building.
1802 if context.executing_eagerly():
1803 output = compat.as_str(output.numpy())
1804 else:
1805 # Graph + Session, so we already session.ran it.
1806 output = compat.as_str(output)
1808 if write_done_callback:
1809 write_done_callback(output)
1811 metrics.RecordCheckpointSize(
1812 api_label=_CHECKPOINT_V1, filesize=_get_checkpoint_size(output))
1813 return output
1815 @property
1816 def save_counter(self):
1817 """An integer variable which starts at zero and is incremented on save.
1819 Used to number checkpoints.
1821 Returns:
1822 The save counter variable.
1823 """
1824 self._maybe_create_save_counter()
1825 return self._save_counter
1827 def save(self, file_prefix, session=None):
1828 """Saves a training checkpoint and provides basic checkpoint management.
1830 The saved checkpoint includes variables created by this object and any
1831 trackable objects it depends on at the time `Checkpoint.save()` is
1832 called.
1834 `save` is a basic convenience wrapper around the `write` method,
1835 sequentially numbering checkpoints using `save_counter` and updating the
1836 metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
1837 management, for example garbage collection and custom numbering, may be
1838 provided by other utilities which also wrap `write`
1839 (`tf.train.CheckpointManager` for example).
1841 Args:
1842 file_prefix: A prefix to use for the checkpoint filenames
1843 (/path/to/directory/and_a_prefix). Names are generated based on this
1844 prefix and `Checkpoint.save_counter`.
1845 session: The session to evaluate variables in. Ignored when executing
1846 eagerly. If not provided when graph building, the default session is
1847 used.
1849 Returns:
1850 The full path to the checkpoint.
1851 """
1852 graph_building = not context.executing_eagerly()
1853 if graph_building:
1854 if ops.inside_function():
1855 raise NotImplementedError(
1856 "Calling tf.train.Checkpoint.save() from a function is not "
1857 "supported, as save() modifies saving metadata in ways not "
1858 "supported by TensorFlow Operations. Consider using "
1859 "tf.train.Checkpoint.write(), a lower-level API which does not "
1860 "update metadata. tf.train.latest_checkpoint and related APIs will "
1861 "not see this checkpoint.")
1862 if session is None:
1863 session = get_session()
1864 if self._save_counter is None:
1865 # When graph building, if this is a new save counter variable then it
1866 # needs to be initialized before assign_add. This is only an issue if
1867 # restore() has not been called first.
1868 session.run(self.save_counter.initializer)
1869 if not graph_building or self._save_assign_op is None:
1870 with ops.colocate_with(self.save_counter):
1871 assign_op = self.save_counter.assign_add(1, read_value=True)
1872 if graph_building:
1873 self._save_assign_op = data_structures.NoDependency(assign_op)
1874 if graph_building:
1875 checkpoint_number = session.run(self._save_assign_op)
1876 else:
1877 checkpoint_number = assign_op.numpy()
1878 file_path = self.write(
1879 "%s-%d" % (file_prefix, checkpoint_number), session=session)
1880 checkpoint_management.update_checkpoint_state_internal(
1881 save_dir=os.path.dirname(file_prefix),
1882 model_checkpoint_path=file_path,
1883 all_model_checkpoint_paths=[file_path],
1884 save_relative_paths=True)
1885 return file_path
1887 def restore(self, save_path):
1888 """Restore a training checkpoint.
1890 Restores this `Checkpoint` and any objects it depends on.
1892 When executing eagerly, either assigns values immediately if variables to
1893 restore have been created already, or defers restoration until the variables
1894 are created. Dependencies added after this call will be matched if they have
1895 a corresponding object in the checkpoint (the restore request will queue in
1896 any trackable object waiting for the expected dependency to be added).
1898 When graph building, restoration ops are added to the graph but not run
1899 immediately.
1901 ```python
1902 checkpoint = tf.train.Checkpoint( ... )
1903 checkpoint.restore(path)
1904 ```
1906 To ensure that loading is complete and no more deferred restorations will
1907 take place, you can use the `assert_consumed()` method of the status object
1908 returned by `restore`.
1909 The assert will raise an exception if any Python objects in the dependency
1910 graph were not found in the checkpoint, or if any checkpointed values do not
1911 have a matching Python object:
1913 ```python
1914 checkpoint = tf.train.Checkpoint( ... )
1915 checkpoint.restore(path).assert_consumed()
1916 ```
1918 When graph building, `assert_consumed()` indicates that all of the restore
1919 ops that will be created for this checkpoint have been created. They can be
1920 run via the `run_restore_ops()` method of the status object:
1922 ```python
1923 checkpoint.restore(path).assert_consumed().run_restore_ops()
1924 ```
1926 If the checkpoint has not been consumed completely, then the list of restore
1927 ops will grow as more objects are added to the dependency graph.
1929 To check that all variables in the Python object have restored values from
1930 checkpoint, use `assert_existing_objects_matched()`. This assertion is
1931 useful when called after the variables in your graph have been created.
1933 Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
1934 method. Names are used to match variables. No restore ops are created/run
1935 until `run_restore_ops()` or `initialize_or_restore()` are called on the
1936 returned status object when graph building, but there is restore-on-creation
1937 when executing eagerly. Re-encode name-based checkpoints using
1938 `tf.train.Checkpoint.save` as soon as possible.
1940 Args:
1941 save_path: The path to the checkpoint, as returned by `save` or
1942 `tf.train.latest_checkpoint`. If None (as when there is no latest
1943 checkpoint for `tf.train.latest_checkpoint` to return), returns an
1944 object which may run initializers for objects in the dependency graph.
1945 If the checkpoint was written by the name-based
1946 `tf.compat.v1.train.Saver`, names are used to match variables.
1948 Returns:
1949 A load status object, which can be used to make assertions about the
1950 status of a checkpoint restoration and run initialization/restore ops.
1952 The returned status object has the following methods:
1954 * `assert_consumed()`:
1955 Raises an exception if any variables are unmatched: either
1956 checkpointed values which don't have a matching Python object or
1957 Python objects in the dependency graph with no values in the
1958 checkpoint. This method returns the status object, and so may be
1959 chained with `initialize_or_restore` or `run_restore_ops`.
1961 * `assert_existing_objects_matched()`:
1962 Raises an exception if any existing Python objects in the dependency
1963 graph are unmatched. Unlike `assert_consumed`, this assertion will
1964 pass if values in the checkpoint have no corresponding Python
1965 objects. For example a `tf.keras.Layer` object which has not yet been
1966 built, and so has not created any variables, will pass this assertion
1967 but will fail `assert_consumed`. Useful when loading part of a larger
1968 checkpoint into a new Python program, e.g. a training checkpoint with
1969 a `tf.compat.v1.train.Optimizer` was saved but only the state required
1970 for inference is being loaded. This method returns the status object,
1971 and so may be chained with `initialize_or_restore` or
1972 `run_restore_ops`.
1974 * `assert_nontrivial_match()`: Asserts that something aside from the root
1975 object was matched. This is a very weak assertion, but is useful for
1976 sanity checking in library code where objects may exist in the
1977 checkpoint which haven't been created in Python and some Python
1978 objects may not have a checkpointed value.
1980 * `expect_partial()`: Silence warnings about incomplete checkpoint
1981 restores. Warnings are otherwise printed for unused parts of the
1982 checkpoint file or object when the `Checkpoint` object is deleted
1983 (often at program shutdown).
1985 * `initialize_or_restore(session=None)`:
1986 When graph building, runs variable initializers if `save_path` is
1987 `None`, but otherwise runs restore operations. If no `session` is
1988 explicitly specified, the default session is used. No effect when
1989 executing eagerly (variables are initialized or restored eagerly).
1991 * `run_restore_ops(session=None)`:
1992 When graph building, runs restore operations. If no `session` is
1993 explicitly specified, the default session is used. No effect when
1994 executing eagerly (restore operations are run eagerly). May only be
1995 called when `save_path` is not `None`.
1996 """
1997 start_time = time.time()
1998 status = self._saver.restore(save_path=save_path)
1999 # Create the save counter now so it gets initialized with other variables
2000 # when graph building. Creating it earlier would lead to errors when using,
2001 # say, train.Saver() to save the model before initializing it.
2002 self._maybe_create_save_counter()
2003 if isinstance(status, NameBasedSaverStatus):
2004 status.add_to_optionally_restored(self.save_counter)
2006 metrics.AddCheckpointReadDuration(
2007 api_label=_CHECKPOINT_V1,
2008 microseconds=_get_duration_microseconds(start_time, time.time()))
2009 return status
2012@tf_export("train.Checkpoint", v1=[])
2013class Checkpoint(autotrackable.AutoTrackable):
2014 """Manages saving/restoring trackable values to disk.
2016 TensorFlow objects may contain trackable state, such as `tf.Variable`s,
2017 `tf.keras.optimizers.Optimizer` implementations, `tf.data.Dataset` iterators,
2018 `tf.keras.Layer` implementations, or `tf.keras.Model` implementations.
2019 These are called **trackable objects**.
2021 A `Checkpoint` object can be constructed to save either a single or group of
2022 trackable objects to a checkpoint file. It maintains a `save_counter` for
2023 numbering checkpoints.
2025 Example:
2027 ```python
2028 model = tf.keras.Model(...)
2029 checkpoint = tf.train.Checkpoint(model)
2031 # Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
2032 # checkpoint.save is called, the save counter is increased.
2033 save_path = checkpoint.save('/tmp/training_checkpoints')
2035 # Restore the checkpointed values to the `model` object.
2036 checkpoint.restore(save_path)
2037 ```
2039 Example 2:
2041 ```python
2042 import tensorflow as tf
2043 import os
2045 checkpoint_directory = "/tmp/training_checkpoints"
2046 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
2048 # Create a Checkpoint that will manage two objects with trackable state,
2049 # one we name "optimizer" and the other we name "model".
2050 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
2051 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
2052 for _ in range(num_training_steps):
2053 optimizer.minimize( ... ) # Variables will be restored on creation.
2054 status.assert_consumed() # Optional sanity checks.
2055 checkpoint.save(file_prefix=checkpoint_prefix)
2056 ```
2058 `Checkpoint.save()` and `Checkpoint.restore()` write and read object-based
2059 checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which
2060 writes and
2061 reads `variable.name` based checkpoints. Object-based checkpointing saves a
2062 graph of dependencies between Python objects (`Layer`s, `Optimizer`s,
2063 `Variable`s, etc.) with named edges, and this graph is used to match variables
2064 when restoring a checkpoint. It can be more robust to changes in the Python
2065 program, and helps to support restore-on-create for variables.
2067 `Checkpoint` objects have dependencies on the objects passed as keyword
2068 arguments to their constructors, and each dependency is given a name that is
2069 identical to the name of the keyword argument for which it was created.
2070 TensorFlow classes like `Layer`s and `Optimizer`s will automatically add
2071 dependencies on their own variables (e.g. "kernel" and "bias" for
2072 `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing
2073 dependencies easy in user-defined classes, since `Model` hooks into attribute
2074 assignment. For example:
2076 ```python
2077 class Regress(tf.keras.Model):
2079 def __init__(self):
2080 super().__init__()
2081 self.input_transform = tf.keras.layers.Dense(10)
2082 # ...
2084 def call(self, inputs):
2085 x = self.input_transform(inputs)
2086 # ...
2087 ```
2089 This `Model` has a dependency named "input_transform" on its `Dense` layer,
2090 which in turn depends on its variables. As a result, saving an instance of
2091 `Regress` using `tf.train.Checkpoint` will also save all the variables created
2092 by the `Dense` layer.
2094 When variables are assigned to multiple workers, each worker writes its own
2095 section of the checkpoint. These sections are then merged/re-indexed to behave
2096 as a single checkpoint. This avoids copying all variables to one worker, but
2097 does require that all workers see a common filesystem.
2099 This function differs slightly from the Keras Model `save_weights` function.
2100 `tf.keras.Model.save_weights` creates a checkpoint file with the name
2101 specified in `filepath`, while `tf.train.Checkpoint` numbers the checkpoints,
2102 using `filepath` as the prefix for the checkpoint file names. Aside from this,
2103 `model.save_weights()` and `tf.train.Checkpoint(model).save()` are equivalent.
2105 See the [guide to training
2106 checkpoints](https://www.tensorflow.org/guide/checkpoint) for
2107 details.
2109 Attributes:
2110 save_counter: Incremented when `save()` is called. Used to number
2111 checkpoints.
2112 """
2114 def __init__(self, root=None, **kwargs):
2115 """Creates a training checkpoint for a single or group of objects.
2117 Args:
2118 root: The root object to checkpoint. `root` may be a trackable object or
2119 `WeakRef` of a trackable object.
2120 **kwargs: Keyword arguments are set as attributes of this object, and are
2121 saved with the checkpoint. All `kwargs` must be trackable objects, or a
2122 nested structure of trackable objects (`list`, `dict`, or `tuple`).
2124 Raises:
2125 ValueError: If `root` or the objects in `kwargs` are not trackable. A
2126 `ValueError` is also raised if the `root` object tracks different
2127 objects from the ones listed in attributes in kwargs (e.g.
2128 `root.child = A` and `tf.train.Checkpoint(root, child=B)` are
2129 incompatible).
2131 """
2132 super().__init__()
2133 global _END_TIME_OF_LAST_WRITE
2134 with _END_TIME_OF_LAST_WRITE_LOCK:
2135 if _END_TIME_OF_LAST_WRITE is None:
2136 _END_TIME_OF_LAST_WRITE = time.time()
2138 # Store a reference to root and kwargs if we need to instantiate an
2139 # AsyncCheckpointer later.
2140 self._root = root
2141 self._kwargs = kwargs
2142 self._delete_tracking("_kwargs")
2144 # Don't instantiate the AsyncCheckpointer unless required.
2145 self._async_checkpointer_impl = None
2147 # Store checkpoint options during the save/write calls so that subsequent
2148 # read/restore calls are done properly. This is only populated when
2149 # async read/write is enabled.
2150 self._checkpoint_options = None
2152 attached_dependencies = None
2153 self._save_counter = None # Created lazily for restore-on-create.
2154 self._save_assign_op = None
2156 if root:
2157 trackable_root = root() if isinstance(root, weakref.ref) else root
2158 _assert_trackable(trackable_root, "root")
2159 attached_dependencies = []
2161 # All keyword arguments (including root itself) are set as children
2162 # of root.
2163 kwargs["root"] = root
2164 trackable_root._maybe_initialize_trackable()
2166 self._save_counter = data_structures.NoDependency(
2167 trackable_root._lookup_dependency("save_counter"))
2169 for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
2170 setattr(self, k, v)
2172 # Call getattr instead of directly using v because setattr converts
2173 # v to a Trackable data structure when v is a list/dict/tuple.
2174 converted_v = getattr(self, k)
2175 if isinstance(converted_v, weakref.ref):
2176 converted_v = converted_v()
2177 _assert_trackable(converted_v, k)
2179 if root:
2180 # Make sure that root doesn't already have dependencies with these names
2181 child = trackable_root._lookup_dependency(k)
2182 if child is None:
2183 attached_dependencies.append(
2184 base.WeakTrackableReference(k, converted_v))
2185 elif child != converted_v:
2186 raise ValueError(
2187 f"Cannot create a Checkpoint with keyword argument {k} if "
2188 f"root.{k} already exists.")
2190 self._saver = TrackableSaver(
2191 graph_view_lib.ObjectGraphView(
2192 root if root else self,
2193 attached_dependencies=attached_dependencies))
2194 self._attached_dependencies = data_structures.NoDependency(
2195 attached_dependencies)
2197 def _maybe_create_save_counter(self):
2198 """Create a save counter if it does not yet exist."""
2199 if self._save_counter is None:
2200 # Initialized to 0 and incremented before saving.
2201 with ops.device("/cpu:0"):
2202 # add_variable creates a dependency named "save_counter"; NoDependency
2203 # prevents creating a second dependency named "_save_counter".
2204 self._save_counter = data_structures.NoDependency(
2205 add_variable(
2206 self,
2207 name="save_counter",
2208 initializer=0,
2209 dtype=dtypes.int64,
2210 trainable=False))
2211 if self._attached_dependencies is not None:
2212 self._attached_dependencies.append(
2213 # Store a stronge reference to the `save_counter`, so that if the
2214 # `Checkpoint` object is deleted, the `save_counter` does not get
2215 # deleted immediately. (The LoadStatus object needs to indirectly
2216 # reference the counter through the ObjectGraphView).
2217 base.TrackableReference("save_counter", self._save_counter))
2218 # When loading a checkpoint, the save counter is created after
2219 # the checkpoint has been loaded, so it must be handled in a deferred
2220 # manner.
2221 if isinstance(self.root, weakref.ref):
2222 root = self.root()
2223 else:
2224 root = self.root
2225 restore = root._deferred_dependencies.pop("save_counter", ()) # pylint: disable=protected-access
2226 if restore:
2227 restore[0].restore(self._save_counter)
2229 def write(self, file_prefix, options=None):
2230 """Writes a training checkpoint.
2232 The checkpoint includes variables created by this object and any
2233 trackable objects it depends on at the time `Checkpoint.write()` is
2234 called.
2236 `write` does not number checkpoints, increment `save_counter`, or update the
2237 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for
2238 use by higher level checkpoint management utilities. `save` provides a very
2239 basic implementation of these features.
2241 Checkpoints written with `write` must be read with `read`.
2243 Example usage:
2245 ```
2246 step = tf.Variable(0, name="step")
2247 checkpoint = tf.Checkpoint(step=step)
2248 checkpoint.write("/tmp/ckpt")
2250 # Later, read the checkpoint with read()
2251 checkpoint.read("/tmp/ckpt")
2253 # You can also pass options to write() and read(). For example this
2254 # runs the IO ops on the localhost:
2255 options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
2256 checkpoint.write("/tmp/ckpt", options=options)
2258 # Later, read the checkpoint with read()
2259 checkpoint.read("/tmp/ckpt", options=options)
2260 ```
2262 Args:
2263 file_prefix: A prefix to use for the checkpoint filenames
2264 (/path/to/directory/and_a_prefix).
2265 options: Optional `tf.train.CheckpointOptions` object.
2267 Returns:
2268 The full path to the checkpoint (i.e. `file_prefix`).
2269 """
2270 if isinstance(file_prefix, os.PathLike):
2271 file_prefix = os.fspath(file_prefix)
2272 return self._write(file_prefix, options)
2274 def _async_checkpointer(self):
2275 """Returns an instantiated AsyncCheckpointHelper."""
2276 if self._async_checkpointer_impl is None:
2277 self._async_checkpointer_impl = (
2278 async_checkpoint_helper.AsyncCheckpointHelper(
2279 Checkpoint,
2280 **self._kwargs))
2282 return self._async_checkpointer_impl
2284 def _write(self, file_prefix, options=None, write_done_callback=None):
2285 """Internal method that implements Checkpoint.write().
2287 Args:
2288 file_prefix: A prefix to use for the checkpoint filenames
2289 (/path/to/directory/and_a_prefix).
2290 options: Optional `tf.train.CheckpointOptions` object.
2291 write_done_callback: Optional callback function to be executed once
2292 the underlying checkpoint saving is finished. Example usage includes
2293 updating the checkpoint internal state.
2295 Returns:
2296 The full path to the checkpoint (i.e. `file_prefix`).
2297 """
2298 # Triggers TF2 async checkpoint handling if:
2299 # 1. async checkpoint is enabled in CheckpointOptions
2300 # 2. running in eager mode
2301 if options and options.experimental_enable_async_checkpoint:
2302 self._checkpoint_options = options
2303 if checkpoint_context.in_preemption_save_context():
2304 # Make sure all in-progress writes have completed before saving the
2305 # final preemption checkpoint.
2306 if self._async_checkpointer_impl is not None:
2307 self._async_checkpointer_impl.sync()
2308 # Additional work done will not be saved in a future checkpoint, so
2309 # we use regular sync checkpoint to avoid overhead of dispatching
2310 # checkpoint write to a new thread.
2311 logging.warning(
2312 "Switching to regular sync checkpoint for preemption checkpoint."
2313 )
2314 elif context.executing_eagerly():
2315 return self._async_checkpointer()._write( # pylint: disable=protected-access
2316 file_prefix, options, write_done_callback)
2317 else:
2318 logging.warning(
2319 "Saving async checkpoint in graph mode is currently not supported;"
2320 " switching to regular sync checkpoint instead.")
2322 start_time = time.time()
2323 options = options or checkpoint_options.CheckpointOptions()
2324 output = self._saver.save(file_prefix=file_prefix, options=options)
2325 output = _convert_file_name_tensor_to_string(output)
2327 if write_done_callback:
2328 write_done_callback(output)
2330 # Ensure save operations have completed when running in eager runtime.
2331 if context.executing_eagerly():
2332 context.async_wait()
2334 end_time = time.time()
2336 if not checkpoint_context.in_async_metrics_context():
2337 # This records the time checkpoint._write() blocks on the main thread.
2338 metrics.AddCheckpointWriteDuration(
2339 api_label=_CHECKPOINT_V2,
2340 microseconds=_get_duration_microseconds(start_time, end_time),
2341 )
2343 global _END_TIME_OF_LAST_WRITE
2344 with _END_TIME_OF_LAST_WRITE_LOCK:
2345 if not checkpoint_context.in_async_metrics_context():
2346 metrics.AddTrainingTimeSaved(
2347 api_label=_CHECKPOINT_V2,
2348 microseconds=_get_duration_microseconds(
2349 _END_TIME_OF_LAST_WRITE, end_time)
2350 )
2351 if checkpoint_context.in_preemption_save_context():
2352 _preemption_checkpoint_saved_time_usecs.get_cell().increase_by(
2353 _get_duration_microseconds(_END_TIME_OF_LAST_WRITE, end_time)
2354 )
2355 _END_TIME_OF_LAST_WRITE = end_time
2357 metrics.RecordCheckpointSize(
2358 api_label=_CHECKPOINT_V2, filesize=_get_checkpoint_size(output)
2359 )
2360 return output
2362 @property
2363 def save_counter(self):
2364 """An integer variable which starts at zero and is incremented on save.
2366 Used to number checkpoints.
2368 Returns:
2369 The save counter variable.
2370 """
2371 self._maybe_create_save_counter()
2372 return self._save_counter
2374 def sync(self):
2375 """Wait for any outstanding save or restore operations."""
2376 # Subclasses of Checkpoint may not have `_async_checkpointer_impl` so use
2377 # `getattr` for safer check.
2378 if getattr(self, "_async_checkpointer_impl", None) is not None:
2379 self._async_checkpointer_impl.sync()
2381 def save(self, file_prefix, options=None):
2382 # pylint:disable=line-too-long
2383 """Saves a training checkpoint and provides basic checkpoint management.
2385 The saved checkpoint includes variables created by this object and any
2386 trackable objects it depends on at the time `Checkpoint.save()` is
2387 called.
2389 `save` is a basic convenience wrapper around the `write` method,
2390 sequentially numbering checkpoints using `save_counter` and updating the
2391 metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
2392 management, for example garbage collection and custom numbering, may be
2393 provided by other utilities which also wrap `write` and `read`.
2394 (`tf.train.CheckpointManager` for example).
2396 ```
2397 step = tf.Variable(0, name="step")
2398 checkpoint = tf.train.Checkpoint(step=step)
2399 checkpoint.save("/tmp/ckpt")
2401 # Later, read the checkpoint with restore()
2402 checkpoint.restore("/tmp/ckpt-1")
2404 # You can also pass options to save() and restore(). For example this
2405 # runs the IO ops on the localhost:
2406 options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
2407 checkpoint.save("/tmp/ckpt", options=options)
2409 # Later, read the checkpoint with restore()
2410 checkpoint.restore("/tmp/ckpt-1", options=options)
2411 ```
2413 Args:
2414 file_prefix: A prefix to use for the checkpoint filenames
2415 (/path/to/directory/and_a_prefix). Names are generated based on this
2416 prefix and `Checkpoint.save_counter`.
2417 options: Optional `tf.train.CheckpointOptions` object.
2419 Returns:
2420 The full path to the checkpoint.
2421 """
2422 # Triggers TF2 async checkpoint handling if:
2423 # 1. async checkpoint is enabled in CheckpointOptions
2424 # 2. running in eager mode
2425 if options and options.experimental_enable_async_checkpoint:
2426 self._checkpoint_options = options
2427 if checkpoint_context.in_preemption_save_context():
2428 # Make sure all in-progress writes have completed before saving the
2429 # final preemption checkpoint.
2430 if self._async_checkpointer_impl is not None:
2431 self._async_checkpointer_impl.sync()
2432 # Additional work done will not be saved in a future checkpoint, so
2433 # we use regular sync checkpoint to avoid overhead of dispatching
2434 # checkpoint write to a new thread.
2435 logging.warning(
2436 "Switching to regular sync checkpoint for preemption checkpoint."
2437 )
2438 elif context.executing_eagerly():
2439 return self._async_checkpointer().save(file_prefix, options)
2440 else:
2441 logging.warning(
2442 "Saving async checkpoint in graph mode is currently not supported;"
2443 " switching to regular sync checkpoint instead.")
2445 if isinstance(file_prefix, os.PathLike):
2446 file_prefix = os.fspath(file_prefix)
2447 # pylint:enable=line-too-long
2448 options = options or checkpoint_options.CheckpointOptions()
2449 graph_building = not context.executing_eagerly()
2450 if graph_building:
2451 if ops.inside_function():
2452 raise NotImplementedError(
2453 "Calling tf.train.Checkpoint.save() from a function is not "
2454 "supported, as save() modifies saving metadata in ways not "
2455 "supported by TensorFlow Operations. Consider using "
2456 "tf.train.Checkpoint.write(), a lower-level API which does not "
2457 "update metadata. tf.train.latest_checkpoint and related APIs will "
2458 "not see this checkpoint.")
2459 session = get_session()
2460 if self._save_counter is None:
2461 # When graph building, if this is a new save counter variable then it
2462 # needs to be initialized before assign_add. This is only an issue if
2463 # restore() has not been called first.
2464 session.run(self.save_counter.initializer)
2466 if not graph_building or self._save_assign_op is None:
2467 with ops.colocate_with(self.save_counter):
2468 assign_op = self.save_counter.assign_add(1, read_value=True)
2469 if graph_building:
2470 self._save_assign_op = data_structures.NoDependency(assign_op)
2472 if graph_building:
2473 checkpoint_number = session.run(self._save_assign_op)
2474 else:
2475 checkpoint_number = assign_op.numpy()
2477 return self._write(
2478 "%s-%d" % (file_prefix, checkpoint_number),
2479 options=options,
2480 write_done_callback=_update_checkpoint_state_internal)
2482 def read(self, save_path, options=None):
2483 """Reads a training checkpoint written with `write`.
2485 Reads this `Checkpoint` and any objects it depends on.
2487 This method is just like `restore()` but does not expect the `save_counter`
2488 variable in the checkpoint. It only restores the objects that the checkpoint
2489 already depends on.
2491 The method is primarily intended for use by higher level checkpoint
2492 management utilities that use `write()` instead of `save()` and have their
2493 own mechanisms to number and track checkpoints.
2495 Example usage:
2497 ```python
2498 # Create a checkpoint with write()
2499 ckpt = tf.train.Checkpoint(v=tf.Variable(1.))
2500 path = ckpt.write('/tmp/my_checkpoint')
2502 # Later, load the checkpoint with read()
2503 # With restore() assert_consumed() would have failed.
2504 checkpoint.read(path).assert_consumed()
2506 # You can also pass options to read(). For example this
2507 # runs the IO ops on the localhost:
2508 options = tf.train.CheckpointOptions(
2509 experimental_io_device="/job:localhost")
2510 checkpoint.read(path, options=options)
2511 ```
2513 Args:
2514 save_path: The path to the checkpoint as returned by `write`.
2515 options: Optional `tf.train.CheckpointOptions` object.
2517 Returns:
2518 A load status object, which can be used to make assertions about the
2519 status of a checkpoint restoration. See `restore` for details.
2520 """
2521 if options and options.experimental_enable_async_checkpoint:
2522 self._checkpoint_options = options
2523 # Triggers TF2 async checkpoint handling if:
2524 # 1. async checkpoint is enabled in CheckpointOptions
2525 # 2. there's a preceeding async save/write
2526 # 3. running in eager mode
2527 if (self._checkpoint_options and
2528 self._checkpoint_options.experimental_enable_async_checkpoint):
2529 if context.executing_eagerly():
2530 return self._async_checkpointer().read(save_path, options)
2531 else:
2532 logging.warning(
2533 "Saving async checkpoint in graph mode is currently not supported;"
2534 " switching to regular sync checkpoint instead.")
2536 start_time = time.time()
2537 if isinstance(save_path, os.PathLike):
2538 save_path = os.fspath(save_path)
2539 options = options or checkpoint_options.CheckpointOptions()
2540 result = self._saver.restore(save_path=save_path, options=options)
2541 metrics.AddCheckpointReadDuration(
2542 api_label=_CHECKPOINT_V2,
2543 microseconds=_get_duration_microseconds(start_time, time.time()))
2544 return result
2546 def restore(self, save_path, options=None):
2547 """Restores a training checkpoint.
2549 Restores this `Checkpoint` and any objects it depends on.
2551 This method is intended to be used to load checkpoints created by `save()`.
2552 For checkpoints created by `write()` use the `read()` method which does not
2553 expect the `save_counter` variable added by `save()`.
2555 `restore()` either assigns values immediately if variables to restore have
2556 been created already, or defers restoration until the variables are
2557 created. Dependencies added after this call will be matched if they have a
2558 corresponding object in the checkpoint (the restore request will queue in
2559 any trackable object waiting for the expected dependency to be added).
2561 ```python
2562 checkpoint = tf.train.Checkpoint( ... )
2563 checkpoint.restore(path)
2565 # You can additionally pass options to restore():
2566 options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
2567 checkpoint.restore(path, options=options)
2568 ```
2570 To ensure that loading is complete and no more deferred restorations will
2571 take place, use the `assert_consumed()` method of the status object returned
2572 by `restore()`:
2574 ```python
2575 checkpoint.restore(path, options=options).assert_consumed()
2576 ```
2578 The assert will raise an error if any Python objects in the dependency graph
2579 were not found in the checkpoint, or if any checkpointed values do not have
2580 a matching Python object.
2582 Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can be
2583 loaded using this method. Names are used to match variables. Re-encode
2584 name-based checkpoints using `tf.train.Checkpoint.save` as soon as possible.
2586 **Loading from SavedModel checkpoints**
2588 To load values from a SavedModel, just pass the SavedModel directory
2589 to checkpoint.restore:
2591 ```python
2592 model = tf.keras.Model(...)
2593 tf.saved_model.save(model, path) # or model.save(path, save_format='tf')
2595 checkpoint = tf.train.Checkpoint(model)
2596 checkpoint.restore(path).expect_partial()
2597 ```
2599 This example calls `expect_partial()` on the loaded status, since
2600 SavedModels saved from Keras often generates extra keys in the checkpoint.
2601 Otherwise, the program prints a lot of warnings about unused keys at exit
2602 time.
2604 Args:
2605 save_path: The path to the checkpoint, as returned by `save` or
2606 `tf.train.latest_checkpoint`. If the checkpoint was written by the
2607 name-based `tf.compat.v1.train.Saver`, names are used to match
2608 variables. This path may also be a SavedModel directory.
2609 options: Optional `tf.train.CheckpointOptions` object.
2611 Returns:
2612 A load status object, which can be used to make assertions about the
2613 status of a checkpoint restoration.
2615 The returned status object has the following methods:
2617 * `assert_consumed()`:
2618 Raises an exception if any variables are unmatched: either
2619 checkpointed values which don't have a matching Python object or
2620 Python objects in the dependency graph with no values in the
2621 checkpoint. This method returns the status object, and so may be
2622 chained with other assertions.
2624 * `assert_existing_objects_matched()`:
2625 Raises an exception if any existing Python objects in the dependency
2626 graph are unmatched. Unlike `assert_consumed`, this assertion will
2627 pass if values in the checkpoint have no corresponding Python
2628 objects. For example a `tf.keras.Layer` object which has not yet been
2629 built, and so has not created any variables, will pass this assertion
2630 but fail `assert_consumed`. Useful when loading part of a larger
2631 checkpoint into a new Python program, e.g. a training checkpoint with
2632 a `tf.compat.v1.train.Optimizer` was saved but only the state required
2633 for
2634 inference is being loaded. This method returns the status object, and
2635 so may be chained with other assertions.
2637 * `assert_nontrivial_match()`: Asserts that something aside from the root
2638 object was matched. This is a very weak assertion, but is useful for
2639 sanity checking in library code where objects may exist in the
2640 checkpoint which haven't been created in Python and some Python
2641 objects may not have a checkpointed value.
2643 * `expect_partial()`: Silence warnings about incomplete checkpoint
2644 restores. Warnings are otherwise printed for unused parts of the
2645 checkpoint file or object when the `Checkpoint` object is deleted
2646 (often at program shutdown).
2648 Raises:
2649 NotFoundError: if the a checkpoint or SavedModel cannot be found at
2650 `save_path`.
2651 """
2652 if options and options.experimental_enable_async_checkpoint:
2653 self._checkpoint_options = options
2654 # Triggers TF2 async checkpoint handling if:
2655 # 1. async checkpoint is enabled in CheckpointOptions
2656 # 2. there's a preceeding async save/write
2657 # 3. running in eager mode
2658 if (self._checkpoint_options and
2659 self._checkpoint_options.experimental_enable_async_checkpoint):
2660 if context.executing_eagerly():
2661 return self._async_checkpointer().restore(save_path, options)
2662 else:
2663 logging.warning(
2664 "Saving async checkpoint in graph mode is currently not supported;"
2665 " switching to regular sync checkpoint instead.")
2667 orig_save_path = save_path
2668 if isinstance(save_path, os.PathLike):
2669 save_path = os.fspath(save_path)
2671 if save_path is not None and gfile.IsDirectory(save_path) and (
2672 (gfile.Exists(path_helpers.get_saved_model_pb_path(save_path)) or
2673 gfile.Exists(path_helpers.get_saved_model_pbtxt_path(save_path)))):
2674 save_path = path_helpers.get_variables_path(save_path)
2676 try:
2677 status = self.read(save_path, options=options)
2678 if context.executing_eagerly():
2679 context.async_wait() # Ensure restore operations have completed.
2680 except errors_impl.NotFoundError as e:
2681 raise errors_impl.NotFoundError(
2682 None, None,
2683 f"Error when restoring from checkpoint or SavedModel at "
2684 f"{orig_save_path}: {e.message}"
2685 f"\nPlease double-check that the path is correct. You may be missing "
2686 "the checkpoint suffix (e.g. the '-1' in 'path/to/ckpt-1').")
2687 # Create the save counter now so it gets initialized with other variables
2688 # when graph building. Creating it earlier would lead to errors when using,
2689 # say, train.Saver() to save the model before initializing it.
2690 self._maybe_create_save_counter()
2691 if isinstance(status, NameBasedSaverStatus):
2692 status.add_to_optionally_restored(self.save_counter)
2693 return status
2696_preemption_checkpoint_saved_time_usecs = monitoring.Counter(
2697 "/tensorflow/api/distribution_strategy/preemption_checkpoint_saved_time_usecs",
2698 "Training time saved by PreemptionCheckpointHandler (us).")