Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/save_util_v1.py: 18%
133 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Extracts tensors for checkpointing while updating a TrackableObjectGraph.
17This is labelled "v1" because the methods here use SaveableObject, which will
18soon be deprecated.
19"""
21import collections
23from tensorflow.core.protobuf import trackable_object_graph_pb2
24from tensorflow.python.checkpoint import saveable_compat
25from tensorflow.python.checkpoint import util
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.saved_model import registration
30from tensorflow.python.trackable import base
31from tensorflow.python.trackable import python_state
32from tensorflow.python.trackable import trackable_utils
33from tensorflow.python.training.saving import saveable_object as saveable_object_lib
34from tensorflow.python.training.saving import saveable_object_util
35from tensorflow.python.util import object_identity
# Everything needed to build the SaveableObject(s) for one attribute of a
# Trackable when writing a checkpoint:
#   factory: either a callable that produces the saveable, or an already-built
#     saveable/op that is used as-is.
#   name: the attribute name under which the factory is registered.
#   checkpoint_key: the full key under which the attribute's values are saved.
_CheckpointFactoryData = collections.namedtuple(
  "_CheckpointFactoryData", "factory name checkpoint_key")
def get_checkpoint_factories_and_keys(object_names, object_map=None):
  """Gets a map of saveable factories and corresponding checkpoint keys.

  Args:
    object_names: a dictionary that maps `Trackable` objects to auto-generated
      string names.
    object_map: a dictionary mapping `Trackable` to copied `Trackable` objects.
      The copied objects are generated from
      `Trackable._export_to_saved_model_graph()`, which copies the object into
      another graph. Generally only resource objects (e.g. Variables, Tables)
      will be in this map.

  Returns:
    A tuple of (
      Dictionary mapping trackable -> list of _CheckpointFactoryData,
      Dictionary mapping registered saver name -> {object name -> trackable})
  """
  checkpoint_factory_map = object_identity.ObjectIdentityDictionary()
  unmapped_registered_savers = collections.defaultdict(dict)
  # Whether factories should be registered under the legacy saveable name
  # (when an object has one). The flag takes no arguments, so read it once.
  keep_legacy_names = not saveable_compat.force_checkpoint_conversion_enabled()
  for trackable, object_name in object_names.items():
    # object_to_save is only used to retrieve the saving functionality. For keys
    # and other data, use the original `trackable`.
    object_to_save = util.get_mapped_trackable(trackable, object_map)

    saver_name = registration.get_registered_saver_name(object_to_save)
    if saver_name:
      # Add the original trackable instead of `object_to_save` to the returned
      # dict because the original is needed for writing the object proto.
      unmapped_registered_savers[saver_name][object_name] = trackable
      continue

    # Keep a local handle on the list so appends avoid repeated identity-dict
    # lookups.
    factories = []
    checkpoint_factory_map[trackable] = factories
    saveable_factories = saveable_object_util.saveable_objects_from_trackable(
        object_to_save)
    # Retrieve the legacy saveable name (for compatibility purposes during
    # SaveableObject deprecation). It depends only on the object, so it is
    # looked up once instead of once per attribute.
    legacy_name = saveable_compat.get_saveable_name(object_to_save)
    for name, saveable_factory in saveable_factories.items():
      key_suffix = legacy_name or name
      checkpoint_key = trackable_utils.checkpoint_key(object_name, key_suffix)
      # Use the legacy saveable name as the factory name if there is one (only
      # when checkpoint conversion is disabled).
      factory_name = key_suffix if keep_legacy_names else name
      factories.append(
          _CheckpointFactoryData(
              factory=saveable_factory,
              name=factory_name,
              checkpoint_key=checkpoint_key))
  return checkpoint_factory_map, unmapped_registered_savers
def _add_attributes_to_object_graph(trackable_objects, object_graph_proto,
                                    node_ids, object_names, object_map,
                                    call_with_mapped_captures, saveables_cache):
  """Records saveable attributes and registered savers on graph nodes.

  The nodes of `object_graph_proto` are indexed through `node_ids` by the
  helpers called here, so they must already exist. This function only adds
  information about which values are written to the checkpoint.

  Args:
    trackable_objects: trackables in traversal order.
    object_graph_proto: `TrackableObjectGraph` whose nodes get updated.
    node_ids: mapping from trackable to its index in `object_graph_proto`.
    object_names: mapping from trackable to its string name.
    object_map: optional mapping from trackable to a copied trackable.
    call_with_mapped_captures: forwarded to `generate_saveable_objects`.
    saveables_cache: forwarded to `generate_saveable_objects`.

  Returns:
    A tuple `(named_saveable_objects, feed_additions, registered_savers)`.
  """
  # Sanity check: a trackable's node id must equal its position in the
  # traversal. As with `zip`, only positions present in both sequences are
  # checked.
  nodes = object_graph_proto.nodes
  for position, (trackable, _) in enumerate(zip(trackable_objects, nodes)):
    assert node_ids[trackable] == position

  factory_map, registered_saver_objects = get_checkpoint_factories_and_keys(
      object_names, object_map)

  # Attach to every node the description of the values it saves: first the
  # registered-saver nodes, then the SaveableObject-based attributes.
  registered_savers = _add_attributes_to_object_graph_for_registered_savers(
      unmapped_registered_savers=registered_saver_objects,
      object_graph_proto=object_graph_proto,
      node_ids=node_ids,
      object_map=object_map)
  named_saveable_objects, feed_additions = generate_saveable_objects(
      factory_map,
      object_graph_proto=object_graph_proto,
      node_ids=node_ids,
      object_map=object_map,
      call_with_mapped_captures=call_with_mapped_captures,
      saveables_cache=saveables_cache)
  return named_saveable_objects, feed_additions, registered_savers
def _add_attributes_to_object_graph_for_registered_savers(
    unmapped_registered_savers, object_graph_proto, node_ids, object_map):
  """Writes registered-saver info into graph nodes and maps savers to objects.

  Args:
    unmapped_registered_savers: mapping from registered saver name to
      {object name -> original trackable}.
    object_graph_proto: `TrackableObjectGraph` whose nodes get the
      `registered_saver` field filled in.
    node_ids: mapping from trackable to its index in `object_graph_proto`.
    object_map: optional mapping from trackable to a copied trackable, passed
      to `util.get_mapped_trackable`.

  Returns:
    A `collections.defaultdict(dict)` mapping saver name to
    {object name -> (possibly mapped) trackable to save}.
  """
  savers_to_objects = collections.defaultdict(dict)
  for saver_name, trackables_by_name in unmapped_registered_savers.items():
    for object_name, original in trackables_by_name.items():
      # The proto node is located through the original trackable. The object
      # actually handed to the saver is its mapped copy, if any.
      saver_proto = object_graph_proto.nodes[node_ids[original]].registered_saver
      saver_proto.name = saver_name
      saver_proto.object_name = object_name
      savers_to_objects[saver_name][object_name] = util.get_mapped_trackable(
          original, object_map)
  return savers_to_objects
def generate_saveable_objects(checkpoint_factory_map,
                              object_graph_proto=None,
                              node_ids=None,
                              object_map=None,
                              call_with_mapped_captures=None,
                              saveables_cache=None):
  """Create SaveableObjects and corresponding SerializedTensor protos.

  Args:
    checkpoint_factory_map: Mapping from each original trackable to a list of
      `_CheckpointFactoryData` entries (factory, attribute name, checkpoint
      key), as returned by `get_checkpoint_factories_and_keys`.
    object_graph_proto: Optional `TrackableObjectGraph`. When it and
      `node_ids` are both given, attribute entries are added to the node of
      every trackable in `checkpoint_factory_map`.
    node_ids: Optional mapping from trackable to its index in
      `object_graph_proto.nodes`.
    object_map: Optional mapping from trackable to a copied trackable. It is
      resolved with `util.get_mapped_trackable`. The resolved object serves as
      the cache key, is the object checked for `PythonState`, and provides the
      attributes' `full_name`.
    call_with_mapped_captures: Forwarded to
      `saveable_object_util.create_saveable_object` for callable factories.
    saveables_cache: Optional dict-like cache keyed by the object to save.
      Each value maps attribute name -> tuple of SaveableObjects. When None,
      nothing is cached and `PythonState` values are frozen into the result.

  Returns:
    A tuple `(named_saveable_objects, feed_additions)`:
      named_saveable_objects: list of every SaveableObject created or reused.
      feed_additions: `None` when `saveables_cache` is None. Otherwise a dict
        collecting `feed_dict_additions()` from `PythonState` saveables.

  Raises:
    AssertionError: If a newly created SaveableObject's name does not contain
      the expected checkpoint key.
  """
  named_saveable_objects = []
  if saveables_cache is None:
    # No SaveableObject caching. Either we're executing eagerly, or building a
    # static save which is specialized to the current Python state.
    feed_additions = None
  else:
    # If we are caching SaveableObjects, we need to build up a feed_dict with
    # functions computing volatile Python state to be saved with the
    # checkpoint.
    feed_additions = {}
  for trackable, factory_data_list in checkpoint_factory_map.items():
    # Attributes are only recorded when both the proto and the node-id map
    # were supplied.
    fill_object_proto = object_graph_proto is not None and node_ids is not None
    if fill_object_proto:
      object_proto = object_graph_proto.nodes[node_ids[trackable]]
    object_to_save = util.get_mapped_trackable(trackable, object_map)
    if saveables_cache is not None:
      cached_attributes = saveables_cache.setdefault(object_to_save, {})
    else:
      cached_attributes = None

    for factory_data in factory_data_list:
      name = factory_data.name
      key = factory_data.checkpoint_key
      saveable_factory = factory_data.factory

      # See if we can skip saving this checkpoint key.
      saveables = cached_attributes.get(name) if cached_attributes else None
      if saveables is not None:
        for saveable in saveables:
          if key not in saveable.name:
            # The checkpoint key for this SaveableObject is different. We
            # need to re-create it.
            saveables = None
            del cached_attributes[name]
            break

      if saveables is None:
        # Cache miss or stale entry: build the saveables. A callable factory
        # is invoked. Anything else is used as-is.
        if callable(saveable_factory):
          maybe_saveable = saveable_object_util.create_saveable_object(
              name, key, saveable_factory, call_with_mapped_captures)
        else:
          maybe_saveable = saveable_factory
        if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
          saveables = (maybe_saveable,)
        else:
          # Not a SaveableObject yet: convert it, naming results after `key`.
          saveables = tuple(
              saveable_object_util.saveable_objects_for_op(
                  op=maybe_saveable, name=key))
        for saveable in saveables:
          if key not in saveable.name:
            raise AssertionError(
                f"The object {trackable} produced a SaveableObject with name "
                f"'{saveable.name}' for attribute '{name}'. Expected a name"
                f" containing '{key}'.")
        if cached_attributes is not None:
          cached_attributes[name] = saveables

      if isinstance(object_to_save, python_state.PythonState):
        assert len(saveables) == 1
        saveable = saveables[0]

        if feed_additions is None:
          assert saveables_cache is None
          # If we're not caching saveables, then we're either executing
          # eagerly or building a static save/restore (e.g. for a
          # SavedModel). In either case, we should embed the current Python
          # state in the graph rather than relying on a feed dict.
          saveables = (saveable.freeze(),)
        else:
          feed_additions.update(saveable.feed_dict_additions())
      named_saveable_objects.extend(saveables)

      # Update the object proto.
      # For updated Trackables that override serialize_to_tensors, add an
      # attribute for each tensor that is serialized.
      # For Trackables that have SaveableObjects or a legacy saveable name,
      # add a single attribute to the proto.
      if not fill_object_proto:
        continue
      # NOTE(review): `saveables[0]` assumes every factory yields at least one
      # saveable. An empty tuple would raise IndexError here.
      if (isinstance(saveables[0], saveable_object_util.TrackableSaveable) and
          (saveable_compat.force_checkpoint_conversion_enabled() or
           saveable_compat.get_saveable_name(object_to_save) is None)):
        for local_name, local_key in (
            saveables[0].get_proto_names_and_checkpoint_keys()):
          object_proto.attributes.add(
              name=local_name,
              checkpoint_key=local_key,
              full_name=util.get_full_name(object_to_save))
      else:
        object_proto.attributes.add(
            name=name,
            checkpoint_key=key,
            full_name=util.get_full_name(object_to_save))

  return named_saveable_objects, feed_additions
def _fill_object_graph_proto(graph_view,
                             trackable_objects,
                             node_ids,
                             slot_variables):
  """Builds a `TrackableObjectGraph` with one node per trackable.

  Each node records the trackable's slot variables (if any) and a reference
  to each of its children, as listed by `graph_view.list_children`.

  Args:
    graph_view: object providing `list_children(trackable)`.
    trackable_objects: trackables in traversal order. A trackable's position
      is its node id.
    node_ids: mapping from trackable to node id.
    slot_variables: mapping from trackable to its slot-variable protos.

  Returns:
    The newly built `TrackableObjectGraph` proto.
  """
  graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  for expected_id, obj in enumerate(trackable_objects):
    # Node ids must follow traversal order so that proto indices and the ids
    # stored in `node_ids` agree.
    assert node_ids[obj] == expected_id
    node = graph_proto.nodes.add(slot_variables=slot_variables.get(obj, ()))
    for child_ref in graph_view.list_children(obj):
      node.children.add(node_id=node_ids[child_ref.ref],
                        local_name=child_ref.name)
  return graph_proto
def serialize_gathered_objects(graph_view,
                               object_map=None,
                               call_with_mapped_captures=None,
                               saveables_cache=None):
  """Traverses `graph_view` and builds saveables plus the object graph proto.

  Args:
    graph_view: object providing `breadth_first_traversal()` and
      `list_children(trackable)`.
    object_map: optional mapping from trackable to a copied trackable.
    call_with_mapped_captures: forwarded to saveable creation.
    saveables_cache: optional SaveableObject cache (see
      `generate_saveable_objects`).

  Returns:
    A tuple `(named_saveable_objects, object_graph_proto, feed_additions,
    registered_savers)`.
  """
  trackable_objects, node_paths = graph_view.breadth_first_traversal()

  # String name for every reachable trackable, derived from its path.
  object_names = object_identity.ObjectIdentityDictionary()
  for obj, obj_path in node_paths.items():
    object_names[obj] = trackable_utils.object_path_to_string(obj_path)

  # A trackable's node id is its position in the breadth-first order.
  node_ids = object_identity.ObjectIdentityDictionary()
  for position, obj in enumerate(trackable_objects):
    node_ids[obj] = position

  slot_variable_protos = util.serialize_slot_variables(
      trackable_objects=trackable_objects,
      node_ids=node_ids,
      object_names=object_names)
  graph_proto = _fill_object_graph_proto(
      graph_view=graph_view,
      trackable_objects=trackable_objects,
      node_ids=node_ids,
      slot_variables=slot_variable_protos)

  saveables, feed_additions, registered_savers = (
      _add_attributes_to_object_graph(
          trackable_objects=trackable_objects,
          object_graph_proto=graph_proto,
          node_ids=node_ids,
          object_names=object_names,
          object_map=object_map,
          call_with_mapped_captures=call_with_mapped_captures,
          saveables_cache=saveables_cache))

  # Record in the proto which trackables have checkpoint values, either
  # directly or through their descendants.
  util.add_checkpoint_values_check(graph_proto)
  return saveables, graph_proto, feed_additions, registered_savers
def serialize_object_graph_with_registered_savers(graph_view, saveables_cache):
  """Determines checkpoint keys and builds the serialized object graph.

  This is `serialize_gathered_objects` with no object mapping and no
  `call_with_mapped_captures`.

  Returns:
    The tuple `(named_saveable_objects, object_graph_proto, feed_additions,
    registered_savers)` produced by `serialize_gathered_objects`.
  """
  return serialize_gathered_objects(
      graph_view,
      object_map=None,
      call_with_mapped_captures=None,
      saveables_cache=saveables_cache)
def frozen_saveables_and_savers(graph_view,
                                object_map=None,
                                to_graph=None,
                                call_with_mapped_captures=None,
                                saveables_cache=None):
  """Generates SaveableObjects and registered savers in the frozen graph.

  Work happens under `to_graph.as_default()` when `to_graph` is truthy, and
  under `ops.NullContextmanager` otherwise. The serialized object graph proto
  is added to the returned saveables as a `base.NoRestoreSaveable` stored
  under `base.OBJECT_GRAPH_PROTO_KEY`. Its tensor is created on "/cpu:0".

  Returns:
    A tuple `(named_saveable_objects, registered_savers)`.
  """
  target_context = to_graph.as_default if to_graph else ops.NullContextmanager
  with target_context():
    saveables, graph_proto, _, registered_savers = serialize_gathered_objects(
        graph_view,
        object_map=object_map,
        call_with_mapped_captures=call_with_mapped_captures,
        saveables_cache=saveables_cache)
    serialized_graph = graph_proto.SerializeToString()
    with ops.device("/cpu:0"):
      object_graph_tensor = constant_op.constant(
          serialized_graph, dtype=dtypes.string)
    saveables.append(
        base.NoRestoreSaveable(
            tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY))
  return saveables, registered_savers