1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains the loss scaling optimizer class."""
16
17from tensorflow.python.distribute import collective_all_reduce_strategy
18from tensorflow.python.distribute import distribute_lib
19from tensorflow.python.distribute import mirrored_strategy
20from tensorflow.python.distribute import one_device_strategy
21from tensorflow.python.distribute import tpu_strategy
22from tensorflow.python.eager import backprop
23from tensorflow.python.eager import context
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import indexed_slices
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import smart_cond
28from tensorflow.python.framework import tensor_conversion
29from tensorflow.python.keras import backend
30from tensorflow.python.keras import optimizers
31from tensorflow.python.keras.mixed_precision import loss_scale as keras_loss_scale_module
32from tensorflow.python.keras.optimizer_v2 import optimizer_v2
33from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
34from tensorflow.python.ops import cond
35from tensorflow.python.ops import control_flow_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import variable_v1
38from tensorflow.python.ops import variables
39from tensorflow.python.platform import tf_logging
40from tensorflow.python.trackable import base as trackable
41from tensorflow.python.trackable import base_delegate
42from tensorflow.python.training.experimental import loss_scale as loss_scale_module
43from tensorflow.python.training.experimental import mixed_precision
44from tensorflow.python.util import nest
45from tensorflow.python.util.tf_export import keras_export
46
47
48class _UnwrapPreventer(object):
49 """Wrapper that DistributionStrategy will not unwrap.
50
51 Typically, DistributionStrategy will unwrap values when going from a cross-
52 replica context to a replica context via `call_for_each_replica`. This class
53 is a wrapper that DistributionStrategy will not unwrap, so it can be used to
54 prevent it from unwrapping a value.
55
56 TODO(reedwm): Find/implement a better way of preventing values from being
57 unwrapped by DistributionStrategy
58 """
59
60 __slots__ = ['value']
61
62 def __init__(self, value):
63 self.value = value
64
65
def _is_all_finite(grads):
  """Returns a scalar boolean tensor that is True iff every gradient is finite.

  Args:
    grads: An iterable of gradients. `None` entries are skipped; for
      `IndexedSlices` only the `.values` tensor is inspected.

  Returns:
    A scalar boolean tensor: `reduce_all` over one per-gradient finiteness
    check for each non-None gradient.
  """
  per_grad_checks = []
  for grad in grads:
    if grad is None:
      continue
    if isinstance(grad, indexed_slices.IndexedSlices):
      # Sparse gradient: only the stored values can be non-finite.
      grad = grad.values
    per_grad_checks.append(math_ops.reduce_all(math_ops.is_finite(grad)))
  return math_ops.reduce_all(per_grad_checks)
77
78
def _op_in_graph_mode(tensor):
  """Returns `tensor.op` when building a graph, or `tensor` when eager.

  Some callers need an op rather than a tensor in graph mode. Eager execution
  has no ops, so there the tensor itself is handed back.

  Args:
    tensor: A tensor.

  Returns:
    `tensor` if executing eagerly, otherwise `tensor.op`.
  """
  return tensor if context.executing_eagerly() else tensor.op
94
95
def _assign_if_finite(var, value):
  """Assigns `value` to `var` only when `value` is finite.

  Args:
    var: The variable to (maybe) assign to.
    value: The candidate new value.

  Returns:
    Whatever `cond.cond` returns for the branch taken: the assignment (an op
    in graph mode, a tensor eagerly) or the result of `no_op`.
  """
  def assign_branch():
    return _op_in_graph_mode(var.assign(value))

  value_is_finite = math_ops.is_finite(value)
  return cond.cond(value_is_finite, assign_branch, control_flow_ops.no_op)
101
102
class _DynamicLossScaleState(trackable.Trackable):
  """The state of a dynamic loss scale.

  Holds two non-trainable resource variables: the current loss scale
  (float32) and a counter of consecutive finite-gradient steps (int64).
  Variables are stored keyed by (name, graph key), so checkpointing saves and
  looks up the variables belonging to the graph that is current at that time
  (the key is None when executing eagerly).
  """

  def __init__(self,
               initial_loss_scale,
               growth_steps,
               multiplier):
    """Creates the dynamic loss scale.

    Args:
      initial_loss_scale: Starting loss scale; converted with `float`.
      growth_steps: After this many consecutive steps with finite gradients
        the loss scale is multiplied by `multiplier`; converted with `int`.
      multiplier: Factor by which the loss scale is multiplied (on growth) or
        divided (on nonfinite gradients); converted with `float`.
    """
    super(_DynamicLossScaleState, self).__init__()
    self._initial_loss_scale = float(initial_loss_scale)
    self._growth_steps = int(growth_steps)
    self._multiplier = float(multiplier)

    # Maps (name, graph_key) -> variable. See `_current_graph_key`.
    self._weights = {}
    self._current_loss_scale = self._add_weight(
        name='current_loss_scale',
        dtype=dtypes.float32,
        initial_value=self._initial_loss_scale)
    # The number of consecutive steps with finite gradients since the last
    # nonfinite gradient or change in loss scale. The name is 'good_steps' for
    # backwards compatibility with older checkpoints.
    self._counter = self._add_weight(
        name='good_steps', dtype=dtypes.int64, initial_value=0)

  @staticmethod
  def _current_graph_key():
    """Returns None when executing eagerly, else the default graph's key."""
    if context.executing_eagerly():
      return None
    return ops.get_default_graph()._graph_key  # pylint: disable=protected-access

  def _add_weight(self, name, initial_value, dtype=None):
    """Adds a weight to this loss scale.

    Args:
      name: Variable name.
      initial_value: The variable's initial value.
      dtype: The type of the variable.

    Returns:
      A variable.

    Raises:
      RuntimeError: If a weight with `name` has already been added for the
        current graph (or eagerly).
    """
    key = (name, self._current_graph_key())
    # Check before creating the variable so no stray variable is built.
    if self._weights.get(key) is not None:
      raise RuntimeError('Duplicate variables detected. {}'.format(key))
    variable = variable_v1.VariableV1(
        initial_value=initial_value,
        name=name,
        dtype=dtype,
        trainable=False,
        use_resource=True,
        synchronization=variables.VariableSynchronization.AUTO,
        # Set aggregation to NONE, as loss scaling variables should never be
        # aggregated.
        aggregation=variables.VariableAggregation.NONE)
    self._weights[key] = variable
    self._handle_deferred_dependencies(name=name, trackable=variable)
    backend.track_variable(variable)
    return variable

  def _trackable_children(self,
                          save_type=trackable.SaveType.CHECKPOINT,
                          **kwargs):
    """From Trackable. Gather graph-specific weights to save."""
    graph_key = self._current_graph_key()
    weights = {}
    for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]):
      if g == graph_key:
        weights[name] = v
    weights.update(
        super(_DynamicLossScaleState,
              self)._trackable_children(save_type, **kwargs))
    return weights

  def _lookup_dependency(self, name, cached_dependencies=None):
    """From Trackable. Find a weight in the current graph.

    Args:
      name: Local name of the dependency (e.g. 'current_loss_scale').
      cached_dependencies: Optional mapping of precomputed dependencies, as
        passed by newer Trackable/checkpoint code. It is forwarded to the base
        class only when given, so base implementations that take just `name`
        keep working.

    Returns:
      The base class's unconditional dependency if there is one, else the
      weight registered under `name` for the current graph, else None.
    """
    if cached_dependencies is None:
      unconditional = super(_DynamicLossScaleState,
                            self)._lookup_dependency(name)
    else:
      unconditional = super(_DynamicLossScaleState, self)._lookup_dependency(
          name, cached_dependencies)
    if unconditional is not None:
      return unconditional
    return self._weights.get((name, self._current_graph_key()), None)

  @property
  def initial_loss_scale(self):
    return self._initial_loss_scale

  @property
  def growth_steps(self):
    return self._growth_steps

  @property
  def multiplier(self):
    return self._multiplier

  @property
  def current_loss_scale(self):
    """Returns the current loss scale as a float32 `tf.Variable`."""
    return self._current_loss_scale

  @property
  def counter(self):
    """Returns the counter as an int64 `tf.Variable`."""
    return self._counter

  def __call__(self):
    """Returns the current loss scale as a scalar `float32` tensor."""
    return tensor_conversion.convert_to_tensor_v2_with_dispatch(
        self._current_loss_scale
    )

  def update(self, grads):
    """Updates the value of the loss scale.

    If all gradients are finite, the counter is incremented. Once it would
    reach `growth_steps`, the loss scale is multiplied by `multiplier` (only
    if the result is finite) and the counter is reset to 0. If any gradient is
    nonfinite, the loss scale is divided by `multiplier` (but not below 1) and
    the counter is reset to 0.

    Args:
      grads: A nested structure of unscaled gradients, each which is an
        all-reduced gradient of the loss with respect to a weight.

    Returns:
      update_op: The result of the `cond` that performs the update. In graph
        mode this is an op to run. In eager mode the update has already been
        executed, and this is the value returned by whichever branch ran
        (possibly None).
      should_apply_gradients: A scalar boolean tensor. If False, the caller
        should skip applying `grads` to the variables this step.
    """
    grads = nest.flatten(grads)
    if (distribute_lib.has_strategy() and
        distribute_lib.in_cross_replica_context()):
      distribution = distribute_lib.get_strategy()
      is_finite_per_replica = distribution.extended.call_for_each_replica(
          _is_all_finite, args=(grads,))
      # Each replica computed the same `is_finite` value, since `grads` is
      # all-reduced across replicas. Arbitrarily take `is_finite` from the
      # first replica.
      is_finite = (
          distribution.experimental_local_results(is_finite_per_replica)[0])
    else:
      is_finite = _is_all_finite(grads)

    def update_if_finite_grads():
      """Update assuming the gradients are finite."""

      def incr_loss_scale():
        new_loss_scale = self.current_loss_scale * self.multiplier
        # The counter is reset even if the grown scale would overflow and is
        # therefore not assigned.
        return control_flow_ops.group(
            _assign_if_finite(self.current_loss_scale, new_loss_scale),
            self.counter.assign(0))

      return cond.cond(
          self.counter + 1 >= self.growth_steps,
          incr_loss_scale,
          lambda: _op_in_graph_mode(self.counter.assign_add(1)))

    def update_if_not_finite_grads():
      """Update assuming the gradients are nonfinite."""

      new_loss_scale = math_ops.maximum(
          self.current_loss_scale / self.multiplier, 1)
      return control_flow_ops.group(
          self.counter.assign(0),
          self.current_loss_scale.assign(new_loss_scale))

    update_op = cond.cond(is_finite, update_if_finite_grads,
                          update_if_not_finite_grads)
    should_apply_gradients = is_finite
    return update_op, should_apply_gradients
