# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Variable class that is replicated to logical cores for model parallelism."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import abc
import contextlib

from tensorflow.python.compiler.xla.experimental import xla_sharding
from tensorflow.python.distribute import tpu_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_tpu_partition_ops as tpu_partition_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save_context
from tensorflow.python.trackable import base as trackable


def _on_device_update(update_fn, var, value, **kwargs):
  with ops.device(var.device):
    return update_fn(var, value, **kwargs)
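
# A minimal sketch (not part of the original module) of how the helper above
# pins an update to the device that owns a variable; `v` and the lambda are
# hypothetical stand-ins added for illustration only:
#
#   v = variables_lib.Variable(1.0)
#   op = _on_device_update(lambda var, val: var.assign(val), v, 2.0)
#   # `op` executes on `v.device`, regardless of the caller's device scope.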


class TPUReplicatedVariable(variables_lib.Variable):
  """Container for replicated `Variables` that are treated as a single variable.

  This class maintains a list of replicated variables that are stored on
  separate logical TPU devices. The TF2XLA bridge accesses these variables as
  if they were a single variable.
  """

  def __init__(self, variables, name='TPUReplicatedVariable'):
    """Treats `variables` as a replicated list of `tf.Variable`s.

    Example:

    ```
    variables = [
        tf.Variable(..., shape=(10, 100), dtype=tf.float32),
        tf.Variable(..., shape=(10, 100), dtype=tf.float32),
        tf.Variable(..., shape=(10, 100), dtype=tf.float32),
        tf.Variable(..., shape=(10, 100), dtype=tf.float32),
    ]
    replicated_variable = TPUReplicatedVariable(variables)
    assert replicated_variable.shape.as_list() == [10, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this replicated
        variable. Variables should not be shared between different
        `TPUReplicatedVariable` objects.
      name: String. Name of this container. Defaults to "TPUReplicatedVariable".
    """
    if not isinstance(variables, abc.Sequence) or not variables or any(
        not isinstance(v, variables_lib.Variable) for v in variables):
      raise TypeError('Argument `variables` should be a non-empty list of '
                      f'`variables.Variable`s. Received {variables}')

    if any(v.dtype != variables[0].dtype for v in variables):
      raise ValueError(
          'All elements in argument `variables` must have the same dtype. '
          f'Received dtypes: {[v.dtype for v in variables]}')

    if any(v.shape != variables[0].shape for v in variables):
      raise ValueError(
          'All elements in argument `variables` must have the same shape. '
          f'Received shapes: {[v.shape for v in variables]}')

    self._vars = variables
    self._name = name
    self._common_name = self._name.split(':')[0]
    self._cached_value = None

  def __iter__(self):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._vars)

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._vars[0].dtype

  @property
  def is_initialized(self):
    return self._vars[0].is_initialized

  @property
  def trainable(self):
    return self._vars[0].trainable

  @property
  def device(self):
    """The device this variable is on."""
    return self._vars[0].device

  @contextlib.contextmanager
  def _handle_graph(self):
    with self.handle.graph.as_default():
      yield

  @contextlib.contextmanager
  def _assign_dependencies(self):
    if self._cached_value is not None:
      with ops.control_dependencies([self._cached_value]):
        yield
    else:
      yield

  @property
  def constraint(self):
    return self._vars[0].constraint

  @property
  def _in_graph_mode(self):
    return self._vars[0]._in_graph_mode  # pylint: disable=protected-access

  @property
  def _unique_id(self):
    return self._vars[0]._unique_id  # pylint: disable=protected-access

  @property
  def graph(self):
    return self._vars[0].graph

  @property
  def _shared_name(self):
    return self._common_name

  @property
  def synchronization(self):
    return variable_scope.VariableSynchronization.NONE

  @property
  def aggregation(self):
    return variable_scope.VariableAggregation.NONE

  @property
  def variables(self):
    """The list of `Variables`."""
    if save_context.in_save_context():
      return [self._vars[0]]
    return self._vars

  def _export_to_saved_model_graph(self, object_map, tensor_map,
                                   options, **kwargs):
    """For implementing `Trackable`."""
    first_var = self._vars[0]
    resource_list = first_var._export_to_saved_model_graph(  # pylint:disable=protected-access
        object_map, tensor_map, options, **kwargs)
    for v in self._vars[1:]:
      object_map[v] = object_map[first_var]
      tensor_map[v.handle] = tensor_map[first_var.handle]
      resource_list.append(v.handle)
    object_map[self] = object_map[first_var]
    tensor_map[self] = tensor_map[first_var.handle]
    resource_list.append(self)
    return resource_list

  def _gather_saveables_for_saved_model(self):
    return {trackable.VARIABLE_VALUE_KEY: self._vars[0]}

  @property
  def shape(self):
    return self._vars[0].shape

  @property
  def handle(self):
    if save_context.in_save_context() or context.executing_eagerly():
      return self._vars[0].handle

    if tpu_util.enclosing_tpu_context() is None:
      raise NotImplementedError(
          'TPUReplicatedVariable.handle is not available outside a TPU '
          'context or a save context.')
    else:
      with tpu_util.outside_or_skip_tpu_context():
        packed_var = getattr(self, '_packed_var', None)

        # TODO(b/202047549): Enable packed variables with soft device placement
        if packed_var is None or config.get_soft_device_placement():
          tensor = tpu_partition_ops.tpu_partitioned_input_v2(
              [v.handle for v in self._vars],
              partition_dims=[], is_packed=False)
        else:
          tensor = tpu_partition_ops.tpu_partitioned_input_v2(
              [packed_var.packed_handle], partition_dims=[], is_packed=True)

      return xla_sharding.replicate(tensor)

  def _read_variable_op(self):
    return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # pylint: disable=protected-access
    if tpu_util.enclosing_tpu_context() is None:
      return self.read_value()
    else:
      return self._read_variable_op()

  def read_value(self):
    return self._vars[0].read_value()

  def _update(self, update_fn, value, **kwargs):
228 """Converts the value to tensor and updates the variable list."""
    input_tensor = ops.convert_to_tensor(
        value, name='value_in_tensor', dtype=self.dtype)

    return control_flow_ops.group(
        *tuple(
            _on_device_update(update_fn, v, input_tensor, **kwargs)
            for v in self.variables))

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
      assign_fn = lambda var, *a, **ka: var.assign(*a, **ka)
      return self._update(
          assign_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
      assign_sub_fn = lambda var, *a, **ka: var.assign_sub(*a, **ka)
      return self._update(
          assign_sub_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
      assign_add_fn = lambda var, *a, **ka: var.assign_add(*a, **ka)
      return self._update(
          assign_add_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)

  def __str__(self):
    debug_str = ',\n'.join(
        '  %d: %s' % (i, v) for i, v in enumerate(self._vars))
    return '%s:{\n%s\n}' % (self.__class__.__name__, debug_str)

  def __repr__(self):
    debug_repr = ',\n'.join(
        '  %d: %r' % (i, v) for i, v in enumerate(self._vars))
    return '%s:{\n%s\n}' % (self.__class__.__name__, debug_repr)
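

# A minimal usage sketch, mirroring the constructor docstring and assuming
# four per-core variables of identical shape and dtype. `_example_usage` is a
# hypothetical helper added for illustration; it is not part of this module.
def _example_usage():
  per_core_vars = [
      variables_lib.Variable([[0.0] * 100] * 10) for _ in range(4)
  ]
  replicated = TPUReplicatedVariable(per_core_vars)
  assert replicated.shape.as_list() == [10, 100]
  # Outside a TPU context, an assignment fans out to every underlying
  # variable via `_update`, with each update pinned to that variable's device.
  replicated.assign([[1.0] * 100] * 10)
  return replicated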


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_tpu_replicated_var(var,
                                          dtype=None,
                                          name=None,
                                          as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


tensor_conversion_registry.register_tensor_conversion_function(
    TPUReplicatedVariable, _tensor_conversion_tpu_replicated_var)
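
# With the conversion function registered above, instances convert implicitly
# wherever a tensor is expected. A hedged sketch, reusing the hypothetical
# `_example_usage` helper defined above for illustration:
#
#   replicated = _example_usage()
#   t = ops.convert_to_tensor(replicated)  # reads the variable's value
#   assert t.shape.as_list() == [10, 100]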