# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities used by tf.distribute.Strategy."""

from collections import abc
import contextlib
import threading

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import tpu_values as tpu_values_lib
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.distribute.reduce_util import ReduceOp
from tensorflow.python.eager import context
from tensorflow.python.eager import record
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["distribute.get_loss_reduction"])
def get_loss_reduction():
  """`tf.distribute.ReduceOp` corresponding to the last loss reduction.

  This is used to decide whether loss should be scaled in the optimizer (used
  only for the estimator + v1 optimizer use case).

  Returns:
    `tf.distribute.ReduceOp` corresponding to the last loss reduction for the
    estimator and v1 optimizer use case; `tf.distribute.ReduceOp.SUM` otherwise.
  """
  if not distribute_lib.get_strategy()._scale_loss_for_estimator:  # pylint: disable=protected-access
    # If we are not in Estimator context then return 'SUM'. We do not need to
    # scale loss in the optimizer.
    return ReduceOp.SUM
  last_reduction = ops.get_default_graph()._last_loss_reduction  # pylint: disable=protected-access
  if (last_reduction == losses_impl.Reduction.SUM or
      last_reduction == "sum"):  # Check for tf.keras.losses.Reduction.SUM
    return ReduceOp.SUM
  return ReduceOp.MEAN
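
# Example (illustrative sketch, not part of the original module): outside the
# Estimator loss-scaling path (`_scale_loss_for_estimator` unset on the current
# strategy), the helper simply reports SUM:
#
#   get_loss_reduction()  # -> ReduceOp.SUM under a default strategy.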

def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
  """Makes a nest of per-replica values into a nest of PerReplica/Mirrored values.

  Args:
    values: Values to regroup.
    wrap_class: Class that `values` will be wrapped in.
    always_wrap: Always wrap the `values` in `wrap_class` even if the values
      are the same, except for DistributedVariable.

  Returns:
    Wrapped `values`.
  """
  v0 = values[0]

  if isinstance(v0, list):
    for v in values[1:]:
      assert isinstance(v, list)
      assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
                                 (len(v), len(v0), v, v0))
    return [
        regroup(tuple(v[i] for v in values), wrap_class, always_wrap)
        for i in range(len(v0))
    ]

  if isinstance(v0, tuple):
    for v in values[1:]:
      assert isinstance(v, tuple)
      assert len(v) == len(v0), ("Values to regroup had different lengths: "
                                 f"len(v) == {len(v)}, len(v0) == {len(v0)}, "
                                 f"v: {v}, v0: {v0}")
    regrouped_tuple = tuple(
        regroup(tuple(v[i] for v in values), wrap_class, always_wrap)
        for i in range(len(v0)))
    if hasattr(v0, "_fields"):
      # This tuple is in fact a namedtuple! Create a new namedtuple instance
      # and initialize it with the regrouped values:
      assert hasattr(v0, "_make")
      return v0._make(regrouped_tuple)
    else:
      return regrouped_tuple

  if isinstance(v0, abc.Mapping):
    v0keys = v0.keys()
    for v in values[1:]:
      assert isinstance(v, abc.Mapping), ("v[0]: %r v[i]: %r" % (v0, v))
      assert set(v.keys()) == set(v0keys), ("v[0].keys: %s v[i].keys: %s" %
                                            (set(v0keys), set(v.keys())))
    # Use the actual type in case it is a class inherited from a dict.
    return type(v0)({
        key: regroup(tuple(v[key] for v in values),
                     wrap_class, always_wrap)
        for key in v0keys
    })

  # If exactly the same object across all devices, return it unwrapped.
  same_id = True
  for v in values[1:]:
    if v is not v0:
      same_id = False
      break
  # Consider three cases where same_id is true:
  # * If v0 is a DistributedVariable (a MirroredVariable or
  #   SyncOnReadVariable, and same_id means it is the same across all
  #   devices), we want to return it. We check DistributedVariable
  #   specifically since it can look like it has a
  #   _distributed_container member since its members do.
  if same_id and isinstance(v0, values_lib.DistributedVariable):
    return v0
  # * If v0 is a member of a distributed variable, in which case
  #   value_container(v0) is not v0 itself, we want to
  #   return the DistributedVariable that contains it using the
  #   _distributed_container logic below. This case can trigger
  #   same_id when there is only one device.
  # * In any other situation, same_id means we return v0 unless `always_wrap`
  #   is true.
  if same_id and not always_wrap and value_container(v0) is v0:
    return v0

  # Detect the case where each device has a parallel component of the
  # same MirroredVariable (or SyncOnReadVariable). In this case we
  # want to return the containing MirroredVariable, after a bunch of
  # sanity checking. In particular, each component should have the
  # same container, and the devices of the variables should match the
  # keys of the per-replica dictionary. For _UnreadVariables, use the
  # wrap_class path, which calls tf.identity on them.
  if (not isinstance(v0, resource_variable_ops._UnreadVariable) and  # pylint: disable=protected-access
      value_container(v0) is not v0):
    # pylint: disable=protected-access
    assert not isinstance(v0, values_lib.MirroredVariable), (
        "ids = %s, values = %s" % ([id(v) for v in values], values))
    distributed_container = value_container(v0)
    assert distributed_container is not None
    for v in values[1:]:
      assert distributed_container is value_container(v)
    return distributed_container
  # pylint: enable=protected-access

  return wrap_class(values)
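
# Example (illustrative sketch, not part of the original module): regrouping
# one structure per replica into a single structure of PerReplica values.
#
#   per_replica_results = [{"loss": 1.0}, {"loss": 2.0}]  # one dict per replica
#   regrouped = regroup(per_replica_results)
#   # regrouped is a dict whose "loss" entry is a PerReplica wrapping (1.0, 2.0).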

def select_replica(replica_id, structured):
  """Specialize a nest of regular & per-replica values for one replica."""

  def _get(x):
    # `DistributedValues` would be sliced according to replica unless it is a
    # `DistributedVariable` because `DistributedVariable` can be handled
    # directly in the replica context.
    if (isinstance(x, values_lib.DistributedVariable) or
        not isinstance(x, values_lib.DistributedValues)):
      return x
    else:
      return x.values[replica_id]

  return nest.map_structure(_get, structured)
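
# Example (illustrative sketch, not part of the original module):
# `select_replica` undoes the wrapping done by `regroup` for one chosen replica.
#
#   wrapped = regroup([(1, 2), (3, 4)])  # (PerReplica((1, 3)), PerReplica((2, 4)))
#   select_replica(0, wrapped)           # -> (1, 2)
#   select_replica(1, wrapped)           # -> (3, 4)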

def select_replica_mirrored(replica_id, structured):
  """Specialize a nest of regular & mirrored values for one replica."""
  assert_mirrored(structured)
  return select_replica(replica_id, structured)


def assert_mirrored(structured):
  """Raises if `structured` is not composed of mirrored or regular values."""

  def _assert_mirrored(x):
    if isinstance(x, values_lib.DistributedValues) and not is_mirrored(x):
      raise TypeError(
          "Expected value to be mirrored across replicas: %s in %s." %
          (x, structured))

  nest.map_structure(_assert_mirrored, structured)


def update_regroup(extended, updates, group):
  """Regroup for an update, with dependencies to ensure all updates execute."""
  if not group:
    regrouped = regroup(updates, values_lib.Mirrored)
    return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access

  def _make_grouped_mirrored(values):
    """Convert per-replica list `values` into Mirrored type with grouping."""
    if len(values) == 1:
      return values_lib.Mirrored(values)

    # Make sure we run all updates. Without this, something like
    # session.run(extended.update(...)) may only update one replica.
    g = control_flow_ops.group(values)

    # If values is just ops, the grouping is enough. Everything in values
    # should have the same type, since we expect every replica to be performing
    # the same computation.
    if not all(tensor_util.is_tf_type(v) for v in values):
      return g

    # Otherwise we need tensors with the same values as `values`, but
    # that have a dependency on `g`.
    with_dep = []
    for v in values:
      with ops.device(v.device), ops.control_dependencies([g]):
        with_dep.append(array_ops.identity(v))

    return values_lib.Mirrored(with_dep)

  return regroup(updates, _make_grouped_mirrored)


def value_container(val):
  """Returns the container that this per-replica `value` belongs to.

  Args:
    val: A value returned by `call_for_each_replica()` or a variable created in
      `scope()`.

  Returns:
    A container that `value` belongs to.
    If value does not belong to any container (including the case of
    container having been destroyed), returns the value itself.
  """
  # DistributedVariable has _distributed_container defined but we don't want to
  # return it.
  container = None
  if not isinstance(val, values_lib.DistributedVariable):
    if hasattr(val, "_distributed_container"):
      container = val._distributed_container()  # pylint: disable=protected-access
    elif (isinstance(val, composite_tensor.CompositeTensor) and
          hasattr(val, "handle") and
          hasattr(val.handle, "_distributed_container")):
      # For ResourceVariables, the _distributed_container attribute
      # is added to their handle tensors.
      container = val.handle._distributed_container()  # pylint: disable=protected-access
  return container if container is not None else val
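
# Example (illustrative sketch, not part of the original module): for a plain
# tensor or an ordinary variable, `value_container` is the identity; for a
# per-replica component of a distributed variable it climbs back to the
# wrapping DistributedVariable.
#
#   value_container(tf.constant(1.0))          # -> the constant itself
#   value_container(mirrored_var.values[0])    # -> mirrored_var, assuming
#                                              #    `mirrored_var` was created
#                                              #    under a MirroredStrategy scope.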

def is_distributed_variable(v):
  """Determine if a variable is a tf.distribute variable or TPU mirrored variable."""
  return getattr(v, "is_distributed_variable", False)


def is_distributed_table(v):
  """Determine if an object is a DistributedTable."""
  return getattr(v, "is_distributed_table", False)


def _validate_colocate_extended(v, extended):
  variable_strategy = v._distribute_strategy  # pylint: disable=protected-access
  if variable_strategy.extended is not extended:
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
        (v, variable_strategy))


def validate_colocate_distributed_variable(v, extended):
  if not isinstance(v, values_lib.DistributedVariable):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)


def validate_colocate(v, extended):
  if not hasattr(v, "_distribute_strategy"):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)


# Variable creation function for sync strategies.
def _validate_synchronization(kwargs):
  """Validate the `synchronization` in `kwargs` and resolve `AUTO` to `ON_WRITE`."""
  synchronization = kwargs.get("synchronization",
                               vs.VariableSynchronization.AUTO)
  if synchronization == vs.VariableSynchronization.NONE:
    raise ValueError(
        "`NONE` variable synchronization mode is not supported with "
        "tf.distribute strategy. Please change the `synchronization` for "
        "variable: " + str(kwargs["name"]))
  if synchronization not in (vs.VariableSynchronization.ON_READ,
                             vs.VariableSynchronization.ON_WRITE,
                             vs.VariableSynchronization.AUTO):
    raise ValueError(
        "Invalid variable synchronization mode: %s for variable: %s" %
        (synchronization, kwargs["name"]))
  if synchronization == vs.VariableSynchronization.AUTO:
    return vs.VariableSynchronization.ON_WRITE
  return synchronization


def _validate_aggregation(kwargs):
  aggregation = kwargs.get("aggregation", vs.VariableAggregation.NONE)

  if aggregation not in (vs.VariableAggregation.NONE,
                         vs.VariableAggregation.SUM,
                         vs.VariableAggregation.MEAN,
                         vs.VariableAggregation.ONLY_FIRST_REPLICA):
    raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
                     (aggregation, kwargs["name"]))
  return aggregation
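
# Example (illustrative sketch, not part of the original module): how the two
# validators resolve variable-creator kwargs before a distributed variable is
# built.
#
#   kwargs = {"name": "v",
#             "synchronization": vs.VariableSynchronization.AUTO,
#             "aggregation": vs.VariableAggregation.MEAN}
#   _validate_synchronization(kwargs)  # -> VariableSynchronization.ON_WRITE
#   _validate_aggregation(kwargs)      # -> VariableAggregation.MEAN
#   # synchronization=NONE or an unknown aggregation raises ValueError.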

def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
  """Create distributed variables with given synchronization and aggregation."""
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  synchronization = _validate_synchronization(kwargs)
  # Update synchronization in kwargs in case it's AUTO, which is converted to
  # ON_WRITE.
  kwargs["synchronization"] = synchronization
  aggregation = _validate_aggregation(kwargs)
  use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with record.stop_recording():
    value_list = real_mirrored_creator(**kwargs)
    # MirroredVariable is recreated during saved_model loading, and its
    # component variables (value_list) will have None initializer. We
    # set their initializers to no_op so that consumers like
    # `global_variables_initializer` won't complain, as it groups all
    # variables' initializers, thus all variables have to have initializers.
    for v in value_list:
      # pylint:disable=protected-access
      if hasattr(v, "_initializer_op") and v._initializer_op is None:
        v._initializer_op = control_flow_ops.no_op()
      # pylint:enable=protected-access
    if use_var_policy:
      var_policy_cls = policy_mapping.get(synchronization)
      var_policy = var_policy_cls(aggregation=aggregation)
      var_cls = class_mapping.get("VariableClass")
      result = var_cls(strategy, value_list, aggregation, var_policy=var_policy)
    else:
      var_cls = class_mapping.get(synchronization)
      result = var_cls(strategy, value_list, aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace them with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        for i, trainable_variable in enumerate(l):
          if value is trainable_variable:
            del l[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
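
# Example (illustrative sketch, not part of the original module): a strategy's
# variable creator typically calls `create_mirrored_variable` with a
# per-device creator plus the class/policy mappings defined at the bottom of
# this file, e.g.
#
#   result = create_mirrored_variable(
#       strategy, _real_mirrored_creator,  # hypothetical per-replica creator
#       VARIABLE_CLASS_MAPPING, VARIABLE_POLICY_MAPPING, **kwargs)
#
# where `_real_mirrored_creator` stands in for whatever callable the calling
# strategy uses to build one component variable per device.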

# Utility functions
# Return True if the Value is Mirrored or the Variable is replicated and kept
# in sync.
def is_mirrored(val):
  return (getattr(val, "_is_mirrored", lambda: False))()


def is_sync_on_read(val):
  return not is_mirrored(val)


class CachingScopeLocal(threading.local):
  """Class for maintaining thread local state for caching scope."""

  def __init__(self):
    super(CachingScopeLocal, self).__init__()
    self.new_cache_scope_count = 0
    self.cache_scope_exited_count = 0

  def enter_scope(self):
    self.new_cache_scope_count += 1

  def exit_scope(self):
    self.cache_scope_exited_count += 1

  def in_caching_scope(self):
    return self.new_cache_scope_count > self.cache_scope_exited_count


caching_scope_local = CachingScopeLocal()


@contextlib.contextmanager
def cache_variable_reads():
  """Scope for caching variable reads for AggregatingVariable.

  Variable reads for AggregatingVariable inside this scope are cached, i.e. the
  first read fetches the value from the (possibly remote) handle, but
  subsequent reads return the locally cached value.

  For example:
    strategy = ParameterServerStrategy...
    with strategy.scope():
      # Variable v is of AggregatingVariable type with actual variable residing
      # on PS.
      v = tf.Variable(1.0)

    with distribute_utils.cache_variable_reads():
      v.read_value()  # Reads value 1.0
      v.assign(constant_op.constant(5.0))  # v changes to 5.0
      t1 = v.read_value()
      t2 = v.read_value()  # Both t1 & t2 return cached value 1.0 from local CPU.

  Notes about the cache_variable_reads scope:
  1. Nesting of cache_variable_reads() scopes is not supported.
  2. When the caching scope is enabled, both the thread enabling the cache and
     the mirrored_run._MirroredReplicaThread threads spawned from it have
     caching enabled.

  Yields:
    A context for caching variables.
  """

  try:
    if caching_scope_local.in_caching_scope():
      # There is a nested cache scope, which is not supported.
      raise ValueError("cache_variable_reads scope cannot be nested")
    caching_scope_local.enter_scope()
    yield
  finally:
    caching_scope_local.exit_scope()


# The following mapping indicates the policy that you must use for a given
# variable `synchronization` and `aggregation` pair.
# OnWritePolicy is used for:
#   (synchronization=AUTO, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
#   (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# OnReadPolicy is used for:
#   (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
    vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
}

VARIABLE_CLASS_MAPPING = {
    "VariableClass": values_lib.DistributedVariable,
    vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable,
    vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
}

TPU_VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUOnWritePolicy,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUOnReadPolicy,
}

TPU_VARIABLE_CLASS_MAPPING = {
    "VariableClass": tpu_values_lib.TPUDistributedVariable,
    vs.VariableSynchronization.ON_WRITE: tpu_values_lib.TPUMirroredVariable,
    vs.VariableSynchronization.ON_READ: tpu_values_lib.TPUSyncOnReadVariable,
}
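
# Example (illustrative sketch, not part of the original module): the mappings
# above select the wrapper class (or policy) from the resolved synchronization
# mode.
#
#   sync = vs.VariableSynchronization.ON_READ
#   VARIABLE_CLASS_MAPPING[sync]                 # -> values_lib.SyncOnReadVariable
#   VARIABLE_POLICY_MAPPING[sync]                # -> values_lib.OnReadPolicy
#   TPU_VARIABLE_CLASS_MAPPING["VariableClass"]  # -> TPUDistributedVariable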