276
277
278# See LossScaleOptimizer docstring for why this is so big
279_DEFAULT_INITIAL_SCALE = 2 ** 15
280_DEFAULT_GROWTH_STEPS = 2000
281
282
283# pylint: disable=g-classes-have-attributes
284@keras_export('keras.mixed_precision.LossScaleOptimizer')
285class LossScaleOptimizer(base_delegate.DelegatingTrackableMixin,
286 optimizer_v2.OptimizerV2):
287 """An optimizer that applies loss scaling to prevent numeric underflow.
288
289 Loss scaling is a technique to prevent numeric underflow in intermediate
290 gradients when float16 is used. To prevent underflow, the loss is multiplied
291 (or "scaled") by a certain factor called the "loss scale", which causes
292 intermediate gradients to be scaled by the loss scale as well. The final
293 gradients are divided (or "unscaled") by the loss scale to bring them back to
294 their original value.
295
296 `LossScaleOptimizer` wraps another optimizer and applies loss scaling to it.
297 By default, the loss scale is dynamically updated over time so you do not have
298 to choose the loss scale. The `minimize` method automatically scales the loss,
299 unscales the gradients, and updates the loss scale so all you have to do is
300 wrap your optimizer with a `LossScaleOptimizer` if you use `minimize`. For
301 example:
302
303 >>> opt = tf.keras.optimizers.SGD(0.25)
304 >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
305 >>> var = tf.Variable(1.)
306 >>> loss_fn = lambda: var ** 2
  >>> # 'minimize' applies loss scaling and updates the loss scale.
308 >>> opt.minimize(loss_fn, var_list=var)
309 >>> var.numpy()
310 0.5
311
312 If a `tf.GradientTape` is used to compute gradients instead of `minimize`, you
313 must scale the loss and gradients manually. This can be done with the
314 `LossScaleOptimizer.get_scaled_loss` and
315 `LossScaleOptimizer.get_unscaled_gradients` methods. For example:
316
317 >>> with tf.GradientTape() as tape:
318 ... loss = loss_fn()
319 ... scaled_loss = opt.get_scaled_loss(loss)
320 >>> scaled_grad = tape.gradient(scaled_loss, var)
321 >>> (grad,) = opt.get_unscaled_gradients([scaled_grad])
322 >>> opt.apply_gradients([(grad, var)]) # Loss scale is updated here
323 >>> var.numpy()
324 0.25
325
326 Warning: If you forget to call `get_scaled_loss` or `get_unscaled_gradients`
327 (or both) when using a `tf.GradientTape`, the model will likely converge to a
328 worse quality. Please make sure you call each function exactly once.
329
330 When mixed precision with float16 is used, there is typically no risk of
331 underflow affecting model quality if loss scaling is properly used. See
332 [the mixed precision guide](
333 https://www.tensorflow.org/guide/keras/mixed_precision) for more information
334 on how to use mixed precision.
335
336 Args:
337 inner_optimizer: The `tf.keras.optimizers.Optimizer` instance to wrap.
338 dynamic: Bool indicating whether dynamic loss scaling is used. Defaults to
339 True. If True, the loss scale will be dynamically updated over time using
340 an algorithm that keeps the loss scale at approximately its optimal value.
341 If False, a single fixed loss scale is used and `initial_scale` must be
342 specified, which is used as the loss scale. Recommended to keep as True,
343 as choosing a fixed loss scale can be tricky. Currently, there is a small
344 performance overhead to dynamic loss scaling compared to fixed loss
345 scaling.
    initial_scale: The initial loss scale. If `dynamic` is True, this defaults
      to `2 ** 15`. If `dynamic` is False, this must be specified and acts as
      the sole loss scale, as the loss scale does not change over time. When
      dynamic loss scaling is used, it is better for this to be a very high
      number, because a loss scale that is too high gets lowered far more
      quickly than a loss scale that is too low gets raised.
352 dynamic_growth_steps: With dynamic loss scaling, every
353 `dynamic_growth_steps` steps with finite gradients, the loss scale is
354 doubled. Defaults to 2000. If a nonfinite gradient is encountered, the
355 count is reset back to zero, gradients are skipped that step, and the loss
356 scale is halved. The count can be queried with
357 `LossScaleOptimizer.dynamic_counter`. This argument can only be specified
358 if `dynamic` is True.
359
360 `LossScaleOptimizer` will occasionally skip applying gradients to the
361 variables, in which case the trainable variables will not change that step.
362 This is done because the dynamic loss scale will sometimes be raised too
363 high, causing overflow in the gradients. Typically, the first 2 to 15 steps of
364 the model are skipped as the initial loss scale is very high, but afterwards
365 steps will only be skipped on average 0.05% of the time (the fraction of steps
366 skipped is `1 / dynamic_growth_steps`).
367
368 `LossScaleOptimizer` delegates all public `Optimizer` methods to the inner
369 optimizer. Additionally, in methods `minimize` and `get_gradients`, it scales
370 the loss and unscales the gradients. In methods `minimize` and
371 `apply_gradients`, it additionally updates the loss scale and skips applying
372 gradients if any gradient has a nonfinite value.
373
374 ### Hyperparameters
375
376 Hyperparameters can be accessed and set on the LossScaleOptimizer, which will
377 be delegated to the wrapped optimizer.
378
379 >>> opt = tf.keras.optimizers.Adam(beta_1=0.8, epsilon=1e-5)
380 >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
381 >>> opt.beta_1 # Equivalent to `opt.inner_optimizer.beta_1`
382 0.8
383 >>> opt.beta_1 = 0.7 # Equivalent to `opt.inner_optimizer.beta_1 = 0.7`
384 >>> opt.beta_1
385 0.7
386 >>> opt.inner_optimizer.beta_1
387 0.7
388
389 However, accessing or setting non-hyperparameters is not delegated to the
390 LossScaleOptimizer. In an Adam optimizer, `beta_1` is a hyperparameter but
391 `epsilon` is not, as the Adam optimizer only calls `Optimizer._set_hyper` on
392 `beta_1`.
393
394 >>> opt.inner_optimizer.epsilon
395 1e-5
396 >>> opt.epsilon
397 Traceback (most recent call last):
398 ...
399 AttributeError: 'LossScaleOptimizer' object has no attribute 'epsilon'
400 >>> opt.epsilon = 1e-4 # This does NOT set epsilon on `opt.inner_optimizer`
401 >>> opt.inner_optimizer.epsilon
  1e-5
403
404 In the above example, despite epsilon being set on the LossScaleOptimizer, the
405 old epsilon value will still be used when training as epsilon was not set on
406 the inner optimizer.
407 """
408
409 _HAS_AGGREGATE_GRAD = True
410
  def __init__(self, inner_optimizer, dynamic=True, initial_scale=None,
               dynamic_growth_steps=None):
    """Wraps `inner_optimizer` with loss scaling.

    Args:
      inner_optimizer: The `OptimizerV2` instance to wrap.
      dynamic: Bool; whether the loss scale is updated dynamically.
      initial_scale: Initial (or, if not dynamic, fixed) loss scale. Defaults
        to `_DEFAULT_INITIAL_SCALE` when `dynamic` is True.
      dynamic_growth_steps: Growth period for dynamic loss scaling. Defaults
        to `_DEFAULT_GROWTH_STEPS` when `dynamic` is True; must be None
        otherwise.

    Raises:
      TypeError: If `inner_optimizer` is not an `OptimizerV2` or is itself a
        `LossScaleOptimizer`, or if `dynamic` is not a bool.
      ValueError: If the current tf.distribute strategy does not support loss
        scaling, if `inner_optimizer` is already wrapped by a
        LossScaleOptimizer, if `dynamic` is False and `initial_scale` is None,
        or if `dynamic` is False and `dynamic_growth_steps` is not None.
    """
    if not isinstance(inner_optimizer, optimizer_v2.OptimizerV2):
      raise TypeError('"inner_optimizer" must be an instance of OptimizerV2, '
                      'but got: %s' % inner_optimizer)
    if not isinstance(dynamic, bool):
      # Catch errors if a user incorrectly passes a string or float to the
      # second argument, as this is commonly done for
      # LossScaleOptimizerV1.
      raise TypeError('"dynamic" argument to LossScaleOptimizer.__init__ must '
                      'be a bool, but got: %r' % (dynamic,))
    if isinstance(inner_optimizer, LossScaleOptimizer):
      raise TypeError('LossScaleOptimizer cannot wrap another '
                      'LossScaleOptimizer, but got: %s' % (inner_optimizer,))
    self._raise_if_strategy_unsupported()
    if getattr(inner_optimizer, '_is_wrapped_by_loss_scale_optimizer', False):
      # TODO(reedwm): Maybe support this. The difficulty is that LSO has the
      # same checkpoint format as the inner optimizer, so multiple LSOs wrapping
      # the same optimizer causes the checkpointing logic to become confused.
      raise ValueError('"inner_optimizer" is already wrapped by a '
                       'LossScaleOptimizer. An optimizer can only be wrapped '
                       'by a single LossScaleOptimizer')
    # NOTE: `_optimizer` must be the first attribute assigned. This class's
    # `__setattr__` reads `self._optimizer._hyper` for every other attribute
    # name, so assigning anything else first would fail.
    self._optimizer = inner_optimizer
    self._optimizer._is_wrapped_by_loss_scale_optimizer = True

    # We don't call super().__init__, since we do not want to call OptimizerV2's
    # constructor.
    base_delegate.DelegatingTrackableMixin.__init__(self, self._optimizer)

    if dynamic:
      if initial_scale is None:
        initial_scale = _DEFAULT_INITIAL_SCALE
      if dynamic_growth_steps is None:
        dynamic_growth_steps = _DEFAULT_GROWTH_STEPS
      self._loss_scale = _DynamicLossScaleState(
          initial_scale, dynamic_growth_steps, multiplier=2)
      # Track the state so its variables are saved under 'loss_scale'.
      self._track_trackable(self._loss_scale, 'loss_scale')
    else:
      if initial_scale is None:
        raise ValueError('"initial_scale" must be specified if "dynamic" is '
                         'False')
      # A fixed loss scale is stored as a plain Python float.
      self._loss_scale = float(initial_scale)
      if dynamic_growth_steps is not None:
        raise ValueError('"dynamic_growth_steps" must be None if "dynamic" '
                         'is False, but got: %s' % (dynamic_growth_steps,))

    # To support restoring TensorFlow 2.2 checkpoints.
    self._track_trackable(FakeOptimizerForRestoration(self._optimizer),
                          'base_optimizer')
