Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/tpu_values.py: 32%
260 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing TPU distributed values.

Note that the tests are in values_test.py.
"""

from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import tpu_replicated_variable
from tensorflow.python.distribute import tpu_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope


_scatter_error_msg = ("{op_name} is only supported for distributed "
                      "variable (variable created within certain "
                      "`tf.distribute.Strategy` scope) with NONE "
                      "aggregation, got: {aggregation}.")


class TPUVariableMixin(object):
  """Mixin for TPU variables."""

  def __init__(self, *args, **kwargs):
    super(TPUVariableMixin, self).__init__(*args, **kwargs)

    # Handle ID is needed for `get_replicated_var_handle` to cache the variables
    # correctly since in eager mode different variables can have the same name.
    if ops.executing_eagerly_outside_functions():
      self._handle_id = self._common_name + "_" + str(id(self._primary))
    else:
      self._handle_id = self._common_name

  def __getattr__(self, name):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).__getattr__(name)
    else:
      raise AttributeError(
          f"`TPUVariableMixin.{name}` not accessible within a TPU context.")

  def get(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).get()
    else:
      raise NotImplementedError(
          "`TPUVariableMixin.get()` is not supported within a TPU context.")

  def _get_as_operand(self):
    return self.read_value()

  @property
  def handle(self):
    """The handle by which this variable can be accessed."""
    # If we're in a tpu.rewrite(), return the replicated handle.
    tpu_context = tpu_util.enclosing_tpu_context()
    if tpu_context is None or context.executing_eagerly():
      var = self._get_on_device_or_primary()
      if isinstance(var, packed.PackedVarAndDevice):
        return var.on_device_handle()
      else:
        return var.handle
    else:
      is_packed = self._packed_var is not None
      val = self._values
      if is_packed:
        val = [self._packed_var]

      return tpu_context.get_replicated_var_handle(self._common_name,
                                                   self._handle_id, val,
                                                   self._is_mirrored(),
                                                   is_packed)

  @property
  def device(self):
    return self.handle.device

  def _read_variable_op(self):
    """Reads the value of this variable."""
    if self.trainable:
      tape.variable_accessed(self)

    handle = self.handle
    if getattr(handle, "is_packed", False):
      # Add a device scope for a packed variable handle.
      with ops.device(self._get_on_device_or_primary().device):
        return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
    else:
      return gen_resource_variable_ops.read_variable_op(handle, self.dtype)

  def read_value(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).read_value()
    else:
      return self._read_variable_op()

  def value(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).value()
    else:
      return self._read_variable_op()

  def _as_graph_element(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
    else:
      return None

  @property
  def op(self):
    if values_util.is_saving_non_distributed():
      return self._primary.op
    return values.DistributedVarOp(self._primary.op.name,
                                   self._primary.op.graph,
                                   self._primary.op.traceback,
                                   self._primary.op.type)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # pylint: disable=protected-access
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._dense_var_to_tensor(
          dtype=dtype, name=name, as_ref=as_ref)
    # pylint: enable=protected-access
    elif dtype is not None and dtype != self.dtype:
      return math_ops.cast(self.read_value(), dtype)
    else:
      return self.handle if as_ref else self.read_value()
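

# Editor's illustrative sketch (not part of the original module): inside a TPU
# context, `TPUVariableMixin._dense_var_to_tensor` above is the hook that lets
# implicit conversions such as `tf.convert_to_tensor(v)` or `v + 1.0` read the
# replicated value instead of a per-replica handle. The `strategy` and `v`
# arguments are assumed to come from an already-initialized
# `tf.distribute.TPUStrategy` and a variable created in its scope.
def _example_implicit_conversion(strategy, v):  # pragma: no cover
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  @tf.function
  def step():
    # Both expressions route through `_dense_var_to_tensor` / `read_value()`.
    as_tensor = tf.convert_to_tensor(v)
    return as_tensor + 1.0

  return strategy.run(step)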


class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
  """DistributedVariable subclass for TPUStrategy."""

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_sub(value, use_locking, name, read_value)
    return self._policy.assign_sub(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_add(value, use_locking, name, read_value)
    return self._policy.assign_add(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign(value, use_locking, name, read_value)
    return self._policy.assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(sparse_delta, use_locking, name)
    return self._policy.scatter_sub(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(sparse_delta, use_locking, name)
    return self._policy.scatter_add(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(sparse_delta, use_locking, name)
    return self._policy.scatter_mul(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(sparse_delta, use_locking, name)
    return self._policy.scatter_div(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(sparse_delta, use_locking, name)
    return self._policy.scatter_min(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(sparse_delta, use_locking, name)
    return self._policy.scatter_max(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(sparse_delta, use_locking, name)
    return self._policy.scatter_update(
        self, sparse_delta, use_locking=use_locking, name=name)


class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
  """Holds a map from replica to TPU variables whose values are kept in sync."""

  def _is_replicated_or_sharded_to_logical_cores(self):
    """Returns whether each of the underlying variables is replicated or sharded to logical cores.

    If True, the handles of the underlying variables are not available outside a
    TPU context.
    """
    return isinstance(self._primary,
                      tpu_replicated_variable.TPUReplicatedVariable)

  @property
  def device(self):
    if (self._is_replicated_or_sharded_to_logical_cores() and
        tpu_util.enclosing_tpu_context() is None):
      return self._primary.device
    return super(TPUMirroredVariable, self).device

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    tpu_context = tpu_util.enclosing_tpu_context()
    if (self._is_replicated_or_sharded_to_logical_cores() and
        tpu_context is None):
      assign_sub_fn = lambda v, *a, **ka: v.assign_sub(*a, **ka)
      return self._update(
          update_fn=assign_sub_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)

    if (tpu_context and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_sub(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    tpu_context = tpu_util.enclosing_tpu_context()
    if (self._is_replicated_or_sharded_to_logical_cores() and
        tpu_context is None):
      assign_add_fn = lambda v, *a, **ka: v.assign_add(*a, **ka)
      return self._update(
          update_fn=assign_add_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)

    if (tpu_context and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_add(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    tpu_context = tpu_util.enclosing_tpu_context()
    if (self._is_replicated_or_sharded_to_logical_cores() and
        tpu_context is None):
      assign_fn = lambda v, *a, **ka: v.assign(*a, **ka)
      return self._update(
          update_fn=assign_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)

    if (tpu_util.enclosing_tpu_context() and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(*args, **kwargs)
    raise NotImplementedError

  def scatter_add(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(*args, **kwargs)
    raise NotImplementedError

  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    raise NotImplementedError

  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    raise NotImplementedError

  def scatter_mul(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(*args, **kwargs)
    raise NotImplementedError

  def scatter_div(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(*args, **kwargs)
    raise NotImplementedError

  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    raise NotImplementedError


class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
  """Holds a map from replica to variables whose values are reduced on save."""

  def assign_sub(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
                                                            **kwargs)

  def assign_add(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(self, *args,
                                                            **kwargs)

  def assign(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign(self, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(self, *args, **kwargs)


# Common assign methods shared by OnWrite and Mirrored variables.
def assign_sub(var, value, use_locking=False, name=None, read_value=True):
  assign_sub_fn = tpu_util.make_raw_assign_fn(
      gen_resource_variable_ops.assign_sub_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_sub_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def assign_add(var, value, use_locking=False, name=None, read_value=True):
  assign_add_fn = tpu_util.make_raw_assign_fn(
      gen_resource_variable_ops.assign_add_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_add_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def assign(var, value, use_locking=False, name=None, read_value=True):
  assign_fn = tpu_util.make_raw_assign_fn(
      gen_resource_variable_ops.assign_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)
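

# Editor's illustrative sketch (not part of the original module) of the rough
# shape of the wrapper returned by `tpu_util.make_raw_assign_fn`: it applies a
# raw `assign*_variable_op` to the variable handle and, if requested, returns
# the freshly read value. The real helper may differ in details; this only
# shows why `use_locking` is irrelevant for the raw resource ops.
def _example_make_raw_assign_fn(raw_assign_op):  # pragma: no cover
  def assign_fn(var, value, use_locking=False, name=None, read_value=True):
    del use_locking  # The raw resource-variable ops take no locking argument.
    with ops.device(var.device):
      raw_assign_op(
          var.handle,
          ops.convert_to_tensor(value, dtype=var.dtype),
          name=name)
      return var.read_value() if read_value else None

  return assign_fn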


class TPUOnWritePolicy(values.OnWritePolicy):
  """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.AUTO` or `tf.VariableSynchronization.ON_WRITE`.
  """

  def assign_sub(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              var,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_sub(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              var,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_add(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(
              var,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def _scatter_xxx(self,
                   raw_scatter_xxx_fn,
                   op_name,
                   var,
                   sparse_delta,
                   use_locking=False,
                   name=None):
    scatter_xxx_fn = tpu_util.make_raw_scatter_xxx_fn(raw_scatter_xxx_fn)
    if tpu_util.enclosing_tpu_context():
      if self._aggregation != variable_scope.VariableAggregation.NONE:
        raise NotImplementedError(
            _scatter_error_msg.format(
                op_name=op_name, aggregation=self._aggregation))
      return scatter_xxx_fn(
          var, sparse_delta=sparse_delta, use_locking=use_locking, name=name)
    else:
      return var._update(  # pylint: disable=protected-access
          update_fn=scatter_xxx_fn,
          value=sparse_delta,
          use_locking=use_locking,
          name=name)

  def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_sub,
                             "scatter_sub", var, sparse_delta, use_locking,
                             name)

  def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_add,
                             "scatter_add", var, sparse_delta, use_locking,
                             name)

  def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_max,
                             "scatter_max", var, sparse_delta, use_locking,
                             name)

  def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_min,
                             "scatter_min", var, sparse_delta, use_locking,
                             name)

  def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_mul,
                             "scatter_mul", var, sparse_delta, use_locking,
                             name)

  def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_div,
                             "scatter_div", var, sparse_delta, use_locking,
                             name)

  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_update,
                             "scatter_update", var, sparse_delta, use_locking,
                             name)
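

# Editor's illustrative sketch (not part of the original module) of the rough
# shape of the wrapper produced by `tpu_util.make_raw_scatter_xxx_fn`, which
# `_scatter_xxx` above applies either directly (inside a TPU context with NONE
# aggregation) or through `var._update`. The real helper may differ in
# details.
def _example_make_raw_scatter_fn(raw_scatter_op):  # pragma: no cover
  def scatter_fn(var, sparse_delta, use_locking=False, name=None):
    del use_locking  # The raw resource scatter ops take no locking argument.
    with ops.device(var.device):
      raw_scatter_op(
          var.handle,
          sparse_delta.indices,
          ops.convert_to_tensor(sparse_delta.values, dtype=var.dtype),
          name=name)
      return var.read_value()

  return scatter_fn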


class TPUOnReadPolicy(values.OnReadPolicy):
  """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
  values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`,
  `MEAN` or `ONLY_FIRST_REPLICA` when creating a `tf.Variable` in
  `tf.distribute` scope.
  """

  def assign_sub(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign_sub(var, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(var, *args,
                                                            **kwargs)

  def assign_add(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign_add(var, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(var, *args,
                                                            **kwargs)

  def assign(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(var, *args, **kwargs)

  def scatter_sub(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_add(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_max(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_min(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_mul(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_div(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_update(self, *args, **kwargs):
    raise NotImplementedError
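

# Editor's illustrative usage sketch (not part of the original module): under
# `tf.distribute.TPUStrategy`, variables created inside `strategy.scope()` are
# wrapped by the classes above, and only `tf.VariableAggregation.NONE` permits
# raw assign/scatter updates inside a TPU context (see `_scatter_error_msg`).
# The cluster-resolver arguments below are assumptions about the runtime.
def _example_tpu_strategy_usage():  # pragma: no cover
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.TPUStrategy(resolver)

  with strategy.scope():
    # NONE aggregation keeps per-replica updates unaggregated on TPU.
    v = tf.Variable([0.0, 0.0, 0.0],
                    aggregation=tf.VariableAggregation.NONE)

  @tf.function
  def step():
    v.assign_add([1.0, 1.0, 1.0])
    return v.read_value()

  return strategy.run(step)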