Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/values_util.py: 30%
139 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions used by values.py and ps_values.py."""

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.training.saving import saveable_object


def write_object_proto(var, proto, options):
  """Update a SavedObject proto for the caller.

  If a DistributedVariable object supports this method, it will be called when
  saving with a pre-built `SavedObject` proto representing the object, plus an
  instance of `SaveOptions`. This method is then free to modify that proto
  instance.

  A `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally
  writes out information about its components to the
  `experimental_distributed_variable_components` field of a
  `SavedVariable` (depending on the `SaveOptions` variable policy).

  Args:
    var: The DistributedVariable object.
    proto: A pre-built `SavedObject` proto for this object. It is assumed this
      will be a `SavedVariable` instance.
    options: A `SaveOptions` instance.
  """
  if options.experimental_variable_policy._expand_distributed_variables(  # pylint: disable=protected-access
  ):
    for component_var in var.values:
      var_proto = (
          proto.variable.experimental_distributed_variable_components.add())
      var_proto.name = component_var.name.split(":")[0]
      var_proto.device = component_var.device
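

# Editor's sketch (not part of the original module): how the variable policy
# checked above is typically set from user code. `tf.saved_model.SaveOptions`
# and `tf.saved_model.experimental.VariablePolicy` are public APIs; the helper
# name below is hypothetical and only illustrates the option that makes
# `write_object_proto` record per-component information.
def _example_expand_distributed_variables_options():
  import tensorflow as tf  # assumed to be importable in the example's context
  return tf.saved_model.SaveOptions(
      experimental_variable_policy=(
          tf.saved_model.experimental.VariablePolicy
          .EXPAND_DISTRIBUTED_VARIABLES))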


def get_on_write_saveable(var, primary_var, name):
  """Return saveable spec for AUTO and ON_WRITE variables."""
  # We use a callable so that we don't have to evaluate this expression
  # in the case where we are trying to restore instead of save.
  def tensor():
    if context.executing_eagerly() and not primary_var.is_initialized():
      # A SaveSpec tensor value of `None` indicates that the variable is
      # uninitialized.
      return None
    strategy = var.distribute_strategy
    return strategy.extended.read_var(var)

  spec = saveable_object.SaveSpec(
      tensor=tensor,
      slice_spec="",
      name=name,
      dtype=var.dtype,
      device=primary_var.device)

  return tensor, [spec]


def get_on_write_restore_ops(var, tensor):
  """Return restore ops for AUTO and ON_WRITE variables."""
  packed_var = var._packed_variable  # pylint: disable=protected-access
  if packed_var is not None:
    return control_flow_ops.group(
        tuple(
            assign_on_device(d, packed_var, tensor)
            for d in packed_var.devices))
  return control_flow_ops.group(
      tuple(
          assign_on_device(v.device, v, tensor)
          for v in var.values))


def get_on_read_saveable(var, primary_var, name):
  """Return saveables for ON_READ variable."""

  # We use a callable so that we don't have to evaluate this expression
  # in the case where we are trying to restore instead of save.
  def tensor():
    return var._get_cross_replica()  # pylint: disable=protected-access

  spec = saveable_object.SaveSpec(
      tensor=tensor,
      slice_spec="",
      name=name,
      dtype=var.dtype,
      device=primary_var.device)

  return tensor, [spec]


def get_on_read_restore_ops(var, tensor, aggregation):
  """Return restore ops for ON_READ variables."""
  # To preserve the sum across save and restore, we have to divide the
  # total across all devices when restoring a variable that was summed
  # when saving.
  if aggregation == vs.VariableAggregation.SUM:
    strategy = var.distribute_strategy
    tensor = math_ops.cast(tensor / strategy.num_replicas_in_sync,
                           var.dtype)
  return control_flow_ops.group(
      tuple(
          assign_on_device(v.device, v, tensor)
          for v in var.values))
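

# Editor's sketch (not part of the original module): the arithmetic behind the
# SUM branch above. With 4 replicas, a SUM-aggregated variable whose
# components each hold 2.0 is saved as the combined value 8.0; on restore each
# component receives 8.0 / 4 = 2.0, so a later summed read reproduces 8.0.
# Plain Python only; all names here are hypothetical.
def _example_sum_restore_arithmetic():
  num_replicas = 4
  saved_value = 2.0 * num_replicas          # value written at save time
  per_replica_restored = saved_value / num_replicas
  assert per_replica_restored * num_replicas == saved_value
  return per_replica_restored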


# Utility function that indicates if you are in an UpdateContext when running
# in a replica fn.
def in_replica_update_context():
  return distribute_lib.get_update_replica_id() is not None


def on_write_assign(var, value, use_locking=False, name=None, read_value=True):
  assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def on_write_assign_add(var, value, use_locking=False, name=None,
                        read_value=True):
  assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_add_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def on_write_assign_sub(var, value, use_locking=False, name=None,
                        read_value=True):
  assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_sub_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)
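

# Editor's sketch (not part of the original module): the public-facing
# behavior the three helpers above back. Assigning a mirrored (ON_WRITE)
# variable in cross-replica context writes the same value to every component;
# the strategy, device list and helper name are assumptions for illustration.
def _example_on_write_assign():
  import tensorflow as tf  # assumed to be importable in the example's context
  strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
  with strategy.scope():
    v = tf.Variable(0.0)
  v.assign(3.0)      # roughly the path taken by on_write_assign
  v.assign_add(1.0)  # roughly the path taken by on_write_assign_add
  return v.read_value()  # 4.0 on every component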


def assign_on_each_device(var, assign_func, value, read_value):
  """Update the variable on each replica with the given assign_func and value."""
  if var._packed_variable is not None:  # pylint: disable=protected-access
    update = control_flow_ops.group(
        tuple(
            assign_func(d, var._packed_variable, value) for d in var._devices))  # pylint: disable=protected-access
  else:
    update = control_flow_ops.group(
        tuple(assign_func(v.device, v, value) for v in var._values))  # pylint: disable=protected-access
  if not read_value:
    return update
  with ops.control_dependencies([update] if update else []):
    return var.read_value()


def on_read_assign_sub_cross_replica(var, value, read_value=True):
  with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
    if distribute_lib.in_cross_replica_context():
      if var.aggregation == vs.VariableAggregation.SUM:
        raise ValueError(
            "SyncOnReadVariable does not support `assign_sub` in "
            "cross-replica context when aggregation is set to "
            "`tf.VariableAggregation.SUM`.")
      return assign_on_each_device(var, assign_sub_on_device,
                                   value, read_value)


def on_read_assign_add_cross_replica(var, value, read_value=True):
  with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
    if distribute_lib.in_cross_replica_context():
      if var.aggregation == vs.VariableAggregation.SUM:
        raise ValueError(
            "SyncOnReadVariable does not support `assign_add` in "
            "cross-replica context when aggregation is set to "
            "`tf.VariableAggregation.SUM`.")
      return assign_on_each_device(var, assign_add_on_device,
                                   value, read_value)


def on_read_assign_cross_replica(var, value, read_value=True):
  """Assign `value` to the variable in cross-replica context."""
  with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
    if distribute_lib.in_cross_replica_context():
      # To preserve the sum across save and restore, we have to divide the
      # total across all devices when restoring a variable that was summed
      # when saving.
      tensor = value
      if var.aggregation == vs.VariableAggregation.SUM:
        strategy = var._distribute_strategy  # pylint: disable=protected-access
        tensor = math_ops.cast(tensor / strategy.num_replicas_in_sync,
                               var.dtype)
      return assign_on_each_device(var, assign_on_device, tensor,
                                   read_value)


def scatter_sub(var, sparse_delta, use_locking=False, name=None):
  scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_sub_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_add(var, sparse_delta, use_locking=False, name=None):
  scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_add_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_mul(var, sparse_delta, use_locking=False, name=None):
  scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_mul_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_div(var, sparse_delta, use_locking=False, name=None):
  scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_div_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_min(var, sparse_delta, use_locking=False, name=None):
  scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_min_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_max(var, sparse_delta, use_locking=False, name=None):
  scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_max_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_update(var, sparse_delta, use_locking=False, name=None):
  scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_update_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
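

# Editor's sketch (not part of the original module): a rough public-API
# analogue of the scatter_* helpers above. A mirrored variable with the
# default (NONE) aggregation accepts a `tf.IndexedSlices` delta; the strategy,
# device list and helper name are assumptions for illustration (see the
# scatter error message defined later in this module for the aggregation
# restriction).
def _example_scatter_update():
  import tensorflow as tf  # assumed to be importable in the example's context
  strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
  with strategy.scope():
    v = tf.Variable([1.0, 2.0, 3.0])
  delta = tf.IndexedSlices(values=tf.constant([9.0]), indices=tf.constant([1]))
  v.scatter_update(delta)  # sets element 1 to 9.0 on every component
  return v.read_value()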


def get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or `None`."""
  replica_context = distribute_lib.get_replica_context()
  if replica_context:
    replica_id = replica_context._replica_id  # pylint: disable=protected-access
    if not isinstance(replica_id, int):
      replica_id = tensor_util.constant_value(replica_id)
  else:
    replica_id = distribute_lib.get_update_replica_id()
  return replica_id
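

# Editor's sketch (not part of the original module): where a non-None replica
# id comes from in practice. Inside `strategy.run` the replica context carries
# the id; the strategy, device list and helper name are assumptions for
# illustration.
def _example_replica_id():
  import tensorflow as tf  # assumed to be importable in the example's context
  strategy = tf.distribute.MirroredStrategy(["/cpu:0"])

  def replica_fn():
    return tf.distribute.get_replica_context().replica_id_in_sync_group

  return strategy.run(replica_fn)  # 0 for the single replica above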


def assign_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign(tensor)


def assign_add_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign_add(tensor)


def assign_sub_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign_sub(tensor)


def assert_replica_context(strategy):
  replica_context = distribute_lib.get_replica_context()
  if not replica_context:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")
  if replica_context.strategy is not strategy:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context of "
        "the strategy they were created under.")


def apply_aggregation(strategy, value, aggregation, destinations):
  if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
    return strategy.extended.broadcast_to(
        strategy.experimental_local_results(value)[0],
        destinations=destinations)
  reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
  return strategy.extended.reduce_to(reduce_op, value, destinations)
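

# Editor's sketch (not part of the original module): the mapping used above.
# `ReduceOp.from_variable_aggregation` turns a variable aggregation mode into
# the reduce op handed to `strategy.extended.reduce_to`, while
# ONLY_FIRST_REPLICA skips the reduction and broadcasts replica 0's value.
# Uses this module's own imports; the helper name is hypothetical.
def _example_aggregation_to_reduce_op():
  assert (reduce_util.ReduceOp.from_variable_aggregation(
      vs.VariableAggregation.SUM) == reduce_util.ReduceOp.SUM)
  assert (reduce_util.ReduceOp.from_variable_aggregation(
      vs.VariableAggregation.MEAN) == reduce_util.ReduceOp.MEAN)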


aggregation_error_msg = (
    "You must specify an aggregation method to update a "
    "{variable_type} in Replica Context. You can do so by passing "
    "an explicit value for argument `aggregation` to tf.Variable(..), "
    "e.g. `tf.Variable(..., aggregation=tf.VariableAggregation.SUM)`. "
    "`tf.VariableAggregation` lists the possible aggregation methods. "
    "This is required because {variable_type} should always be "
    "kept in sync. When updating them or assigning to them in a "
    "replica context, we automatically try to aggregate the values "
    "before updating the variable. For this aggregation, we need to "
    "know the aggregation method. "
    "Another alternative is to not try to update such "
    "{variable_type} in replica context, but in cross replica "
    "context. You can enter cross replica context by calling "
    "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`. "
    "Inside `merge_fn`, you can then update the {variable_type} "
    "using `tf.distribute.StrategyExtended.update()`.")


scatter_error_msg = ("{op_name} is only supported for mirrored "
                     "variables (variables created within certain "
                     "`tf.distribute.Strategy` scopes) with `NONE` or "
                     "`ONLY_FIRST_REPLICA` aggregation, got: {aggregation}.")


def is_saving_non_distributed():
  """Returns whether we're saving a non-distributed version of the model.

  It returns True iff we are in a save context and are saving a
  non-distributed version of the model, i.e.
  `SaveOptions.experimental_variable_policy` is not
  `EXPAND_DISTRIBUTED_VARIABLES`.

  Returns:
    A boolean.
  """
  if not save_context.in_save_context():
    return False
  options = save_context.get_save_options()
  return (options.experimental_variable_policy !=
          save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)
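

# Editor's sketch (not part of the original module): under the definition
# above, every `VariablePolicy` except EXPAND_DISTRIBUTED_VARIABLES counts as
# a non-distributed save. Uses this module's own `save_options` import; the
# helper name is hypothetical, and the expected members (NONE and
# SAVE_VARIABLE_DEVICES) are an assumption about the enum.
def _example_non_distributed_policies():
  return [
      policy for policy in save_options.VariablePolicy
      if policy != save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
  ]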


def mark_as_unsaveable():
  """Marks the function as unsaveable if not inside save context."""
  if ops.inside_function() and not save_context.in_save_context():
    ops.get_default_graph().mark_as_unsaveable("""
ConcreteFunction that uses distributed variables in a certain way cannot be saved.
If you're saving with

tf.saved_model.save(..., signatures=f.get_concrete_function())

do

@tf.function(input_signature=...)
def f_with_input_signature():
  ...

tf.saved_model.save(..., signatures=f_with_input_signature)

instead.""")