# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic for restoring checkpointed values for Trackables."""

import collections

from tensorflow.python.checkpoint import checkpoint_view
from tensorflow.python.checkpoint import functional_saver
from tensorflow.python.checkpoint import save_util_v1
from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import registration
from tensorflow.python.trackable import base
from tensorflow.python.trackable import constants
from tensorflow.python.trackable import python_state
from tensorflow.python.trackable import trackable_utils
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import object_identity


class CheckpointPosition(object):
  """Indicates a position within a `_CheckpointRestoreCoordinator`."""

  __slots__ = ["_checkpoint", "_proto_id", "skip_restore"]

  def __init__(self, checkpoint, proto_id):
    """Specify an object within a checkpoint.

    Args:
      checkpoint: A _CheckpointRestoreCoordinator object.
      proto_id: The index of this object in TrackableObjectGraph.nodes.
    """
    self._checkpoint = checkpoint
    self._proto_id = proto_id
    # This may be set to True if the registered saver cannot be used with this
    # object.
    self.skip_restore = False
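
  # Illustrative usage (hypothetical names): the restore coordinator creates
  # one CheckpointPosition per node in the saved object graph and binds each
  # to its Python object, roughly:
  #   position = CheckpointPosition(checkpoint=coordinator, proto_id=3)
  #   position.restore(my_trackable)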

  def restore(self, trackable, reader=None):
    """Restore this value into `trackable`."""
    with ops.init_scope():
      if self.bind_object(trackable):
        # This object's correspondence with a checkpointed object is new, so
        # process deferred restorations for it and its dependencies.
        restore_ops = self._restore_descendants(reader)
        if restore_ops:
          self._checkpoint.new_restore_ops(restore_ops)

  def bind_object(self, trackable):
    """Set a checkpoint<->object correspondence.

    Args:
      trackable: The object to record a correspondence for.

    Returns:
      True if this is a new assignment, False if this object has already been
      mapped to a checkpointed `Object` proto.

    Raises:
      AssertionError: If another object is already bound to the `Object` proto.
    """
    checkpoint = self.checkpoint
    checkpoint.all_python_objects.add(trackable)
    current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
    checkpoint.matched_proto_ids.add(self._proto_id)
    if current_assignment is None:
      checkpoint.object_by_proto_id[self._proto_id] = trackable
      return True  # New assignment
    else:
      # The object was already mapped for this checkpoint load, which means
      # we don't need to do anything besides check that the mapping is
      # consistent (if the dependency DAG is not a tree then there are
      # multiple paths to the same object).
      if current_assignment is not trackable:
        logging.warning(
            "Inconsistent references when loading the checkpoint into this "
            "object graph. For example, in the saved checkpoint object, "
            "`model.layer.weight` and `model.layer_copy.weight` reference the "
            "same variable, while in the current object these are two "
            "different variables. The referenced variables are: "
            f"({current_assignment} and {trackable}).")
      return False  # Not a new assignment

  def is_simple_variable(self):
    """Determine whether this value is restorable with a Tensor initializer."""
    attributes = self.object_proto.attributes
    return (len(attributes) == 1 and
            attributes[0].name == constants.VARIABLE_VALUE_KEY and
            not self.object_proto.children)
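
  # This fast path matters for restore-on-create: a plain variable with a
  # single VARIABLE_VALUE attribute and no children can take the restored
  # tensor directly as its initializer. A hedged sketch, with hypothetical
  # names:
  #   if position.is_simple_variable():
  #     initial_value = position.value_tensors()[constants.VARIABLE_VALUE_KEY]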

  def value_tensors(self, shape_and_slices=None):
    """Create value `Tensor`s for this object's attributes.

    Does not require that the Python object has been created. Used for
    restore-on-create when executing eagerly.

    Args:
      shape_and_slices: A dict mapping from object attribute names to a shape
        and slice string that will be passed to a RestoreV2 op. If the dict is
        None or if an object attribute is not in the dict, the full tensor will
        be restored.

    Returns:
      A dictionary mapping from object attribute names to `Tensor`s.
    """
    value_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
      with ops.init_scope():
        with ops.device(io_device):
          # Run the restore itself on the io_device (CPU or specified).
          if (shape_and_slices is not None and
              serialized_tensor.name in shape_and_slices):
            shape_and_slice = shape_and_slices[serialized_tensor.name]
          else:
            shape_and_slice = ""
          value, = io_ops.restore_v2(
              prefix=self._checkpoint.save_path_tensor,
              tensor_names=[checkpoint_key],
              shape_and_slices=[shape_and_slice],
              dtypes=[base_type],
              name="%s_checkpoint_read" % (serialized_tensor.name,))
          # Copy the value to the current device if necessary.
          value_tensors[serialized_tensor.name] = array_ops.identity(value)
    return value_tensors
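
  # A note on `shape_and_slices`: RestoreV2 expects a spec such as
  # "4 2 0,2:-" (illustrative), meaning a full shape of [4, 2] from which
  # rows 0..1 and every column are read; "" restores the full tensor. See the
  # RestoreV2 op documentation for the exact grammar.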

  def gather_ops_or_named_saveables(self):
    """Looks up or creates SaveableObjects which don't have cached ops.

    Returns:
      A tuple of (
        existing_restore_ops: list,
        named_saveables: dict,
        python_positions: list,
        registered_savers: dict)
    """
    recorded_registered_saver = self.get_registered_saver_name()
    if not (self.object_proto.attributes or recorded_registered_saver):
      return [], {}, [], {}

    existing_restore_ops = []
    named_saveables = {}
    python_positions = []
    registered_savers = collections.defaultdict(dict)

    saveable_factories = saveable_object_util.saveable_objects_from_trackable(
        self.trackable)
    saver_name = registration.get_registered_saver_name(self.trackable)

    if recorded_registered_saver:
      if not self.skip_restore:
        name = self.object_proto.registered_saver.object_name
        registered_savers[recorded_registered_saver][name] = self.trackable
      # Else: Skip restoration of this Trackable. This skip only happens if the
      # registered saver has enabled `option_restore`. Otherwise, an error
      # would have been raised at `self.get_registered_saver_name()`.
    elif saver_name:
      # In this case, the checkpoint has a recorded serialized tensor but no
      # registered saver, while the Trackable loading the checkpoint has
      # migrated to the registered checkpoint functionality (TPUEmbedding is an
      # example of this).

      # Set the Trackable's object name to the first checkpoint key that is
      # stored in the checkpoint. If there is a use case that requires the
      # other keys, then we can take another look at this.
      registered_savers[saver_name] = {
          self.object_proto.attributes[0].checkpoint_key: self.trackable
      }
    elif isinstance(self.trackable, python_state.PythonState):
      python_positions.append(self)
    elif saveable_factories.keys() == {
        trackable_utils.SERIALIZE_TO_TENSORS_NAME
    }:
      existing_restore_ops, named_saveables = (
          self._create_serialize_to_tensor_saveable(saveable_factories))
    elif saveable_factories:
      existing_restore_ops, named_saveables = (
          self._create_saveables_by_attribute_name(saveable_factories))
    else:
      # If no registered savers were found, then it means that one or more
      # serialized tensors were never used.
      for serialized_tensor in self.object_proto.attributes:
        self._checkpoint.unused_attributes.setdefault(
            self._proto_id, []).append(serialized_tensor.name)
    return (existing_restore_ops, named_saveables, python_positions,
            registered_savers)

  def _create_serialize_to_tensor_saveable(self, saveable_factories):
    """Creates a saveable using the _serialize_to_tensor method."""
    # Extract the saveable name from the checkpoint key. This will be used as
    # the cache key or the name to pass to the saveable factory.
    suffix = saveable_compat.get_saveable_name(self.trackable) or ""
    saveable_name = _extract_saveable_name(
        self.object_proto.attributes[0].checkpoint_key) + suffix

    # Try to find the cached saveable (only in graph mode).
    if not context.executing_eagerly():
      existing_op = self._checkpoint.restore_ops_by_name.get(
          saveable_name, None)
      if existing_op is not None:
        return [existing_op], {}

      saveables_cache = self._checkpoint.saveables_cache.setdefault(
          self.trackable, {})
      if saveable_name in saveables_cache:
        return [], {saveable_name: saveables_cache[saveable_name]}

    saveable = saveable_factories[trackable_utils.SERIALIZE_TO_TENSORS_NAME](
        name=saveable_name)
    if not context.executing_eagerly():
      saveables_cache[saveable_name] = saveable
    return [], {saveable_name: saveable}

  def _create_saveables_by_attribute_name(self, saveable_factories):
    """Creates or caches SaveableObjects by matching the attribute names.

    The attribute name keys in `saveable_factories` are used to find the
    corresponding attribute in the object proto. Attributes contain checkpoint
    keys which are passed to the factory function to generate the
    SaveableObject.

    Args:
      saveable_factories: a dict mapping attribute name to a callable factory
        function that produces a SaveableObject.

    Returns:
      A tuple of (
        existing_restore_ops: list,
        named_saveables: dict)
    """
    # Name saveables based on the name this object had when it was
    # checkpointed.
    named_saveables = {}
    existing_restore_ops = []

    # Forward compatibility code: when loading a future checkpoint, there may
    # be multiple SerializedTensors mapped to a single saveable.
    created_compat_names = set()

    for serialized_tensor in self.object_proto.attributes:
      if context.executing_eagerly():
        existing_op = None
      else:
        existing_op = self._checkpoint.restore_ops_by_name.get(
            serialized_tensor.checkpoint_key, None)
      if existing_op is not None:
        existing_restore_ops.append(existing_op)
        continue

      if any(serialized_tensor.name.startswith(name)
             for name in created_compat_names):
        continue  # Saveable has already been created for this tensor.

      # Only if we don't have cached ops for this SaveableObject, we'll see if
      # the SaveableObject itself has been cached. If not, we'll make it, and
      # either way we'll extract new ops from it (or if it has Python state to
      # restore, we'll run that).
      saveables_cache = self._checkpoint.saveables_cache
      if saveables_cache is None:
        # No SaveableObject caching when executing eagerly.
        saveable = None
      else:
        # If we've already created and cached a SaveableObject for this
        # attribute, we can re-use it to avoid re-creating some ops when graph
        # building.
        saveable_list = saveables_cache.get(self.trackable,
                                            {}).get(serialized_tensor.name,
                                                    (None,))
        if len(saveable_list) == 1:
          # Almost every attribute will have exactly one SaveableObject.
          saveable, = saveable_list
        else:
          # Don't use cached SaveableObjects for partitioned variables, which
          # is the only case where we'd have a list of SaveableObjects. Op
          # caching will catch them.
          saveable = None
      if saveable is not None:
        # The name of this attribute has changed, so we need to re-generate
        # the SaveableObject.
        if serialized_tensor.checkpoint_key not in saveable.name:
          saveable = None
          del saveables_cache[self.trackable]
      if saveable is None:
        # If there was no cached SaveableObject, create one.
        # Use the name to check if the Python object has the same attribute.
        saveable = _get_saveable_from_factory(saveable_factories,
                                              serialized_tensor,
                                              created_compat_names)
        if saveable is None:
          # Purposefully does not throw an exception if attributes have been
          # added or deleted. Stores unused attributes so an exception can be
          # raised if the user decides to check that everything in the
          # checkpoint was loaded.
          self._checkpoint.unused_attributes.setdefault(
              self._proto_id, []).append(serialized_tensor.name)
          continue
        if saveables_cache is not None:
          saveables_cache.setdefault(self.trackable,
                                     {})[serialized_tensor.name] = [saveable]
      named_saveables[serialized_tensor.checkpoint_key] = saveable

    return existing_restore_ops, named_saveables

  def restore_ops(self, reader=None):
    """Create or fetch restore ops for this object's attributes.

    Requires that the `Trackable` Python object has been bound to an object
    ID in the checkpoint.

    Args:
      reader: A `CheckpointReader`. If None, a new instance will be created.

    Returns:
      A list of operations when graph building, or an empty list when executing
      eagerly.
    """
    if self._has_registered_saver():
      raise ValueError("Unable to run individual checkpoint restore for "
                       "objects with registered savers.")
    (restore_ops, tensor_saveables, python_positions,
     _) = self.gather_ops_or_named_saveables()
    restore_ops.extend(
        self._checkpoint.restore_saveables(
            tensor_saveables, python_positions, reader=reader))
    return restore_ops

  @property
  def checkpoint(self):
    return self._checkpoint

  @property
  def trackable(self):
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def proto_id(self):
    return self._proto_id

  @property
  def restore_uid(self):
    return self._checkpoint.restore_uid

  def __repr__(self):
    return repr(self.object_proto)

  def value_shape(self):
    """The shape of the VARIABLE_VALUE tensor.

    Returns:
      The `TensorShape` of the value if found, otherwise None.
    """
    for serialized_tensor in self.object_proto.attributes:
      if serialized_tensor.name == constants.VARIABLE_VALUE_KEY:
        return self._checkpoint.shape_map[serialized_tensor.checkpoint_key]
    return None
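
  # For example (illustrative), a variable checkpointed with shape [3, 5]
  # yields TensorShape([3, 5]) here; an object with no VARIABLE_VALUE
  # attribute yields None.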

  def _has_registered_saver(self):
    return bool(self.object_proto.registered_saver.name)

  def get_registered_saver_name(self):
    """Returns the registered saver name defined in the Checkpoint."""
    if self._has_registered_saver():
      saver_name = self.object_proto.registered_saver.name
      try:
        registration.validate_restore_function(self.trackable, saver_name)
      except ValueError as e:
        if registration.get_strict_predicate_restore(saver_name):
          raise e
        self.skip_restore = True
      return saver_name
    return None

  def create_slot_variable_position(self, optimizer_object, variable,
                                    slot_variable_id, slot_name):
    """Generates CheckpointPosition for a slot variable.

    Args:
      optimizer_object: Optimizer that owns the slot variable.
      variable: Variable associated with the slot variable.
      slot_variable_id: ID of the slot variable.
      slot_name: Name of the slot variable.

    Returns:
      If there is a slot variable in the `optimizer_object` that has not been
      bound to the checkpoint, this function returns a tuple of (
        new `CheckpointPosition` for the slot variable,
        the slot variable itself).
    """
    slot_variable_position = CheckpointPosition(
        checkpoint=self.checkpoint, proto_id=slot_variable_id)
    # pylint: disable=protected-access
    slot_variable = optimizer_object._create_or_restore_slot_variable(
        slot_variable_position=slot_variable_position,
        variable=variable,
        slot_name=slot_name)
    # pylint: enable=protected-access
    if (slot_variable is not None and
        slot_variable_position.bind_object(slot_variable)):
      return slot_variable_position, slot_variable
    else:
      return None, None
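
  # For context, optimizer slots are auxiliary per-variable state; e.g. Adam
  # keeps slots commonly named "m" and "v" for each trained variable, and each
  # slot gets its own node (and proto_id) in the checkpoint's object graph.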

  def create_child_position(self, node_id):
    return CheckpointPosition(checkpoint=self.checkpoint, proto_id=node_id)

  def _restore_descendants(self, reader=None):
    """Restore the bound Trackable and dependencies (may be deferred)."""
    # Attempt a breadth-first traversal, since presumably the user has more
    # control over shorter paths. If we don't have all of the dependencies at
    # this point, the end result is not breadth-first (since other deferred
    # traversals will happen later).

    # You may be wondering why elements in the `visit_queue` are tuples that
    # contain both a CheckpointPosition and its Trackable. The reason is that
    # Optimizers will not keep a strong reference to slot vars for
    # ShardedVariables. The slot variable must be kept in memory until the
    # restore saveables have been created.
    visit_queue = collections.deque([(self, self.trackable)])
    restore_ops = []
    tensor_saveables = {}
    python_positions = []
    registered_savers = collections.defaultdict(dict)
    while visit_queue:
      current_position, _ = visit_queue.popleft()

      # Restore using the ops defined in a Saveable or registered function.
      (new_restore_ops, new_tensor_saveables, new_python_positions,
       new_registered_savers) = current_position._single_restore()  # pylint: disable=protected-access
      restore_ops.extend(new_restore_ops)
      tensor_saveables.update(new_tensor_saveables)
      python_positions.extend(new_python_positions)
      for saver_name, trackable_map in new_registered_savers.items():
        registered_savers[saver_name].update(trackable_map)

      # Pass the restoration to the dependencies.
      _queue_children_for_restoration(current_position, visit_queue)
      _queue_slot_variables(current_position, visit_queue)

    restore_ops.extend(
        current_position.checkpoint.restore_saveables(
            tensor_saveables,
            python_positions,
            registered_savers,
            reader=reader))
    return restore_ops

  def _single_restore(self):
    """Restores the trackable."""
    trackable = self.trackable
    trackable._maybe_initialize_trackable()  # pylint: disable=protected-access
    checkpoint = self.checkpoint
    # If the UID of this restore is lower than our current update UID, we don't
    # need to actually restore the object.
    if checkpoint.restore_uid > trackable._update_uid:  # pylint: disable=protected-access
      restore_ops, tensor_saveables, python_positions, registered_savers = (
          self.gather_ops_or_named_saveables())
      trackable._update_uid = checkpoint.restore_uid  # pylint: disable=protected-access
    else:
      restore_ops = ()
      tensor_saveables = {}
      python_positions = ()
      registered_savers = {}
    return restore_ops, tensor_saveables, python_positions, registered_savers


def restore_nodes(save_path, nodes_to_restore):
  """Restores nodes from a dict.

  Requires that the `Trackable` Python object has been bound to an object
  ID in the checkpoint.

  Args:
    save_path: a string representing the path to the checkpoint.
    nodes_to_restore: a dict mapping `node_id` to the `trackable` to be
      restored.
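
  Example (an illustrative sketch; the checkpoint path, node id, and variable
  are hypothetical):

    v = tf.Variable(1.0)
    # Node 3 in the saved object graph holds v's checkpointed value.
    restore_nodes("/tmp/ckpt/ckpt-1", {3: v})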
498 """
499 if save_path is None:
500 raise ValueError("save_path cannot be empty.")
501 if not isinstance(nodes_to_restore, dict):
502 raise ValueError(
503 "Expecting a dictionary of node_id to Trackable for nodes_to_restore.")

  ckpt_view = checkpoint_view.CheckpointView(save_path)
  ckpt_view_descendants = ckpt_view.descendants()
  for node_id, trackable in nodes_to_restore.items():
    # node_id does not have a corresponding Checkpoint value.
    if (node_id not in ckpt_view_descendants or
        ckpt_view._object_graph_proto.nodes[  # pylint: disable=protected-access
            node_id] is None):
      raise ValueError(
          f"The node_id {node_id}, mapped to Trackable {trackable} to "
          "restore, does not exist in the checkpoint.")
    # The Trackable mapped to node_id is missing or not a Trackable.
    if trackable is None or not isinstance(trackable, base.Trackable):
      raise ValueError(
          f"Expecting a valid Trackable for node_id: {node_id} but got "
          f"trackable: {trackable}.")

  serialized_tensors = object_identity.ObjectIdentityDictionary()
  for node_id, current_trackable in nodes_to_restore.items():
    ckpt_contains_serialized_tensors = ckpt_view._object_graph_proto.nodes[  # pylint: disable=protected-access
        node_id].attributes
    node = ckpt_view._object_graph_proto.nodes[node_id]  # pylint: disable=protected-access
    trackable_has_serialize_to_tensor = (
        saveable_object_util.trackable_has_serialize_to_tensor(
            current_trackable))
    if not trackable_has_serialize_to_tensor:
      if not node.attributes:
        if saveable_object_util.saveable_objects_from_trackable(
            current_trackable):
          raise ValueError(
              f"Trackable {current_trackable} expects checkpointed values but "
              "checkpoint does not contain serialized tensors for node_id: "
              f"{node_id}.")
        else:
          continue
      object_names = object_identity.ObjectIdentityDictionary()
      object_names[current_trackable] = trackable_utils.extract_object_name(
          node.attributes[0].checkpoint_key)
      checkpoint_factory_map, _ = save_util_v1.get_checkpoint_factories_and_keys(
          object_names, None)
      saveable_objects = save_util_v1.generate_saveable_objects(
          checkpoint_factory_map)[0]
      if len(node.attributes) != len(saveable_objects):
        raise ValueError("Size of saveable_objects for Trackable: "
                         f"{len(saveable_objects)} did not match the size of "
                         "serialized_tensors for checkpoint: "
                         f"{len(node.attributes)}.")
      current_trackable = saveable_object_util.SaveableCompatibilityConverter(
          current_trackable, saveable_objects)

    serialized_tensors[
        current_trackable] = current_trackable._serialize_to_tensors()  # pylint: disable=protected-access
    trackable_expects_ckpted_value = bool(serialized_tensors[current_trackable])

    if trackable_expects_ckpted_value and not ckpt_contains_serialized_tensors:
      raise ValueError(
          f"Trackable {current_trackable} expects checkpointed values but "
          "checkpoint does not contain serialized tensors for node_id: "
          f"{node_id}.")

    if not trackable_expects_ckpted_value and ckpt_contains_serialized_tensors:
      raise ValueError(
          f"Trackable {current_trackable} does not expect checkpointed "
          "values but checkpoint contains serialized tensors: "
          f"{ckpt_contains_serialized_tensors} for node_id: {node_id}.")

    if len(node.attributes) != len(serialized_tensors[current_trackable]):
      raise ValueError("Size of serialized_tensors for Trackable: "
                       f"{len(serialized_tensors[current_trackable])} did not "
                       "match size of serialized_tensors for checkpoint: "
                       f"{len(node.attributes)}.")

    if not trackable_has_serialize_to_tensor:
      functional_saver.MultiDeviceSaver(serialized_tensors).restore(save_path)
    else:
      # Converts attribute.name to attribute.checkpoint_key since that's what
      # the restore method expects, i.e., converts "a" to "/.ATTRIBUTES/a".
      serialized_tensors_renamed = object_identity.ObjectIdentityDictionary()
      serialized_tensors_renamed[current_trackable] = {}
      for attribute in node.attributes:
        name = attribute.name
        checkpoint_key = attribute.checkpoint_key
        serialized_tensors_renamed[current_trackable][
            checkpoint_key] = serialized_tensors[current_trackable][name]
      functional_saver.MultiDeviceSaver(serialized_tensors_renamed).restore(
          save_path)


def _queue_children_for_restoration(checkpoint_position, visit_queue):
  """Queues the restoration of trackable's children or defers them."""
  # pylint: disable=protected-access
  trackable = checkpoint_position.trackable
  for child in checkpoint_position.object_proto.children:
    # trackable._lookup_dependency can be expensive so first check if this node
    # already has an object correspondence. If so we skip this node.
    correspondence = checkpoint_position.checkpoint.object_by_proto_id.get(
        child.node_id, None
    )
    if correspondence is not None:
      continue
    child_position = checkpoint_position.create_child_position(child.node_id)
    local_object = trackable._lookup_dependency(child.local_name)
    child_proto = child_position.object_proto
    if local_object is None:
      # We don't yet have a dependency registered with this name. Save it
      # in case we do.
      if child_proto.HasField("has_checkpoint_values"):
        has_value = child_proto.has_checkpoint_values.value
      else:
        # If the field is not set, do a simple check to see if the dependency
        # has children and/or checkpointed values.
        has_value = bool(
            child_proto.children or child_proto.attributes or
            child_proto.slot_variables or
            child_proto.HasField("registered_saver"))
      if has_value:
        trackable._deferred_dependencies.setdefault(child.local_name,
                                                    []).append(child_position)
    else:
      if child_position.bind_object(trackable=local_object):
        # This object's correspondence is new, so dependencies need to be
        # visited. Delay doing it so that we get a breadth-first dependency
        # resolution order (shallowest paths first). The caller is responsible
        # for emptying visit_queue.
        visit_queue.append((child_position, local_object))


_DeferredSlotVariableRestoration = collections.namedtuple(
    "_DeferredSlotVariableRestoration", [
        "original_variable",
        "slot_variable_id",
        "slot_name",
    ])


def _queue_slot_variables(checkpoint_position, visit_queue):
  """Queues slot variables for restoration."""
  trackable = checkpoint_position.trackable
  checkpoint = checkpoint_position.checkpoint
  for deferred_slot_restoration in (checkpoint.deferred_slot_restorations.pop(
      checkpoint_position.proto_id, ())):
    slot_variable_position, slot_variable = (
        checkpoint_position.create_slot_variable_position(
            trackable, deferred_slot_restoration.original_variable,
            deferred_slot_restoration.slot_variable_id,
            deferred_slot_restoration.slot_name))
    if slot_variable_position is not None:
      visit_queue.append((slot_variable_position, slot_variable))
  for slot_restoration in checkpoint.slot_restorations.pop(
      checkpoint_position.proto_id, ()):
    optimizer_object = checkpoint.object_by_proto_id.get(
        slot_restoration.optimizer_id, None)
    if optimizer_object is None:
      # The optimizer has not yet been created or tracked. Record in the
      # checkpoint that the slot variables need to be restored when it is.
      checkpoint.deferred_slot_restorations.setdefault(
          slot_restoration.optimizer_id, []).append(
              _DeferredSlotVariableRestoration(
                  original_variable=trackable,
                  slot_variable_id=slot_restoration.slot_variable_id,
                  slot_name=slot_restoration.slot_name))

    # `optimizer_object` can be a `Checkpoint` when the user only needs the
    # attributes the optimizer holds, such as `iterations`. In those cases,
    # it would not have the optimizer's `_create_or_restore_slot_variable`
    # method.
    elif hasattr(optimizer_object, "_create_or_restore_slot_variable"):
      slot_variable_position, slot_variable = (
          checkpoint_position.create_slot_variable_position(
              optimizer_object, trackable, slot_restoration.slot_variable_id,
              slot_restoration.slot_name))
      if slot_variable_position is not None:
        visit_queue.append((slot_variable_position, slot_variable))


def _extract_saveable_name(checkpoint_key):
  # Truncate the checkpoint key at the end of the "{...}.ATTRIBUTES/"
  # substring.
  search_key = trackable_utils.OBJECT_ATTRIBUTES_NAME + "/"
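  # For example (illustrative key): "model/dense/.ATTRIBUTES/VARIABLE_VALUE"
  # is truncated to "model/dense/.ATTRIBUTES/", since OBJECT_ATTRIBUTES_NAME
  # is ".ATTRIBUTES".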
  return checkpoint_key[:checkpoint_key.index(search_key) + len(search_key)]


def _get_saveable_from_factory(saveable_factories, serialized_tensor,
                               created_compat_names):
  """Returns the saveable generated from the factory method."""
  matched_factory = None

  # The `expected_factory_name` is used to find the right saveable factory,
  # while the `factory_input_name` is the value that is passed to the factory
  # method to instantiate the SaveableObject.
  expected_factory_name = serialized_tensor.name
  factory_input_name = serialized_tensor.checkpoint_key

  # Case 1: the name already exactly matches a key in saveable_factories.
  if expected_factory_name in saveable_factories:
    matched_factory = saveable_factories[expected_factory_name]

  # Case 2: (Forward compat) The serialized name is composed of
  # "factory_name" + "SUFFIX". Get the matching factory name.
  if matched_factory is None:
    for factory_name, factory in saveable_factories.items():
      if expected_factory_name.startswith(factory_name):
        if matched_factory is not None:
          # This condition is met in the extreme edge case where the object
          # returns two saveable factories with similar names. This is very
          # unlikely because there are zero objects inside TensorFlow that use
          # more than one saveable factory.
          raise ValueError("Forward compatibility load error: Unable to load "
                           "checkpoint saved in future version of TensorFlow. "
                           "Please update your version of TensorFlow to the "
                           "version in which the checkpoint was saved.")

        matched_factory = factory
        factory_input_name = _extract_saveable_name(
            serialized_tensor.checkpoint_key) + factory_name
        created_compat_names.add(factory_name)

  if callable(matched_factory):
    return matched_factory(name=factory_input_name)
  return matched_factory