Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/save_util.py: 21%
136 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Extracts tensors for checkpointing while updating a TrackableObjectGraph.
17The tensors are extracted from `Trackable._serialize_to_tensors`.
18"""
19import collections
21from typing import Any, Callable, List, Optional, Tuple, Mapping, Union, Dict
23from tensorflow.core.protobuf import trackable_object_graph_pb2
24from tensorflow.python.checkpoint import graph_view as graph_view_lib
25from tensorflow.python.checkpoint import save_util_v1
26from tensorflow.python.checkpoint import saveable_compat
27from tensorflow.python.checkpoint import util
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.saved_model import registration
32from tensorflow.python.trackable import base
33from tensorflow.python.trackable import python_state
34from tensorflow.python.trackable import trackable_utils
35from tensorflow.python.training.saving import saveable_object as saveable_object_lib
36from tensorflow.python.training.saving import saveable_object_util
37from tensorflow.python.types import core
38from tensorflow.python.util import object_identity
40# Attributes for each Trackable in the checkpointed object graph.
41_TrackableData = collections.namedtuple("_TrackableData", [
42 # A trackable in the root Trackable object graph.
43 "trackable",
44 # The index at which the Trackable appears in TrackableObjectGraph.nodes.
45 "node_id",
46 # The BFS-generated path from the root object / used to generate readable
47 # checkpoint keys.
48 "object_name",
49 # A list of ObjectReference for each child connected to this Trackable.
50 "children_proto",
51 # A list of SlotVariableReference to save to the object (only valid for
52 # Optimizer objects).
53 "slot_variable_proto",
54 # The object to save to checkpoint. Usually this is the same as `trackable`,
55 # but can differ when the the caller wants to specify a different object to
56 # save. For example, when saving checkpoints asynchronously, variables are
57 # copied to the CPU. `object_to_save` is set as the copied variable.
58 "object_to_save",
59 ])
def _split_trackables(
    trackable_data: List[_TrackableData]
) -> Tuple[List[_TrackableData], List[_TrackableData],
           Dict[str, List[_TrackableData]]]:
  """Partitions trackable data by the mechanism used to checkpoint it.

  Args:
    trackable_data: Data for every Trackable in the object graph.

  Returns:
    A tuple `(tensor_trackables, pystate_trackables, registered_trackables)`:
      * `tensor_trackables`: entries that are neither `PythonState` nor
        associated with a registered saver.
      * `pystate_trackables`: entries whose `object_to_save` is a
        `python_state.PythonState` (this check wins over a registered saver).
      * `registered_trackables`: a `defaultdict(list)` mapping each registered
        saver name to the entries that use it.
  """
  by_tensor = []
  by_pystate = []
  by_saver = collections.defaultdict(list)

  for data in trackable_data:
    target = data.object_to_save
    registered_name = registration.get_registered_saver_name(target)
    if isinstance(target, python_state.PythonState):
      by_pystate.append(data)
    elif registered_name:
      by_saver[registered_name].append(data)
    else:
      by_tensor.append(data)

  return by_tensor, by_pystate, by_saver
def _gather_trackable_data(
    graph_view: graph_view_lib.ObjectGraphView,
    object_map: Mapping[base.Trackable, base.Trackable]
) -> Tuple[List[_TrackableData], Dict[base.Trackable, int]]:
  """Builds one `_TrackableData` per object reachable in `graph_view`.

  Args:
    graph_view: The object graph; traversed with `breadth_first_traversal()`.
    object_map: Mapping passed to `util.get_mapped_trackable` to choose the
      object that is actually saved for each Trackable.

  Returns:
    A tuple `(trackable_data, node_ids)`. `trackable_data` holds one entry per
    Trackable in traversal order. `node_ids` is an identity-keyed dictionary
    mapping each Trackable to its index in that order.
  """
  ordered_objects, paths = graph_view.breadth_first_traversal()

  # Readable names derived from each object's path from the root.
  names = object_identity.ObjectIdentityDictionary()
  for obj, path in paths.items():
    names[obj] = trackable_utils.object_path_to_string(path)

  # Node ids are positions in the breadth-first ordering.
  node_ids = object_identity.ObjectIdentityDictionary()
  for index, obj in enumerate(ordered_objects):
    node_ids[obj] = index

  slot_variables = util.serialize_slot_variables(
      trackable_objects=ordered_objects,
      node_ids=node_ids,
      object_names=names)

  object_reference = (
      trackable_object_graph_pb2.TrackableObjectGraph.TrackableObject
      .ObjectReference)

  def _build(obj):
    """Creates the `_TrackableData` record for a single traversed object."""
    children = [
        object_reference(node_id=node_ids[child.ref], local_name=child.name)
        for child in graph_view.list_children(obj)
    ]
    return _TrackableData(
        obj,
        node_id=node_ids[obj],
        object_name=names[obj],
        children_proto=children,
        slot_variable_proto=slot_variables.get(obj, []),
        object_to_save=util.get_mapped_trackable(obj, object_map))

  return [_build(obj) for obj in ordered_objects], node_ids
def _fill_object_graph_proto(
    trackable_data: List[_TrackableData]
) -> trackable_object_graph_pb2.TrackableObjectGraph:
  """Creates a `TrackableObjectGraph` containing one node per Trackable.

  Each node is populated with the entry's child references and slot variable
  references. Nodes are appended in list order, so node `i` describes
  `trackable_data[i]`.

  Args:
    trackable_data: Entries whose `node_id` must equal their list position.

  Returns:
    The newly created `TrackableObjectGraph` proto.
  """
  graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  for expected_id, data in enumerate(trackable_data):
    # Appending in order only yields the right node ids if they match indices.
    assert data.node_id == expected_id
    graph_proto.nodes.add(
        slot_variables=data.slot_variable_proto,
        children=data.children_proto)
  return graph_proto
def _get_and_write_tensors_to_serialize(
    tensor_trackables: List[_TrackableData],
    node_ids: Dict[base.Trackable, int],
    call_with_mapped_captures: Union[Callable[..., Any], None],
    cache: Union[Dict[base.Trackable, Any], None],
    object_graph_proto: trackable_object_graph_pb2.TrackableObjectGraph
) -> Dict[base.Trackable, Any]:
  """Creates dictionary of tensors to checkpoint, and updates the proto.

  Tensors for each entry come from one of two paths:
    * the legacy SaveableObject path, for objects without
      `_serialize_to_tensors` or using the compat saveable-name decorator;
    * `Trackable._serialize_to_tensors` otherwise.
  Both paths write checkpoint attributes into the entry's node of
  `object_graph_proto`.

  Args:
    tensor_trackables: Entries to serialize through tensors.
    node_ids: Identity-keyed map from Trackable to node index; forwarded to
      the legacy SaveableObject path.
    call_with_mapped_captures: Optional callable used to invoke
      `ConcreteFunction` save functions; may be None.
    cache: Optional dictionary keyed by `object_to_save`. When given, each
      result is stored as `(trackable, tensor_dict, node_attributes)`. Later
      calls reuse the stored result instead of serializing again, and merge
      the stored attributes into the proto.
    object_graph_proto: Proto whose nodes receive the checkpoint attributes.

  Returns:
    An identity-keyed dictionary mapping each saved object (or its legacy
    compatibility wrapper) to a dictionary of
    checkpoint key (-> slice_spec) -> tensor.
  """
  # Maps trackable to a dictionary of tensors, which maps
  # checkpoint key (-> slice_spec) -> tensor.
  serialized_tensors = object_identity.ObjectIdentityDictionary()

  for td in tensor_trackables:
    if cache is not None and td.object_to_save in cache:
      # Reuse the cached tensors and copy the cached attributes into the
      # node of this call's proto.
      trackable, tensor_dict, object_proto = cache[td.object_to_save]
      serialized_tensors[trackable] = tensor_dict
      object_graph_proto.nodes[td.node_id].attributes.MergeFrom(object_proto)
      continue

    legacy_name = saveable_compat.get_saveable_name(td.object_to_save) or ""

    if (not saveable_object_util.trackable_has_serialize_to_tensor(
        td.object_to_save) or legacy_name):
      # Use the legacy code path for objects that are using SaveableObjects
      # or the compat saveable name decorator.
      trackable, tensor_dict = _get_tensors_from_legacy_saveable(
          td, node_ids, call_with_mapped_captures, object_graph_proto)
    else:
      tensor_dict = _get_tensors_from_trackable(
          td, call_with_mapped_captures, object_graph_proto)
      trackable = td.object_to_save
    serialized_tensors[trackable] = tensor_dict

    # Cache hits `continue` above, so this object is not cached yet.
    if cache is not None:
      cache[td.object_to_save] = (
          trackable, tensor_dict,
          object_graph_proto.nodes[td.node_id].attributes)

  return serialized_tensors
def _get_tensors_from_legacy_saveable(
    trackable_data: _TrackableData,
    node_ids: Dict[base.Trackable, int],
    call_with_mapped_captures: Callable[..., Any],
    object_graph_proto: trackable_object_graph_pb2.TrackableObjectGraph
) -> Tuple[base.Trackable, Dict[str, Any]]:
  """Serializes one Trackable through the legacy SaveableObject machinery.

  The `save_util_v1` helpers create SaveableObjects for this single Trackable
  and update `object_graph_proto`. The result is wrapped in a
  `SaveableCompatibilityConverter` so that it can be serialized like a
  regular Trackable.

  Args:
    trackable_data: The entry to serialize.
    node_ids: Identity-keyed map from Trackable to node index.
    call_with_mapped_captures: Forwarded to
      `save_util_v1.generate_saveable_objects`.
    object_graph_proto: Proto updated by `save_util_v1`.

  Returns:
    A `(converter, tensor_dict)` pair. `converter` is the
    `SaveableCompatibilityConverter` wrapping `object_to_save`, and
    `tensor_dict` is the result of its `_serialize_to_tensors()`.
  """
  source = trackable_data.trackable

  # Single-entry name and object maps, as expected by the v1 helpers.
  single_name = object_identity.ObjectIdentityDictionary()
  single_name[source] = trackable_data.object_name
  single_map = object_identity.ObjectIdentityDictionary()
  single_map[source] = trackable_data.object_to_save

  factories, _ = save_util_v1.get_checkpoint_factories_and_keys(
      single_name, single_map)
  saveables, _ = save_util_v1.generate_saveable_objects(
      factories,
      object_graph_proto,
      node_ids,
      single_map,
      call_with_mapped_captures,
      saveables_cache=None)

  converter = saveable_object_util.SaveableCompatibilityConverter(
      trackable_data.object_to_save, saveables)
  tensors = converter._serialize_to_tensors()  # pylint: disable=protected-access
  return converter, tensors
def _get_tensors_from_trackable(
    trackable_data: _TrackableData,
    call_with_mapped_captures: Union[Callable[..., Any], None],
    object_graph_proto: trackable_object_graph_pb2.TrackableObjectGraph
) -> Dict[str, Any]:
  """Serializes `object_to_save` through its `_serialize_to_tensors` method.

  Each returned tensor name is escaped and turned into a full checkpoint key
  under the entry's `object_name`. When `object_graph_proto` is given, one
  attribute per tensor is also added to the entry's node.

  Args:
    trackable_data: The entry to serialize.
    call_with_mapped_captures: If set and the save function is a
      `core.ConcreteFunction`, the save function is invoked through this
      callable; otherwise it is called directly.
    object_graph_proto: Optional proto that receives the attributes.

  Returns:
    A dictionary mapping checkpoint key -> value returned by the save
    function.
  """
  trackable = trackable_data.object_to_save
  save_fn = trackable._serialize_to_tensors  # pylint: disable=protected-access

  use_mapped_captures = (
      call_with_mapped_captures
      and isinstance(save_fn, core.ConcreteFunction))
  ret_tensor_dict = (
      call_with_mapped_captures(save_fn, []) if use_mapped_captures
      else save_fn())

  tensor_dict = {}
  for name, value in ret_tensor_dict.items():
    escaped = trackable_utils.escape_local_name(name)
    key = trackable_utils.checkpoint_key(trackable_data.object_name, escaped)
    tensor_dict[key] = value

    # TODO(b/261786493): Delete this when DCheckpoint is removed.
    if isinstance(value, saveable_object_lib.SaveSpec):
      value.name = key
      value.slice_spec = ""

    if object_graph_proto is None:
      continue
    object_graph_proto.nodes[trackable_data.node_id].attributes.add(
        name=escaped,
        checkpoint_key=key,
        full_name=util.get_full_name(trackable))

  return tensor_dict
def _get_and_write_pystate_feed_additions(
    pystate_trackables: List[_TrackableData],
    cache: Union[Dict[base.Trackable, Any], None],
    object_graph_proto: Optional[
        trackable_object_graph_pb2.TrackableObjectGraph] = None
) -> Tuple[Dict[base.Trackable, Any], Dict[Any, Any]]:
  """Gets feed additions needed for checkpointing Python State.

  Each PythonState is represented by a string constant created on
  "/cpu:0". The object's current `serialize()` value is returned in
  `feed_additions`, keyed by that tensor, so the caller can supply it when
  the save runs.

  Args:
    pystate_trackables: Entries whose `object_to_save` is a PythonState.
    cache: Optional dictionary keyed by `object_to_save`. It stores
      `{python_state.PYTHON_STATE: tensor}` so the same tensor is reused
      across calls. If None, a new tensor is created on every call and
      nothing is cached.
    object_graph_proto: Optional proto. When given, a PYTHON_STATE attribute
      is added to each entry's node.

  Returns:
    A tuple `(serialized_tensors, feed_additions)`:
      * `serialized_tensors`: identity-keyed dict mapping each object to
        `{checkpoint_key: tensor}`.
      * `feed_additions`: dict mapping each tensor to the object's serialized
        value.
  """
  serialized_tensors = object_identity.ObjectIdentityDictionary()
  # Maps tensor placeholders to python values.
  feed_additions = {}

  for td in pystate_trackables:
    trackable = td.object_to_save
    checkpoint_key = trackable_utils.checkpoint_key(td.object_name,
                                                    python_state.PYTHON_STATE)
    if cache is not None and trackable in cache:
      save_string = cache[trackable][python_state.PYTHON_STATE]
    else:
      with ops.device("/cpu:0"):
        save_string = constant_op.constant("", dtype=dtypes.string)
        if cache is not None:
          cache[trackable] = {python_state.PYTHON_STATE: save_string}

    # Serialize outside any function-building graph so the current Python
    # value is captured.
    with ops.init_scope():
      value = trackable.serialize()
    feed_additions[save_string] = value
    serialized_tensors[trackable] = {checkpoint_key: save_string}

    # `object_graph_proto` defaults to None, so guard before writing
    # (consistent with `_get_tensors_from_trackable`).
    if object_graph_proto is not None:
      object_graph_proto.nodes[td.node_id].attributes.add(
          name=python_state.PYTHON_STATE,
          checkpoint_key=checkpoint_key,
          full_name=util.get_full_name(trackable))

  return serialized_tensors, feed_additions
def _get_and_write_registered_savers(
    registered_trackables: Dict[str, List[_TrackableData]],
    object_graph_proto: trackable_object_graph_pb2.TrackableObjectGraph
) -> Dict[str, Dict[str, base.Trackable]]:
  """Groups objects by registered saver and records the saver in the proto.

  For every entry, the node's `registered_saver` message gets the saver name
  and the entry's object name.

  Args:
    registered_trackables: Mapping of saver name -> entries using that saver.
    object_graph_proto: Proto whose nodes are updated.

  Returns:
    A `defaultdict(dict)` mapping saver name -> {object_name: object_to_save}.
  """
  savers_by_name = collections.defaultdict(dict)
  for saver_name, members in registered_trackables.items():
    for data in members:
      savers_by_name[saver_name][data.object_name] = data.object_to_save
      saver_proto = object_graph_proto.nodes[data.node_id].registered_saver
      saver_proto.name = saver_name
      saver_proto.object_name = data.object_name
  return savers_by_name
def serialize_graph_view(
    graph_view: graph_view_lib.ObjectGraphView,
    object_map: Optional[Mapping[base.Trackable, base.Trackable]] = None,
    call_with_mapped_captures: Optional[Callable[..., Any]] = None,
    cache: Optional[Dict[base.Trackable, Any]] = None
) -> Tuple[Dict[base.Trackable, Any], Optional[Dict[Any, Any]],
           Dict[str, Dict[str, base.Trackable]],
           trackable_object_graph_pb2.TrackableObjectGraph]:
  """Gathers serialization objects, and creates a TrackableObjectGraph proto.

  Args:
    graph_view: The object graph to serialize.
    object_map: Optional mapping used to substitute the object that is saved
      for a Trackable (for example, a CPU copy made for asynchronous
      checkpointing).
    call_with_mapped_captures: Optional callable used to invoke
      `ConcreteFunction` save functions.
    cache: Optional cache of previously serialized results. A non-None cache
      also selects the TF1 graph-mode handling of PythonState objects.

  Returns:
    A tuple `(serialized_tensors, feed_additions, registered_savers,
    object_graph_proto)`:
      * `serialized_tensors`: identity-keyed dict mapping each saved object to
        its dictionary of checkpoint key (-> slice_spec) -> tensor.
      * `feed_additions`: None when `cache` is None; otherwise a dict mapping
        PythonState tensors to their serialized Python values.
      * `registered_savers`: dict mapping registered saver name to
        {object_name: object_to_save}.
      * `object_graph_proto`: the populated `TrackableObjectGraph`.
  """
  # There are 3 types of checkpoint serialization types supported:
  # 1. Trackables that override `Trackable._serialize_to_tensors()`.
  # 2. PythonState: A special type of Trackable that serializes a Python string.
  # 3. Registered Trackable Savers: For objects that need to define advanced
  #    checkpointing operations not supported by (1) or (2).
  trackable_data, node_ids = _gather_trackable_data(graph_view, object_map)
  tensor_trackables, pystate_trackables, registered_trackables = (
      _split_trackables(trackable_data))

  object_graph_proto = _fill_object_graph_proto(trackable_data)

  serialized_tensors = _get_and_write_tensors_to_serialize(
      tensor_trackables,
      node_ids,
      call_with_mapped_captures,
      cache,
      object_graph_proto)
  registered_savers = _get_and_write_registered_savers(
      registered_trackables, object_graph_proto)

  # PythonState trackables must be treated differently depending on if the
  # checkpoint is being saved in TF1 graph mode (`cache` exists) or
  # eager mode (`cache` is None).
  if cache is None:
    # When the tensor cache is None, get the serialized tensors directly.
    feed_additions = None
    serialized_tensors.update(_get_and_write_tensors_to_serialize(
        pystate_trackables,
        node_ids,
        call_with_mapped_captures,
        cache,
        object_graph_proto))
  else:
    # Python state is not automatically updated within a TF session, so the
    # current values are returned as `feed_additions`; presumably the caller
    # merges them into the session's feed_dict when running the save.
    new_serialized_tensors, feed_additions = (
        _get_and_write_pystate_feed_additions(pystate_trackables,
                                              cache,
                                              object_graph_proto))
    serialized_tensors.update(new_serialized_tensors)

  # Gather all trackables that have checkpoint values or descendants with
  # checkpoint values, and add that info to the proto.
  util.add_checkpoint_values_check(object_graph_proto)
  return (serialized_tensors, feed_additions, registered_savers,
          object_graph_proto)