Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/checkpoint_view.py: 23%
74 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""Manages a Checkpoint View."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16import collections
18from tensorflow.core.protobuf import trackable_object_graph_pb2
19from tensorflow.python.checkpoint import trackable_view
20from tensorflow.python.framework import errors_impl
21from tensorflow.python.platform import tf_logging as logging
22from tensorflow.python.trackable import base
23from tensorflow.python.training import py_checkpoint_reader
24from tensorflow.python.util import object_identity
25from tensorflow.python.util.tf_export import tf_export
@tf_export("train.CheckpointView", v1=[])
class CheckpointView(object):
  """Read-only view of the object graph stored in a TF2 checkpoint.

  This is for loading specific portions of a module from a checkpoint, and for
  comparing two modules by matching their components.

  Example usage:

  >>> class SimpleModule(tf.Module):
  ...   def __init__(self, name=None):
  ...     super().__init__(name=name)
  ...     self.a_var = tf.Variable(5.0)
  ...     self.b_var = tf.Variable(4.0)
  ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

  >>> root = SimpleModule(name="root")
  >>> root.leaf = SimpleModule(name="leaf")
  >>> ckpt = tf.train.Checkpoint(root)
  >>> save_path = ckpt.save('/tmp/tf_ckpts')
  >>> checkpoint_view = tf.train.CheckpointView(save_path)

  Pass `node_id=0` to `tf.train.CheckpointView.children()` to get the dictionary
  of all children directly linked to the checkpoint root.

  >>> for name, node_id in checkpoint_view.children(0).items():
  ...   print(f"- name: '{name}', node_id: {node_id}")
  - name: 'a_var', node_id: 1
  - name: 'b_var', node_id: 2
  - name: 'vars', node_id: 3
  - name: 'leaf', node_id: 4
  - name: 'root', node_id: 0
  - name: 'save_counter', node_id: 5

  """

  def __init__(self, save_path):
    """Configure the checkpoint view.

    Reads the serialized object graph from the checkpoint at `save_path` and
    keeps the parsed `TrackableObjectGraph` proto for later queries.

    Args:
      save_path: The path to the checkpoint.

    Raises:
      ValueError: If the save_path does not lead to a TF2 checkpoint.
    """
    reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    try:
      object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError as not_found_error:
      # The object-graph key is missing; chain the original error so the
      # low-level cause stays visible in the traceback.
      raise ValueError(
          f"The specified checkpoint \"{save_path}\" does not appear to be "
          "object-based (saved with TF2) since it is missing the key "
          f"\"{base.OBJECT_GRAPH_PROTO_KEY}\". Likely it was created with the "
          "TF1 name-based saver and does not contain an object dependency graph."
      ) from not_found_error
    object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
    object_graph_proto.ParseFromString(object_graph_string)
    self._object_graph_proto = object_graph_proto

  def children(self, node_id):
    """Returns the direct children of a checkpoint node.

    Args:
      node_id: Id of the node whose children are returned.

    Returns:
      Dictionary mapping each child's local name to its node_id.
    """
    return {
        child.local_name: child.node_id
        for child in self._object_graph_proto.nodes[node_id].children
    }

  def descendants(self):
    """Returns the node_ids reachable from the root, the root included.

    The ids are listed in the order they are discovered by the breadth-first
    walk in `_descendants_with_paths()`, so node 0 comes first.
    """
    return list(self._descendants_with_paths())

  def _descendants_with_paths(self):
    """Returns a dict of descendants by node_id and paths to node.

    Each value is a dotted path such as "root.leaf.a_var", built from the
    first (breadth-first) route by which the node is reached.

    The names returned by this private method are subject to change.
    """
    # node_id:0 will always be "root".
    all_nodes_with_paths = {0: "root"}
    to_visit = collections.deque([0])
    while to_visit:
      node_id = to_visit.popleft()
      # Every queued node was given a path when it was discovered.
      path = all_nodes_with_paths[node_id]
      for child in self._object_graph_proto.nodes[node_id].children:
        # Already discovered (this also covers edges back to the root), so the
        # first path found is kept and cycles terminate.
        if child.node_id in all_nodes_with_paths:
          continue
        to_visit.append(child.node_id)
        all_nodes_with_paths[child.node_id] = path + "." + child.local_name
    return all_nodes_with_paths

  def match(self, obj):
    """Returns all matching trackables between CheckpointView and Trackable.

    Matching trackables represents trackables with the same name and position in
    graph.

    Args:
      obj: `Trackable` root.

    Returns:
      Dictionary containing all overlapping trackables that maps `node_id` to
      `Trackable`. Node 0 (the root) is always mapped to `obj`.

    Raises:
      ValueError: If `obj` is not a `Trackable`.

    Example usage:

    >>> class SimpleModule(tf.Module):
    ...   def __init__(self, name=None):
    ...     super().__init__(name=name)
    ...     self.a_var = tf.Variable(5.0)
    ...     self.b_var = tf.Variable(4.0)
    ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

    >>> root = SimpleModule(name="root")
    >>> leaf = root.leaf = SimpleModule(name="leaf")
    >>> leaf.leaf3 = tf.Variable(6.0, name="leaf3")
    >>> leaf.leaf4 = tf.Variable(7.0, name="leaf4")
    >>> ckpt = tf.train.Checkpoint(root)
    >>> save_path = ckpt.save('/tmp/tf_ckpts')
    >>> checkpoint_view = tf.train.CheckpointView(save_path)

    >>> root2 = SimpleModule(name="root")
    >>> leaf2 = root2.leaf2 = SimpleModule(name="leaf2")
    >>> leaf2.leaf3 = tf.Variable(6.0)
    >>> leaf2.leaf4 = tf.Variable(7.0)

    `root2` has no child named `leaf` (its submodule is attached as `leaf2`),
    so the checkpointed `leaf` subtree is not part of the match.

    >>> checkpoint_view_match = checkpoint_view.match(root2).items()
    >>> for item in checkpoint_view_match:
    ...   print(item)
    (0, ...)
    (1, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
    (2, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
    (3, ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
    numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
    (6, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>)
    (7, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>)

    """
    if not isinstance(obj, base.Trackable):
      raise ValueError(f"Expected a Trackable, got {obj} of type {type(obj)}.")

    # Root node is always matched.
    overlapping_nodes = {0: obj}

    # Queue of tuples of node_id and trackable.
    to_visit = collections.deque([(0, obj)])
    # Checkpoint nodes whose children have already been expanded.
    visited = set()
    view = trackable_view.TrackableView(obj)
    while to_visit:
      current_node_id, current_trackable = to_visit.popleft()
      trackable_children = view.children(current_trackable)
      for child_name, child_node_id in self.children(current_node_id).items():
        # Skip edges that lead back to the root or to an expanded node.
        if child_node_id in visited or child_node_id == 0:
          continue
        # The live object has no child under this name: nothing to match.
        if child_name not in trackable_children:
          continue
        child_trackable = trackable_children[child_name]
        current_assignment = overlapping_nodes.get(child_node_id)
        if current_assignment is None:
          overlapping_nodes[child_node_id] = child_trackable
          to_visit.append((child_node_id, child_trackable))
        elif current_assignment is not child_trackable:
          # The object was already mapped for this checkpoint load, which
          # means we don't need to do anything besides check that the mapping
          # is consistent (if the dependency DAG is not a tree then there are
          # multiple paths to the same object).
          logging.warning(
              "Inconsistent references when matching the checkpoint into "
              "this object graph. The referenced objects are: "
              f"({current_assignment} and "
              f"{child_trackable}).")
      visited.add(current_node_id)
    return overlapping_nodes

  def diff(self, obj):
    """Returns diff between CheckpointView and Trackable.

    This method is intended to be used to compare the object stored in a
    checkpoint vs a live model in Python. For example, if checkpoint
    restoration fails the `assert_consumed()` or
    `assert_existing_objects_matched()` checks, you can use this to list out
    the objects/checkpoint nodes which were not restored.

    Example Usage:

    >>> class SimpleModule(tf.Module):
    ...   def __init__(self, name=None):
    ...     super().__init__(name=name)
    ...     self.a_var = tf.Variable(5.0)
    ...     self.b_var = tf.Variable(4.0)
    ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

    >>> root = SimpleModule(name="root")
    >>> leaf = root.leaf = SimpleModule(name="leaf")
    >>> leaf.leaf3 = tf.Variable(6.0, name="leaf3")
    >>> leaf.leaf4 = tf.Variable(7.0, name="leaf4")
    >>> ckpt = tf.train.Checkpoint(root)
    >>> save_path = ckpt.save('/tmp/tf_ckpts')
    >>> checkpoint_view = tf.train.CheckpointView(save_path)

    >>> root2 = SimpleModule(name="root")
    >>> leaf2 = root2.leaf2 = SimpleModule(name="leaf2")
    >>> leaf2.leaf3 = tf.Variable(6.0)
    >>> leaf2.leaf4 = tf.Variable(7.0)

    The first element of the result is the same mapping that
    `tf.train.CheckpointView.match()` returns.

    >>> checkpoint_view_diff = checkpoint_view.diff(root2)
    >>> checkpoint_view_match = checkpoint_view_diff[0].items()
    >>> for item in checkpoint_view_match:
    ...   print(item)
    (0, ...)
    (1, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
    (2, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
    (3, ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
    numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
    (6, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>)
    (7, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>)

    >>> only_in_checkpoint_view = checkpoint_view_diff[1]
    >>> print(only_in_checkpoint_view)
    [4, 5, 8, 9, 10, 11, 12, 13, 14]

    >>> only_in_trackable = checkpoint_view_diff[2]
    >>> print(only_in_trackable)
    [..., <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>,
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>,
    ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]),
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=6.0>,
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>,
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]

    Args:
      obj: `Trackable` root.

    Returns:
      Tuple of (
      - Overlaps: Dictionary containing all overlapping trackables that maps
      `node_id` to `Trackable`, same as CheckpointView.match().
      - Only in CheckpointView: List of `node_id` that only exist in
      CheckpointView.
      - Only in Trackable: List of `Trackable` that only exist in Trackable.
      )

    Raises:
      ValueError: If `obj` is not a `Trackable` (raised by `match()`).

    """
    overlapping_nodes = self.match(obj)
    only_in_checkpoint_view = [
        node_id for node_id in self.descendants()
        if node_id not in overlapping_nodes
    ]
    # Build the identity set once; it does not change while the live object
    # graph is being scanned.
    matched_trackables = object_identity.ObjectIdentitySet(
        overlapping_nodes.values())
    only_in_trackable = [
        trackable
        for trackable in trackable_view.TrackableView(obj).descendants()
        if trackable not in matched_trackables
    ]
    return overlapping_nodes, only_in_checkpoint_view, only_in_trackable