Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/save_util_v1.py: 18%

133 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Extracts tensors for checkpointing while updating a TrackableObjectGraph. 

16 

17This is labelled "v1" because the methods here use SaveableObject, which will 

18soon be deprecated. 

19""" 

20 

21import collections 

22 

23from tensorflow.core.protobuf import trackable_object_graph_pb2 

24from tensorflow.python.checkpoint import saveable_compat 

25from tensorflow.python.checkpoint import util 

26from tensorflow.python.framework import constant_op 

27from tensorflow.python.framework import dtypes 

28from tensorflow.python.framework import ops 

29from tensorflow.python.saved_model import registration 

30from tensorflow.python.trackable import base 

31from tensorflow.python.trackable import python_state 

32from tensorflow.python.trackable import trackable_utils 

33from tensorflow.python.training.saving import saveable_object as saveable_object_lib 

34from tensorflow.python.training.saving import saveable_object_util 

35from tensorflow.python.util import object_identity 

36 

# Record describing how to build a SaveableObject that saves a Trackable to a
# checkpoint: the saveable factory, the attribute name, and the checkpoint key.
_CheckpointFactoryData = collections.namedtuple(
    typename="_CheckpointFactoryData",
    field_names=("factory", "name", "checkpoint_key"),
)

41 

42 

def get_checkpoint_factories_and_keys(object_names, object_map=None):
  """Collects saveable factories and checkpoint keys for each trackable.

  Trackables whose mapped object has a registered saver are grouped by saver
  name. Every other trackable gets a list holding one `_CheckpointFactoryData`
  entry per saveable factory returned by
  `saveable_object_util.saveable_objects_from_trackable`.

  Args:
    object_names: a dictionary that maps `Trackable` objects to auto-generated
      string names.
    object_map: a dictionary mapping `Trackable` to copied `Trackable` objects.
      The copied objects are generated from
      `Trackable._export_to_saved_model_graph()` which copies the object into
      another graph. Generally only resource objects (e.g. Variables, Tables)
      will be in this map.

  Returns:
    A tuple of (
      Dictionary mapping trackable -> list of _CheckpointFactoryData,
      Dictionary mapping registered saver name -> {object name -> trackable})
  """
  factories_by_trackable = object_identity.ObjectIdentityDictionary()
  trackables_by_saver = collections.defaultdict(dict)

  for trackable, object_name in object_names.items():
    # The mapped object only supplies the saving functionality; keys and the
    # returned dictionaries always refer to the original `trackable`.
    mapped = util.get_mapped_trackable(trackable, object_map)

    registered_name = registration.get_registered_saver_name(mapped)
    if registered_name:
      # The original trackable is stored (not the mapped copy) because it is
      # the one needed later when writing the object proto.
      trackables_by_saver[registered_name][object_name] = trackable
      continue

    entries = []
    factories_by_trackable[trackable] = entries
    factories = saveable_object_util.saveable_objects_from_trackable(mapped)
    for attribute_name, factory in factories.items():
      # Prefer the legacy saveable name, when one exists, as the key suffix
      # (kept for compatibility during the SaveableObject deprecation).
      suffix = saveable_compat.get_saveable_name(mapped) or attribute_name
      key = trackable_utils.checkpoint_key(object_name, suffix)

      # The legacy name is also recorded as the entry name, but only while
      # forced checkpoint conversion is disabled.
      if saveable_compat.force_checkpoint_conversion_enabled():
        recorded_name = attribute_name
      else:
        recorded_name = suffix

      entries.append(
          _CheckpointFactoryData(
              factory=factory, name=recorded_name, checkpoint_key=key))

  return factories_by_trackable, trackables_by_saver

94 

95 

def _add_attributes_to_object_graph(trackable_objects, object_graph_proto,
                                    node_ids, object_names, object_map,
                                    call_with_mapped_captures, saveables_cache):
  """Builds saveables and registered savers, filling in the object graph.

  Returns:
    A tuple of (named saveable objects, feed additions, registered savers),
    where the first two come from `generate_saveable_objects` and the last from
    `_add_attributes_to_object_graph_for_registered_savers`.
  """
  # Sanity check: the position of every trackable must agree with `node_ids`.
  # The node protos themselves are populated by the helpers called below.
  paired = zip(trackable_objects, object_graph_proto.nodes)
  for expected_id, (trackable, _) in enumerate(paired):
    assert node_ids[trackable] == expected_id

  factory_map, unmapped_savers = get_checkpoint_factories_and_keys(
      object_names, object_map)

  # Attributes describe which values are saved in the checkpoint for each
  # trackable.
  registered_savers = _add_attributes_to_object_graph_for_registered_savers(
      unmapped_savers, object_graph_proto, node_ids, object_map)
  named_saveable_objects, feed_additions = generate_saveable_objects(
      factory_map,
      object_graph_proto=object_graph_proto,
      node_ids=node_ids,
      object_map=object_map,
      call_with_mapped_captures=call_with_mapped_captures,
      saveables_cache=saveables_cache)
  return named_saveable_objects, feed_additions, registered_savers

118 

119 

def _add_attributes_to_object_graph_for_registered_savers(
    unmapped_registered_savers, object_graph_proto, node_ids, object_map):
  """Records registered-saver info in the object graph proto.

  For every trackable listed under a saver name, the matching node proto gets
  its `registered_saver.name` and `registered_saver.object_name` set.

  Returns:
    A dictionary mapping saver name -> {object name -> mapped trackable}, where
    the mapped trackable is obtained via `util.get_mapped_trackable`.
  """
  mapped_savers = collections.defaultdict(dict)
  for saver_name, named_trackables in unmapped_registered_savers.items():
    for object_name, trackable in named_trackables.items():
      node_proto = object_graph_proto.nodes[node_ids[trackable]]
      saver_proto = node_proto.registered_saver
      saver_proto.name = saver_name
      saver_proto.object_name = object_name

      mapped_savers[saver_name][object_name] = util.get_mapped_trackable(
          trackable, object_map)
  return mapped_savers

133 

134 

def generate_saveable_objects(checkpoint_factory_map,
                              object_graph_proto=None,
                              node_ids=None,
                              object_map=None,
                              call_with_mapped_captures=None,
                              saveables_cache=None):
  """Creates SaveableObjects and, optionally, their proto attributes.

  Args:
    checkpoint_factory_map: mapping from trackable to a list of entries that
      expose `factory`, `name` and `checkpoint_key` attributes.
    object_graph_proto: optional object graph proto. Attributes are added to
      its nodes only when both this and `node_ids` are provided.
    node_ids: optional mapping from trackable to its index in
      `object_graph_proto.nodes`.
    object_map: forwarded to `util.get_mapped_trackable` to find the object
      that actually provides the saving functionality.
    call_with_mapped_captures: forwarded to
      `saveable_object_util.create_saveable_object` for callable factories.
    saveables_cache: optional mapping from mapped trackable to a dictionary of
      attribute name -> tuple of saveables. When None, nothing is cached.

  Returns:
    A tuple of (list of saveable objects, feed additions). Feed additions is
    None when `saveables_cache` is None, otherwise a dictionary.

  Raises:
    AssertionError: if a newly created saveable has a name that does not
      contain the expected checkpoint key.
  """
  named_saveable_objects = []
  # Without a cache we are either executing eagerly or building a static save
  # specialized to the current Python state, so no feed dict is produced. With
  # a cache, the feed dict carries functions computing volatile Python state.
  feed_additions = None if saveables_cache is None else {}
  fill_object_proto = object_graph_proto is not None and node_ids is not None

  for trackable, entries in checkpoint_factory_map.items():
    if fill_object_proto:
      object_proto = object_graph_proto.nodes[node_ids[trackable]]
    mapped = util.get_mapped_trackable(trackable, object_map)
    attribute_cache = (
        None if saveables_cache is None
        else saveables_cache.setdefault(mapped, {}))

    for entry in entries:
      name = entry.name
      key = entry.checkpoint_key
      factory = entry.factory

      # Reuse cached saveables unless one of them was built for a different
      # checkpoint key, in which case the stale entry is dropped.
      saveables = None
      if attribute_cache:
        saveables = attribute_cache.get(name)
      if saveables is not None and any(
          key not in cached.name for cached in saveables):
        saveables = None
        del attribute_cache[name]

      if saveables is None:
        candidate = factory
        if callable(factory):
          candidate = saveable_object_util.create_saveable_object(
              name, key, factory, call_with_mapped_captures)
        if isinstance(candidate, saveable_object_lib.SaveableObject):
          saveables = (candidate,)
        else:
          saveables = tuple(
              saveable_object_util.saveable_objects_for_op(
                  op=candidate, name=key))
        for saveable in saveables:
          if key not in saveable.name:
            raise AssertionError(
                f"The object {trackable} produced a SaveableObject with name "
                f"'{saveable.name}' for attribute '{name}'. Expected a name"
                f" containing '{key}'.")
        if attribute_cache is not None:
          attribute_cache[name] = saveables

      if isinstance(mapped, python_state.PythonState):
        assert len(saveables) == 1
        python_saveable = saveables[0]

        if feed_additions is not None:
          feed_additions.update(python_saveable.feed_dict_additions())
        else:
          assert saveables_cache is None
          # Not caching means eager execution or a static save/restore (e.g.
          # for a SavedModel); embed the current Python state in the graph
          # instead of relying on a feed dict.
          saveables = (python_saveable.freeze(),)
      named_saveable_objects.extend(saveables)

      # Update the object proto. A TrackableSaveable contributes one attribute
      # per serialized tensor, unless a legacy saveable name is in use and
      # checkpoint conversion is not forced; everything else contributes a
      # single attribute.
      if fill_object_proto:
        first = saveables[0]
        if (isinstance(first, saveable_object_util.TrackableSaveable) and
            (saveable_compat.force_checkpoint_conversion_enabled() or
             saveable_compat.get_saveable_name(mapped) is None)):
          for local_name, local_key in (
              first.get_proto_names_and_checkpoint_keys()):
            object_proto.attributes.add(
                name=local_name,
                checkpoint_key=local_key,
                full_name=util.get_full_name(mapped))
        else:
          object_proto.attributes.add(
              name=name,
              checkpoint_key=key,
              full_name=util.get_full_name(mapped))

  return named_saveable_objects, feed_additions

237 

238 

def _fill_object_graph_proto(graph_view,
                             trackable_objects,
                             node_ids,
                             slot_variables):
  """Builds a `TrackableObjectGraph` with one node per trackable.

  Each node records the slot variables found for the trackable in
  `slot_variables` (none if absent) and one child reference for every child
  reported by `graph_view.list_children`.

  Returns:
    The populated `TrackableObjectGraph` proto.
  """
  graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  for position, trackable in enumerate(trackable_objects):
    # Node order in the proto must match the ids assigned in `node_ids`.
    assert node_ids[trackable] == position
    own_slots = slot_variables.get(trackable, ())
    node_proto = graph_proto.nodes.add(slot_variables=own_slots)
    for child in graph_view.list_children(trackable):
      node_proto.children.add(
          node_id=node_ids[child.ref], local_name=child.name)
  return graph_proto

255 

256 

def serialize_gathered_objects(graph_view,
                               object_map=None,
                               call_with_mapped_captures=None,
                               saveables_cache=None):
  """Creates SaveableObjects and protos for the objects in `graph_view`.

  Returns:
    A tuple of (named saveable objects, object graph proto, feed additions,
    registered savers).
  """
  trackable_objects, node_paths = graph_view.breadth_first_traversal()

  # Name each object after its path, and number nodes in traversal order.
  names_by_object = object_identity.ObjectIdentityDictionary()
  for trackable, path in node_paths.items():
    names_by_object[trackable] = trackable_utils.object_path_to_string(path)
  ids_by_object = object_identity.ObjectIdentityDictionary()
  for index, trackable in enumerate(trackable_objects):
    ids_by_object[trackable] = index

  slot_variables = util.serialize_slot_variables(
      trackable_objects=trackable_objects,
      node_ids=ids_by_object,
      object_names=names_by_object)
  graph_proto = _fill_object_graph_proto(
      graph_view, trackable_objects, ids_by_object, slot_variables)
  saveables, feed_additions, registered_savers = (
      _add_attributes_to_object_graph(
          trackable_objects, graph_proto, ids_by_object, names_by_object,
          object_map, call_with_mapped_captures, saveables_cache))

  # Gather all trackables that have checkpoint values or descendants with
  # checkpoint values, and add that info to the proto.
  util.add_checkpoint_values_check(graph_proto)
  return saveables, graph_proto, feed_additions, registered_savers

292 

293 

def serialize_object_graph_with_registered_savers(graph_view, saveables_cache):
  """Serializes `graph_view` with no object map or capture-mapping function.

  Delegates to `serialize_gathered_objects` and returns its result unchanged.
  """
  return serialize_gathered_objects(
      graph_view=graph_view,
      object_map=None,
      call_with_mapped_captures=None,
      saveables_cache=saveables_cache)

297 

298 

def frozen_saveables_and_savers(graph_view,
                                object_map=None,
                                to_graph=None,
                                call_with_mapped_captures=None,
                                saveables_cache=None):
  """Generates SaveableObjects and registered savers, plus the object graph.

  The work is done inside `to_graph.as_default()` when `to_graph` is truthy,
  and inside `ops.NullContextmanager()` otherwise. The serialized object graph
  proto is appended to the saveables as a `base.NoRestoreSaveable` under
  `base.OBJECT_GRAPH_PROTO_KEY`.

  Returns:
    A tuple of (named saveable objects, registered savers).
  """
  enter_target = to_graph.as_default if to_graph else ops.NullContextmanager
  with enter_target():
    gathered = serialize_gathered_objects(
        graph_view, object_map, call_with_mapped_captures, saveables_cache)
    # Feed additions (third item) are not used here.
    named_saveable_objects, graph_proto, _, registered_savers = gathered

    # The string constant holding the serialized proto is created on the CPU.
    with ops.device("/cpu:0"):
      serialized_graph = graph_proto.SerializeToString()
      object_graph_tensor = constant_op.constant(
          serialized_graph, dtype=dtypes.string)

    graph_saveable = base.NoRestoreSaveable(
        tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY)
    named_saveable_objects.append(graph_saveable)
  return named_saveable_objects, registered_savers

319 return named_saveable_objects, registered_savers