460
461 @property
462 def dynamic(self):
463 """Bool indicating whether dynamic loss scaling is used."""
464 return isinstance(self._loss_scale, _DynamicLossScaleState)
465
466 @property
467 def loss_scale(self):
468 """The current loss scale as a float32 scalar tensor."""
469 if isinstance(self._loss_scale, _DynamicLossScaleState):
470 return tensor_conversion.convert_to_tensor_v2_with_dispatch(
471 self._loss_scale.current_loss_scale
472 )
473 else:
474 return tensor_conversion.convert_to_tensor_v2_with_dispatch(
475 self._loss_scale
476 )
477
478 @property
479 def dynamic_counter(self):
480 """The number of steps since the loss scale was last increased or decreased.
481
482 This is None if `LossScaleOptimizer.dynamic` is False.
483
484 The counter is incremented every step. Once it reaches
485 `LossScaleOptimizer.dynamic_growth_steps`, the loss scale will be doubled
486 and the counter will be reset back to zero. If nonfinite gradients are
487 encountered, the loss scale will be halved and the counter will be reset
488 back to zero.
489 """
490 if isinstance(self._loss_scale, _DynamicLossScaleState):
491 return self._loss_scale.counter
492 else:
493 return None
494
495 @property
496 def initial_scale(self):
497 """The initial loss scale.
498
499 If `LossScaleOptimizer.dynamic` is False, this is the same number as
500 `LossScaleOptimizer.loss_scale`, as the loss scale never changes.
501 """
502 if isinstance(self._loss_scale, _DynamicLossScaleState):
503 return self._loss_scale.initial_loss_scale
504 else:
505 return self._loss_scale
506
507 @property
508 def dynamic_growth_steps(self):
509 """The number of steps it takes to increase the loss scale.
510
511 This is None if `LossScaleOptimizer.dynamic` is False.
512
513 Every `dynamic_growth_steps` consecutive steps with finite gradients, the
514 loss scale is increased.
515 """
516 if isinstance(self._loss_scale, _DynamicLossScaleState):
517 return self._loss_scale.growth_steps
518 else:
519 return None
520
521 @property
522 def inner_optimizer(self):
523 """The optimizer that this LossScaleOptimizer is wrapping."""
524 return self._optimizer
525
526 def get_scaled_loss(self, loss):
527 """Scales the loss by the loss scale.
528
529 This method is only needed if you compute gradients manually, e.g. with
530 `tf.GradientTape`. In that case, call this method to scale the loss before
531 passing the loss to `tf.GradientTape`. If you use
532 `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
533 scaling is automatically applied and this method is unneeded.
534
535 If this method is called, `get_unscaled_gradients` should also be called.
536 See the `tf.keras.mixed_precision.LossScaleOptimizer` doc for
537 an example.
538
539 Args:
540 loss: The loss, which will be multiplied by the loss scale. Can either be
541 a tensor or a callable returning a tensor.
542
543 Returns:
544 `loss` multiplied by `LossScaleOptimizer.loss_scale`.
545 """
546 if callable(loss):
547 def new_loss():
548 loss_val = loss()
549 return loss_val * math_ops.cast(self.loss_scale, loss_val.dtype)
550 return new_loss
551 else:
552 return loss * math_ops.cast(self.loss_scale, loss.dtype)
553
554 def get_unscaled_gradients(self, grads):
555 """Unscales the gradients by the loss scale.
556
557 This method is only needed if you compute gradients manually, e.g. with
558 `tf.GradientTape`. In that case, call this method to unscale the gradients
559 after computing them with `tf.GradientTape`. If you use
560 `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
561 scaling is automatically applied and this method is unneeded.
562
563 If this method is called, `get_scaled_loss` should also be called. See
564 the `tf.keras.mixed_precision.LossScaleOptimizer` doc for an
565 example.
566
567 Args:
568 grads: A list of tensors, each which will be divided by the loss scale.
569 Can have None values, which are ignored.
570
571 Returns:
572 A new list the same size as `grads`, where every non-None value in `grads`
573 is divided by `LossScaleOptimizer.loss_scale`.
574 """
575 loss_scale_reciprocal = 1. / self.loss_scale
576 return [
577 _multiply_gradient(g, loss_scale_reciprocal) if g is not None else None
578 for g in grads
579 ]
580
581 def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
582 tape = backprop.GradientTape() if tape is None else tape
583 with tape:
584 loss = self.get_scaled_loss(loss)
585 grads_and_vars = self._optimizer._compute_gradients( # pylint: disable=protected-access
586 loss,
587 var_list,
588 grad_loss,
589 tape=tape)
590 grads = [g for g, _ in grads_and_vars]
591 weights = [v for _, v in grads_and_vars]
592 unscaled_grads = self.get_unscaled_gradients(grads)
593 return list(zip(unscaled_grads, weights))
594
595 def get_gradients(self, loss, params):
596 loss = self.get_scaled_loss(loss)
597 grads = self._optimizer.get_gradients(loss, params)
598 return self.get_unscaled_gradients(grads)
599
600 def _create_all_weights(self, var_list):
601 self._optimizer._create_all_weights(var_list) # pylint: disable=protected-access
602
  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      experimental_aggregate_gradients=True):
    """Applies gradients, skipping the step when any gradient is nonfinite.

    With dynamic loss scaling, the loss scale state is updated from the
    gradients and gradients are applied only if they are all finite. When a
    step is skipped, the inner optimizer's `iterations` is still incremented.
    With a fixed loss scale, gradients are always applied.

    Args:
      grads_and_vars: Iterable of (gradient, variable) pairs. The gradients
        should already be unscaled.
      name: Optional name forwarded to the inner optimizer's
        `apply_gradients`.
      experimental_aggregate_gradients: If True, gradients are transformed and
        aggregated here (before the finiteness check) via the inner
        optimizer's private helpers.

    Returns:
      A `control_flow_ops.group` of the (possibly skipped) gradient
      application and the loss scale update.

    Raises:
      ValueError: If called in a cross-replica context, or if the current
        tf.distribute strategy does not support loss scaling.
    """
    if distribute_lib.in_cross_replica_context():
      raise ValueError('apply_gradients() must be called in a replica context.')
    # We check for the strategy here despite already checking in the constructor
    # as frequently the optimizer is created outside the strategy's scope.
    self._raise_if_strategy_unsupported()

    grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
    if experimental_aggregate_gradients:
      # We must aggregate the gradients here instead of in
      # self.optimizer.apply_gradients, so that any NaN or Inf gradients are
      # propagated to each replica. If any replica has a NaN or Inf gradient,
      # they must all have a NaN or Inf gradient so that they all skip the step.
      # pylint: disable=protected-access
      grads_and_vars = self._optimizer._transform_unaggregated_gradients(
          grads_and_vars)
      grads_and_vars = self._optimizer._aggregate_gradients(grads_and_vars)
      # pylint: enable=protected-access

    # Materialize once: it is iterated twice below.
    grads_and_vars = tuple(grads_and_vars)
    grads = [g for g, _ in grads_and_vars]
    # We do not want DistributionStrategy to unwrap any MirroredVariables in
    # grads_and_vars, because even in a replica context, the wrapped
    # optimizer expects mirrored variables. So we wrap the variables with an
    # _UnwrapPreventer, preventing DistributionStrategy from unwrapping the
    # MirroredVariables.
    wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars])

    def do_not_apply_fn():
      # Normally self._optimizer.iterations is incremented in
      # self._optimizer.apply_gradients(). Since that is not called in this
      # branch, we increment it here instead.
      return self._optimizer.iterations.assign_add(1, read_value=False)

    def _if_should_apply_grads(grads):
      # Returns (loss_scale_update_op, should_apply_grads). A fixed loss scale
      # has nothing to update and never skips a step.
      if isinstance(self._loss_scale, _DynamicLossScaleState):
        return self._loss_scale.update(grads)
      else:
        return (control_flow_ops.no_op(), True)

    if optimizer_utils.strategy_supports_no_merge_call():
      # This branch runs entirely in the current replica context, without
      # `merge_call`.
      loss_scale_update_op, should_apply_grads = _if_should_apply_grads(grads)
      def apply_fn():
        return self._apply_gradients(grads, wrapped_vars, name)

      maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                             do_not_apply_fn)
      return control_flow_ops.group(maybe_apply_op, loss_scale_update_op)

    else:

      def _apply_gradients_cross_replica(distribution, grads, wrapped_vars,
                                         name):
        loss_scale_update_op, should_apply_grads = _if_should_apply_grads(grads)

        def apply_fn():
          return distribution.extended.call_for_each_replica(
              self._apply_gradients,
              args=(grads, wrapped_vars, name))

        # Note: We must call this cond() in a cross-replica context.
        # DistributionStrategy does not support having a cond in a replica
        # context with a branch that calls `merge_call`, and
        # self._optimizer.apply_gradients calls `merge_call`.
        maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                               do_not_apply_fn)
        return control_flow_ops.group(maybe_apply_op, loss_scale_update_op)
      return distribute_lib.get_replica_context().merge_call(
          _apply_gradients_cross_replica,
          args=(grads, wrapped_vars, name))
