Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/save_restore.py: 37%
59 statements
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Contains functionaility for Checkpoint/SavedModel in DTensor."""

import collections
from typing import Dict, List, Union

from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import d_variable
from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.dtensor.python import mesh_util
from tensorflow.python.eager import context
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.util.tf_export import tf_export


@tf_export('experimental.dtensor.sharded_save', v1=[])
def sharded_save(
    mesh: layout_lib.Mesh,
    file_prefix: Union[str, ops.Tensor],
    tensor_names: Union[List[str], ops.Tensor],
    shape_and_slices: Union[List[str], ops.Tensor],
    tensors: List[Union[ops.Tensor, tf_variables.Variable]],
):
41 """Saves given named tensor slices in a sharded, multi-client safe fashion.
43 The method makes sure the checkpoint directory state is correct in a sharded
44 mutli-client saving. Namely, we place a barrier after SaveV2 to make sure
45 every client has done writing the files. And another one after
46 MergeV2Checkpoints to make sure all Metadata is properly merged.
48 Upon existing, the checkpoint is completed and the all directory operations
49 are done.
51 Args:
52 mesh: The Mesh that contains the Tensors to save.
53 file_prefix: The prefix of checkpoint.
54 tensor_names: a list of tensor names used in save op.
55 shape_and_slices: a list of shape and slice specification used in save op.
56 The only supported value is "" as we don't support distributed saving with
57 slices yet.
58 tensors: a list of tensors used in save op. The order should match
59 tensor_names.
61 Returns:
62 A MergeV2Checkpoints op that merged all Metadata.
63 """
  with ops.device(api.device_name()):
    io_ops.save_v2(file_prefix, tensor_names, shape_and_slices, tensors)

  # Make sure all clients have written the files.
  mesh_util.barrier(mesh.host_mesh(), 'SaveV2')  # pylint: disable=protected-access

  with api.default_mesh(mesh.host_mesh()):
    merge_op = io_ops.MergeV2Checkpoints(
        checkpoint_prefixes=[file_prefix],
        destination_prefix=file_prefix,
        delete_old_dirs=True)

  # Make sure the first device in the first host has finished the merge.
  mesh_util.barrier(mesh.host_mesh(), 'MergeV2Checkpoints')

  return merge_op
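
# A minimal, illustrative usage sketch for `sharded_save` (not part of this
# module). The mesh shape, variable names, and checkpoint path are assumptions,
# and `array_ops` would need to be imported from tensorflow.python.ops:
#
#   mesh = mesh_util.create_mesh([('batch', 2)], device_type='CPU')
#   layout = layout_lib.Layout.replicated(mesh, rank=1)
#   w = d_variable.DVariable(
#       api.call_with_layout(array_ops.zeros, layout, shape=[4]))
#   b = d_variable.DVariable(
#       api.call_with_layout(array_ops.ones, layout, shape=[4]))
#   sharded_save(
#       mesh,
#       file_prefix='/tmp/dtensor_ckpt/ckpt',
#       tensor_names=['w', 'b'],
#       shape_and_slices=['', ''],
#       tensors=[w.read_value(), b.read_value()])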


@tf_export('experimental.dtensor.enable_save_as_bf16', v1=[])
def enable_save_as_bf16(variables: List[tf_variables.Variable]):
84 """Allows float32 DVariables to be checkpointed and restored as bfloat16.
86 The method only affects the DVariable part inside the model and leaves
87 non-DTensor Variables/Tensors untouched.
89 Args:
90 variables: A list of tf.Variable to be enabled with bfloat16 save/restore.
91 Only has effect on DTensor Variables as they go through d_variables with
92 DTensor Specific logis.
93 """
  for v in variables:
    if isinstance(v, d_variable.DVariable):
      v.save_as_bf16 = True
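
# A minimal sketch of calling `enable_save_as_bf16` ahead of a checkpoint write
# (illustrative; `w` and `b` are the hypothetical DVariables from the sketch
# above):
#
#   enable_save_as_bf16([w, b])
#   # Subsequent saves (e.g. through `name_based_save` below) write these
#   # float32 DVariables as bfloat16; non-DTensor variables are unaffected.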


@tf_export('experimental.dtensor.name_based_restore', v1=[])
def name_based_restore(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: str,
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
105 """Restores from checkpoint_prefix to name based DTensors.
107 It is required to have already-initialized DTensor variables that have same
108 shape/dtype for the tensors being restored.
110 Also, we currently only support a named based restore on a single mesh.
112 Args:
113 mesh: The single mesh that all Tensors would be restored to.
114 checkpoint_prefix : The prefix of checkpoint to be restored.
115 name_tensor_dict: A ordered dictionary of tensor_names to a DTensor. The
116 DTensor shape/dtype must match the tensors being saved/restored for now.
118 Returns:
119 A dictionary of name to its restored DTensor value.
120 """
  if not context.executing_eagerly():
    raise ValueError('name based restore must run eagerly.')

  ordered_name_tensor_dict = name_tensor_dict
  if not isinstance(name_tensor_dict, collections.OrderedDict):
    ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

  # Make sure that all tensors are on CPU mesh for now.
  # This might not be a hard limitation in the future.
  for name, tensor in ordered_name_tensor_dict.items():
    try:
      if api.fetch_layout(tensor).mesh.device_type().upper() != 'CPU':
        raise ValueError(
            'Restoring a non CPU Tensor is not supported currently. Offending '
            'tensor name : {tensor_name}'.format(tensor_name=name))
    except errors_impl.OpError as op_error:
      raise ValueError(
          'Saving/Restoring tensor must be a DTensor') from op_error

  # Now that we have all tensors on CPU mesh, do a DTensorRestoreV2.
  checkpoint_prefix = api.pack(
      [checkpoint_prefix] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=0))
  # Explicitly pack to mesh to avoid implicit small constant extraction, which
  # does not work for larger restores that have lots of names.
  tensor_names = api.pack(
      [list(ordered_name_tensor_dict.keys())] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))
  shape_and_slices = api.pack(
      [[''] * len(ordered_name_tensor_dict)] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))
  # A list of TensorShape representing all shapes for the input tensors.
  input_shapes = [tensor.shape for tensor in ordered_name_tensor_dict.values()]
  input_layouts = [
      api.fetch_layout(tensor).to_string()
      for tensor in ordered_name_tensor_dict.values()
  ]

  with ops.device(api.device_name()):
    restored_cpu_tensors = gen_dtensor_ops.d_tensor_restore_v2(
        prefix=checkpoint_prefix,
        tensor_names=tensor_names,
        shape_and_slices=shape_and_slices,
        input_shapes=input_shapes,
        input_layouts=input_layouts,
        dtypes=[tensor.dtype for tensor in ordered_name_tensor_dict.values()])

  return collections.OrderedDict(
      zip(ordered_name_tensor_dict.keys(), restored_cpu_tensors))
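
# A minimal restore sketch (illustrative; the prefix and tensor names are
# assumptions and must match a previous save). The DTensors passed in must
# already exist on a CPU mesh with the saved shapes/dtypes:
#
#   restored = name_based_restore(
#       mesh,
#       checkpoint_prefix='/tmp/dtensor_ckpt/ckpt',
#       name_tensor_dict={'w': w, 'b': b})
#   w.assign(restored['w'])
#   b.assign(restored['b'])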


@tf_export('experimental.dtensor.name_based_save', v1=[])
def name_based_save(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: Union[str, ops.Tensor],
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
177 """Saves name based Tensor into a Checkpoint.
179 The function prepares the input dictionary to the format of a `sharded_save`,
180 so that it can take advantage of DTensor SPMD based distributed save.
182 Same as restore, the function only supports saving on the single mesh.
184 Args:
185 mesh: The single mesh that all Tensors would be restored to.
186 checkpoint_prefix : The prefix of checkpoint to be restored.
187 name_tensor_dict: A ordered dictionary of tensor_names to a DTensor. The
188 DTensor shape/dtype must match the tensors being saved/restored for now.
189 """
  if not context.executing_eagerly():
    raise ValueError('name based save must run eagerly.')

  ordered_name_tensor_dict = name_tensor_dict
  if not isinstance(name_tensor_dict, collections.OrderedDict):
    ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

  # The current _dtensor_device() in api.py is the correct way of specifying
  # the DTensor device singleton. The API itself will eventually be moved to
  # a public API and will provide the global singleton in the DTensor context.
  # For now, we just use the current `internal` API and aim at migrating in
  # one shot later.
  # TODO(hthu): Provide _dtensor_device() singleton as a public API.
  # pylint: disable=protected-access
  checkpoint_prefix = api.pack(
      [checkpoint_prefix] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=0))
  tensor_names = api.pack(
      [list(ordered_name_tensor_dict.keys())] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))

  sharded_save(
      mesh,
      file_prefix=checkpoint_prefix,
      tensor_names=tensor_names,
      shape_and_slices=[''] * len(ordered_name_tensor_dict),
      tensors=list(ordered_name_tensor_dict.values()))
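
# A minimal save sketch pairing with the restore sketch above (illustrative;
# the prefix and names are assumptions). `name_based_save` packs the prefix and
# names onto the host mesh and then delegates to `sharded_save`:
#
#   name_based_save(
#       mesh,
#       checkpoint_prefix='/tmp/dtensor_ckpt/ckpt',
#       name_tensor_dict={'w': w, 'b': b})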