Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/experimental/loss_scale.py: 38% (157 statements)
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains LossScale classes."""
import abc

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables
from tensorflow.python.trackable import base as trackable
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@deprecation.deprecated_endpoints('mixed_precision.experimental.LossScale',
                                  'train.experimental.LossScale')
@tf_export(
    v1=[
        'mixed_precision.LossScale',
        'mixed_precision.experimental.LossScale',
        'train.experimental.LossScale'
    ])
class LossScale(trackable.Trackable, metaclass=abc.ABCMeta):
  """Base class for all TF1 loss scales.

  This is an abstract base class, so you cannot instantiate it directly.
  Instead, use one of its concrete subclasses:
    * `tf.compat.v1.mixed_precision.DynamicLossScale`
    * `tf.compat.v1.mixed_precision.FixedLossScale`

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. By multiplying the loss, each intermediate gradient will
  have the same multiplier applied.

  Instances of this class represent a loss scale. Calling instances of this
  class returns the loss scale as a scalar float32 tensor, while method
  `update()` updates the loss scale depending on the values of the gradients.
  Optimizers use instances of this class to scale loss and gradients.

  In most functions that accept a LossScale, you can also pass an int (such as
  8) to create a `FixedLossScale` or the string `"dynamic"` to create a dynamic
  loss scale.
  """

  def __init__(self):
    """Initializes the loss scale class."""
    self._weights = {}

  @abc.abstractmethod
  def __call__(self):
    """Returns the current loss scale as a scalar `float32` tensor."""
    pass

  @abc.abstractmethod
  def update(self, grads):
    """Updates the value of the loss scale.

    The loss scale will be potentially updated, based on the value of `grads`.
    The tensor returned by calling this class is only updated when this
    function is evaluated.

    In eager mode, this directly updates the loss scale, so that calling
    `__call__` will return the newly updated loss scale. In graph mode,
    this returns an op that, when evaluated, updates the loss scale.

    This function also returns a `should_apply_gradients` bool. If False,
    gradients should not be applied to the variables that step, as nonfinite
    gradients were found, and the loss scale has been updated to reduce the
    chance of finding nonfinite gradients in the next step. Some loss scale
    classes will always return True, as they cannot adjust themselves in
    response to nonfinite gradients.

    When a DistributionStrategy is used, this function may only be called in a
    cross-replica context.

    Args:
      grads: A nested structure of unscaled gradients, each of which is the
        gradient of the loss with respect to a weight. The gradients should
        have already been divided by the loss scale before being passed to
        this function. `None` gradients are accepted, and are ignored.

    Returns:
      update_op: In eager mode, None. In graph mode, an op to update the loss
        scale.
      should_apply_gradients: Either a bool or a scalar boolean tensor. If
        False, the caller should skip applying `grads` to the variables this
        step.
    """
    pass

  def _add_weight(self, name, initial_value, dtype=None):
    """Adds a weight to this loss scale.

    Args:
      name: Variable name.
      initial_value: The variable's initial value.
      dtype: The type of the variable.

    Returns:
      A variable.

    Raises:
      RuntimeError: If a weight with `name` has already been added.
    """
    variable = variable_v1.VariableV1(
        initial_value=initial_value,
        name=name,
        dtype=dtype,
        trainable=False,
        use_resource=True,
        synchronization=variables.VariableSynchronization.AUTO,
        # Set aggregation to NONE, as loss scaling variables should never be
        # aggregated.
        aggregation=variables.VariableAggregation.NONE)
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access

    key = (name, graph_key)
    if self._weights.get(key, None) is not None:
      raise RuntimeError('Duplicate variables detected. {}'.format(key))
    self._weights[key] = variable
    self._handle_deferred_dependencies(name=name, trackable=variable)
    return variable

  def _trackable_children(self,
                          save_type=trackable.SaveType.CHECKPOINT,
                          **kwargs):
    """From Trackable. Gather graph-specific weights to save."""
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access
    weights = {}
    for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]):
      if g == graph_key:
        weights[name] = v
    weights.update(
        super(LossScale, self)._trackable_children(save_type, **kwargs))
    return weights

  def _lookup_dependency(self, name):
    """From Trackable. Find a weight in the current graph."""
    unconditional = super(LossScale, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access
    return self._weights.get((name, graph_key), None)

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of this loss scale."""
    pass

  @classmethod
  def from_config(cls, config):
    """Creates the LossScale from its config."""
    return cls(**config)
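

# Editor's illustrative sketch, not part of the original module or the
# TensorFlow API: how a training step might drive any `LossScale` instance in
# eager mode. `compute_grads_fn` and `apply_grads_fn` are hypothetical
# callables standing in for an optimizer's gradient computation and gradient
# application.
def _loss_scale_usage_sketch(loss_scale, loss, compute_grads_fn,
                             apply_grads_fn):
  scale = loss_scale()  # Current loss scale as a scalar float32 tensor.
  scaled_loss = loss * scale  # Scale the loss before computing gradients.
  scaled_grads = compute_grads_fn(scaled_loss)
  # Unscale the gradients before passing them to `update()` or applying them.
  grads = [g / scale if g is not None else None for g in scaled_grads]
  _, should_apply = loss_scale.update(grads)
  if should_apply:  # Skip the step if nonfinite gradients were found.
    apply_grads_fn(grads)
  return should_apply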


@deprecation.deprecated_endpoints('mixed_precision.experimental.FixedLossScale',
                                  'train.experimental.FixedLossScale')
@tf_export(
    v1=[
        'mixed_precision.FixedLossScale',
        'mixed_precision.experimental.FixedLossScale',
        'train.experimental.FixedLossScale'
    ])
class FixedLossScale(LossScale):
  """Loss scale with a fixed value.

  The loss scale is not updated for the lifetime of instances of this class.
  A given instance of this class always returns the same number when called.
  """

  @deprecation.deprecated(
      None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
      'LossScaleOptimizer now has all the functionality of '
      'FixedLossScale')
  def __init__(self, loss_scale_value):
    """Creates the fixed loss scale.

    Args:
      loss_scale_value: A Python float. Its ideal value varies from model to
        model. A loss scale that is too small may degrade model quality; one
        that is too big may cause inf or nan gradients. There is no single
        right loss scale to apply. There is no harm in choosing a relatively
        big number as long as no nan or inf values are encountered during
        training.

    Raises:
      ValueError: If loss_scale_value is less than 1.
    """
    super(FixedLossScale, self).__init__()
    if not isinstance(loss_scale_value, (int, float)):
      raise ValueError('loss_scale_value must be a Python int or float.')
    if loss_scale_value < 1:
      raise ValueError('loss_scale_value must be at least 1.')
    # It's important we do not create tensors in the constructor, as such
    # tensors might be on a different device or in a different tf.function
    # than where the tensor is used. This would hurt performance. Therefore,
    # we do not create a tensor from loss_scale_value, but instead leave it as
    # a Python float.
    # TODO(reedwm): Also do not create tensors in the DynamicLossScale
    # constructor.
    self._loss_scale_value = float(loss_scale_value)

  def __call__(self):
    return ops.convert_to_tensor(self._loss_scale_value)

  def update(self, grads):
    del grads
    return control_flow_ops.no_op(), True

  def __repr__(self):
    return 'FixedLossScale(%s)' % self._loss_scale_value

  def get_config(self):
    return {'loss_scale_value': self._loss_scale_value}
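

# Editor's illustrative sketch, not part of the original module or the
# TensorFlow API: a FixedLossScale never changes. Calling it always yields the
# same scalar, and its `update()` is a no-op that never asks the caller to
# skip a step.
def _fixed_loss_scale_sketch():
  loss_scale = FixedLossScale(128)
  scale = loss_scale()  # Always a float32 tensor equal to 128.0.
  _, should_apply = loss_scale.update(grads=[])
  return scale, should_apply  # `should_apply` is always True.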


def _is_all_finite(grads):
  """Returns a scalar boolean tensor indicating if all gradients are finite."""

  def raw_values(g):
    return g.values if isinstance(g, indexed_slices.IndexedSlices) else g

  is_finite_per_grad = [
      math_ops.reduce_all(math_ops.is_finite(raw_values(g)))
      for g in grads
      if g is not None
  ]
  return math_ops.reduce_all(is_finite_per_grad)


def _op_in_graph_mode(tensor):
  """Returns the tensor's op in graph mode, or the tensor in eager mode.

  This is useful because sometimes an op is needed in graph mode instead of a
  tensor. In eager mode, there are no ops.

  Args:
    tensor: A tensor.

  Returns:
    The tensor's op in graph mode. The tensor in eager mode.
  """
  if context.executing_eagerly():
    return tensor
  return tensor.op


def _assign_if_finite(var, value):
  """Assigns a value to a variable if the value is finite."""
  return cond.cond(
      math_ops.is_finite(value), lambda: _op_in_graph_mode(var.assign(value)),
      control_flow_ops.no_op)


@deprecation.deprecated_endpoints(
    'mixed_precision.experimental.DynamicLossScale',
    'train.experimental.DynamicLossScale')
@tf_export(
    v1=[
        'mixed_precision.DynamicLossScale',
        'mixed_precision.experimental.DynamicLossScale',
        'train.experimental.DynamicLossScale'
    ])
class DynamicLossScale(LossScale):
  """Loss scale that dynamically adjusts itself.

  Dynamic loss scaling works by adjusting the loss scale as training progresses.
  The goal is to keep the loss scale as high as possible without overflowing the
  gradients. As long as the gradients do not overflow, raising the loss scale
  never hurts.

  The algorithm starts by setting the loss scale to an initial value. Every N
  steps that the gradients are finite, the loss scale is increased by some
  factor. However, if a NaN or Inf gradient is found, the gradients for that
  step are not applied, and the loss scale is decreased by the factor. This
  process tends to keep the loss scale as high as possible without gradients
  overflowing.
  """

  @deprecation.deprecated(
      None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
      'LossScaleOptimizer now has all the functionality of '
      'DynamicLossScale')
  def __init__(self,
               initial_loss_scale=2 ** 15,  # See docstring for why this is big.
               increment_period=2000,
               multiplier=2.):
    """Creates the dynamic loss scale.

    Args:
      initial_loss_scale: A Python float. The loss scale to use at the
        beginning. It's better to start this at a very high number, because a
        loss scale that is too high gets lowered far more quickly than a loss
        scale that is too low gets raised. The default is 2 ** 15, which is
        approximately half the maximum float16 value.
      increment_period: Increases loss scale every `increment_period`
        consecutive steps that finite gradients are encountered. If a nonfinite
        gradient is encountered, the count is reset back to zero.
      multiplier: The multiplier to use when increasing or decreasing the loss
        scale.
    """
    super(DynamicLossScale, self).__init__()
    self._initial_loss_scale = float(initial_loss_scale)
    self._increment_period = int(increment_period)
    self._multiplier = float(multiplier)

    self._current_loss_scale = self._add_weight(
        name='current_loss_scale',
        dtype=dtypes.float32,
        initial_value=self._initial_loss_scale)
    # The number of consecutive steps with finite gradients since the last
    # nonfinite gradient or change in loss scale.
    self._num_good_steps = self._add_weight(
        name='good_steps', dtype=dtypes.int64, initial_value=0)

  @property
  def initial_loss_scale(self):
    return self._initial_loss_scale

  @property
  def increment_period(self):
    return self._increment_period

  @property
  def multiplier(self):
    return self._multiplier

  def __call__(self):
    return ops.convert_to_tensor(self._current_loss_scale)

  def update(self, grads):
    """Updates loss scale based on if gradients are finite in current step."""
    grads = nest.flatten(grads)
    if distribute_lib.has_strategy():
      distribution = distribute_lib.get_cross_replica_context()

      def get_is_finite(grads):
        is_finite = _is_all_finite(grads)
        # We cast to float, because we cannot reduce booleans with
        # DistributionStrategy.
        return math_ops.cast(is_finite, dtypes.float32)

      is_finite_float = distribution.extended.call_for_each_replica(
          get_is_finite, args=(grads,))
      reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM,
                                                    is_finite_float, axis=None)
      is_finite = math_ops.equal(reduced_is_finite_float,
                                 distribution.num_replicas_in_sync)
    else:
      is_finite = _is_all_finite(grads)

    def update_if_finite_grads():
      """Update assuming the gradients are finite."""

      def incr_loss_scale():
        new_loss_scale = self._current_loss_scale * self._multiplier
        return control_flow_ops.group(
            _assign_if_finite(self._current_loss_scale, new_loss_scale),
            self._num_good_steps.assign(0))

      return cond.cond(
          self._num_good_steps + 1 >= self._increment_period,
          incr_loss_scale, lambda: _op_in_graph_mode(
              self._num_good_steps.assign_add(1)))

    def update_if_not_finite_grads():
      """Update assuming the gradients are nonfinite."""

      new_loss_scale = math_ops.maximum(
          self._current_loss_scale / self._multiplier, 1)
      return control_flow_ops.group(
          self._num_good_steps.assign(0),
          self._current_loss_scale.assign(new_loss_scale))

    update_op = cond.cond(is_finite, update_if_finite_grads,
                          update_if_not_finite_grads)
    should_apply_gradients = is_finite
    return update_op, should_apply_gradients

  def __repr__(self):
    if context.executing_eagerly():
      return ('DynamicLossScale(current_loss_scale=%s, num_good_steps=%s, '
              'initial_loss_scale=%s, increment_period=%s, multiplier=%s)' %
              (self._current_loss_scale.numpy(), self._num_good_steps.numpy(),
               self.initial_loss_scale, self.increment_period, self.multiplier))
    else:
      return ('DynamicLossScale(initial_loss_scale=%s, increment_period=%s, '
              'multiplier=%s)' %
              (self.initial_loss_scale, self.increment_period, self.multiplier))

  def get_config(self):
    return {
        'initial_loss_scale': self.initial_loss_scale,
        'increment_period': self.increment_period,
        'multiplier': self.multiplier,
    }
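

# Editor's illustrative sketch, not part of the original module or the
# TensorFlow API: how DynamicLossScale reacts to finite and nonfinite
# gradients in eager mode. With increment_period=2 and multiplier=2, two
# consecutive finite steps double the scale, and a single nonfinite step
# halves it, resets the good-step counter, and asks the caller to skip
# applying gradients.
def _dynamic_loss_scale_sketch():
  loss_scale = DynamicLossScale(
      initial_loss_scale=4., increment_period=2, multiplier=2.)
  finite_grads = [ops.convert_to_tensor([1.0, 2.0])]
  nonfinite_grads = [ops.convert_to_tensor([float('inf')])]

  loss_scale.update(finite_grads)  # 1 good step; scale stays at 4.
  loss_scale.update(finite_grads)  # 2 good steps; scale doubles to 8.
  _, should_apply = loss_scale.update(nonfinite_grads)
  # `should_apply` is False and the scale is halved back to 4.
  return loss_scale(), should_apply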


def get(identifier):
  """Get a loss scale object."""
  if isinstance(identifier, (int, float)):
    return FixedLossScale(identifier)
  if identifier == 'dynamic':
    return DynamicLossScale()
  if isinstance(identifier, LossScale):
    return identifier
  elif identifier is None:
    return None
  else:
    raise ValueError('Could not interpret loss scale identifier: %s' %
                     identifier)
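

# Editor's illustrative sketch, not part of the original module or the
# TensorFlow API: the identifier forms `get()` accepts.
def _get_identifier_sketch():
  assert isinstance(get(128), FixedLossScale)  # An int or float -> fixed scale.
  assert isinstance(get('dynamic'), DynamicLossScale)  # The string 'dynamic'.
  existing = FixedLossScale(8)
  assert get(existing) is existing  # A LossScale is returned unchanged.
  assert get(None) is None  # None is passed through.
  # Any other identifier raises a ValueError.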