Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/d_checkpoint.py: 22%
195 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""DTensor Checkpoint.
17Note that this module contains deprecated functionality, and the DTensor related
18checkpoint has been integrated with tf.train.Checkpoint. It can be used out of
19the box to save and restore dtensors.
20"""
from typing import Dict, List, Optional
import weakref

from tensorflow.core.protobuf import trackable_object_graph_pb2

from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import d_variable
from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout
from tensorflow.dtensor.python import save_restore
from tensorflow.python.checkpoint import checkpoint as util
from tensorflow.python.checkpoint import checkpoint_options
from tensorflow.python.checkpoint import graph_view as graph_view_lib
from tensorflow.python.checkpoint import restore as restore_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.trackable import base
from tensorflow.python.trackable import data_structures
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

class _DSaver:  # pylint: disable=protected-access
  """A single device saver that places tensors on DTensor Device."""

  def __init__(self, mesh: layout.Mesh,
               saveable_objects: List[saveable_object.SaveableObject]):
    self._saveable_objects = saveable_objects
    self._mesh = mesh

  def save(
      self,
      file_prefix: str,
      options: Optional[checkpoint_options.CheckpointOptions] = None
  ) -> Optional[ops.Operation]:
    """Saves the saveable objects to a checkpoint with `file_prefix`.

    Also queries the generated shards from the distributed DTensor SaveV2 ops
    and runs a MergeV2 on them. Each op here is backed by a global_barrier to
    avoid races between multiple clients.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object. This is unused in DTensor.

    Returns:
      An `Operation`, or None when executing eagerly.
    """
    if options is not None and options.experimental_io_device is not None:
      raise ValueError(
          "Specified experimental_io_device in DTensor checkpoint is not "
          "supported.")
    del options
    tensor_names = []
    tensors = []
    tensor_slices = []
    for saveable in self._saveable_objects:
      for spec in saveable.specs:
        tensor = spec.tensor
        # A tensor value of `None` indicates that this SaveableObject gets
        # recorded in the object graph, but that no value is saved in the
        # checkpoint.
        if tensor is not None:
          if api.device_name() != spec.device:
            # Some small tensors are placed on CPU:0 by the save manager and
            # broadcast to the DTensor mesh, e.g., the SaveCounter.
            tensor = api.pack(
                [tensor] * self._mesh.host_mesh().num_local_devices(),
                layout.Layout.replicated(
                    self._mesh.host_mesh(), rank=tensor.shape.rank))
          tensor_names.append(spec.name)
          tensors.append(tensor)
          tensor_slices.append(spec.slice_spec)
    return save_restore.sharded_save(self._mesh, file_prefix, tensor_names,
                                     tensor_slices, tensors)
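  # For reference, the packing above broadcasts a host-resident tensor to all
  # local devices of the host mesh with a fully replicated layout, roughly
  # (a sketch in terms of the public `tf.experimental.dtensor` API; `mesh` is
  # a hypothetical host mesh):
  #
  #   t = tf.constant(1.0)  # e.g., a save counter created on CPU:0
  #   packed = dtensor.pack(
  #       [t] * mesh.num_local_devices(),
  #       dtensor.Layout.replicated(mesh, rank=t.shape.rank))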

  def restore(
      self,
      file_prefix: str,
      options: Optional[checkpoint_options.CheckpointOptions] = None
  ) -> Dict[str, ops.Operation]:
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object. This is unused in DTensor.

    Returns:
      A dictionary mapping from SaveableObject names to restore operations.
    """
    if options is not None and options.experimental_io_device is not None:
      raise ValueError(
          "Specified experimental_io_device in DTensor checkpoint is not "
          "supported.")
    del options
    restore_specs = []
    tensor_structure = []
    for saveable in self._saveable_objects:
      saveable_tensor_structure = []
      tensor_structure.append(saveable_tensor_structure)
      # DTensor change 1: Gather shapes and layouts from the original saveable
      # specs.
      # Note that this relies on the fact that the variables are already
      # initialized -- which isn't the behavior we want eventually.
      # TODO(b/159035705): Handle the variable initialization in restore.
      for spec in saveable.specs:
        saveable_tensor_structure.append(spec.name)
        if isinstance(spec, d_variable.DSaveSpec):
          restore_specs.append((spec.name, spec.slice_spec, spec.dtype,
                                spec.layout, spec.global_shape))
        # Fall back to replicated layouts for non-DTensor saves that construct
        # a normal SaveSpec.
        elif isinstance(spec, saveable_object.SaveSpec):
          restore_specs.append(
              (spec.name, spec.slice_spec, spec.dtype,
               layout.Layout.replicated(self._mesh.host_mesh(),
                                        spec.tensor.shape.rank).to_string(),
               spec.tensor.shape.as_list()))
    tensor_names, tensor_slices, tensor_dtypes, layouts, global_shapes = zip(
        *restore_specs)
    with ops.device(api.device_name()):
      # DTensor change 2: Run the customized DTensor RestoreV2 op rather than
      # the stock TF io_ops.RestoreV2.
      restored_tensors = gen_dtensor_ops.d_tensor_restore_v2(
          prefix=file_prefix,
          tensor_names=tensor_names,
          shape_and_slices=tensor_slices,
          input_shapes=global_shapes,
          input_layouts=layouts,
          dtypes=tensor_dtypes)
    structured_restored_tensors = nest.pack_sequence_as(tensor_structure,
                                                        restored_tensors)
    restore_ops = {}
    for saveable, restored_tensors in zip(self._saveable_objects,
                                          structured_restored_tensors):
      restore_ops[saveable.name] = saveable.restore(
          restored_tensors, restored_shapes=None)
    return restore_ops
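
# Note on `_DSaver.restore` above: `nest.pack_sequence_as` regroups the flat
# list of tensors returned by DTensorRestoreV2 back into the per-saveable
# structure, e.g. (illustrative):
#
#   nest.pack_sequence_as([["a", "b"], ["c"]], [t0, t1, t2])
#   # -> [[t0, t1], [t2]]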

class _DCheckpointRestoreCoordinator(util._CheckpointRestoreCoordinator):  # pylint: disable=protected-access
  """Holds the status of an object-based checkpoint load."""

  def __init__(self, mesh: layout.Mesh, **kwargs):
    super().__init__(**kwargs)
    self._mesh = mesh

  def restore_saveables(
      self,
      tensor_saveables: Dict[str, saveable_object.SaveableObject],
      python_positions: List[restore_lib.CheckpointPosition],
      registered_savers: Optional[Dict[str, Dict[str, base.Trackable]]] = None,
      reader: py_checkpoint_reader.NewCheckpointReader = None
  ) -> Optional[List[ops.Operation]]:
    """Run or build restore operations for SaveableObjects.

    Args:
      tensor_saveables: `SaveableObject`s which correspond to Tensors.
      python_positions: `CheckpointPosition`s which correspond to `PythonState`
        Trackables bound to the checkpoint.
      registered_savers: A dict mapping saver names -> object names ->
        Trackables. This argument is not implemented for DTensorCheckpoint.
      reader: A `CheckpointReader`. Creates one lazily if None.

    Returns:
      When graph building, a list of restore operations, either cached or newly
      created, to restore `tensor_saveables`.
    """
    del registered_savers

    restore_ops = []
    # Eagerly run restorations for Python state.
    if python_positions:
      # Lazily create the NewCheckpointReader, since this requires file access
      # and we may not have any Python saveables.
      if reader is None:
        reader = py_checkpoint_reader.NewCheckpointReader(
            self.save_path_string)
      for position in python_positions:
        key = position.object_proto.attributes[0].checkpoint_key
        position.trackable.deserialize(reader.get_tensor(key))

    # If we have new SaveableObjects, extract and cache restore ops.
    if tensor_saveables:
      validated_saveables = saveable_object_util.validate_and_slice_inputs(
          tensor_saveables)
      validated_names = set(saveable.name for saveable in validated_saveables)
      if set(tensor_saveables.keys()) != validated_names:
        raise AssertionError(
            ("Saveable keys changed when validating. Got back %s, was "
             "expecting %s") % (tensor_saveables.keys(), validated_names))
      # DTensor change: Use _DSaver, which restores on DTensor with the
      # customized DTensorRestoreV2 op.
      new_restore_ops = _DSaver(self._mesh, validated_saveables).restore(
          self.save_path_tensor, self.options)
      if not context.executing_eagerly():
        for name, restore_op in sorted(new_restore_ops.items()):
          restore_ops.append(restore_op)
          assert name not in self.restore_ops_by_name
          self.restore_ops_by_name[name] = restore_op
    return restore_ops
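
# For reference, the lazily constructed reader used above exposes the standard
# checkpoint-reader surface (a sketch; the path and key are hypothetical):
#
#   reader = py_checkpoint_reader.NewCheckpointReader("/tmp/ckpt-1")
#   reader.get_variable_to_dtype_map()  # {checkpoint_key: dtype, ...}
#   reader.get_tensor("model/dense/kernel/.ATTRIBUTES/VARIABLE_VALUE")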

class DTrackableSaver(util.TrackableSaver):
  """A DTensor trackable saver that uses a single-device saver (_DSaver)."""

  def __init__(self, mesh: layout.Mesh, graph_view):
    super(DTrackableSaver, self).__init__(graph_view)
    self._mesh = mesh

  def _gather_saveables(self, object_graph_tensor=None):
    # Since the base Checkpoint class does not return SaveableObjects, re-use
    # the saveables cache or generate new Saveables.
    (serialized_tensors, feed_additions, registered_savers,
     graph_proto) = self._gather_serialized_tensors(object_graph_tensor)

    saveables_dict = self._saveables_cache
    if saveables_dict is None:
      # Get and remove the object graph tensor from `serialized_tensors`,
      # because the function `serialized_tensors_to_saveable_cache` isn't
      # equipped to handle it.
      object_graph_tensor = serialized_tensors.pop(
          None)[base.OBJECT_GRAPH_PROTO_KEY]
      saveables_dict = (
          saveable_object_util.serialized_tensors_to_saveable_cache(
              serialized_tensors))
    named_saveable_objects = []
    for saveable_by_name in saveables_dict.values():
      for saveables in saveable_by_name.values():
        named_saveable_objects.extend(saveables)
    named_saveable_objects.append(
        base.NoRestoreSaveable(
            tensor=object_graph_tensor,
            name=base.OBJECT_GRAPH_PROTO_KEY))
    return (named_saveable_objects, graph_proto, feed_additions,
            registered_savers)

  def _save_cached_when_graph_building(self,
                                       file_prefix,
                                       object_graph_tensor,
                                       options,
                                       update_ckpt_state=False):
    """Create or retrieve save ops; overrides the parent's private method.

    Args:
      file_prefix: The prefix for saved checkpoint files.
      object_graph_tensor: A `Tensor` to which the current object graph will be
        fed.
      options: `CheckpointOptions` object.
      update_ckpt_state: Optional bool flag. Indicates whether the internal
        checkpoint state needs to be updated. This is used for async
        checkpointing, which DTrackableSaver currently does not support.
        TODO(chienchunh): Implement async checkpoint for DTrackableSaver.

    Returns:
      A two-element tuple with a filename tensor and a feed_dict of tensors to
      feed when running it (if graph building). The feed dict contains the
      current object graph and any Python state to be saved in the
      checkpoint. When executing eagerly only the first argument is meaningful.
    """
    (named_saveable_objects, graph_proto, feed_additions,
     unused_registered_savers) = self._gather_saveables(
         object_graph_tensor=object_graph_tensor)
    if (self._last_save_object_graph != graph_proto
        # When executing eagerly, we need to re-create SaveableObjects each
        # time save() is called so they pick up new Tensors passed to their
        # constructors. That means the Saver needs to be copied with a new
        # var_list.
        or context.executing_eagerly() or ops.inside_function()):
      # This is needed to avoid MultiDeviceSaver creating unnecessary MergeV2
      # ops in DTensor. It is an issue when saving TPU Variables on a host CPU
      # mesh, given our limited expressiveness in the API and the hard-coded
      # logic in broadcasting -- for a small constant Tensor with no extra
      # information, we place it on the first registered mesh (a.k.a. the
      # default mesh).
      saver = _DSaver(self._mesh, named_saveable_objects)
      save_op = saver.save(file_prefix, options=options)
      with ops.device("/cpu:0"):
        with ops.control_dependencies([save_op]):
          self._cached_save_operation = array_ops.identity(file_prefix)
      self._last_save_object_graph = graph_proto
    return self._cached_save_operation, feed_additions
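  # The identity-under-control-dependencies pattern above ties the returned
  # filename tensor to the save op, so evaluating the filename forces the save
  # to run first. A minimal sketch of the same pattern:
  #
  #   with ops.device("/cpu:0"):
  #     with ops.control_dependencies([save_op]):
  #       saved_prefix = array_ops.identity(file_prefix)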

  # TODO(b/180466245): Use proper mesh placement semantic.
  def restore(self, save_path, options=None):
    """Restore a training checkpoint with host mesh placement."""
    options = options or checkpoint_options.CheckpointOptions()
    if save_path is None:
      return util.InitializationOnlyStatus(self._graph_view, ops.uid())
    reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    graph_building = not context.executing_eagerly()
    if graph_building:
      dtype_map = None
    else:
      dtype_map = reader.get_variable_to_dtype_map()
    try:
      object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError:
      # The object graph proto does not exist in this checkpoint. Try the
      # name-based compatibility mode.
      restore_coordinator = util._NameBasedRestoreCoordinator(  # pylint: disable=protected-access
          save_path=save_path,
          dtype_map=dtype_map)
      if not graph_building:
        for existing_trackable in self._graph_view.list_objects():
          # pylint: disable=protected-access
          existing_trackable._maybe_initialize_trackable()
          existing_trackable._name_based_restores.add(restore_coordinator)
          existing_trackable._name_based_attribute_restore(restore_coordinator)
          # pylint: enable=protected-access
      return util.NameBasedSaverStatus(
          restore_coordinator, graph_view=self._graph_view)

    if graph_building:
      if self._file_prefix_placeholder is None:
        # DTensor change: provide a hint for mesh broadcasting to put the
        # input onto the host mesh.
        self._file_prefix_placeholder = api.pack(
            [constant_op.constant("model")] * self._mesh.num_local_devices(),
            layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
      file_prefix_tensor = self._file_prefix_placeholder
      file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
    else:
      # DTensor change: provide a hint for mesh broadcasting to put the input
      # onto the host mesh.
      file_prefix_tensor = api.pack(
          [constant_op.constant(save_path)] * self._mesh.num_local_devices(),
          layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
      file_prefix_feed_dict = None
    object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
    object_graph_proto.ParseFromString(object_graph_string)
    # DTensor change: Hook the proper DSaver in restore.
    checkpoint = _DCheckpointRestoreCoordinator(
        mesh=self._mesh,
        object_graph_proto=object_graph_proto,
        save_path=save_path,
        save_path_tensor=file_prefix_tensor,
        reader=reader,
        restore_op_cache=self._restore_op_cache,
        graph_view=self._graph_view,
        options=options,
        saveables_cache=self._saveables_cache)
    restore_lib.CheckpointPosition(
        checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)

    # Attached dependencies are not attached to the root, so they should be
    # restored separately.
    if self._graph_view.attached_dependencies:
      for ref in self._graph_view.attached_dependencies:
        if ref.name == "root":
          # The root dependency is automatically added to attached
          # dependencies -- this can be ignored since it maps back to the
          # root object.
          continue
        proto_id = None
        # Find the proto ID of the attached dependency (if it is in the
        # proto).
        for proto_ref in object_graph_proto.nodes[0].children:
          if proto_ref.local_name == ref.name:
            proto_id = proto_ref.node_id
            break

        if proto_id in checkpoint.object_by_proto_id:
          # Object has already been restored. This can happen when there's an
          # indirect connection from the attached object to the root.
          continue

        restore_lib.CheckpointPosition(
            checkpoint=checkpoint, proto_id=proto_id).restore(ref.ref)

    load_status = util.CheckpointLoadStatus(
        checkpoint,
        graph_view=self._graph_view,
        feed_dict=file_prefix_feed_dict)
    return load_status
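
# A minimal usage sketch for DTrackableSaver (names here are hypothetical):
#
#   saver = DTrackableSaver(
#       mesh, graph_view_lib.ObjectGraphView(weakref.ref(model)))
#   status = saver.restore("/tmp/ckpt-1")
#   # In graph mode, the packed file-prefix placeholder created above is fed
#   # via `feed_dict` when the restore ops are run.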

@deprecation.deprecated(
    date=None,
    instructions="Please use tf.train.Checkpoint instead of DTensorCheckpoint. "
    "DTensor is integrated with tf.train.Checkpoint and it can be "
    "used out of the box to save and restore dtensors.")
@tf_export("experimental.dtensor.DTensorCheckpoint", v1=[])
class DTensorCheckpoint(util.Checkpoint):
  """Manages saving/restoring trackable values to disk, for DTensor."""

  def __init__(self, mesh: layout.Mesh, root=None, **kwargs):
    super(DTensorCheckpoint, self).__init__(root=root, **kwargs)
    self._mesh = mesh

    saver_root = self
    attached_dependencies = None
    self._save_counter = None  # Created lazily for restore-on-create.
    self._save_assign_op = None

    if root:
      util._assert_trackable(root, "root")
      saver_root = root
      attached_dependencies = []

      # All keyword arguments (including root itself) are set as children
      # of root.
      kwargs["root"] = root
      root._maybe_initialize_trackable()

      self._save_counter = data_structures.NoDependency(
          root._lookup_dependency("save_counter"))
      self._root = data_structures.NoDependency(root)

    for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
      setattr(self, k, v)

      # Call getattr instead of directly using v because setattr converts
      # v to a Trackable data structure when v is a list/dict/tuple.
      converted_v = getattr(self, k)
      util._assert_trackable(converted_v, k)

      if root:
        # Make sure that root doesn't already have dependencies with these
        # names.
        attached_dependencies = attached_dependencies or []
        child = root._lookup_dependency(k)
        if child is None:
          attached_dependencies.append(base.TrackableReference(k, converted_v))
        elif child != converted_v:
          raise ValueError(
              "Cannot create a Checkpoint with keyword argument {name} if "
              "root.{name} already exists.".format(name=k))

    # DTensor change:
    # Override the parent's saver with a DTrackableSaver backed by _DSaver (a
    # single-device saver).
    self._saver = DTrackableSaver(
        mesh,
        graph_view_lib.ObjectGraphView(
            weakref.ref(saver_root),
            attached_dependencies=attached_dependencies))
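
# Deprecated usage sketch (prefer tf.train.Checkpoint, as the deprecation
# notice above says); `mesh` and `model` are hypothetical:
#
#   ckpt = DTensorCheckpoint(mesh=mesh, root=model)
#   path = ckpt.save("/tmp/dtensor_ckpt")
#   ckpt.restore(path)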