676
677 def _apply_gradients(self, grads, wrapped_vars, name):
678 # Pass experimental_aggregate_gradients=False since LossScaleOptimizer
679 # already aggregated the gradients.
680 # TODO(reedwm): This will raise a fairly cryptic error message if
681 # self._optimizer.apply_gradients does not take
682 # experimental_aggregate_gradients.
683 return self._optimizer.apply_gradients(
684 list(zip(grads, wrapped_vars.value)), name,
685 experimental_aggregate_gradients=False)
686
687 def get_config(self):
688 serialized_optimizer = optimizers.serialize(self._optimizer)
689 return {
690 'inner_optimizer': serialized_optimizer,
691 'dynamic': self.dynamic,
692 'initial_scale': self.initial_scale,
693 'dynamic_growth_steps': self.dynamic_growth_steps,
694 }
695
696 @classmethod
697 def from_config(cls, config, custom_objects=None):
698 config = config.copy() # Make a copy, since we mutate config
699 if 'loss_scale' in config:
700 # If loss_scale is in config, we assume we are deserializing a
701 # LossScaleOptimizer from TF 2.3 or below. We convert the config so it
702 # can be deserialized in the current LossScaleOptimizer.
703 loss_scale = keras_loss_scale_module.deserialize(
704 config.pop('loss_scale'))
705 if isinstance(loss_scale, loss_scale_module.FixedLossScale):
706 config['dynamic'] = False
707 config['initial_scale'] = loss_scale._loss_scale_value # pylint: disable=protected-access
708 elif isinstance(loss_scale, loss_scale_module.DynamicLossScale):
709 config['dynamic'] = True
710 config['initial_scale'] = loss_scale.initial_loss_scale
711 config['dynamic_growth_steps'] = loss_scale.increment_period
712 if loss_scale.multiplier != 2:
713 raise ValueError('Cannot deserialize LossScaleOptimizer with a '
714 'DynamicLossScale whose multiplier is not 2. Got '
715 'DynamicLossScale: %s' % (loss_scale,))
716 else:
717 raise ValueError(
718 'Serialized LossScaleOptimizers with a LossScale that is neither a '
719 'FixedLossScale nor a DynamicLossScale can no longer be '
720 'deserialized')
721 config['inner_optimizer'] = config.pop('optimizer')
722 config['inner_optimizer'] = optimizers.deserialize(
723 config['inner_optimizer'], custom_objects=custom_objects)
724 return cls(**config)
725
726 def _raise_if_strategy_unsupported(self):
727 if not strategy_supports_loss_scaling():
728 strategy = distribute_lib.get_strategy()
729 if isinstance(strategy,
730 (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
731 tpu_strategy.TPUStrategyV2)):
732 raise ValueError(
733 'Loss scaling is not supported with TPUStrategy. Loss scaling is '
734 'unnecessary with TPUs, since they support bfloat16 instead of '
735 'float16 and bfloat16 does not require loss scaling. You should '
736 'remove the use of the LossScaleOptimizer when TPUs are used.')
737 else:
738 raise ValueError('Loss scaling is not supported with the '
739 'tf.distribute.Strategy: %s. Try using a different '
740 'Strategy, e.g. a MirroredStrategy' %
741 strategy.__class__.__name__)
742
743 # Delegations: We delegate most OptimizerV2 methods to the wrapped optimizer
744 # below.
745
746 @property
747 def iterations(self):
748 return self._optimizer.iterations
749
750 @iterations.setter
751 def iterations(self, variable):
752 self._optimizer.iterations = variable
753
754 def get_slot_names(self):
755 return self._optimizer.get_slot_names()
756
757 def variables(self):
758 return self._optimizer.variables()
759
760 @property
761 def weights(self):
762 return self._optimizer.weights
763
764 def get_weights(self):
765 return self._optimizer.get_weights()
766
767 def set_weights(self, weights):
768 return self._optimizer.set_weights(weights)
769
770 @property
771 def clipnorm(self):
772 return self._optimizer.clipnorm
773
774 @clipnorm.setter
775 def clipnorm(self, val):
776 self._optimizer.clipnorm = val
777
778 @property
779 def global_clipnorm(self):
780 return self._optimizer.global_clipnorm
781
782 @global_clipnorm.setter
783 def global_clipnorm(self, val):
784 self._optimizer.global_clipnorm = val
785
786 @property
787 def clipvalue(self):
788 return self._optimizer.clipvalue
789
790 @clipvalue.setter
791 def clipvalue(self, val):
792 self._optimizer.clipvalue = val
793
794 def _aggregate_gradients(self, grads_and_vars):
795 return self._optimizer._aggregate_gradients(grads_and_vars) # pylint: disable=protected-access
796
797 def _restore_slot_variable(self, slot_name, variable, slot_variable):
798 return self._optimizer._restore_slot_variable(slot_name, variable, # pylint: disable=protected-access
799 slot_variable)
800
801 def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
802 variable):
803 return self._optimizer._create_or_restore_slot_variable( # pylint: disable=protected-access
804 slot_variable_position, slot_name, variable)
805
806 def get_slot(self, var, slot_name):
807 return self._optimizer.get_slot(var, slot_name)
808
809 def add_slot(self, var, slot_name, initializer='zeros'):
810 return self._optimizer.add_slot(var, slot_name, initializer)
811
812 def __getattribute__(self, name):
813 try:
814 return object.__getattribute__(self, name)
815 except AttributeError as e:
816 if name == '_optimizer' or name == '_hyper':
817 # Avoid infinite recursion
818 raise e
819
820 # Delegate hyperparameter accesses to inner optimizer.
821 if name == 'lr':
822 name = 'learning_rate'
823 if name in self._optimizer._hyper:
824 return self._optimizer._get_hyper(name)
825 raise e
826
827 def __dir__(self):
828 result = set(super(LossScaleOptimizer, self).__dir__())
829 if '_optimizer' in result:
830 result |= self._optimizer._hyper.keys()
831 if 'learning_rate' in self._optimizer._hyper.keys():
832 result.add('lr')
833 return list(result)
834
835 def __setattr__(self, name, value):
836 if name == 'lr':
837 name = 'learning_rate'
838 # Delegate setting hyperparameter to inner optimizer if the attribute does
839 # not exist on the LossScaleOptimizer
840 try:
841 # We cannot check for the 'iterations' attribute as it cannot be set after
842 # it is accessed.
843 if name != 'iterations':
844 object.__getattribute__(self, name)
845 has_attribute = True
846 except AttributeError:
847 has_attribute = False
848 if (name != '_optimizer' and name in self._optimizer._hyper
849 and not has_attribute):
850 self._optimizer._set_hyper(name, value)
851 else:
852 super(LossScaleOptimizer, self).__setattr__(name, value)
853
854 # Explicitly delegate learning_rate. Normally hyperparameters are delegated in
855 # __getattribute__, but if a hyperparameter is not in self._optimizer._hyper
856 # (e.g. because self._optimizer itself wraps another optimizer), then it won't
857 # be delegated. Since learning_rate is a very commonly accessed
858 # hyperparameter, we delegate it here.
859 @property
860 def learning_rate(self):
861 return self._optimizer.learning_rate
862
863 @learning_rate.setter
864 def learning_rate(self, value):
865 self._optimizer.learning_rate = value
866
867 @property
868 def lr(self):
869 return self._optimizer.learning_rate
870
871 @lr.setter
872 def lr(self, value):
873 self._optimizer.lr = value
874
875 # We do not override some OptimizerV2 methods. For each, we describe why we do
876 # not delegate them to self._optimizer:
877 # * get_updates: get_updates() calls get_gradients(). Since we override
878 # get_gradients(), we cannot delegate get_updates() to self._optimizer,
879 # otherwise the overridden get_gradients() method would not be called.
880 # Luckily, get_updates() does not access any OptimizerV2 fields, so
881 # inheriting the OptimizerV2 version works fine.
  # * minimize: We don't delegate for the same reason as get_updates(): it
  #   calls both self._compute_gradients() and self.apply_gradients(), and
  #   both need to have the LossScaleOptimizer version called.
