# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Contains functionaility for Checkpoint/SavedModel in DTensor.""" 

16 

import collections
from typing import Dict, List, Union

from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import d_variable
from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.dtensor.python import mesh_util
from tensorflow.python.eager import context
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.util.tf_export import tf_export

@tf_export('experimental.dtensor.sharded_save', v1=[])
def sharded_save(
    mesh: layout_lib.Mesh,
    file_prefix: Union[str, ops.Tensor],
    tensor_names: Union[List[str], ops.Tensor],
    shape_and_slices: Union[List[str], ops.Tensor],
    tensors: List[Union[ops.Tensor, tf_variables.Variable]],
):
  """Saves given named tensor slices in a sharded, multi-client safe fashion.

  The method makes sure the checkpoint directory state is correct in a sharded
  multi-client save. Namely, we place a barrier after SaveV2 to make sure
  every client has finished writing its files, and another one after
  MergeV2Checkpoints to make sure all metadata is properly merged.

  Upon returning, the checkpoint is complete and all directory operations
  are done.

  Args:
    mesh: The Mesh that contains the Tensors to save.
    file_prefix: The prefix of the checkpoint.
    tensor_names: a list of tensor names used in the save op.
    shape_and_slices: a list of shape and slice specifications used in the
      save op. The only supported value is "" as we don't support distributed
      saving with slices yet.
    tensors: a list of tensors used in the save op. The order should match
      tensor_names.

  Returns:
    A MergeV2Checkpoints op that merged all metadata.
  """
  with ops.device(api.device_name()):
    io_ops.save_v2(file_prefix, tensor_names, shape_and_slices, tensors)

  # Make sure all clients have written the files.
  mesh_util.barrier(mesh.host_mesh(), 'SaveV2')

  with api.default_mesh(mesh.host_mesh()):
    merge_op = io_ops.MergeV2Checkpoints(
        checkpoint_prefixes=[file_prefix],
        destination_prefix=file_prefix,
        delete_old_dirs=True)

  # Make sure the first device on the first host has finished the merge.
  mesh_util.barrier(mesh.host_mesh(), 'MergeV2Checkpoints')

  return merge_op

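# A minimal usage sketch for `sharded_save`, assuming `import tensorflow as
# tf` and a single-client CPU mesh; the mesh dimension name 'x', the tensor
# values, and the '/tmp/ckpt' prefix below are illustrative assumptions.
#
#   mesh = mesh_util.create_mesh([('x', 2)], device_type='CPU')
#   layout = layout_lib.Layout.replicated(mesh, rank=1)
#   w = api.copy_to_mesh(tf.constant([1.0, 2.0]), layout)
#   merge_op = sharded_save(
#       mesh,
#       file_prefix='/tmp/ckpt',
#       tensor_names=['w'],
#       shape_and_slices=[''],
#       tensors=[w])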

@tf_export('experimental.dtensor.enable_save_as_bf16', v1=[])
def enable_save_as_bf16(variables: List[tf_variables.Variable]):
  """Allows float32 DVariables to be checkpointed and restored as bfloat16.

  The method only affects the DVariable part inside the model and leaves
  non-DTensor Variables/Tensors untouched.

  Args:
    variables: A list of tf.Variable to be enabled with bfloat16 save/restore.
      Only has an effect on DTensor Variables, as they go through DVariable
      with DTensor-specific logic.
  """
  for v in variables:
    if isinstance(v, d_variable.DVariable):
      v.save_as_bf16 = True

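# A minimal usage sketch for `enable_save_as_bf16`; `model.variables` is a
# hypothetical list containing DVariables.
#
#   enable_save_as_bf16(model.variables)
#   # Any DVariable in the list is now checkpointed as bfloat16; plain
#   # tf.Variable entries are left untouched.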

@tf_export('experimental.dtensor.name_based_restore', v1=[])
def name_based_restore(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: str,
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
  """Restores from checkpoint_prefix to name-based DTensors.

  It is required to have already-initialized DTensor variables that have the
  same shape/dtype as the tensors being restored.

  Also, we currently only support name-based restore on a single mesh.

  Args:
    mesh: The single mesh that all Tensors will be restored to.
    checkpoint_prefix: The prefix of the checkpoint to be restored.
    name_tensor_dict: An ordered dictionary of tensor names to a DTensor. The
      DTensor shape/dtype must match the tensors being saved/restored for now.

  Returns:
    A dictionary of name to its restored DTensor value.
  """
  if not context.executing_eagerly():
    raise ValueError('name based restore must run eagerly.')

  ordered_name_tensor_dict = name_tensor_dict
  if not isinstance(name_tensor_dict, collections.OrderedDict):
    ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

  # Make sure that all tensors are on the CPU mesh for now.
  # This might not be a hard limitation in the future.
  for name, tensor in ordered_name_tensor_dict.items():
    try:
      if api.fetch_layout(tensor).mesh.device_type().upper() != 'CPU':
        raise ValueError(
            'Restoring a non-CPU Tensor is not supported currently. Offending '
            'tensor name: {tensor_name}'.format(tensor_name=name))
    except errors_impl.OpError as op_error:
      raise ValueError(
          'Saving/Restoring tensor must be a DTensor') from op_error

  # Now that we have all tensors on the CPU mesh, do a DTensorRestoreV2.
  checkpoint_prefix = api.pack(
      [checkpoint_prefix] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=0))
  # Explicitly pack to the mesh to avoid implicit small-constant extraction,
  # which does not work for larger restores that have lots of names.
  tensor_names = api.pack(
      [list(ordered_name_tensor_dict.keys())] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))
  shape_and_slices = api.pack(
      [[''] * len(ordered_name_tensor_dict)] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))
  # A list of TensorShapes representing all shapes for the input tensors.
  input_shapes = [tensor.shape for tensor in ordered_name_tensor_dict.values()]
  input_layouts = [
      api.fetch_layout(tensor).to_string()
      for tensor in ordered_name_tensor_dict.values()
  ]

  with ops.device(api.device_name()):
    restored_cpu_tensors = gen_dtensor_ops.d_tensor_restore_v2(
        prefix=checkpoint_prefix,
        tensor_names=tensor_names,
        shape_and_slices=shape_and_slices,
        input_shapes=input_shapes,
        input_layouts=input_layouts,
        dtypes=[tensor.dtype for tensor in ordered_name_tensor_dict.values()])

  return collections.OrderedDict(
      zip(ordered_name_tensor_dict.keys(), restored_cpu_tensors))

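# A minimal usage sketch for `name_based_restore`, assuming `import
# tensorflow as tf`, a single-client CPU mesh, and a checkpoint previously
# written under the illustrative '/tmp/ckpt' prefix. The target DTensor must
# already be initialized with the saved shape/dtype.
#
#   mesh = mesh_util.create_mesh([('x', 2)], device_type='CPU')
#   layout = layout_lib.Layout.replicated(mesh, rank=1)
#   w = api.copy_to_mesh(tf.zeros([2]), layout)
#   restored = name_based_restore(mesh, '/tmp/ckpt', {'w': w})
#   # restored['w'] holds the checkpointed value as a DTensor.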

@tf_export('experimental.dtensor.name_based_save', v1=[])
def name_based_save(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: Union[str, ops.Tensor],
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
  """Saves name-based Tensors into a checkpoint.

  The function prepares the input dictionary to the format of a
  `sharded_save`, so that it can take advantage of DTensor SPMD-based
  distributed save.

  As with restore, the function only supports saving on a single mesh.

  Args:
    mesh: The single mesh that hosts all Tensors to be saved.
    checkpoint_prefix: The prefix of the checkpoint to be saved.
    name_tensor_dict: An ordered dictionary of tensor names to a DTensor. The
      DTensor shape/dtype must match the tensors being saved/restored for now.
  """
  if not context.executing_eagerly():
    raise ValueError('name based save must run eagerly.')

  ordered_name_tensor_dict = name_tensor_dict
  if not isinstance(name_tensor_dict, collections.OrderedDict):
    ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

  # The current _dtensor_device() in api.py is the correct way of specifying
  # the DTensor device singleton. The API itself will eventually be moved to
  # a public API that provides the global singleton in a DTensor context.
  # For now, we just use the current `internal` API and aim to migrate in
  # one shot later.
  # TODO(hthu): Provide _dtensor_device() singleton as a public API.
  # pylint: disable=protected-access
  checkpoint_prefix = api.pack(
      [checkpoint_prefix] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=0))
  tensor_names = api.pack(
      [list(ordered_name_tensor_dict.keys())] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))

  sharded_save(
      mesh,
      file_prefix=checkpoint_prefix,
      tensor_names=tensor_names,
      shape_and_slices=[''] * len(ordered_name_tensor_dict),
      tensors=list(ordered_name_tensor_dict.values()))
216 tensors=list(ordered_name_tensor_dict.values()))