Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/checkpoint_view.py: 23%

74 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1"""Manages a Checkpoint View.""" 

2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15# ============================================================================== 

16import collections 

17 

18from tensorflow.core.protobuf import trackable_object_graph_pb2 

19from tensorflow.python.checkpoint import trackable_view 

20from tensorflow.python.framework import errors_impl 

21from tensorflow.python.platform import tf_logging as logging 

22from tensorflow.python.trackable import base 

23from tensorflow.python.training import py_checkpoint_reader 

24from tensorflow.python.util import object_identity 

25from tensorflow.python.util.tf_export import tf_export 

26 

27 

@tf_export("train.CheckpointView", v1=[])
class CheckpointView(object):
  """Gathers and serializes a checkpoint view.

  This is for loading specific portions of a module from a
  checkpoint, and for comparing two modules by matching components.

  The view is backed by the object graph proto stored in the checkpoint; nodes
  are addressed by their integer `node_id`, and `node_id=0` is the root.

  Example usage:

  >>> class SimpleModule(tf.Module):
  ...   def __init__(self, name=None):
  ...     super().__init__(name=name)
  ...     self.a_var = tf.Variable(5.0)
  ...     self.b_var = tf.Variable(4.0)
  ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

  >>> root = SimpleModule(name="root")
  >>> root.leaf = SimpleModule(name="leaf")
  >>> ckpt = tf.train.Checkpoint(root)
  >>> save_path = ckpt.save('/tmp/tf_ckpts')
  >>> checkpoint_view = tf.train.CheckpointView(save_path)

  Pass `node_id=0` to `tf.train.CheckpointView.children()` to get the dictionary
  of all children directly linked to the checkpoint root.

  >>> for name, node_id in checkpoint_view.children(0).items():
  ...   print(f"- name: '{name}', node_id: {node_id}")
  - name: 'a_var', node_id: 1
  - name: 'b_var', node_id: 2
  - name: 'vars', node_id: 3
  - name: 'leaf', node_id: 4
  - name: 'root', node_id: 0
  - name: 'save_counter', node_id: 5

  """

  def __init__(self, save_path):
    """Configure the checkpoint view.

    Reads the serialized object graph from the checkpoint at `save_path` and
    keeps the parsed proto for later queries.

    Args:
      save_path: The path to the checkpoint.

    Raises:
      ValueError: If the save_path does not lead to a TF2 checkpoint.
    """
    reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    try:
      object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError as not_found_error:
      raise ValueError(
          f"The specified checkpoint \"{save_path}\" does not appear to be "
          "object-based (saved with TF2) since it is missing the key "
          f"\"{base.OBJECT_GRAPH_PROTO_KEY}\". Likely it was created with the "
          "TF1 name-based saver and does not contain an object dependency graph."
      ) from not_found_error
    object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
    object_graph_proto.ParseFromString(object_graph_string)
    self._object_graph_proto = object_graph_proto

  def children(self, node_id):
    """Returns all child trackables attached to the node with id `node_id`.

    Args:
      node_id: Id of the node to return its children.

    Returns:
      Dictionary of all children attached to the object with name to node_id.
    """
    return {
        child.local_name: child.node_id
        for child in self._object_graph_proto.nodes[node_id].children
    }

  def descendants(self):
    """Returns a list of the node_ids of all nodes reachable from the root.

    The root itself (`node_id=0`) is included and comes first.
    """
    return list(self._descendants_with_paths())

  def _descendants_with_paths(self):
    """Returns a dict of descendants by node_id and paths to node.

    Nodes are discovered breadth-first from the root; each node keeps the path
    through which it was first reached.

    The names returned by this private method are subject to change.
    """
    # node_id:0 will always be "root".
    all_nodes_with_paths = {0: "root"}
    to_visit = collections.deque([0])
    while to_visit:
      node_id = to_visit.popleft()
      # A node is only enqueued after its path has been recorded, so this
      # lookup cannot miss.
      parent_path = all_nodes_with_paths[node_id]
      for child in self._object_graph_proto.nodes[node_id].children:
        # Already-seen nodes (including the root, which is seeded above) are
        # skipped so that cycles and shared children are visited only once.
        if child.node_id in all_nodes_with_paths:
          continue
        all_nodes_with_paths[child.node_id] = (
            parent_path + "." + child.local_name)
        to_visit.append(child.node_id)
    return all_nodes_with_paths

  def match(self, obj):
    """Returns all matching trackables between CheckpointView and Trackable.

    Matching trackables represents trackables with the same name and position in
    graph.

    Args:
      obj: `Trackable` root.

    Returns:
      Dictionary containing all overlapping trackables that maps `node_id` to
      `Trackable`.

    Raises:
      ValueError: If `obj` is not a `Trackable`.

    Example usage:

    >>> class SimpleModule(tf.Module):
    ...   def __init__(self, name=None):
    ...     super().__init__(name=name)
    ...     self.a_var = tf.Variable(5.0)
    ...     self.b_var = tf.Variable(4.0)
    ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

    >>> root = SimpleModule(name="root")
    >>> leaf = root.leaf = SimpleModule(name="leaf")
    >>> leaf.leaf3 = tf.Variable(6.0, name="leaf3")
    >>> leaf.leaf4 = tf.Variable(7.0, name="leaf4")
    >>> ckpt = tf.train.Checkpoint(root)
    >>> save_path = ckpt.save('/tmp/tf_ckpts')
    >>> checkpoint_view = tf.train.CheckpointView(save_path)

    >>> root2 = SimpleModule(name="root")
    >>> leaf2 = root2.leaf2 = SimpleModule(name="leaf2")
    >>> leaf2.leaf3 = tf.Variable(6.0)
    >>> leaf2.leaf4 = tf.Variable(7.0)

    Pass the live root `root2` to `tf.train.CheckpointView.match()` to get the
    checkpoint nodes that line up with it by name and position.

    >>> checkpoint_view_match = checkpoint_view.match(root2).items()
    >>> for item in checkpoint_view_match:
    ...   print(item)
    (0, ...)
    (1, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
    (2, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
    (3, ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
    numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
    (6, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>)
    (7, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>)

    """
    if not isinstance(obj, base.Trackable):
      raise ValueError(f"Expected a Trackable, got {obj} of type {type(obj)}.")

    # Root node is always matched.
    overlapping_nodes = {0: obj}

    # Queue of tuples of node_id and trackable.
    to_visit = collections.deque([(0, obj)])
    visited = set()
    view = trackable_view.TrackableView(obj)
    while to_visit:
      current_node_id, current_trackable = to_visit.popleft()
      trackable_children = view.children(current_trackable)
      for child_name, child_node_id in self.children(current_node_id).items():
        if child_node_id in visited or child_node_id == 0:
          continue
        if child_name not in trackable_children:
          # The checkpoint has an edge the live object lacks: no match here.
          continue
        candidate = trackable_children[child_name]
        current_assignment = overlapping_nodes.get(child_node_id)
        if current_assignment is None:
          overlapping_nodes[child_node_id] = candidate
          to_visit.append((child_node_id, candidate))
        elif current_assignment is not candidate:
          # The object was already mapped for this checkpoint load, which
          # means we don't need to do anything besides check that the mapping
          # is consistent (if the dependency DAG is not a tree then there are
          # multiple paths to the same object).
          logging.warning(
              "Inconsistent references when matching the checkpoint into "
              "this object graph. The referenced objects are: "
              f"({current_assignment} and "
              f"{candidate}).")
      visited.add(current_node_id)
    return overlapping_nodes

  def diff(self, obj):
    """Returns diff between CheckpointView and Trackable.

    This method is intended to be used to compare the object stored in a
    checkpoint vs a live model in Python. For example, if checkpoint
    restoration fails the `assert_consumed()` or
    `assert_existing_objects_matched()` checks, you can use this to list out
    the objects/checkpoint nodes which were not restored.

    Example Usage:

    >>> class SimpleModule(tf.Module):
    ...   def __init__(self, name=None):
    ...     super().__init__(name=name)
    ...     self.a_var = tf.Variable(5.0)
    ...     self.b_var = tf.Variable(4.0)
    ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

    >>> root = SimpleModule(name="root")
    >>> leaf = root.leaf = SimpleModule(name="leaf")
    >>> leaf.leaf3 = tf.Variable(6.0, name="leaf3")
    >>> leaf.leaf4 = tf.Variable(7.0, name="leaf4")
    >>> ckpt = tf.train.Checkpoint(root)
    >>> save_path = ckpt.save('/tmp/tf_ckpts')
    >>> checkpoint_view = tf.train.CheckpointView(save_path)

    >>> root2 = SimpleModule(name="root")
    >>> leaf2 = root2.leaf2 = SimpleModule(name="leaf2")
    >>> leaf2.leaf3 = tf.Variable(6.0)
    >>> leaf2.leaf4 = tf.Variable(7.0)

    Pass the live root `root2` to `tf.train.CheckpointView.diff()`; the first
    element of the result is the same mapping that `match()` returns.

    >>> checkpoint_view_diff = checkpoint_view.diff(root2)
    >>> checkpoint_view_match = checkpoint_view_diff[0].items()
    >>> for item in checkpoint_view_match:
    ...   print(item)
    (0, ...)
    (1, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
    (2, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
    (3, ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
    numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
    (6, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>)
    (7, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>)

    >>> only_in_checkpoint_view = checkpoint_view_diff[1]
    >>> print(only_in_checkpoint_view)
    [4, 5, 8, 9, 10, 11, 12, 13, 14]

    >>> only_in_trackable = checkpoint_view_diff[2]
    >>> print(only_in_trackable)
    [..., <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>,
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>,
    ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]),
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=6.0>,
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>,
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]

    Args:
      obj: `Trackable` root.

    Returns:
      Tuple of (
      - Overlaps: Dictionary containing all overlapping trackables that maps
      `node_id` to `Trackable`, same as CheckpointView.match().
      - Only in CheckpointView: List of `node_id` that only exist in
      CheckpointView.
      - Only in Trackable: List of `Trackable` that only exist in Trackable.
      )

    """
    overlapping_nodes = self.match(obj)

    only_in_checkpoint_view = [
        node_id for node_id in self.descendants()
        if node_id not in overlapping_nodes
    ]

    # Build the identity set once; membership is by object identity so that
    # trackables are not compared with `==`.
    matched_trackables = object_identity.ObjectIdentitySet(
        overlapping_nodes.values())
    only_in_trackable = [
        trackable
        for trackable in trackable_view.TrackableView(obj).descendants()
        if trackable not in matched_trackables
    ]
    return overlapping_nodes, only_in_checkpoint_view, only_in_trackable