885
886 # TODO(reedwm): Maybe throw an error if mixed precision is used without this
887 # optimizer being used.
888
889
@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
class LossScaleOptimizerV1(LossScaleOptimizer):
  """A deprecated optimizer that applies loss scaling.

  Warning: This class is deprecated and will be removed in a future version of
  TensorFlow. Please use the non-experimental class
  `tf.keras.mixed_precision.LossScaleOptimizer` instead.

  This class is identical to the non-experimental
  `keras.mixed_precision.LossScaleOptimizer` except its constructor takes
  different arguments. For this class (the experimental version), the
  constructor takes a `loss_scale` argument. For the non-experimental class,
  the constructor encodes the loss scaling information in multiple arguments.
  Note that unlike this class, the non-experimental class does not accept a
  `tf.compat.v1.mixed_precision.LossScale`, which is deprecated.

  If you currently use this class, you should switch to the non-experimental
  `tf.keras.mixed_precision.LossScaleOptimizer` instead. We show several
  examples of converting the use of the experimental class to the equivalent
  non-experimental class.

  >>> # In all of the examples below, `opt1` and `opt2` are identical
  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale='dynamic')
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD())
  >>> assert opt1.get_config() == opt2.get_config()

  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale=123)
  >>> # dynamic=False indicates to use fixed loss scaling. initial_scale=123
  >>> # refers to the initial loss scale, which is the single fixed loss scale
  >>> # when dynamic=False.
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), dynamic=False, initial_scale=123)
  >>> assert opt1.get_config() == opt2.get_config()

  >>> loss_scale = tf.compat.v1.mixed_precision.experimental.DynamicLossScale(
  ...     initial_loss_scale=2048, increment_period=500)
  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale=loss_scale)
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), initial_scale=2048,
  ...     dynamic_growth_steps=500)
  >>> assert opt1.get_config() == opt2.get_config()

  Make sure to also switch from this class to the non-experimental class in
  isinstance checks, if you have any. If you do not do this, your model may run
  into hard-to-debug issues, as the experimental `LossScaleOptimizer` subclasses
  the non-experimental `LossScaleOptimizer`, but not vice versa. It is safe to
  switch isinstance checks to the non-experimental `LossScaleOptimizer` even
  before using the non-experimental `LossScaleOptimizer`.

  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale='dynamic')
  >>> # The experimental class subclasses the non-experimental class
  >>> isinstance(opt1, tf.keras.mixed_precision.LossScaleOptimizer)
  True
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD())
  >>> # The non-experimental class does NOT subclass the experimental class.
  >>> isinstance(opt2, tf.keras.mixed_precision.experimental.LossScaleOptimizer)
  False

  Args:
    optimizer: The Optimizer instance to wrap.
    loss_scale: The loss scale to scale the loss and gradients. This can be
      an int/float to use a fixed loss scale, the string "dynamic" to use
      dynamic loss scaling, a FixedLossScale, a DynamicLossScale, or a dict,
      which is first deserialized into a LossScale. The string "dynamic" is
      equivalent to passing `DynamicLossScale()`, and passing an int/float is
      equivalent to passing a FixedLossScale with the given loss scale. If a
      DynamicLossScale is passed, DynamicLossScale.multiplier must be 2 (the
      default). Other LossScale subclasses are not supported.
  """

  def __init__(self, optimizer, loss_scale):
    """Wraps `optimizer`, translating `loss_scale` into the new-style arguments.

    Every accepted form of `loss_scale` logs a deprecation warning showing the
    equivalent non-experimental constructor call. The constructor then
    forwards to `LossScaleOptimizer.__init__` with the corresponding
    `dynamic` / `initial_scale` / `dynamic_growth_steps` arguments.

    Raises:
      ValueError: If `loss_scale` is a DynamicLossScale whose multiplier is not
        2, or is not one of the accepted forms.
      TypeError: If `loss_scale` is a LossScale other than a FixedLossScale or
        DynamicLossScale.
    """
    warn_msg_prefix = (
        'tf.keras.mixed_precision.experimental.LossScaleOptimizer is '
        'deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer '
        'instead. ')

    # A serialized LossScale (e.g. from a saved config) is turned back into an
    # object so it can be handled by the branches below.
    if isinstance(loss_scale, dict):
      loss_scale = keras_loss_scale_module.deserialize(loss_scale)

    if isinstance(loss_scale, (int, float)):
      tf_logging.warning(
          warn_msg_prefix + 'For example:\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt, dynamic=False, initial_scale={})'.format(loss_scale))
      super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False,
                                                 initial_scale=loss_scale)
    elif isinstance(loss_scale, loss_scale_module.FixedLossScale):
      ls_val = loss_scale._loss_scale_value  # pylint: disable=protected-access
      tf_logging.warning(
          warn_msg_prefix + 'For example:\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt, dynamic=False, initial_scale={})'.format(ls_val))
      super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False,
                                                 initial_scale=ls_val)
    # Check the type before comparing: objects whose `==` is elementwise
    # (arrays, tensors) would otherwise fail with an ambiguous-truth error
    # instead of reaching the descriptive ValueError at the end.
    elif isinstance(loss_scale, str) and loss_scale == 'dynamic':
      tf_logging.warning(
          warn_msg_prefix + 'For example:\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt)')
      super(LossScaleOptimizerV1, self).__init__(optimizer)
    elif isinstance(loss_scale, loss_scale_module.DynamicLossScale):
      # Only non-default settings are forwarded (and shown in the warning).
      kwargs = {}
      extra_arguments = ''
      if loss_scale.initial_loss_scale != _DEFAULT_INITIAL_SCALE:
        kwargs['initial_scale'] = loss_scale.initial_loss_scale
        extra_arguments += (', initial_scale=%s' %
                            loss_scale.initial_loss_scale)
      if loss_scale.increment_period != _DEFAULT_GROWTH_STEPS:
        kwargs['dynamic_growth_steps'] = loss_scale.increment_period
        extra_arguments += (', dynamic_growth_steps=%s' %
                            loss_scale.increment_period)
      # The non-experimental optimizer has no multiplier argument, so only the
      # default multiplier of 2 can be represented.
      if loss_scale.multiplier != 2:
        raise ValueError('When passing a DynamicLossScale to "loss_scale", '
                         'DynamicLossScale.multiplier must be 2. Got: %s'
                         % (loss_scale,))
      tf_logging.warning(
          warn_msg_prefix +
          'Note that the non-experimental LossScaleOptimizer does not take a '
          'DynamicLossScale but instead takes the dynamic configuration '
          'directly in the constructor. For example:\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt{})\n'.format(extra_arguments))
      super(LossScaleOptimizerV1, self).__init__(optimizer, **kwargs)
    elif isinstance(loss_scale, loss_scale_module.LossScale):
      raise TypeError('Passing a LossScale that is not a FixedLossScale or a '
                      'DynamicLossScale is no longer supported. Got: {}'
                      .format(loss_scale))
    else:
      raise ValueError('Invalid value passed to loss_scale. loss_scale '
                       'must be the string "dynamic" (recommended), an int, '
                       'a float, a FixedLossScale, or a DynamicLossScale. Got '
                       'value: {}'.format(loss_scale))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates an instance from a config dict.

    Two config layouts are accepted. A config containing `loss_scale` is
    assumed to come from TF 2.3 or below. Otherwise the config is assumed to
    come from TF 2.4 or above, as produced by `LossScaleOptimizer.get_config`,
    with `dynamic`, `initial_scale`, `dynamic_growth_steps` and
    `inner_optimizer` keys.

    Args:
      config: The config dict. It is copied, not mutated.
      custom_objects: Optional custom objects used when deserializing the
        wrapped optimizer.

    Returns:
      A new `LossScaleOptimizerV1`.

    Raises:
      ValueError: If an old-style config holds a DynamicLossScale whose
        multiplier is not 2.
    """
    config = config.copy()  # Make a copy, since we mutate config

    if 'loss_scale' in config:
      # TF 2.3-and-below layout.
      config['loss_scale'] = keras_loss_scale_module.deserialize(
          config['loss_scale'])
      if (isinstance(config['loss_scale'], loss_scale_module.DynamicLossScale)
          and config['loss_scale'].multiplier != 2):
        raise ValueError('Cannot deserialize LossScaleOptimizer with a '
                         'DynamicLossScale whose multiplier is not 2. Got '
                         'DynamicLossScale: %s' % (config['loss_scale'],))
      config['optimizer'] = optimizers.deserialize(
          config['optimizer'], custom_objects=custom_objects)
      return cls(**config)

    # TF 2.4-and-above layout: convert the config generated by
    # LossScaleOptimizer.get_config into the arguments that
    # LossScaleOptimizerV1.__init__ accepts.
    dynamic = config.pop('dynamic')
    initial_scale = config.pop('initial_scale')
    growth_steps = config.pop('dynamic_growth_steps')
    if dynamic:
      config['loss_scale'] = loss_scale_module.DynamicLossScale(
          initial_scale, growth_steps, multiplier=2)
    else:
      config['loss_scale'] = loss_scale_module.FixedLossScale(initial_scale)
    config['optimizer'] = optimizers.deserialize(
        config.pop('inner_optimizer'), custom_objects=custom_objects)
    return cls(**config)
1062
1063
class FakeOptimizerForRestoration(trackable.Trackable):
  """Stand-in optimizer that keeps TF 2.2 LossScaleOptimizer checkpoints loadable.

  The checkpoint layout of LossScaleOptimizer changed after TF 2.2. In TF 2.2,
  `LossScaleOptimizer.__init__` tracked the wrapped optimizer explicitly:

  ```
  self._track_trackable(self._optimizer, 'base_optimizer')
  ```

  TF 2.2 checkpoints therefore record a dependency from the LossScaleOptimizer
  to the wrapped optimizer. Newer versions do not record that dependency.
  Instead, a LossScaleOptimizer overrides every Trackable method and delegates
  it to the wrapped optimizer. To a checkpoint it looks just like the wrapped
  optimizer, with the loss scale stored in addition.

  To keep TF 2.2 checkpoints restorable, LossScaleOptimizer depends on an
  instance of this class where it used to depend on the inner optimizer. On
  restore, this class forwards slot-variable requests to the inner optimizer.
  It owns no variables, so saving it adds nothing to a checkpoint.
  """

  def __init__(self, optimizer):
    # The only state: the optimizer that slot requests are forwarded to.
    self._optimizer = optimizer

  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
                                       variable):
    """Forwards slot creation/restoration to the wrapped optimizer."""
    inner = self._optimizer
    return inner._create_or_restore_slot_variable(  # pylint: disable=protected-access
        slot_variable_position, slot_name, variable)

  def get_slot_names(self):
    """Returns the wrapped optimizer's slot names."""
    inner = self._optimizer
    return inner.get_slot_names()
