Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/d_variable.py: 33%

88 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DTensor variable and saveable."""

import contextlib
import functools

from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.util.tf_export import tf_export


class DSaveSpec(saveable_object.SaveSpec):
  """DTensor SaveSpec that additionally captures global_shape and layout."""
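
  # Illustrative sketch (hypothetical values, not from the original source):
  # compared with a plain SaveSpec, two extra fields travel with the spec so
  # the DTensor save path can reconstruct sharded values, e.g.
  #
  #   spec = DSaveSpec(tensor=read_fn, slice_spec='', name='v',
  #                    global_shape=[4, 8], layout=layout.to_string())
  #
  # The layout is serialized to a string so it can be attached as an op
  # attribute rather than materialized as a tensor on the DTensor device.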

  def __init__(self,
               tensor,
               slice_spec,
               name,
               global_shape,
               layout,
               dtype=None,
               device=None):
    super().__init__(
        tensor=tensor,
        slice_spec=slice_spec,
        name=name,
        dtype=dtype,
        device=device)
    self.global_shape = global_shape
    self.layout = layout


class _DVariableSaveable(saveable_object.SaveableObject):
  """Class for defining how to save/restore a DTensor variable."""
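
  # Summary (descriptive note, not in the original source): on save, the
  # variable's value is moved to the host mesh (packing/unpacking in eager
  # mode, copy_to_mesh otherwise) and optionally cast to bfloat16; on restore,
  # the tensor produced by RestoreV2 on CPU is copied back to the variable's
  # original layout before being assigned.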

  def __init__(self, dvariable, name):
    with ops.device(dvariable.device):
      original_layout = api.fetch_layout(dvariable)
    # Record the original layout to allow restore.
    self._original_layout = original_layout
    self._dvariable = dvariable

    def pack(tensors, layout):
      with ops.device(dvariable.device):
        return api.pack(tensors, layout)

    host_layout = layout_lib.Layout(original_layout.sharding_specs,
                                    original_layout.mesh.host_mesh())

    def get_host_dtensor():
      # Copy to the host mesh if needed.
      if original_layout.mesh.device_type().upper() != 'CPU':
        # Prefer pack and unpack in eager mode because they support sharded
        # layouts.
        if context.executing_eagerly():
          host_dtensor = api.pack(
              api.unpack(dvariable.read_value()), host_layout)
        else:
          host_dtensor = api.copy_to_mesh(dvariable.read_value(), host_layout)
      else:
        host_dtensor = dvariable.read_value()
      return (math_ops.cast(host_dtensor, dtypes.bfloat16)
              if self.should_cast(host_dtensor) else host_dtensor)

    num_local_devices = original_layout.mesh.num_local_devices()
    super(_DVariableSaveable, self).__init__(
        None,
        [
            DSaveSpec(
                tensor=get_host_dtensor,
                slice_spec=pack([''] * num_local_devices,
                                layout_lib.Layout.replicated(
                                    original_layout.mesh.host_mesh(), rank=0)),
                name=pack([name] * num_local_devices,
                          layout_lib.Layout.replicated(
                              original_layout.mesh.host_mesh(), rank=0)),
                global_shape=dvariable.shape,
                # The layout is attached as an attribute; there is no need to
                # put it as a Tensor on the DTensorDevice.
                layout=host_layout.to_string(),
                dtype=dtypes.bfloat16
                if self.should_cast(dvariable) else dvariable.dtype,
                device=dvariable.device)
        ],
        name)

  def should_cast(self, v):
    """Returns True if v has float32 dtype and is instructed to save as bf16.

    Args:
      v: The variable that determines whether to cast.

    Returns:
      True if the current saveable DVariable is instructed to save as bfloat16
      and the variable has dtype float32.
    """
    return self._dvariable.save_as_bf16 and v.dtype == dtypes.float32

  def restore(self, restored_tensors, restored_shapes):
    """Restores the same value into all variables."""
    tensor, = restored_tensors

    @def_function.function
    def _restore(t):
      with ops.device(self._dvariable.device):
        return api.copy_to_mesh(t, self._original_layout)

    # This assign establishes a connection from the restored tensor to the
    # variable being restored -- so that restore in SPMD can backtrack the
    # DVariable and its layout, given that we're using tf.function-style
    # restore.
    # Note that the restored dvariable is on CPU no matter what, as the
    # RestoreV2 op must run on CPU.
    # TODO(b/159035705): Allow restore for Tensor objects as well?
    # Restore the dvariable back to its original layout.
    if self._original_layout.mesh.device_type().upper() != 'CPU':
      tensor = _restore(tensor)
    return self._dvariable.assign(
        math_ops.cast(tensor, dtype=self._dvariable.dtype)
        if self._dvariable.save_as_bf16 else tensor)


@tf_export('experimental.dtensor.DVariable', v1=[])
class DVariable(resource_variable_ops.ResourceVariable):
  """A replacement for tf.Variable which follows initial value placement.

  The class also handles restore/save operations in DTensor. Note that
  DVariable may fall back to a normal tf.Variable at the moment if
  `initial_value` is not a DTensor.
  """
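
  # Illustrative usage (a sketch, not from the original source; assumes the
  # public `tf.experimental.dtensor` module is imported as `dtensor` and that
  # two local devices are available):
  #
  #   mesh = dtensor.create_mesh([('x', 2)])
  #   layout = dtensor.Layout.replicated(mesh, rank=1)
  #   initial_value = dtensor.copy_to_mesh(tf.zeros([8]), layout)
  #   v = dtensor.DVariable(initial_value)
  #   v.layout  # the layout recorded from the initial value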

  def __init__(self, initial_value, *args, dtype=None, **kwargs):
    """Overrides tf.Variable to fix VarHandleOp placements."""
    # Variables by default use the current device scope for placement. This
    # wrapper has them follow the initial value's placement instead (which will
    # be the DTensor device if the initial value has a layout).

    # Pop `layout` from kwargs since Keras' make_variable may pass a 'layout'
    # keyword argument. We need to pop it because we are passing kwargs to the
    # super class constructor.
    layout = kwargs.pop('layout', None)
    shape = kwargs.get('shape', None)

    if callable(initial_value):
      unwrapped = initial_value
      if issubclass(type(initial_value), functools.partial):
        unwrapped = initial_value.func

      # If the unwrapped callable is a CheckpointInitialValueCallable, we are
      # creating a Variable during a checkpoint restore. Thus the restore will
      # happen now through this callable and we will create the DVariable with
      # the restored dtensor.
      if issubclass(type(unwrapped), trackable.CheckpointInitialValueCallable):
        if not shape or not layout:
          raise ValueError('Expected shape and layout to be not None.')

        # CheckpointInitialValueCallable will call an eager tf.RestoreV2,
        # which does not have any shape or layout information attached.
        # Thus we do two things to have them correctly specified:
        #
        # The default layout scope allows us to correctly specify the output
        # layout of the tf.RestoreV2 that will be called.
        #
        # Passing shard_info with the correct shape allows the tf.RestoreV2
        # ShapeInference to extract the shape.
        initial_value = api.call_with_layout(
            initial_value,
            layout,
            shard_info=trackable.ShardInfo(
                shape=shape, offset=[0] * len(shape)))
      else:
        initial_value = initial_value()

    # When the initial value came from a checkpoint restoration, fetch the
    # wrapped tensor.
    if isinstance(initial_value, trackable.CheckpointInitialValue):
      initial_value = initial_value.wrapped_value

    initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
    variable_device = initial_value.device
    self._save_as_bf16 = False
    # TODO(b/159035705): The following code enables variable creation inside
    # a tf.function. However, it requires a global dtensor device.
    # if not variable_device and not tf.executing_eagerly():
    #   try:
    #     initial_value.op.get_attr("_layout")
    #   except ValueError:
    #     pass
    #   else:
    #     # The initial value is a DTensor, but because the DTensor device is
    #     # only active during eager execution at the moment, we need to
    #     # translate that into a placement for the eager VarHandleOp.
    #     variable_device = _dtensor_device().name
    with ops.device(variable_device):
      # If the initial tensor assigned to the DVariable is a DTensor, record
      # the layout of the resource so that it can be queried.
      self.layout = None
      if context.executing_eagerly():
        try:
          self.layout = api.fetch_layout(initial_value)
        except (errors.InvalidArgumentError, errors.NotFoundError):
          # For non-DTensor tensors, fetching the layout results in an expected
          # InvalidArgument or NotFound error, depending on whether the API is
          # called within a DTensor device scope or not.
          self.layout = None
      mesh = self.layout.mesh if self.layout else None
      with api.default_mesh(mesh) if mesh else contextlib.nullcontext():
        super(DVariable, self).__init__(
            initial_value, *args, dtype=dtype, **kwargs)

  @property
  def save_as_bf16(self):
    return self._save_as_bf16

  @save_as_bf16.setter
  def save_as_bf16(self, save_as_bf16):
    """Enables saving float32 as bfloat16."""
    self._save_as_bf16 = save_as_bf16 and self.dtype == dtypes.float32
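
  # Hypothetical usage sketch (not from the original source): opt a float32
  # variable into bfloat16 storage before writing a checkpoint.
  #
  #   v.save_as_bf16 = True  # no effect unless v.dtype == tf.float32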

  def _gather_saveables_for_checkpoint(self):
    return {
        trackable.VARIABLE_VALUE_KEY:
            functools.partial(_DVariableSaveable, self)
    }
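
# Checkpointing sketch (illustrative, not part of the original source): the
# checkpoint machinery gathers a `_DVariableSaveable` for each DVariable via
# `_gather_saveables_for_checkpoint`; saving materializes the DSaveSpec above
# (host-mesh layout string, global shape, optional bfloat16 cast), and
# restoring copies the value back to the variable's original layout.
#
#   ckpt = tf.train.Checkpoint(v=v)
#   path = ckpt.save('/tmp/dtensor_ckpt')
#   ckpt.restore(path)  # goes through _DVariableSaveable.restore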