Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/values.py: 33%
884 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Various classes representing distributed values."""
17import copy
18from typing import Optional
19import weakref
21from tensorflow.core.protobuf import struct_pb2
22from tensorflow.python.distribute import device_util
23from tensorflow.python.distribute import distribute_lib
24from tensorflow.python.distribute import packed_distributed_variable as packed
25from tensorflow.python.distribute import reduce_util
26from tensorflow.python.distribute import values_util
27from tensorflow.python.eager import context
28from tensorflow.python.eager import record
29from tensorflow.python.framework import composite_tensor
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import tensor_conversion_registry
33from tensorflow.python.framework import tensor_util
34from tensorflow.python.framework import type_spec
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import resource_variable_ops
39from tensorflow.python.ops import variable_scope as vs
40from tensorflow.python.ops import variables as variables_lib
41from tensorflow.python.saved_model import nested_structure_coder
42from tensorflow.python.trackable import base as trackable
43from tensorflow.python.training.saving import saveable_object
44from tensorflow.python.types import core
45from tensorflow.python.types import distribute as ds_types
46from tensorflow.python.types import trace
49def _on_write_update_replica(var, update_fn, value, **kwargs):
50 """Updates variables with ON_WRITE synchronization in replica context."""
51 if var.aggregation == vs.VariableAggregation.NONE:
52 return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access
54 if not distribute_lib.get_strategy().extended._use_merge_call(): # pylint: disable=protected-access
55 # Don't allow MEAN with a non-float dtype, since it may cause unexpected
56 # precision loss. Python 3 and NumPy automatically upcast integers to
57 # float in division, but we should always preserve the type.
58 if var.aggregation == vs.VariableAggregation.MEAN and (
59 not var.dtype.is_floating) and tensor_util.is_tf_type(value):
60 raise ValueError(
61 "Cannot update non-float variables with "
62 "tf.VariableAggregation.MEAN aggregation in replica context. "
63 "Either change the variable dtype to float or update it in "
64 "cross-replica context.")
66 aggregated_value = apply_aggregation_replica_context(
67 value, var.aggregation, var)
68 values_util.mark_as_unsaveable()
70 return distribute_lib.get_replica_context()._update( # pylint: disable=protected-access
71 var,
72 update_fn,
73 args=(aggregated_value,),
74 kwargs=kwargs,
75 group=True)
77 else:
79 def merge_fn(strategy, value, **kwargs):
80 """Aggregate values and update all variables in cross replica context."""
81 # Don't allow MEAN with a non-float dtype, since it may cause unexpected
82 # precision loss. Python 3 and NumPy automatically upcast integers to
83 # float in division, but we should always preserve the type.
84 #
85 # Note that, to be backward compatible, we allow the case where the value
86 # is *always* the same on each replica, i.e. the value is not a
87 # PerReplica. Refer to regroup() to see how values are grouped.
88 if var.aggregation == vs.VariableAggregation.MEAN and (
89 not var.dtype.is_floating) and isinstance(value, PerReplica):
90 raise ValueError(
91 "Cannot update non-float variables with "
92 "tf.VariableAggregation.MEAN aggregation in replica context. "
93 "Either change the variable dtype to float or update it in "
94 "cross-replica context.")
96 assert strategy == var.distribute_strategy
97 v = values_util.apply_aggregation(strategy, value, var.aggregation, var)
98 return var._update_cross_replica(update_fn, v, **kwargs) # pylint: disable=protected-access
100 return distribute_lib.get_replica_context().merge_call(
101 merge_fn, args=(value,), kwargs=kwargs)
104def apply_aggregation_replica_context(value, aggregation, destinations):
105 """Aggregate `value` to `destinations` as specified by `aggregation`."""
106 # If `value` is a Python literal, return it without aggregation.
107 if isinstance(value, DistributedValues):
108 raise TypeError(
109 "Cannot use DistributedValues to update variables in replica context.")
110 if not tensor_util.is_tf_type(value):
111 return value
113 if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
114 # Switch to cross-replica context to broadcast
115 def merge_fn(strategy, value):
116 return strategy.extended.broadcast_to(
117 strategy.experimental_local_results(value)[0],
118 destinations=destinations)
120 return distribute_lib.get_replica_context().merge_call(
121 merge_fn, args=(value,))
123 else:
124 reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
125 aggregated_value = distribute_lib.get_strategy( # pylint: disable=protected-access
126 ).extended._replica_ctx_all_reduce(reduce_op, value)
127 return aggregated_value
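# A sketch (editorial, not part of the original module) of how the replica-
# context aggregation above is reached from user code: assigning a MEAN-
# aggregated mirrored variable inside `strategy.run` averages the per-replica
# values before the ON_WRITE update is applied to every component. Assumes two
# replicas are available, e.g. two logical CPUs configured via
# tf.config.set_logical_device_configuration; the device names are illustrative.
def _example_mean_aggregated_assign():
  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])
  with strategy.scope():
    v = tf.Variable(0.0, aggregation=tf.VariableAggregation.MEAN)

  def replica_fn():
    ctx = tf.distribute.get_replica_context()
    # Each replica proposes a different value; MEAN aggregation averages them.
    v.assign(tf.cast(ctx.replica_id_in_sync_group, tf.float32))

  strategy.run(replica_fn)
  return v.read_value()  # 0.5, the mean of 0.0 and 1.0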
130class DistributedValues(ds_types.DistributedValues):
131 """Base class for representing distributed values."""
133 def __init__(self, values):
134 """Should only be called by subclass __init__."""
135 self._values = tuple(values)
137 def _get(self):
138 """Returns the value for the current device or raises a ValueError."""
139 replica_id = values_util.get_current_replica_id_as_int()
140 if replica_id is None:
141 return self._get_cross_replica()
142 else:
143 return self._values[replica_id]
145 def _get_cross_replica(self):
146 raise NotImplementedError(
147 "DistributedValues._get_cross_replica should be implemented by "
148 "sub-classes which support cross-replica accesses.")
150 def _get_on_device_or_primary(self):
151 """Returns value in same replica or device if possible, else the _primary."""
152 replica_id = values_util.get_current_replica_id_as_int()
153 if replica_id is None:
154 # Try to find a value on the current device.
155 current_device = device_util.canonicalize(device_util.current())
156 for value in self._values:
157 if device_util.canonicalize(value.device) == current_device:
158 return value
159 return self._primary
160 else:
161 return self._values[replica_id]
163 @property
164 def _primary(self):
165 """Returns a representative component."""
166 return self._values[0]
168 @property
169 def _devices(self):
170 return tuple(v.device for v in self._values)
172 def __str__(self):
173 debug_str = ",\n".join(
174 " %d: %s" % (i, v) for i, v in enumerate(self._values))
175 return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)
177 def __repr__(self):
178 debug_repr = ",\n".join(
179 " %d: %r" % (i, v) for i, v in enumerate(self._values))
180 return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)
183# NOTE(josh11b,apassos): It would be great if we could inspect the values this was
184# initialized with and use that to generate the overloaded operators here.
185# Unfortunately, Python's rules for special methods don't allow this, see
186# https://docs.python.org/3/reference/datamodel.html#special-method-names
187# "if a class defines a method named __getitem__(), and x is an instance of
188# this class, then x[i] is roughly equivalent to type(x).__getitem__(x, i)."
189# In particular, these special methods don't go through __getattr__, and
190# it will only use those methods if they are defined in the class, not the
191# object.
192class DistributedDelegate(DistributedValues):
193 """A map from device to values; acts as the same type as the values."""
195 def __getattr__(self, name):
196 # The '_use_resource_variables' attribute and attributes starting with
197 # '_self' are used for restoring the saved_model proto, and
198 # '_attribute_sentinel' is used for Layer tracking. At the point these
199 # attributes are queried, the variable has not been initialized, so we
200 # should not query those of the underlying components.
201 if name.startswith("_self_") or name in ("_use_resource_variables",
202 "_attribute_sentinel",
203 "_distributed_container"):
204 return super(DistributedDelegate, self).__getattr__(name)
206 # This allows copy.copy(DistributedDelegate). When copying an object,
207 # copy.copy doesn't invoke its __init__ method, instead it makes a new
208 # empty object, then copies the attributes over. copy.copy looks for
209 # attributes like "__getstate__" in case the object implements its custom
210 # copying. Since DistributedDelegate doesn't have those attributes defined,
211 # __getattr__ will be invoked, which tries to access "_values" attributes,
212 # but that doesn't exist either because this is an empty object, and again
213 # __getattr__ is invoked, leading to an infinite recursion.
214 if name == "_values":
215 raise AttributeError()
217 # TODO(priyag): This needs to be made robust against pitfalls from mixing
218 # __getattr__ and @property. See b/120402273.
219 return getattr(self._get(), name)
221 @property
222 def values(self):
223 """Returns the per replica values."""
224 return self._values
226 def _get_as_operand(self):
227 """Returns the value for operations for the current device.
229 Some implementations, e.g. `TPUMirroredVariable`, are not able to return the
230 value type within a replica context. They can, however, return a value that
231 can be used by the operations below.
232 """
233 return self._get()
235 # pylint: disable=multiple-statements
236 def __add__(self, o):
237 return self._get_as_operand() + o
239 def __radd__(self, o):
240 return o + self._get_as_operand()
242 def __sub__(self, o):
243 return self._get_as_operand() - o
245 def __rsub__(self, o):
246 return o - self._get_as_operand()
248 def __mul__(self, o):
249 return self._get_as_operand() * o
251 def __rmul__(self, o):
252 return o * self._get_as_operand()
254 def __truediv__(self, o):
255 return self._get_as_operand() / o
257 def __rtruediv__(self, o):
258 return o / self._get_as_operand()
260 def __floordiv__(self, o):
261 return self._get_as_operand() // o
263 def __rfloordiv__(self, o):
264 return o // self._get_as_operand()
266 def __mod__(self, o):
267 return self._get_as_operand() % o
269 def __rmod__(self, o):
270 return o % self._get_as_operand()
272 def __lt__(self, o):
273 return self._get_as_operand() < o
275 def __le__(self, o):
276 return self._get_as_operand() <= o
278 def __gt__(self, o):
279 return self._get_as_operand() > o
281 def __ge__(self, o):
282 return self._get_as_operand() >= o
284 def __and__(self, o):
285 return self._get_as_operand() & o
287 def __rand__(self, o):
288 return o & self._get_as_operand()
290 def __or__(self, o):
291 return self._get_as_operand() | o
293 def __ror__(self, o):
294 return o | self._get_as_operand()
296 def __xor__(self, o):
297 return self._get_as_operand() ^ o
299 def __rxor__(self, o):
300 return o ^ self._get_as_operand()
302 def __getitem__(self, o):
303 return self._get_as_operand()[o]
305 def __pow__(self, o, modulo=None):
306 return pow(self._get_as_operand(), o, modulo)
308 def __rpow__(self, o):
309 return pow(o, self._get_as_operand())
311 def __invert__(self):
312 return ~self._get_as_operand()
314 def __neg__(self):
315 return -self._get_as_operand()
317 def __abs__(self):
318 return abs(self._get_as_operand())
320 def __div__(self, o):
321 try:
322 return self._get_as_operand().__div__(o)
323 except AttributeError:
324 # See https://docs.python.org/3/library/constants.html#NotImplemented
325 return NotImplemented
327 def __rdiv__(self, o):
328 try:
329 return self._get_as_operand().__rdiv__(o)
330 except AttributeError:
331 # See https://docs.python.org/3/library/constants.html#NotImplemented
332 return NotImplemented
334 def __matmul__(self, o):
335 try:
336 return self._get_as_operand().__matmul__(o)
337 except AttributeError:
338 # See https://docs.python.org/3/library/constants.html#NotImplemented
339 return NotImplemented
341 def __rmatmul__(self, o):
342 try:
343 return self._get_as_operand().__rmatmul__(o)
344 except AttributeError:
345 # See https://docs.python.org/3/library/constants.html#NotImplemented
346 return NotImplemented
348 # TODO(josh11b): Even more operator overloads.
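# A sketch of what the operator overloads above enable: a Mirrored value (for
# example a MirroredVariable, which is a DistributedDelegate) can be used
# directly in arithmetic, with each operator delegated to the component for the
# current replica. The two-replica MirroredStrategy is an assumption.
def _example_delegated_operators():
  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])
  with strategy.scope():
    v = tf.Variable(3.0)

  def replica_fn():
    # __mul__/__add__ resolve to the underlying per-replica variable value.
    return v * 2.0 + 1.0

  return strategy.run(replica_fn)  # 7.0 on each replica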
351class PerReplica(DistributedValues, composite_tensor.CompositeTensor,
352 ds_types.PerReplica):
353 """Holds a map from replica to unsynchronized values."""
355 @property
356 def _type_spec(self):
357 return PerReplicaSpec(
358 *(type_spec.type_spec_from_value(v) for v in self._values))
360 @property
361 def values(self):
362 """Returns the per replica values."""
363 return self._values
366def _per_replica_to_tensor(var, dtype=None, name=None, as_ref=False):
367 """Converts a `PerReplica` to a `Tensor`."""
368 del name
369 if dtype is not None and not dtype.is_compatible_with(var.dtype):
370 raise ValueError(
371 "Incompatible type conversion requested to type {!r} for variable "
372 "of type {!r}".format(dtype.name, var.dtype.name))
373 if as_ref:
374 raise NotImplementedError(
375 "PerReplica doesn't support being used as a reference.")
376 if (distribute_lib.in_cross_replica_context() or
377 not distribute_lib.has_strategy()):
378 raise ValueError("It looks like you are using a PerReplica object while "
379 "not inside a replica context, which is not supported. "
380 "Try running your op or function inside a replica context "
381 "by using `strategy.run`")
382 else:
383 replica_id = values_util.get_current_replica_id_as_int()
384 return var.values[replica_id]
386# Register a conversion function to provide a useful error message when users
387# try to use PerReplica values in the wrong contexts
388tensor_conversion_registry.register_tensor_conversion_function(
389 PerReplica, _per_replica_to_tensor)
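# A sketch of the error path registered above: converting a PerReplica value to
# a tensor outside a replica context raises the ValueError defined in
# _per_replica_to_tensor, nudging users toward `strategy.run`. The two-replica
# strategy is an assumption made for the example.
def _example_per_replica_conversion_error():
  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])
  per_replica = strategy.run(lambda: tf.random.uniform([]))
  try:
    tf.convert_to_tensor(per_replica)  # cross-replica context: not allowed
  except ValueError as e:
    return e  # "It looks like you are using a PerReplica object while not ..."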
392class PerReplicaSpec(type_spec.TypeSpec):
393 """Type specification for a `PerReplica`."""
395 __slots__ = ["_value_specs"]
397 value_type = property(lambda self: PerReplica)
399 def __init__(self, *value_specs):
400 self._value_specs = tuple(value_specs)
402 def _serialize(self):
403 return self._value_specs
405 @property
406 def _component_specs(self):
407 return self._value_specs
409 def _to_components(self, value):
410 replica_context = distribute_lib.get_replica_context()
411 if replica_context is not None and replica_context.num_replicas_in_sync > 1:
412 raise ValueError(
413 "Flattening a PerReplica to components is not supported in replica "
414 "context.")
415 return value._values # pylint: disable=protected-access
417 def _from_components(self, tensor_list):
418 return PerReplica(tensor_list)
421nested_structure_coder.register_codec(
422 nested_structure_coder.BuiltInTypeSpecCodec(
423 PerReplicaSpec, struct_pb2.TypeSpecProto.PER_REPLICA_SPEC
424 )
425)
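# A sketch of the TypeSpec wiring above: tf.type_spec_from_value on a
# PerReplica returns a PerReplicaSpec whose components are the per-replica
# TensorSpecs, which is what lets tf.function and tf.nest treat PerReplica
# values structurally. The two-replica strategy is an assumption.
def _example_per_replica_spec():
  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])
  per_replica = strategy.run(lambda: tf.random.uniform([2]))
  spec = tf.type_spec_from_value(per_replica)
  return spec  # PerReplicaSpec wrapping one TensorSpec([2], float32) per replica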
428# Note that unlike PerReplica, Mirrored values inherit from
429# DistributedDelegate and so can be used directly in cross-replica mode.
430# TODO(tomhennigan) Should this extend CompositeTensor?
431class Mirrored(DistributedDelegate, ds_types.Mirrored):
432 """Holds a map from replica to values which are kept in sync."""
434 def _get_cross_replica(self):
435 return self._get_on_device_or_primary()
437 def _as_graph_element(self):
438 obj = self._get()
439 conv_fn = getattr(obj, "_as_graph_element", None)
440 if conv_fn and callable(conv_fn):
441 return conv_fn()
442 return obj
444 def _is_mirrored(self):
445 return True
448class DistributedVarOp(object):
449 """A class that looks like `tf.Operation`."""
451 def __init__(self, name, graph, traceback, typ):
452 self.name = name
453 self.graph = graph
454 self.traceback = traceback
455 self.type = typ
457 def __eq__(self, o):
458 if not isinstance(o, self.__class__):
459 return NotImplemented
460 return (self.name == o.name and self.graph == o.graph and
461 self.traceback == o.traceback and self.type == o.type)
463 def __hash__(self):
464 return hash((self.name, self.graph, tuple(self.traceback), self.type))
467# TODO(b/209081027): Remove this once Variable is a CompositeTensor.
468class DistributedVariableTraceType(trace.TraceType):
469 """TraceType of DistributedVariable objects."""
471 def __init__(self, distributed_variable):
472 self.distributed_variable = distributed_variable
473 self.components = (tuple(distributed_variable.shape.as_list()),
474 distributed_variable.dtype)
476 def is_subtype_of(self, other):
477 return self == other
479 def most_specific_common_supertype(self, others):
480 return self if all(self == other for other in others) else None
482 def placeholder_value(self, placeholder_context=None):
483 return self.distributed_variable
485 def _to_tensors(self, value):
486 return []
488 def __hash__(self) -> int:
489 return hash(self.components)
491 def __eq__(self, other) -> bool:
492 if not isinstance(other, DistributedVariableTraceType):
493 return False
495 return self.components == other.components
498class DistributedVariable(DistributedDelegate, variables_lib.Variable,
499 core.Tensor):
500 """Holds a map from replica to variables."""
502 def __init__(self, strategy, values, aggregation, var_policy=None):
503 if (aggregation == variables_lib.VariableAggregation.MEAN and
504 not values[0].dtype.is_floating):
505 raise ValueError(
506 "creating distributed tf.Variable with aggregation=MEAN and a "
507 "non-floating dtype is not supported, please use a different "
508 "aggregation or dtype")
509 self._distribute_strategy = strategy
510 self._aggregation = aggregation
511 super(DistributedVariable, self).__init__(values)
512 self._common_name = self._primary.name.split(":")[0]
513 # Use a weakref to make it easy to map from the contained values
514 # to the container without introducing a reference cycle.
515 for v in values:
516 # ResourceVariable is a CompositeTensor. Attributes added to
517 # CompositeTensors will get lost through tf.nest packing and unpacking.
518 if isinstance(v, composite_tensor.CompositeTensor) and hasattr(
519 v, "handle"):
520 v.handle._distributed_container = weakref.ref(self) # pylint: disable=protected-access
521 else:
522 v._distributed_container = weakref.ref(self) # pylint: disable=protected-access
524 # Packed variable is used to reduce the overhead of function execution.
525 # For a DistributedVariable, only one variable handle is captured into a
526 # function graph. It's only supported in eager mode.
527 if ops.executing_eagerly_outside_functions() and getattr(
528 strategy, "_enable_packed_variable_in_eager_mode", False):
529 name = "%s/packed/" % self._common_name
530 if hasattr(values[0], "_vars"):
531 # Handle the case where the resource variables are "nested" underneath
532 # another layer of values, e.g. TPUReplicatedVariable, by packing them
533 # all together and pushing the packed var down a level.
534 # pylint: disable=protected-access
535 packed_var = packed.PackedDistributedVariable(
536 sum((value._vars for value in values), []), name=name)
537 for value in values:
538 value._packed_var = packed_var
539 self._packed_var = None
540 # pylint: enable=protected-access
541 else:
542 self._packed_var = packed.PackedDistributedVariable(values, name=name)
543 else:
544 self._packed_var = None
546 # tf.keras keeps track of variables initialized using this attribute. When
547 # tf.keras gets the default session, it initializes all uninitialized vars.
548 # We need to make _keras_initialized a member of DistributedVariable because
549 # without this it will use `__getattr__` which will delegate to a component
550 # variable.
551 self._keras_initialized = False
552 # Typically, a `DistributedVariable`'s initializer is composed of the
553 # initializers of the component variables. However, in some cases, such as
554 # when restoring from a checkpoint, we may set the _initializer_op
555 # property on the entire `DistributedVariable`.
556 self._initializer_op = None
557 # Set a VariablePolicy which decides how we replicate/aggregate the given
558 # variable.
559 self._policy = var_policy
561 def __deepcopy__(self, memo):
562 """Perform a deepcopy of the `DistributedVariable`.
564 Unlike the deepcopy of a regular tf.Variable, this keeps the original
565 strategy and devices of the `DistributedVariable`. To avoid confusion
566 with the behavior of deepcopy on a regular `Variable` (which does
567 copy into new devices), we only allow a deepcopy of a `DistributedVariable`
568 within its originating strategy scope.
570 Args:
571 memo: The memoization object for `deepcopy`.
573 Returns:
574 A deep copy of the current `DistributedVariable`.
576 Raises:
577 RuntimeError: If trying to deepcopy into a different strategy.
578 """
579 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
580 new_values = []
582 for value in self._values:
583 with ops.device(value.device):
584 new_values.append(copy.deepcopy(value, memo))
586 copied_variable = type(self)(
587 strategy=self._distribute_strategy,
588 values=new_values,
589 aggregation=self._aggregation,
590 var_policy=copy.deepcopy(self._policy, memo))
592 memo[id(self)] = copied_variable
594 return copied_variable
596 def _use_packed_variable(self):
597 # Don't use packed variable when under a SaveContext to avoid explicit
598 # device placement on variable consuming ops.
599 return self._packed_var is not None and (
600 not values_util.is_saving_non_distributed())
602 def is_initialized(self, name=None):
603 """Identifies if all the component variables are initialized.
605 Args:
606 name: Name of the final `logical_and` op.
608 Returns:
609 The op that evaluates to True or False depending on whether all the
610 component variables are initialized.
611 """
612 if values_util.is_saving_non_distributed():
613 return self._primary.is_initialized()
614 if self._use_packed_variable():
615 return self._packed_var.is_initialized()
616 result = self._primary.is_initialized()
617 # We iterate through the remaining values, leaving out the last one so we
618 # can give the final `logical_and` op the name that the user passed to
619 # `is_initialized`. For distributed variables, the `is_initialized` op is
620 # a `logical_and` op.
621 for v in self._values[1:-1]:
622 result = math_ops.logical_and(result, v.is_initialized())
623 result = math_ops.logical_and(
624 result, self._values[-1].is_initialized(), name=name)
625 return result
627 @property
628 def initializer(self):
629 if values_util.is_saving_non_distributed():
630 return self._primary.initializer
631 if self._initializer_op:
632 init_op = self._initializer_op
633 else:
634 # Return a grouped op of the initializers of all the component
635 # variables of the distributed variable.
636 init_op = control_flow_ops.group(
637 tuple(v.initializer for v in self._values))
638 return init_op
640 def initialized_value(self):
641 return self._get_on_device_or_primary().initialized_value()
643 def _is_mirrored(self):
644 return (self._policy is not None) and (self._policy._is_mirrored()) # pylint: disable=protected-access
646 @property
647 def initial_value(self):
648 return self._get_on_device_or_primary().initial_value
650 @property
651 def constraint(self):
652 return self._primary.constraint
654 @property
655 def graph(self):
656 return self._primary.graph
658 @property
659 def _shared_name(self):
660 return self._common_name
662 @property
663 def _unique_id(self):
664 return self._primary._unique_id # pylint: disable=protected-access
666 @property
667 def _graph_key(self):
668 """Lets Optimizers know which graph this variable is from."""
669 return self._primary._graph_key # pylint: disable=protected-access
671 @property
672 def name(self):
673 return self._primary.name
675 @property
676 def dtype(self):
677 return self._primary.dtype
679 @property
680 def shape(self):
681 return self._primary.shape
683 @property
684 def synchronization(self):
685 return self._primary.synchronization
687 @property
688 def aggregation(self):
689 return self._aggregation
691 @property
692 def _packed_variable(self):
693 if self._use_packed_variable():
694 return self._packed_var
695 return None
697 @property
698 def handle(self):
699 if values_util.is_saving_non_distributed():
700 return self._primary.handle
701 replica_id = values_util.get_current_replica_id_as_int()
702 if replica_id is None:
703 raise ValueError(
704 "DistributedVariable.handle is not available outside the replica "
705 "context or a `tf.distribute.Strategy.update()` call.")
706 else:
707 if self._use_packed_variable():
708 return self._packed_var.handle
709 return self._values[replica_id].handle
711 def eval(self, session=None):
712 return self._get_on_device_or_primary().eval(session)
714 @property
715 def _save_slice_info(self):
716 return self._primary._save_slice_info # pylint: disable=protected-access
718 def _get_save_slice_info(self):
719 return self._primary._get_save_slice_info() # pylint: disable=protected-access
721 def _set_save_slice_info(self, save_slice_info):
722 for v in self._values:
723 v._set_save_slice_info(save_slice_info) # pylint: disable=protected-access
725 @property
726 def device(self):
727 return self._get_on_device_or_primary().device
729 @property
730 def trainable(self):
731 return self._primary.trainable
733 @property
734 def distribute_strategy(self):
735 return self._distribute_strategy
737 def get_shape(self):
738 return self._primary.get_shape()
740 def to_proto(self, export_scope=None):
741 return self._primary.to_proto(export_scope=export_scope)
743 @property
744 def op(self):
745 if values_util.is_saving_non_distributed():
746 return self._primary.op
747 # We want cross-replica code that does some var.op.X calls
748 # to work (even if the current device isn't in self._devices), but
749 # other uses of var.op in a cross-replica context to fail.
750 if distribute_lib.in_cross_replica_context():
751 return DistributedVarOp(self._primary.op.name, self._primary.op.graph,
752 self._primary.op.traceback, self._primary.op.type)
753 return self._get().op
755 @property
756 def _in_graph_mode(self):
757 return self._primary._in_graph_mode # pylint: disable=protected-access
759 def _get_replica(self, replica_id):
760 """Returns the value on a device with the given replica_id."""
761 value = self._values[replica_id]
762 if self._use_packed_variable():
763 return self._packed_var.on_device(value.device)
764 else:
765 return value
767 def _get(self):
768 """Returns the value for the current device or raises a ValueError."""
769 if values_util.is_saving_non_distributed():
770 return self._primary
771 replica_id = values_util.get_current_replica_id_as_int()
772 if replica_id is None:
773 return self._get_cross_replica()
774 else:
775 return self._get_replica(replica_id)
777 def _get_on_device_or_primary(self):
778 """Returns value in same replica or device if possible, else the _primary."""
779 if values_util.is_saving_non_distributed():
780 return self._primary
781 replica_id = values_util.get_current_replica_id_as_int()
782 if replica_id is None:
783 # Try to find a value on the current device.
784 current_device = device_util.canonicalize(device_util.current())
785 for i, value in enumerate(self._values):
786 if device_util.canonicalize(value.device) == current_device:
787 return self._get_replica(i)
788 return self._get_replica(0)
789 else:
790 return self._get_replica(replica_id)
792 def read_value(self):
793 if values_util.is_saving_non_distributed():
794 return self._primary.read_value()
795 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
796 return array_ops.identity(self._get())
798 def value(self):
799 if values_util.is_saving_non_distributed():
800 return self._primary.value()
801 if self._policy:
802 return self._policy.value(self)
803 return self._get_on_device_or_primary().value()
805 def numpy(self):
806 if context.executing_eagerly():
807 return self.read_value().numpy()
808 else:
809 raise NotImplementedError("DistributedVariable.numpy() is only available "
810 "when eager execution is enabled.")
812 def assign_sub(self, value, use_locking=False, name=None, read_value=True):
813 if values_util.is_saving_non_distributed():
814 return self._primary.assign_sub(value, use_locking, name, read_value)
815 if self._policy:
816 return self._policy.assign_sub(
817 self,
818 value,
819 use_locking=use_locking,
820 name=name,
821 read_value=read_value)
822 return values_util.on_write_assign_sub(
823 self, value, use_locking=use_locking, name=name, read_value=read_value)
825 def assign_add(self, value, use_locking=False, name=None, read_value=True):
826 if values_util.is_saving_non_distributed():
827 return self._primary.assign_add(value, use_locking, name, read_value)
828 if self._policy:
829 return self._policy.assign_add(
830 self,
831 value,
832 use_locking=use_locking,
833 name=name,
834 read_value=read_value)
835 return values_util.on_write_assign_add(
836 self, value, use_locking=use_locking, name=name, read_value=read_value)
838 def assign(self, value, use_locking=False, name=None, read_value=True):
839 if values_util.is_saving_non_distributed():
840 return self._primary.assign(value, use_locking, name, read_value)
841 if self._policy:
842 return self._policy.assign(
843 self,
844 value,
845 use_locking=use_locking,
846 name=name,
847 read_value=read_value)
848 return values_util.on_write_assign(
849 self, value, use_locking=use_locking, name=name, read_value=read_value)
851 def scatter_sub(self, sparse_delta, use_locking=False, name=None):
852 if values_util.is_saving_non_distributed():
853 return self._primary.scatter_sub(sparse_delta, use_locking, name)
854 if self._policy:
855 return self._policy.scatter_sub(
856 self, sparse_delta, use_locking=use_locking, name=name)
857 return values_util.scatter_sub(
858 self, sparse_delta, use_locking=use_locking, name=name)
860 def scatter_add(self, sparse_delta, use_locking=False, name=None):
861 if values_util.is_saving_non_distributed():
862 return self._primary.scatter_add(sparse_delta, use_locking, name)
863 if self._policy:
864 return self._policy.scatter_add(
865 self, sparse_delta, use_locking=use_locking, name=name)
866 return values_util.scatter_add(
867 self, sparse_delta, use_locking=use_locking, name=name)
869 def scatter_mul(self, sparse_delta, use_locking=False, name=None):
870 if values_util.is_saving_non_distributed():
871 return self._primary.scatter_mul(sparse_delta, use_locking, name)
872 if self._policy:
873 return self._policy.scatter_mul(
874 self, sparse_delta, use_locking=use_locking, name=name)
875 return values_util.scatter_mul(
876 self, sparse_delta, use_locking=use_locking, name=name)
878 def scatter_div(self, sparse_delta, use_locking=False, name=None):
879 if values_util.is_saving_non_distributed():
880 return self._primary.scatter_div(sparse_delta, use_locking, name)
881 if self._policy:
882 return self._policy.scatter_div(
883 self, sparse_delta, use_locking=use_locking, name=name)
884 return values_util.scatter_div(
885 self, sparse_delta, use_locking=use_locking, name=name)
887 def scatter_min(self, sparse_delta, use_locking=False, name=None):
888 if values_util.is_saving_non_distributed():
889 return self._primary.scatter_min(sparse_delta, use_locking, name)
890 if self._policy:
891 return self._policy.scatter_min(
892 self, sparse_delta, use_locking=use_locking, name=name)
893 return values_util.scatter_min(
894 self, sparse_delta, use_locking=use_locking, name=name)
896 def scatter_max(self, sparse_delta, use_locking=False, name=None):
897 if values_util.is_saving_non_distributed():
898 return self._primary.scatter_max(sparse_delta, use_locking, name)
899 if self._policy:
900 return self._policy.scatter_max(
901 self, sparse_delta, use_locking=use_locking, name=name)
902 return values_util.scatter_max(
903 self, sparse_delta, use_locking=use_locking, name=name)
905 def scatter_update(self, sparse_delta, use_locking=False, name=None):
906 if values_util.is_saving_non_distributed():
907 return self._primary.scatter_update(sparse_delta, use_locking, name)
908 if self._policy:
909 return self._policy.scatter_update(
910 self, sparse_delta, use_locking=use_locking, name=name)
911 return values_util.scatter_update(
912 self, sparse_delta, use_locking=use_locking, name=name)
914 def __tf_tracing_type__(self, _):
915 return DistributedVariableTraceType(self)
917 def _gather_saveables_for_checkpoint(self):
918 """Overrides Trackable method.
920 This allows both name-based and object-based save and restore of
921 DistributedVariables.
923 Returns:
924 A dictionary mapping attribute names to `SaveableObject` factories.
925 """
927 def _saveable_factory(name=self._common_name):
928 return _DistributedVariableSaveable(self, self._primary, name)
930 return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
932 def _as_graph_element(self):
933 if values_util.is_saving_non_distributed():
934 return self._primary._as_graph_element() # pylint: disable=protected-access
935 if self._policy:
936 return self._policy._as_graph_element(self) # pylint: disable=protected-access
938 raise NotImplementedError(
939 "DistributedVariable._as_graph_element requires a valid "
940 "VariablePolicy. Please set the policy via the `var_policy` argument "
941 "in the constructor, or override this method in sub-classes which "
942 "support cross-replica accesses.")
944 def _get_cross_replica(self):
945 if values_util.is_saving_non_distributed():
946 return self._primary
947 if self._policy:
948 return self._policy._get_cross_replica(self) # pylint: disable=protected-access
950 raise NotImplementedError(
951 "DistributedVariable._get_cross_replica requires a valid "
952 "VariablePolicy. Please set the policy via the `var_policy` argument "
953 "in the constructor, or override this method in sub-classes which "
954 "support cross-replica accesses.")
956 def _update_cross_replica(self, update_fn, value, **kwargs):
957 """Applies updates across replicas.
959 Args:
960 update_fn: A callable to pass to `strategy.extended.update` to update the
961 variable. It should have the same signature as `Variable.assign()`.
962 value: value to be passed to `update_fn`.
963 **kwargs: remaining arguments to `update_fn`.
965 Returns:
966 Updated variable or `tf.Operation`.
967 """
968 values_util.mark_as_unsaveable()
969 return self.distribute_strategy.extended.update(
970 self, update_fn, args=(value,), kwargs=kwargs, group=True)
972 def _update_replica(self, update_fn, value, **kwargs):
973 """Applies updates in one replica.
975 Args:
976 update_fn: A callable to update the variable. It should have the same
977 signature as `Variable.assign()`.
978 value: value to be passed to `update_fn`.
979 **kwargs: remaining arguments to `update_fn`.
981 Returns:
982 Updated variable or `tf.Operation`.
983 """
984 if self._policy:
985 return self._policy._update_replica(self, update_fn, value, **kwargs) # pylint: disable=protected-access
986 raise NotImplementedError(
987 "DistributedVariable._update_replica requires a valid VariablePolicy. "
988 "Please set the policy via the `var_policy` argument in the "
989 "constructor, or override this method in sub-classes which support "
990 "cross-replica accesses.")
992 def _update(self, update_fn, value, **kwargs):
993 """Applies updates depending on the context.
995 The method calls `_update_replica` in replica context,
996 `_update_cross_replica` in cross replica context, and `update_fn` in update
997 context.
999 If `read_value` is True, the method returns the updated Variable. If
1000 `read_value` is False, the method returns the update `tf.Operation`.
1002 Args:
1003 update_fn: A callable to pass to `strategy.extended.update` to update the
1004 variable. It should have the same signature as `Variable.assign()`.
1005 value: value to be passed to `update_fn`.
1006 **kwargs: keyword arguments to `update_fn`.
1008 Returns:
1009 Updated variable or `tf.Operation`.
1011 """
1012 if values_util.is_saving_non_distributed():
1013 return update_fn(self._primary, value, **kwargs)
1014 with distribute_lib.enter_or_assert_strategy(self.distribute_strategy):
1015 if distribute_lib.in_cross_replica_context():
1016 update_replica_id = distribute_lib.get_update_replica_id()
1017 if update_replica_id is not None:
1018 replica_value = self._get_replica(update_replica_id)
1019 return update_fn(replica_value, value, **kwargs)
1020 return self._update_cross_replica(update_fn, value, **kwargs)
1021 else:
1022 values_util.assert_replica_context(self.distribute_strategy)
1023 return self._update_replica(update_fn, value, **kwargs)
1025 def _should_act_as_resource_variable(self):
1026 """Pass resource_variable_ops.is_resource_variable check."""
1027 pass
1029 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
1030 """Converts a variable to a tensor."""
1031 if values_util.is_saving_non_distributed():
1032 return ops.convert_to_tensor(
1033 self._primary, dtype=dtype, name=name, as_ref=as_ref)
1034 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
1035 return ops.convert_to_tensor(
1036 self._get(), dtype=dtype, name=name, as_ref=as_ref)
1038 def __tf_tensor__(self,
1039 dtype: Optional[dtypes.DType] = None,
1040 name: Optional[str] = None) -> ops.Tensor:
1041 return self._dense_var_to_tensor(dtype, name)
1043 def _export_to_saved_model_graph(self,
1044 object_map=None,
1045 tensor_map=None,
1046 options=None,
1047 **kwargs):
1048 # Initialize for self._primary first, so that object_map[self._primary] and
1049 # tensor_map[self._primary.handle] contain mapped values.
1050 resource_list = self._primary._export_to_saved_model_graph( # pylint:disable=protected-access
1051 object_map=object_map,
1052 tensor_map=tensor_map,
1053 options=options,
1054 **kwargs)
1055 for v in [v for v in self._values if v != self._primary]:
1056 if (options.experimental_variable_policy # pylint:disable=protected-access
1057 ._expand_distributed_variables()):
1058 resource_list.extend(
1059 v._export_to_saved_model_graph( # pylint:disable=protected-access
1060 object_map=object_map,
1061 tensor_map=tensor_map,
1062 options=options,
1063 **kwargs)) # pylint:disable=protected-access
1064 else:
1065 object_map[v] = object_map[self._primary]
1066 tensor_map[v.handle] = tensor_map[self._primary.handle]
1067 resource_list.append(v.handle)
1068 object_map[self] = object_map[self._primary]
1069 tensor_map[self] = tensor_map[self._primary.handle]
1070 resource_list.append(self)
1071 if self._packed_var is not None:
1072 tensor_map[self._packed_var.packed_handle] = tensor_map[
1073 self._primary.handle]
1074 resource_list.append(self._packed_var.packed_handle)
1075 return resource_list
1077 def _write_object_proto(self, proto, options):
1078 """Update a SavedObject proto for the caller.
1080 If a DistributedVariable object supports this method, it will be called when
1081 saving with a pre-built `SavedObject` proto representing the object, plus an
1082 instance of `SaveOptions`. This method is then free to modify that proto
1083 instance.
1085 `DistributedVariable`s with `AUTO` or `ON_WRITE` synchronization optionally
1086 write out information about their components to the
1087 `experimental_distributed_variable_components` field of a
1088 `SavedVariable` (depending on the `SaveOptions` variable policy).
1090 Args:
1091 proto: A pre-built `SavedObject` proto for this object. It is assumed this
1092 will be a `SavedVariable` instance.
1093 options: A `SaveOptions` instance.
1094 """
1095 resource_variable_ops.write_object_proto_for_resource_variable(
1096 self, proto, options)
1097 if self._is_mirrored():
1098 values_util.write_object_proto(self, proto, options)
1100 @property
1101 def is_distributed_variable(self):
1102 return True
1104 def __tf_experimental_restore_capture__(
1105 self, concrete_function, internal_capture):
1106 graph = concrete_function.graph
1107 # Add given distributed variable to captures with given placeholder.
1108 graph.replace_capture(self, internal_capture)
1109 record.record_operation(
1110 "captured_value", [internal_capture], [self],
1111 backward_function=lambda x: [x],
1112 forward_function=lambda x: [x])
1113 return self
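# A sketch of how the DistributedVariable machinery above is exercised through
# the public API: creating a tf.Variable under strategy.scope() yields a
# distributed variable with one component per device, and the same assign call
# is dispatched differently in cross-replica vs. replica context (see _update).
# The two-replica MirroredStrategy is an assumption made for the example.
def _example_distributed_variable_usage():
  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])
  with strategy.scope():
    v = tf.Variable(0.0)

  # One component variable per device, kept in sync on writes.
  devices = [component.device for component in v.values]

  # Cross-replica context: the update is applied to every component.
  v.assign(1.0)

  # Replica context: the update goes through _update_replica and, for ON_WRITE
  # variables, the aggregation logic at the top of this file.
  strategy.run(lambda: v.assign_add(1.0))

  return devices, v.read_value()  # v is 2.0 on every replica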
1116# We extend from `saveable_object.SaveableObject` instead of
1117# `saveable_object_util.ResourceVariableSaveable` since we need to read the
1118# value of ON_READ variables when saving.
1119# specify the function to run to get the value of the variable or tensor at
1120# saving time. We can use this for both ON_READ and ON_WRITE variables.
1121# TODO(b/164586507): Consolidate ON_WRITE and ON_READ saving/restoring logic
1122# if possible.
1123class _DistributedVariableSaveable(saveable_object.SaveableObject):
1124 """Class for defining how to restore a DistributedVariable."""
1126 def __init__(self, distributed_variable, primary_variable, name):
1127 self._distributed_variable = distributed_variable
1128 if not self._distributed_variable._policy:
1129 raise ValueError(
1130 "The VariablePolicy of the argument `distributed_variable` must be "
1131 "set to create a _DistributedVariableSaveable. Please set it via "
1132 "the `var_policy` argument in the constructor of DistributedVariable."
1133 )
1134 tensor, spec = distributed_variable._policy.get_saveable(
1135 distributed_variable, primary_variable, name)
1136 super(_DistributedVariableSaveable, self).__init__(tensor, spec, name)
1138 def restore(self, restored_tensors, restored_shapes):
1139 """Restore the same value into all variables."""
1140 tensor, = restored_tensors
1141 return self._distributed_variable._policy.get_restore_ops( # pylint: disable=protected-access
1142 self._distributed_variable, tensor)
1145class _MirroredSaveable(saveable_object.SaveableObject):
1146 """Class for defining how to restore a MirroredVariable."""
1148 def __init__(self, mirrored_variable, primary_variable, name):
1149 self._mirrored_variable = mirrored_variable
1150 tensor, spec = values_util.get_on_write_saveable(self._mirrored_variable,
1151 primary_variable, name)
1152 super(_MirroredSaveable, self).__init__(tensor, spec, name)
1154 def restore(self, restored_tensors, restored_shapes):
1155 """Restore the same value into all variables."""
1156 tensor, = restored_tensors
1157 return values_util.get_on_write_restore_ops(self._mirrored_variable, tensor)
1160class MirroredVariable(DistributedVariable, Mirrored):
1161 """Holds a map from replica to variables whose values are kept in sync."""
1163 def _is_mirrored(self):
1164 return Mirrored._is_mirrored(self) # Use correct parent class.
1166 def _update_replica(self, update_fn, value, **kwargs):
1167 return _on_write_update_replica(self, update_fn, value, **kwargs)
1169 def scatter_min(self, *args, **kwargs):
1170 if values_util.is_saving_non_distributed():
1171 return self._primary.scatter_min(*args, **kwargs)
1172 if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
1173 self._aggregation != vs.VariableAggregation.NONE):
1174 raise NotImplementedError(
1175 values_util.scatter_error_msg.format(
1176 op_name="scatter_min", aggregation=self._aggregation))
1177 return super(MirroredVariable, self).scatter_min(*args, **kwargs)
1179 def scatter_max(self, *args, **kwargs):
1180 if values_util.is_saving_non_distributed():
1181 return self._primary.scatter_max(*args, **kwargs)
1182 if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
1183 self._aggregation != vs.VariableAggregation.NONE):
1184 raise NotImplementedError(
1185 values_util.scatter_error_msg.format(
1186 op_name="scatter_max", aggregation=self._aggregation))
1187 return super(MirroredVariable, self).scatter_max(*args, **kwargs)
1189 def scatter_update(self, *args, **kwargs):
1190 if values_util.is_saving_non_distributed():
1191 return self._primary.scatter_update(*args, **kwargs)
1192 if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
1193 self._aggregation != vs.VariableAggregation.NONE):
1194 raise NotImplementedError(
1195 values_util.scatter_error_msg.format(
1196 op_name="scatter_update", aggregation=self._aggregation))
1197 return super(MirroredVariable, self).scatter_update(*args, **kwargs)
1199 def _get_cross_replica(self):
1200 # Return identity, to avoid directly exposing the variable to the user and
1201 # allowing it to be modified by mistake.
1202 return array_ops.identity(Mirrored._get_cross_replica(self))
1204 def _as_graph_element(self):
1205 return self._get_on_device_or_primary()._as_graph_element() # pylint: disable=protected-access
1207 def _gather_saveables_for_checkpoint(self):
1208 """Overrides Trackable method.
1210 This allows both name-based and object-based save and restore of
1211 MirroredVariables.
1213 Returns:
1214 A dictionary mapping attribute names to `SaveableObject` factories.
1215 """
1217 def _saveable_factory(name=self._common_name):
1218 return _MirroredSaveable(self, self._primary, name)
1220 return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
1222 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
1223 """Converts a variable to a tensor."""
1224 # TODO(b/154017756): Make _dense_var_to_tensor consistent between ON_READ
1225 # and ON_WRITE.
1226 # Try to avoid assignments to and other mutations of MirroredVariable
1227 # state except through a DistributionStrategy.extended.update() or any of
1228 # the `assign*` and `scatter*` calls.
1229 if as_ref:
1230 # A TF 1.x case where the variable is a boolean variable and used like:
1231 # tf.cond(v, true_fn, false_fn).
1232 raise ValueError(
1233 "You may be using variable created under distribute strategy in TF "
1234 "1.x control flows. Try explicitly converting the variable to Tensor "
1235 "using variable.read_value(), or switch to TF 2.x.")
1236 return ops.convert_to_tensor(
1237 self._get(), dtype=dtype, name=name, as_ref=as_ref)
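# A sketch of the checkpointing path above (_gather_saveables_for_checkpoint
# and _MirroredSaveable): a MirroredVariable saves a single value and restores
# it into every component. The checkpoint prefix and the two-replica strategy
# are assumptions made for the example.
def _example_mirrored_checkpoint(prefix="/tmp/mirrored_ckpt"):
  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])
  with strategy.scope():
    v = tf.Variable(3.0)

  ckpt = tf.train.Checkpoint(v=v)
  path = ckpt.save(prefix)

  v.assign(0.0)
  ckpt.restore(path)     # the saved value is written back to all components
  return v.read_value()  # 3.0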
1240class _SyncOnReadSaveable(saveable_object.SaveableObject):
1241 """Class for defining how to restore a SyncOnReadVariable."""
1243 def __init__(self, sync_on_read_variable, name):
1244 self._sync_on_read_variable = sync_on_read_variable
1245 tensor, spec = values_util.get_on_read_saveable(
1246 sync_on_read_variable, sync_on_read_variable._primary, name)
1248 super(_SyncOnReadSaveable, self).__init__(tensor, spec, name)
1250 def restore(self, restored_tensors, restored_shapes):
1251 """Restore the same value into all variables."""
1252 tensor, = restored_tensors
1253 return values_util.get_on_read_restore_ops(
1254 self._sync_on_read_variable, tensor,
1255 self._sync_on_read_variable.aggregation)
1258class SyncOnReadVariable(DistributedVariable):
1259 """Holds a map from replica to variables whose values are reduced on save."""
1261 def _update_replica(self, update_fn, value, **kwargs):
1262 return update_fn(self._get_on_device_or_primary(), value, **kwargs)
1264 def _get(self):
1265 """Returns the value of SyncOnReadVariable based on surrounding context.
1267 If called under a non-default replica context, returns the corresponding
1268 variable on that replica.
1269 If called under the default replica context or a cross-replica context,
1270 returns the synced value.
1271 """
1272 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
1273 return super(SyncOnReadVariable, self)._get()
1275 # TODO(b/154017756): Make assign behavior in cross-replica context consistent
1276 # with MirroredVariable.
1277 def assign_sub(self, value, use_locking=False, name=None, read_value=True):
1278 if values_util.is_saving_non_distributed():
1279 return self._primary.assign_sub(value, use_locking, name, read_value)
1280 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
1281 if (distribute_lib.in_cross_replica_context() and
1282 not values_util.in_replica_update_context()):
1283 values_util.mark_as_unsaveable()
1284 return values_util.on_read_assign_sub_cross_replica(
1285 self, value, read_value=read_value)
1286 else:
1287 return super(SyncOnReadVariable,
1288 self).assign_sub(value, use_locking, name, read_value)
1290 def assign_add(self, value, use_locking=False, name=None, read_value=True):
1291 if values_util.is_saving_non_distributed():
1292 return self._primary.assign_add(value, use_locking, name, read_value)
1293 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
1294 if (distribute_lib.in_cross_replica_context() and
1295 not values_util.in_replica_update_context()):
1296 values_util.mark_as_unsaveable()
1297 return values_util.on_read_assign_add_cross_replica(
1298 self, value, read_value=read_value)
1299 else:
1300 return super(SyncOnReadVariable,
1301 self).assign_add(value, use_locking, name, read_value)
1303 def assign(self, value, use_locking=False, name=None, read_value=True):
1304 if values_util.is_saving_non_distributed():
1305 return self._primary.assign(value, use_locking, name, read_value)
1306 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
1307 if (distribute_lib.in_cross_replica_context() and
1308 not values_util.in_replica_update_context()):
1309 values_util.mark_as_unsaveable()
1310 return values_util.on_read_assign_cross_replica(
1311 self, value, read_value=read_value)
1312 else:
1313 return super(SyncOnReadVariable, self).assign(value, use_locking, name,
1314 read_value)
1316 def _scatter_not_implemented(self, method):
1317 raise NotImplementedError(
1318 f"Variables with `synchronization=ON_READ` doesn't support `{method}`")
1320 def scatter_sub(self, *args, **kwargs):
1321 if values_util.is_saving_non_distributed():
1322 return self._primary.scatter_sub(*args, **kwargs)
1323 self._scatter_not_implemented("scatter_sub")
1325 def scatter_add(self, *args, **kwargs):
1326 if values_util.is_saving_non_distributed():
1327 return self._primary.scatter_add(*args, **kwargs)
1328 self._scatter_not_implemented("scatter_add")
1330 def scatter_mul(self, *args, **kwargs):
1331 if values_util.is_saving_non_distributed():
1332 return self._primary.scatter_mul(*args, **kwargs)
1333 self._scatter_not_implemented("scatter_mul")
1335 def scatter_div(self, *args, **kwargs):
1336 if values_util.is_saving_non_distributed():
1337 return self._primary.scatter_div(*args, **kwargs)
1338 self._scatter_not_implemented("scatter_div")
1340 def scatter_min(self, *args, **kwargs):
1341 if values_util.is_saving_non_distributed():
1342 return self._primary.scatter_min(*args, **kwargs)
1343 self._scatter_not_implemented("scatter_min")
1345 def scatter_max(self, *args, **kwargs):
1346 if values_util.is_saving_non_distributed():
1347 return self._primary.scatter_max(*args, **kwargs)
1348 self._scatter_not_implemented("scatter_max")
1350 def scatter_update(self, *args, **kwargs):
1351 if values_util.is_saving_non_distributed():
1352 return self._primary.scatter_update(*args, **kwargs)
1353 self._scatter_not_implemented("scatter_update")
1355 def value(self):
1356 if distribute_lib.in_variable_sync_on_read_context():
1357 raise NotImplementedError(
1358 "call `variable.value()` inside variable_sync_on_read_context is not "
1359 "supported")
1360 if values_util.is_saving_non_distributed():
1361 return self._primary.value()
1362 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
1363 if (distribute_lib.in_cross_replica_context() and
1364 not values_util.in_replica_update_context()):
1365 if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
1366 return self._get_replica(0).value()
1367 return self._get_cross_replica()
1368 else:
1369 # _get_on_device_or_primary() returns a Variable.
1370 return self._get_on_device_or_primary().value()
1372 def read_value(self):
1373 if distribute_lib.in_variable_sync_on_read_context():
1374 raise NotImplementedError(
1375 "call `variable.read_value()` inside variable_sync_on_read_context is"
1376 " not supported")
1377 return super().read_value()
1379 def _get_cross_replica(self):
1380 if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
1381 # Consider returning a tensor value here to make the return value of
1382 # _get_cross_replica consistent.
1383 return self._get_replica(0)
1384 if self._aggregation == vs.VariableAggregation.SUM:
1385 values_util.mark_as_unsaveable()
1386 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
1387 return self._distribute_strategy.reduce(
1388 reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
1389 self,
1390 axis=None)
1392 def _as_graph_element(self):
1393 if values_util.is_saving_non_distributed():
1394 return self._primary._as_graph_element() # pylint: disable=protected-access
1395 # pylint: disable=protected-access
1396 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
1397 if distribute_lib.in_cross_replica_context():
1398 return ops.convert_to_tensor(self._get_cross_replica())
1399 return self._get()._as_graph_element()
1401 def _gather_saveables_for_checkpoint(self):
1402 """Overrides Trackable method.
1404 This allows both name-based and object-based save and restore of
1405 `SyncOnReadVariable`s.
1407 Returns:
1408 A dictionary mapping attribute names to `SaveableObject` factories.
1409 """
1411 def _saveable_factory(name=self._common_name):
1412 return _SyncOnReadSaveable(self, name)
1414 return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
1416 def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
1417 """Converts a SyncOnReadVariable to a tensor."""
1418 if values_util.is_saving_non_distributed():
1419 return ops.convert_to_tensor(
1420 self._primary, dtype=dtype, name=name, as_ref=as_ref)
1421 with distribute_lib.enter_or_assert_strategy(self._distribute_strategy):
1422 replica_context = distribute_lib.get_replica_context()
1423 if (replica_context is not None and
1424 distribute_lib.in_variable_sync_on_read_context()):
1425 if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
1426 return ops.convert_to_tensor(
1427 self._get_replica(0), dtype=dtype, name=name, as_ref=as_ref)
1428 if self._aggregation == vs.VariableAggregation.SUM:
1429 values_util.mark_as_unsaveable()
1430 # pylint: disable=protected-access
1431 reduced = (
1432 replica_context.strategy.extended._replica_ctx_all_reduce(
1433 reduce_util.ReduceOp.from_variable_aggregation(
1434 self._aggregation),
1435 self._get().read_value()))
1436 return ops.convert_to_tensor(
1437 reduced, dtype=dtype, name=name, as_ref=as_ref)
1439 return ops.convert_to_tensor(
1440 self._get(), dtype=dtype, name=name, as_ref=as_ref)
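# A sketch of the ON_READ behavior implemented above: each replica writes only
# to its own component, and reading in cross-replica context reduces across
# replicas (the way metric-style counters behave). The two-replica strategy is
# an assumption made for the example.
def _example_sync_on_read_sum():
  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])
  with strategy.scope():
    counter = tf.Variable(
        0.0,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM)

  # Each replica updates only its local component.
  strategy.run(lambda: counter.assign_add(1.0))
  # Reading in cross-replica context sums the per-replica components.
  return counter.read_value()  # 2.0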
1443# Register conversion functions which read the value of the variable,
1444# allowing instances of these classes to be used as tensors.
1445# DistributedVariable
1446def _tensor_conversion_distributed_var(var,
1447 dtype=None,
1448 name=None,
1449 as_ref=False):
1450 return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
1453tensor_conversion_registry.register_tensor_conversion_function(
1454 DistributedVariable, _tensor_conversion_distributed_var)
1457# MirroredVariables
1458def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
1459 return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
1462tensor_conversion_registry.register_tensor_conversion_function(
1463 MirroredVariable, _tensor_conversion_mirrored)
1466# Mirrored Values
1467def _tensor_conversion_mirrored_val(value, dtype=None, name=None, as_ref=False):
1468 return ops.convert_to_tensor(
1469 value._get(), dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
1472tensor_conversion_registry.register_tensor_conversion_function(
1473 Mirrored, _tensor_conversion_mirrored_val)
1476# SyncOnReadVariables
1477def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False):
1478 return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
1481tensor_conversion_registry.register_tensor_conversion_function(
1482 SyncOnReadVariable, _tensor_conversion_sync_on_read)
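# A sketch of what the conversion registrations above enable: a distributed
# variable can be passed anywhere a tensor is expected, and the registered
# function reads the value appropriate for the current context. The two-replica
# strategy is an assumption made for the example.
def _example_implicit_tensor_conversion():
  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])
  with strategy.scope():
    v = tf.Variable([1.0, 2.0])

  # In cross-replica context this converts the primary/current component.
  t = tf.convert_to_tensor(v)
  return tf.reduce_sum(t)  # 3.0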
1485class VariablePolicy(object):
1486 """Policy defining synchronization and aggregation of a distributed variable.
1488 Given `synchronization` and `aggregation` parameters set on a `tf.Variable`
1489 during variable creation within `tf.distribute` scope, `tf.distribute` creates
1490 an appropriate policy object and assigns it to the distributed variable. All
1491 variable operations are delegated to the respective policy object.
1492 """
1494 def __init__(self, aggregation):
1495 self._aggregation = aggregation
1497 def value(self):
1498 raise NotImplementedError(
1499 "VariablePolicy.value should be overriden by sub-classes.")
1501 def _is_mirrored(self):
1502 raise NotImplementedError(
1503 "VariablePolicy._is_mirrored should be overriden by sub-classes.")
1505 def _as_graph_element(self, _):
1506 raise NotImplementedError(
1507 "VariablePolicy._as_graph_element should be overriden by sub-classes.")
1509 def _get_cross_replica(self, var):
1510 raise NotImplementedError(
1511 "VariablePolicy._get_cross_replica should be overriden by sub-classes.")
1513 def _update_replica(self, var, update_fn, value, **kwargs):
1514 raise NotImplementedError(
1515 "VariablePolicy._update_replica should be overriden by sub-classes.")
1518class OnReadPolicy(VariablePolicy):
1519 """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.
1521 This policy is created when `synchronization` is set to
1522 `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
1523 values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`,
1524 `MEAN` or `ONLY_FIRST_REPLICA`, when creating a `tf.Variable` in `tf.distribute`
1525 scope.
1526 """
1528 def _is_mirrored(self):
1529 return False
1531 def value(self, var):
1532 with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
1533 if (distribute_lib.in_cross_replica_context() and
1534 not values_util.in_replica_update_context()):
1535 if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
1536 return var._get_replica(0).value() # pylint: disable=protected-access
1537 return var._get_cross_replica() # pylint: disable=protected-access
1538 else:
1539 return var._get_on_device_or_primary().value() # pylint: disable=protected-access
1541 def _as_graph_element(self, var):
1542 with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
1543 if distribute_lib.in_cross_replica_context():
1544 return ops.convert_to_tensor(var._get_cross_replica()) # pylint: disable=protected-access
1545 return var._get()._as_graph_element() # pylint: disable=protected-access
1547 def _get_cross_replica(self, var):
1548 if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
1549 return var._get_replica(0) # pylint: disable=protected-access
1550 if self._aggregation == vs.VariableAggregation.SUM:
1551 values_util.mark_as_unsaveable()
1552 with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
1553 return var.distribute_strategy.reduce(
1554 reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
1555 var,
1556 axis=None)
1558 def _update_replica(self, var, update_fn, value, **kwargs):
1559 return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access
1561 def _scatter_not_implemented(self, method):
1562 raise NotImplementedError(f"ON_READ variables don't support `{method}` "
1563 "in cross-replica context.")
1565 def assign_sub(self,
1566 var,
1567 value,
1568 use_locking=False,
1569 name=None,
1570 read_value=True):
1571 """Subtracts a value from this variable."""
1572 with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
1573 if (distribute_lib.in_cross_replica_context() and
1574 not values_util.in_replica_update_context()):
1575 values_util.mark_as_unsaveable()
1576 return values_util.on_read_assign_sub_cross_replica(
1577 var, value, read_value=read_value)
1578 else:
1579 return values_util.on_write_assign_sub(
1580 var,
1581 value,
1582 use_locking=use_locking,
1583 name=name,
1584 read_value=read_value)
1586 def assign_add(self,
1587 var,
1588 value,
1589 use_locking=False,
1590 name=None,
1591 read_value=True):
1592 """Adds a value to this variable."""
1593 with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
1594 if (distribute_lib.in_cross_replica_context() and
1595 not values_util.in_replica_update_context()):
1596 values_util.mark_as_unsaveable()
1597 return values_util.on_read_assign_add_cross_replica(
1598 var, value, read_value=read_value)
1599 else:
1600 return values_util.on_write_assign_add(
1601 var,
1602 value,
1603 use_locking=use_locking,
1604 name=name,
1605 read_value=read_value)
1607 def assign(self, var, value, use_locking=False, name=None, read_value=True):
1608 with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
1609 if (distribute_lib.in_cross_replica_context() and
1610 not values_util.in_replica_update_context()):
1611 values_util.mark_as_unsaveable()
1612 return values_util.on_read_assign_cross_replica(
1613 var, value, read_value=read_value)
1614 else:
1615 return values_util.on_write_assign(
1616 var,
1617 value,
1618 use_locking=use_locking,
1619 name=name,
1620 read_value=read_value)
1622 def scatter_sub(self, *args, **kwargs):
1623 del args, kwargs
1624 self._scatter_not_implemented("scatter_sub")
1626 def scatter_add(self, *args, **kwargs):
1627 del args, kwargs
1628 self._scatter_not_implemented("scatter_add")
1630 def scatter_mul(self, *args, **kwargs):
1631 del args, kwargs
1632 self._scatter_not_implemented("scatter_mul")
1634 def scatter_div(self, *args, **kwargs):
1635 del args, kwargs
1636 self._scatter_not_implemented("scatter_div")
1638 def scatter_min(self, *args, **kwargs):
1639 del args, kwargs
1640 self._scatter_not_implemented("scatter_min")
1642 def scatter_max(self, *args, **kwargs):
1643 del args, kwargs
1644 self._scatter_not_implemented("scatter_max")
1646 def scatter_update(self, *args, **kwargs):
1647 del args, kwargs
1648 self._scatter_not_implemented("scatter_update")
1650 def get_saveable(self, var, primary_var, name):
1651 """Create a saveable object for the given variable."""
1652 return values_util.get_on_read_saveable(var, primary_var, name)
1654 def get_restore_ops(self, var, tensor):
1655 """Restore the same value into all variables."""
1656 return values_util.get_on_read_restore_ops(var, tensor, self._aggregation)
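# Illustrative sketch (not part of the original source): a variable governed
# by `OnReadPolicy` keeps an independent copy per replica and only aggregates
# when it is read in cross-replica context. The helper below is hypothetical,
# assumes the public `tensorflow` package, and is never called from this
# module.
def _example_on_read_variable():  # pragma: no cover
  import tensorflow as tf  # assumed available; imported lazily on purpose
  strategy = tf.distribute.MirroredStrategy(["CPU:0"])
  with strategy.scope():
    counter = tf.Variable(
        0.0,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM)

  def step():
    counter.assign_add(1.0)  # each replica updates only its own copy

  strategy.run(step)
  # Reading outside `strategy.run` is a cross-replica read, so the per-replica
  # copies are reduced with the SUM aggregation before a value is returned.
  return counter.read_value()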
1659class OnWritePolicy(VariablePolicy):
1660 """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.
1662 This policy is created when a `tf.Variable` is created in `tf.distribute`
1663 scope with `synchronization` set to `tf.VariableSynchronization.ON_WRITE`
1664 or `tf.VariableSynchronization.AUTO`, together with any of the supported
1665 `aggregation` values.
1666 """
1668 def _is_mirrored(self):
1669 return True
1671 def value(self, var):
1672 return var._get_on_device_or_primary().value() # pylint: disable=protected-access
1674 def _as_graph_element(self, var):
1675 return var._get_on_device_or_primary()._as_graph_element() # pylint: disable=protected-access
1677 def _get_cross_replica(self, var):
1678 # Return identity, to avoid directly exposing the variable to the user and
1679 # allowing it to be modified by mistake.
1680 return array_ops.identity(var._get_on_device_or_primary()) # pylint: disable=protected-access
1682 def _update_replica(self, var, update_fn, value, **kwargs):
1683 if var.aggregation == variables_lib.VariableAggregation.NONE:
1684 return update_fn(var._get_on_device_or_primary(), value, **kwargs) # pylint: disable=protected-access
1685 return _on_write_update_replica(var, update_fn, value, **kwargs)
1687 def assign(self, var, value, use_locking=False, name=None, read_value=True):
1688 return values_util.on_write_assign(
1689 var, value, use_locking=use_locking, name=name, read_value=read_value)
1691 def assign_add(self,
1692 var,
1693 value,
1694 use_locking=False,
1695 name=None,
1696 read_value=True):
1697 return values_util.on_write_assign_add(
1698 var, value, use_locking=use_locking, name=name, read_value=read_value)
1700 def assign_sub(self,
1701 var,
1702 value,
1703 use_locking=False,
1704 name=None,
1705 read_value=True):
1706 return values_util.on_write_assign_sub(
1707 var, value, use_locking=use_locking, name=name, read_value=read_value)
1709 def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
1710 return values_util.scatter_sub(
1711 var, sparse_delta, use_locking=use_locking, name=name)
1713 def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
1714 return values_util.scatter_add(
1715 var, sparse_delta, use_locking=use_locking, name=name)
1717 def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
1718 return values_util.scatter_mul(
1719 var, sparse_delta, use_locking=use_locking, name=name)
1721 def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
1722 return values_util.scatter_div(
1723 var, sparse_delta, use_locking=use_locking, name=name)
1725 def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
1726 if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
1727 self._aggregation != vs.VariableAggregation.NONE):
1728 raise NotImplementedError(
1729 values_util.scatter_error_msg.format(
1730 op_name="scatter_min", aggregation=self._aggregation))
1731 return values_util.scatter_min(
1732 var, sparse_delta, use_locking=use_locking, name=name)
1734 def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
1735 if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
1736 self._aggregation != vs.VariableAggregation.NONE):
1737 raise NotImplementedError(
1738 values_util.scatter_error_msg.format(
1739 op_name="scatter_max", aggregation=self._aggregation))
1740 return values_util.scatter_max(
1741 var, sparse_delta, use_locking=use_locking, name=name)
1743 def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
1744 if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
1745 self._aggregation != vs.VariableAggregation.NONE):
1746 raise NotImplementedError(
1747 values_util.scatter_error_msg.format(
1748 op_name="scatter_update", aggregation=self._aggregation))
1749 return values_util.scatter_update(
1750 var, sparse_delta, use_locking=use_locking, name=name)
1752 def get_saveable(self, var, primary_var, name):
1753 """Saveable ops for AUTO variables."""
1754 return values_util.get_on_write_saveable(var, primary_var, name)
1756 def get_restore_ops(self, var, tensor):
1757 return values_util.get_on_write_restore_ops(var, tensor)
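# Illustrative sketch (not part of the original source): a variable governed
# by `OnWritePolicy` (the default for `tf.VariableSynchronization.AUTO`) keeps
# its replicas in sync at write time, so a plain read simply returns the
# primary copy and no aggregation is needed. The helper below is hypothetical,
# assumes the public `tensorflow` package, and is never called from this
# module.
def _example_on_write_variable():  # pragma: no cover
  import tensorflow as tf  # assumed available; imported lazily on purpose
  strategy = tf.distribute.MirroredStrategy(["CPU:0"])
  with strategy.scope():
    weight = tf.Variable(1.0)  # AUTO synchronization selects OnWritePolicy

  # Assigning in cross-replica context updates every replica's copy, so the
  # copies never diverge and reads need no reduction.
  weight.assign(3.0)
  return weight.read_value()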
1760class PerWorkerResource(object):
1761 """A per-worker CapturableResource class for non-ParameterServer strategies.
1763 Resources that populate `host_to_resources` should be instances of classes
1764 subclassing CapturableResource, although this is currently only used and
1765 tested for StaticHashTable with TPUStrategy.
1766 """
1768 def __init__(self, strategy, host_to_resources):
1769 distribute_lib.distribution_strategy_input_api_counter.get_cell(
1770 "PerWorkerResource", "TPUDistributedLookupTable").increase_by(1)
1771 self._strategy = strategy
1772 self._host_to_resources = host_to_resources
1774 def __getattribute__(self, name):
1775 if name not in ("__init__", "__getattribute__", "_host_to_resources",
1776 "_strategy", "local_resource"):
1777 return getattr(self.local_resource(), name)
1778 return super(PerWorkerResource, self).__getattribute__(name)
1780 def __setattr__(self, name, value):
1781 if name not in ("_strategy", "_host_to_resources"):
1782 return setattr(self.local_resource(), name, value)
1783 return super(PerWorkerResource, self).__setattr__(name, value)
1785 def local_resource(self):
1786 """Returns the resource on the local worker."""
1787 current_device = device_util.canonicalize(device_util.current())
1788 host_device = device_util.canonicalize(
1789 device_util.get_host_for_device(current_device))
1790 return self._host_to_resources.get(
1791 host_device,
1792 self._host_to_resources[next(iter(self._host_to_resources))])
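# Illustrative sketch (not part of the original source): `PerWorkerResource`
# forwards almost every attribute access to the resource that belongs to the
# local worker, so callers can treat the wrapper as the resource itself. The
# toy class below is hypothetical; it demonstrates the forwarding pattern with
# a plain dictionary lookup and the simpler `__getattr__` hook instead of the
# real `__getattribute__` allowlist and `device_util` host resolution.
class _ToyPerHostProxy(object):

  def __init__(self, host_to_resources, current_host):
    self._host_to_resources = host_to_resources
    self._current_host = current_host

  def _local_resource(self):
    # Fall back to an arbitrary entry when the current host is unknown, just
    # like `PerWorkerResource.local_resource` does.
    fallback = self._host_to_resources[next(iter(self._host_to_resources))]
    return self._host_to_resources.get(self._current_host, fallback)

  def __getattr__(self, name):
    # Only reached when normal lookup fails, so the internal attributes above
    # are unaffected; everything else is served by the local host's resource.
    return getattr(self._local_resource(), name)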