1102
1103
# Module-level registration: pairs `optimizer_v2.OptimizerV2` with the
# deprecated `LossScaleOptimizerV1` wrapper in the `mixed_precision` module.
# NOTE(review): presumably this lets that module's loss-scale wrapping
# utilities wrap Keras OptimizerV2 instances with LossScaleOptimizerV1 --
# confirm against `register_loss_scale_wrapper`.
mixed_precision.register_loss_scale_wrapper(optimizer_v2.OptimizerV2,
                                            LossScaleOptimizerV1)
1106
1107
def _multiply_gradient(gradient, scale):
  """Scales `gradient` by `scale`, supporting both dense and sparse gradients.

  Args:
    gradient: Either an `IndexedSlices` or any gradient value that has a
      `dtype` and supports `*`.
    scale: The scale factor. It is cast to `gradient.dtype` before use.

  Returns:
    For an `IndexedSlices`, a new `IndexedSlices` whose values are scaled and
    whose indices and dense_shape are carried over. Otherwise,
    `gradient * scale`.
  """
  scale = math_ops.cast(scale, gradient.dtype)
  if not isinstance(gradient, indexed_slices.IndexedSlices):
    return gradient * scale
  scaled_values = gradient.values * scale
  return indexed_slices.IndexedSlices(
      scaled_values,
      gradient.indices,
      dense_shape=gradient.dense_shape)
1118
1119
def strategy_supports_loss_scaling():
  """Reports whether loss scaling may be used under the current Strategy.

  Returns:
    True if no Strategy is in scope. Otherwise True only when the current
    Strategy is a Collective-all-reduce, OneDevice or Mirrored strategy (V1 or
    V2), and False for any other Strategy.
  """
  if distribute_lib.has_strategy():
    # Strategies are supported if either there is only one replica or if
    # variables are replicated per device. Otherwise, the current model.fit()
    # implementation and most custom training loops incorrectly unscale the
    # gradients. Currently, gradients are unscaled once per compute replica,
    # but they should be unscaled once per variable replica. When there is one
    # variable replica for each compute replica, this works fine, but
    # otherwise issues will occur.
    # TODO(reedwm): Support all strategies.
    supported_strategy_types = (
        collective_all_reduce_strategy.CollectiveAllReduceStrategy,
        collective_all_reduce_strategy.CollectiveAllReduceStrategyV1,
        one_device_strategy.OneDeviceStrategy,
        one_device_strategy.OneDeviceStrategyV1,
        mirrored_strategy.MirroredStrategy,
        mirrored_strategy.MirroredStrategyV1,
    )
    return isinstance(distribute_lib.get_strategy(), supported_strategy_types)
  return True