Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/util.py: 15%
81 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for extracting and writing checkpoint info."""
17from tensorflow.core.protobuf import trackable_object_graph_pb2
18from tensorflow.python.ops import resource_variable_ops
19from tensorflow.python.ops import variables
20from tensorflow.python.trackable import trackable_utils
21from tensorflow.python.util import object_identity
def serialize_slot_variables(trackable_objects, node_ids, object_names):
  """Collects slot variables, names them, and registers them as new nodes.

  Every object in `trackable_objects` that lists `get_slot_names` in its
  `dir()` is asked, for each slot name, for the slot it holds for each of the
  other listed objects. Each slot variable found is appended to
  `trackable_objects`, given the next node id in `node_ids`, and given a
  checkpoint key in `object_names`.

  Args:
    trackable_objects: List of objects to scan; extended in place with the slot
      variables that are found.
    node_ids: Mapping from object to node id; updated in place.
    object_names: Mapping from object to checkpoint name; updated in place.

  Returns:
    An `ObjectIdentityDictionary` mapping each slot-owning object to a list of
    `SlotVariableReference` protos.

  Raises:
    NotImplementedError: If a slot variable reports trackable children, or is
      already present in `node_ids`.
  """
  # Work from a snapshot: slot variables appended to `trackable_objects` below
  # must not be scanned as owners or as original variables.
  candidates = list(trackable_objects)
  references_by_owner = object_identity.ObjectIdentityDictionary()
  for owner in candidates:
    # TODO(b/110718070): Fix Keras imports.
    # Note: dir() is used rather than hasattr() here to avoid triggering
    # custom __getattr__ code, see b/152031870 for context.
    if "get_slot_names" not in dir(owner):
      continue
    for slot_name in owner.get_slot_names():
      for variable_id, variable in enumerate(candidates):
        try:
          slot = owner.get_slot(variable, slot_name)
        except (AttributeError, KeyError):
          # This owner holds no such slot for this object.
          continue
        if slot is None:
          continue
        slot._maybe_initialize_trackable()  # pylint: disable=protected-access
        if slot._trackable_children():  # pylint: disable=protected-access
          # TODO(allenl): Gather dependencies of slot variables.
          raise NotImplementedError(
              "Currently only variables with no dependencies can be saved as "
              "slot variables. File a feature request if this limitation "
              "bothers you.")
        if slot in node_ids:
          raise NotImplementedError(
              "A slot variable was re-used as a dependency of a Trackable "
              f"object: {slot}. This is not currently allowed. "
              "File a feature request if this limitation bothers you.")
        object_names[slot] = trackable_utils.slot_variable_key(
            variable_path=object_names[variable],
            optimizer_path=object_names[owner],
            slot_name=slot_name)
        # The slot variable becomes the next node at the end of the list.
        slot_id = len(trackable_objects)
        node_ids[slot] = slot_id
        trackable_objects.append(slot)
        reference = (
            trackable_object_graph_pb2.TrackableObjectGraph.TrackableObject
            .SlotVariableReference(
                slot_name=slot_name,
                original_variable_node_id=variable_id,
                slot_variable_node_id=slot_id))
        references_by_owner.setdefault(owner, []).append(reference)
  return references_by_owner
def get_mapped_trackable(trackable, object_map):
  """Looks up `trackable` in `object_map`, falling back to `trackable` itself.

  Args:
    trackable: The object to look up.
    object_map: A mapping supporting `.get`, or None.

  Returns:
    `object_map.get(trackable, trackable)`, or `trackable` unchanged when
    `object_map` is None.
  """
  return (trackable if object_map is None
          else object_map.get(trackable, trackable))
def get_full_name(var):
  """Returns the variable's full name for name-based checkpoint compatibility.

  Args:
    var: The object to name.

  Returns:
    An empty string when `var` is neither a `variables.Variable` instance nor
    a resource variable. Otherwise `var._save_slice_info.full_name` when that
    attribute is present and not None, else `var._shared_name`.
  """
  # pylint: disable=protected-access
  is_variable = isinstance(var, variables.Variable)
  # Some objects do not subclass Variable but still act as one.
  if not is_variable and not resource_variable_ops.is_resource_variable(var):
    return ""

  # getattr with a default: the attribute may be missing or set to `None`.
  slice_info = getattr(var, "_save_slice_info", None)
  if slice_info is None:
    return var._shared_name
  return slice_info.full_name
  # pylint: enable=protected-access
def add_checkpoint_values_check(object_graph_proto):
  """Marks each node in the proto with whether it has checkpoint values.

  A node has checkpoint values if it has attributes, slot variables, or a
  registered saver of its own, or if any node reachable through its
  `children` does. The result is written to every node's
  `has_checkpoint_values.value`.

  Args:
    object_graph_proto: A `TrackableObjectGraph` proto. Modified in place.
  """
  # Node id -> set of node ids that list it as a child (its "parents").
  # If a node has checkpoint values, all of its parents do as well.
  parents = {}
  checkpointed = set()

  # First pass: record parent links and the nodes that directly hold values.
  for node_id, object_proto in enumerate(object_graph_proto.nodes):
    if (object_proto.attributes or object_proto.slot_variables or
        object_proto.HasField("registered_saver")):
      checkpointed.add(node_id)
    for child_proto in object_proto.children:
      parents.setdefault(child_proto.node_id, set()).add(node_id)

  # Second pass: propagate upwards from the directly checkpointed nodes.
  to_visit = set(checkpointed)
  while to_visit:
    current = to_visit.pop()
    # Popping the entry means each node's parents are expanded at most once,
    # so the walk terminates even if the graph contains cycles.
    current_parents = parents.pop(current, None)
    if current_parents is None:
      # Either the node has no parents (e.g. slot variables) or it was
      # already expanded.
      continue
    checkpointed.update(current_parents)
    # Only parents that still have unexpanded parent links need a visit.
    to_visit.update(p for p in current_parents if p in parents)

  for node_id, object_proto in enumerate(object_graph_proto.nodes):
    object_proto.has_checkpoint_values.value = node_id in checkpointed
def objects_ids_and_slot_variables_and_paths(graph_view):
  """Lists the objects of a graph view with their ids, names and slots.

  Objects and their paths come from `graph_view.breadth_first_traversal()`.
  Slot variables found by `serialize_slot_variables` are appended to the
  returned object list and are also given node ids and names.

  Args:
    graph_view: A GraphView object.

  Returns:
    A tuple of (trackable objects, paths from root for each object,
    object -> node id, slot variables, object -> name).
  """
  objects, paths = graph_view.breadth_first_traversal()

  names = object_identity.ObjectIdentityDictionary()
  for trackable, path in paths.items():
    names[trackable] = trackable_utils.object_path_to_string(path)

  ids = object_identity.ObjectIdentityDictionary()
  for index, trackable in enumerate(objects):
    ids[trackable] = index

  # Extends `objects`, `ids` and `names` in place with any slot variables.
  slots = serialize_slot_variables(
      trackable_objects=objects, node_ids=ids, object_names=names)
  return objects, paths, ids, slots, names
def list_objects(graph_view):
  """Returns the list of objects reachable through `graph_view`.

  This is the first element of the tuple produced by
  `objects_ids_and_slot_variables_and_paths`.
  """
  return objects_ids_and_slot_variables_and_paths(graph_view)[0]