Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/d_variable.py: 33%
88 statements
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DTensor variable and saveable."""

import contextlib
import functools

from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.util.tf_export import tf_export


class DSaveSpec(saveable_object.SaveSpec):
  """DTensor SaveSpec that additionally captures global_shape and layout."""

  def __init__(self,
               tensor,
               slice_spec,
               name,
               global_shape,
               layout,
               dtype=None,
               device=None):
    super().__init__(
        tensor=tensor,
        slice_spec=slice_spec,
        name=name,
        dtype=dtype,
        device=device)
    self.global_shape = global_shape
    self.layout = layout
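

# Illustrative sketch, not part of the original module: how a DSaveSpec might
# be filled in for a fully replicated tensor that already lives on the host
# mesh. `_DVariableSaveable` below does the real work (and packs `slice_spec`
# and `name` per local device); the helper name and the plain-string
# `slice_spec` here are assumptions made for illustration only.
def _example_build_dsave_spec(host_tensor, host_layout, name, device):
  """Builds a DSaveSpec from an already host-placed tensor (example only)."""
  return DSaveSpec(
      tensor=lambda: host_tensor,       # `tensor` may be a callable.
      slice_spec='',                    # Empty slice spec means the full tensor.
      name=name,
      global_shape=host_tensor.shape,   # Global (unsharded) shape.
      layout=host_layout.to_string(),   # The layout travels as a string.
      dtype=host_tensor.dtype,
      device=device)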


class _DVariableSaveable(saveable_object.SaveableObject):
  """Class for defining how to save/restore DTensor variable."""

  def __init__(self, dvariable, name):
    with ops.device(dvariable.device):
      original_layout = api.fetch_layout(dvariable)
    # Record original layout to allow restore.
    self._original_layout = original_layout
    self._dvariable = dvariable

    def pack(tensors, layout):
      with ops.device(dvariable.device):
        return api.pack(tensors, layout)

    host_layout = layout_lib.Layout(original_layout.sharding_specs,
                                    original_layout.mesh.host_mesh())

    def get_host_dtensor():
      # Copy to host mesh if needed.
      if original_layout.mesh.device_type().upper() != 'CPU':
        # Prefer pack and unpack in eager mode because they support sharded
        # layouts.
        if context.executing_eagerly():
          host_dtensor = api.pack(
              api.unpack(dvariable.read_value()), host_layout)
        else:
          host_dtensor = api.copy_to_mesh(dvariable.read_value(), host_layout)
      else:
        host_dtensor = dvariable.read_value()
      return (math_ops.cast(host_dtensor, dtypes.bfloat16)
              if self.should_cast(host_dtensor) else host_dtensor)

    num_local_devices = original_layout.mesh.num_local_devices()
    super(_DVariableSaveable, self).__init__(
        None,
        [
            DSaveSpec(
                tensor=get_host_dtensor,
                slice_spec=pack([''] * num_local_devices,
                                layout_lib.Layout.replicated(
                                    original_layout.mesh.host_mesh(), rank=0)),
                name=pack([name] * num_local_devices,
                          layout_lib.Layout.replicated(
                              original_layout.mesh.host_mesh(), rank=0)),
                global_shape=dvariable.shape,
                # Layout is attached as an attribute; no need to put it as a
                # Tensor on the DTensorDevice.
                layout=host_layout.to_string(),
                dtype=dtypes.bfloat16
                if self.should_cast(dvariable) else dvariable.dtype,
                device=dvariable.device)
        ],
        name)

  def should_cast(self, v):
    """Returns True if v has float32 dtype and is instructed to save as bf16.

    Args:
      v: The variable that determines whether to cast.

    Returns:
      True if the current saveable DVariable is instructed to save as bfloat16
      and the variable has dtype float32.
    """
    return self._dvariable.save_as_bf16 and v.dtype == dtypes.float32

  def restore(self, restored_tensors, restored_shapes):
    """Restores the same value into all variables."""
    tensor, = restored_tensors

    @def_function.function
    def _restore(t):
      with ops.device(self._dvariable.device):
        return api.copy_to_mesh(t, self._original_layout)

    # This assign establishes a connection between the restored tensor and the
    # DVariable being restored to, so that restore in SPMD can backtrack the
    # DVariable and its layout, given that we're using tf.function-style
    # restore.
    # Note that the restored value is always on the CPU, since the RestoreV2
    # op must run on the CPU.
    # TODO(b/159035705): Allow restore for Tensor objects as well?
    # Restore the dvariable back to its original layout.
    if self._original_layout.mesh.device_type().upper() != 'CPU':
      tensor = _restore(tensor)
    return self._dvariable.assign(
        math_ops.cast(tensor, dtype=self._dvariable.dtype)
        if self._dvariable.save_as_bf16 else tensor)


@tf_export('experimental.dtensor.DVariable', v1=[])
class DVariable(resource_variable_ops.ResourceVariable):
  """A replacement for tf.Variable that follows the initial value's placement.

  The class also handles save/restore operations in DTensor. Note that
  DVariable may, at the moment, fall back to a normal tf.Variable if
  `initial_value` is not a DTensor.
  """

  def __init__(self, initial_value, *args, dtype=None, **kwargs):
    """Overrides tf.Variable to fix VarHandleOp placements."""
    # Variables by default use the current device scope for placement. This
    # wrapper has them follow the initial value's placement instead (which will
    # be the DTensor device if the initial value has a layout).

    # Pop layout from kwargs since Keras' make_variable may pass a 'layout'
    # keyword argument. We need to pop it because we are passing kwargs to the
    # super class constructor.
    layout = kwargs.pop('layout', None)
    shape = kwargs.get('shape', None)

    if callable(initial_value):
      unwrapped = initial_value
      if issubclass(type(initial_value), functools.partial):
        unwrapped = initial_value.func

      # If the unwrapped callable is a CheckpointInitialValueCallable, we are
      # creating a Variable during a checkpoint restore. The restore will then
      # happen through this callable, and we will create the DVariable with
      # the restored DTensor.
      if issubclass(type(unwrapped), trackable.CheckpointInitialValueCallable):
        if not shape or not layout:
          raise ValueError('Expected shape and layout to be not None.')

        # CheckpointInitialValueCallable will call an eager tf.RestoreV2,
        # which does not have any shape or layout information attached. We do
        # two things to have them correctly specified:
        #
        # The default layout scope allows us to correctly specify the output
        # layout of the tf.RestoreV2 that will be called.
        #
        # Passing shard_info with the correct shape allows the tf.RestoreV2
        # ShapeInference to extract the shape.
        initial_value = api.call_with_layout(
            initial_value,
            layout,
            shard_info=trackable.ShardInfo(
                shape=shape, offset=[0] * len(shape)))
      else:
        initial_value = initial_value()

    # When the initial value came from a checkpoint restoration, fetch the
    # wrapped tensor.
    if isinstance(initial_value, trackable.CheckpointInitialValue):
      initial_value = initial_value.wrapped_value

    initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
    variable_device = initial_value.device
    self._save_as_bf16 = False
    # TODO(b/159035705): The following code enables variable creation inside
    # a tf.function. However, it requires a global dtensor device.
    # if not variable_device and not tf.executing_eagerly():
    #   try:
    #     initial_value.op.get_attr("_layout")
    #   except ValueError:
    #     pass
    #   else:
    #     # The initial value is a DTensor, but because the DTensor device is
    #     # only active during eager execution at the moment we need to
    #     # translate that into a placement for the eager VarHandleOp.
    #     variable_device = _dtensor_device().name
    with ops.device(variable_device):
      # If the initial tensor assigned to the DVariable is a DTensor, record
      # the layout of the resource so that it can be queried later.
      self.layout = None
      if context.executing_eagerly():
        try:
          self.layout = api.fetch_layout(initial_value)
        except (errors.InvalidArgumentError, errors.NotFoundError):
          # For non-DTensor tensors, fetch_layout raises the expected
          # InvalidArgumentError or NotFoundError, depending on whether the
          # API is called within the DTensor device scope or not.
          self.layout = None
      mesh = self.layout.mesh if self.layout else None
      with api.default_mesh(mesh) if mesh else contextlib.nullcontext():
        super(DVariable, self).__init__(
            initial_value, *args, dtype=dtype, **kwargs)

  @property
  def save_as_bf16(self):
    return self._save_as_bf16

  @save_as_bf16.setter
  def save_as_bf16(self, save_as_bf16):
    """Enables saving float32 as bfloat16."""
    self._save_as_bf16 = save_as_bf16 and self.dtype == dtypes.float32

  def _gather_saveables_for_checkpoint(self):
    return {
        trackable.VARIABLE_VALUE_KEY:
            functools.partial(_DVariableSaveable, self)
    }
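

# Illustrative usage sketch, not part of the original module. It shows the
# intended flow: build the initial value under an explicit layout via
# `api.call_with_layout` so the DVariable inherits the DTensor placement, then
# opt the variable into bfloat16 saving. The `mesh` argument, shape, and
# `array_ops.zeros` initializer are assumptions made for this example only;
# any DTensor mesh and layout-compatible initializer would do.
def _example_create_dvariable(mesh):
  """Creates a replicated 4x4 DVariable on `mesh` (example only)."""
  from tensorflow.python.ops import array_ops  # Local import for the sketch.

  layout = layout_lib.Layout.replicated(mesh, rank=2)
  # The initial value is produced as a DTensor with the requested layout, so
  # the variable (and its VarHandleOp) follows that placement.
  initial_value = api.call_with_layout(array_ops.zeros, layout, shape=[4, 4])
  v = DVariable(initial_value)
  # Save float32 values as bfloat16 in checkpoints (a no-op for other dtypes).
  v.save_as_bf16 = True
  return v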