Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/mixed_precision/loss_scale_optimizer.py: 30%
569 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains the loss scaling optimizer class."""
17import tensorflow.compat.v2 as tf
19from keras.src import backend
20from keras.src import optimizers
21from keras.src.dtensor import utils as dtensor_utils
22from keras.src.optimizers import optimizer
23from keras.src.optimizers import utils as optimizer_utils
24from keras.src.optimizers.legacy import optimizer_v2
25from keras.src.saving import serialization_lib
27# isort: off
28from tensorflow.python.platform import tf_logging
29from tensorflow.python.util.tf_export import keras_export
32class _UnwrapPreventer:
33 """Wrapper that DistributionStrategy will not unwrap.
35 Typically, DistributionStrategy will unwrap values when going from a cross-
36 replica context to a replica context via `call_for_each_replica`. This class
37 is a wrapper that DistributionStrategy will not unwrap, so it can be used to
38 prevent it from unwrapping a value.
40 TODO(reedwm): Find/implement a better way of preventing values from being
41 unwrapped by DistributionStrategy
42 """
44 __slots__ = ["value"]
46 def __init__(self, value):
47 self.value = value
def _is_all_finite(grads):
    """Returns a scalar boolean tensor that is True iff every gradient is finite.

    `None` entries are ignored. For `tf.IndexedSlices` only the `.values`
    component is inspected.
    """
    per_grad_checks = []
    for grad in grads:
        if grad is None:
            continue
        dense = grad.values if isinstance(grad, tf.IndexedSlices) else grad
        per_grad_checks.append(tf.reduce_all(tf.math.is_finite(dense)))
    return tf.reduce_all(per_grad_checks)
def _op_in_graph_mode(tensor):
    """Returns `tensor.op` when building a graph, otherwise `tensor` itself.

    Some call sites need an op rather than a tensor in graph mode. Eager
    execution has no ops, so the tensor is returned unchanged there.

    Args:
      tensor: A tensor.

    Returns:
      The tensor's op in graph mode. The tensor in eager mode.
    """
    return tensor if tf.executing_eagerly() else tensor.op
def _assign_if_finite(var, value):
    """Assigns `value` to `var` only when `value` is finite; otherwise no-op."""

    def _do_assign():
        return _op_in_graph_mode(var.assign(value))

    return tf.cond(
        tf.math.is_finite(value), true_fn=_do_assign, false_fn=tf.no_op
    )
91def _maybe_warn_about_scaling(
92 loss_has_been_scaled, gradients_have_been_unscaled
93):
94 """Warn if the loss or gradients hasn't been scaled or unscaled."""
95 if loss_has_been_scaled and gradients_have_been_unscaled:
96 return
98 example_code = """
99 with tf.GradientTape() as tape:
100 loss = loss_fn()
101 scaled_loss = opt.get_scaled_loss(loss)
102 scaled_grads = tape.gradient(scaled_loss, vars)
103 grads = opt.get_unscaled_gradients(scaled_grads)
104 opt.apply_gradients([(grads, var)])"""
106 if not loss_has_been_scaled and not gradients_have_been_unscaled:
107 tf_logging.warning(
108 "You forgot to call LossScaleOptimizer.get_scaled_loss() and "
109 "LossScaleOptimizer.get_unscaled_gradients() before calling "
110 "LossScaleOptimizer.apply_gradients(). This will likely result in "
111 "worse model quality, so please call them in the correct places! "
112 f"For example:{example_code}\nFor more information, see "
113 "https://www.tensorflow.org/api_docs/python/tf/keras/mixed_precision/LossScaleOptimizer" # noqa: E501
114 )
115 elif not loss_has_been_scaled:
116 tf_logging.warning(
117 "You forgot to call LossScaleOptimizer.get_scaled_loss() before "
118 "calling LossScaleOptimizer.apply_gradients() (you did call "
119 "get_unscaled_gradients() however). This will likely result in "
120 "worse model quality, so please call get_scaled_loss() in the "
121 f"correct place! For example:{example_code}\nFor more information, "
122 "see "
123 "https://www.tensorflow.org/api_docs/python/tf/keras/mixed_precision/LossScaleOptimizer" # noqa: E501
124 )
125 elif not gradients_have_been_unscaled:
126 tf_logging.warning(
127 "You forgot to call LossScaleOptimizer.get_unscaled_gradients() "
128 "before calling LossScaleOptimizer.apply_gradients() (you did call "
129 "get_scaled_loss() however). This will likely result in worse "
130 "model quality, so please call get_unscaled_gradients() in the "
131 f"correct place! For example:{example_code}\nFor more information, "
132 "see "
133 "https://www.tensorflow.org/api_docs/python/tf/keras/mixed_precision/LossScaleOptimizer" # noqa: E501
134 )
class _DynamicLossScaleState(tf.__internal__.tracking.Trackable):
    """The state of a dynamic loss scale.

    Owns two non-trainable variables: the current loss scale (float32) and a
    counter of consecutive finite-gradient steps (int64). Each variable is
    stored in `self._weights` under a `(name, graph_key)` key, where
    `graph_key` is None when executing eagerly. Checkpoint dependencies are
    therefore resolved against the variables created in the current graph.
    """

    def __init__(self, initial_loss_scale, growth_steps, multiplier):
        """Creates the dynamic loss scale.

        Args:
          initial_loss_scale: Initial value of the loss scale (cast to float).
          growth_steps: Number of consecutive finite-gradient steps after
            which the loss scale is multiplied by `multiplier` (cast to int).
          multiplier: Factor used to raise the loss scale after
            `growth_steps` finite steps, and to lower it after a nonfinite
            step (cast to float).
        """
        super().__init__()
        self._initial_loss_scale = float(initial_loss_scale)
        self._growth_steps = int(growth_steps)
        self._multiplier = float(multiplier)

        self._weights = {}
        self._current_loss_scale = self._add_weight(
            name="current_loss_scale",
            dtype=tf.float32,
            initial_value=self._initial_loss_scale,
        )
        # The number of consecutive steps with finite gradients since the last
        # nonfinite gradient or change in loss scale. The name is 'good_steps'
        # for backwards compatibility with older checkpoints.
        self._counter = self._add_weight(
            name="good_steps", dtype=tf.int64, initial_value=0
        )

    @staticmethod
    def _graph_key_for_current_context():
        """Returns the default graph's key, or None when executing eagerly.

        Weight creation and checkpoint lookup share this helper so they always
        agree on which graph a weight belongs to.
        """
        if tf.executing_eagerly():
            return None
        return tf.compat.v1.get_default_graph()._graph_key

    def _add_weight(self, name, initial_value, dtype=None):
        """Adds a weight to this loss scale.

        The variable is non-trainable and uses `VariableAggregation.NONE`. It
        is registered under `(name, current graph key)` and is also passed to
        `backend.track_variable`.

        Args:
          name: Variable name.
          initial_value: The variable's initial value.
          dtype: The type of the variable.

        Returns:
          A variable.

        Note: no error is raised if a weight with the same `name` was already
        added in the current graph. The new variable replaces the old entry
        in `self._weights`.
        """
        variable = tf.Variable(
            initial_value=initial_value,
            name=name,
            dtype=dtype,
            trainable=False,
            synchronization=tf.VariableSynchronization.AUTO,
            # Set aggregation to NONE, as loss scaling variables should never be
            # aggregated.
            aggregation=tf.VariableAggregation.NONE,
        )
        key = (name, self._graph_key_for_current_context())
        self._weights[key] = variable
        self._handle_deferred_dependencies(name=name, trackable=variable)
        backend.track_variable(variable)
        return variable

    def _trackable_children(self, save_type="checkpoint", **kwargs):
        """From Trackable. Gather graph-specific weights to save."""
        graph_key = self._graph_key_for_current_context()
        weights = {}
        # Sort by weight name so the children are produced in a stable order.
        for (name, g), v in sorted(
            self._weights.items(), key=lambda i: i[0][0]
        ):
            if g == graph_key:
                weights[name] = v
        weights.update(super()._trackable_children(save_type, **kwargs))
        return weights

    def _lookup_dependency(self, name):
        """From Trackable. Find a weight in the current graph."""
        unconditional = super()._lookup_dependency(name)
        if unconditional is not None:
            return unconditional
        return self._weights.get(
            (name, self._graph_key_for_current_context()), None
        )

    @property
    def initial_loss_scale(self):
        return self._initial_loss_scale

    @property
    def growth_steps(self):
        return self._growth_steps

    @property
    def multiplier(self):
        return self._multiplier

    @property
    def current_loss_scale(self):
        """Returns the current loss scale as a float32 `tf.Variable`."""
        return self._current_loss_scale

    @property
    def counter(self):
        """Returns the counter as an int64 `tf.Variable`."""
        return self._counter

    def __call__(self):
        """Returns the current loss scale as a scalar `float32` tensor."""
        return tf.convert_to_tensor(self._current_loss_scale)

    def update(self, grads):
        """Updates the value of the loss scale.

        With all gradients finite, the counter is incremented. Once it would
        reach `growth_steps`, the loss scale is multiplied by `multiplier`
        (only if the result is finite) and the counter is reset. With any
        nonfinite gradient, the loss scale is divided by `multiplier` (never
        going below 1) and the counter is reset.

        Args:
          grads: A nested structure of unscaled gradients, each which is an
            all-reduced gradient of the loss with respect to a weight.

        Returns:
          update_op: The result of the `tf.cond` that performs the update. In
            graph mode this is an op that must be run. In eager mode the
            update has already executed, and the value should not be relied
            on.
          should_apply_gradients: A scalar boolean tensor. If False, the
            caller should skip applying `grads` to the variables this step.
        """
        grads = tf.nest.flatten(grads)
        if (
            tf.distribute.has_strategy()
            and tf.distribute.in_cross_replica_context()
        ):
            distribution = tf.distribute.get_strategy()
            is_finite_per_replica = distribution.extended.call_for_each_replica(
                _is_all_finite, args=(grads,)
            )
            # Each replica computed the same `is_finite` value, since `grads` is
            # all-reduced across replicas. Arbitrarily take `is_finite` from the
            # first replica.
            is_finite = distribution.experimental_local_results(
                is_finite_per_replica
            )[0]
        else:
            is_finite = _is_all_finite(grads)

        def update_if_finite_grads():
            """Update assuming the gradients are finite."""

            def incr_loss_scale():
                new_loss_scale = self.current_loss_scale * self.multiplier
                # Skip the assignment if the raised scale overflows to inf.
                return tf.group(
                    _assign_if_finite(self.current_loss_scale, new_loss_scale),
                    self.counter.assign(0),
                )

            return tf.cond(
                self.counter + 1 >= self.growth_steps,
                incr_loss_scale,
                lambda: _op_in_graph_mode(self.counter.assign_add(1)),
            )

        def update_if_not_finite_grads():
            """Update assuming the gradients are nonfinite."""
            # Lower the scale but never below 1.
            new_loss_scale = tf.maximum(
                self.current_loss_scale / self.multiplier, 1
            )
            return tf.group(
                self.counter.assign(0),
                self.current_loss_scale.assign(new_loss_scale),
            )

        update_op = tf.cond(
            is_finite, update_if_finite_grads, update_if_not_finite_grads
        )
        should_apply_gradients = is_finite
        return update_op, should_apply_gradients
316# See LossScaleOptimizer docstring for why this is so big
317_DEFAULT_INITIAL_SCALE = 2**15
318_DEFAULT_GROWTH_STEPS = 2000
# TODO(b/215389169): Delete this class after `OptimizerV2` is deprecated.
class LossScaleOptimizerMetaclass(type):
    """Metaclass that routes `BaseLossScaleOptimizer(...)` to a concrete class.

    Calling `BaseLossScaleOptimizer` (which users reach through
    `tf.keras.mixed_precision.LossScaleOptimizer(opt)`) builds one of two
    classes, depending on the type of `opt`:

    - a `LossScaleOptimizer`, when `opt` is an `optimizer_v2.OptimizerV2`;
    - a `LossScaleOptimizerV3`, when `opt` is an `optimizer.Optimizer`.

    Calling any other class that uses this metaclass behaves like an ordinary
    class call.
    """

    def __call__(cls, inner_optimizer, *args, **kwargs):
        if cls is BaseLossScaleOptimizer:
            if isinstance(inner_optimizer, optimizer_v2.OptimizerV2):
                concrete_cls = LossScaleOptimizer
            elif isinstance(inner_optimizer, optimizer.Optimizer):
                concrete_cls = LossScaleOptimizerV3
            else:
                # Not an optimizer of either kind, so there is nothing to wrap.
                raise TypeError(
                    '"inner_optimizer" must be an instance of '
                    "`tf.keras.optimizers.Optimizer` or "
                    "`tf.keras.optimizers.experimental.Optimizer`, but got: "
                    f"{inner_optimizer}."
                )
            return concrete_cls(inner_optimizer, *args, **kwargs)
        return super().__call__(inner_optimizer, *args, **kwargs)
# TODO(b/215389169): Delete this class after `OptimizerV2` is deprecated.
@keras_export("keras.mixed_precision.LossScaleOptimizer")
class BaseLossScaleOptimizer(metaclass=LossScaleOptimizerMetaclass):
    """An optimizer that applies loss scaling to prevent numeric underflow.

    Loss scaling is a technique to prevent numeric underflow in intermediate
    gradients when float16 is used. To prevent underflow, the loss is multiplied
    (or "scaled") by a certain factor called the "loss scale", which causes
    intermediate gradients to be scaled by the loss scale as well. The final
    gradients are divided (or "unscaled") by the loss scale to bring them back
    to their original value.

    `LossScaleOptimizer` wraps another optimizer and applies loss scaling to it.
    By default, the loss scale is dynamically updated over time so you do not
    have to choose the loss scale. The `minimize` method automatically scales
    the loss, unscales the gradients, and updates the loss scale so all you have
    to do is wrap your optimizer with a `LossScaleOptimizer` if you use
    `minimize`. For example:

    >>> opt = tf.keras.optimizers.experimental.SGD(0.25)
    >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
    >>> var = tf.Variable(1.)
    >>> loss_fn = lambda: var ** 2
    >>> # 'minimize' applies loss scaling and updates the loss scale.
    >>> opt.minimize(loss_fn, var_list=[var])
    >>> var.numpy()
    0.5

    If a `tf.GradientTape` is used to compute gradients instead of `minimize`,
    you must scale the loss and gradients manually. This can be done with the
    `LossScaleOptimizer.get_scaled_loss` and
    `LossScaleOptimizer.get_unscaled_gradients` methods. For example:

    >>> with tf.GradientTape() as tape:
    ...   loss = loss_fn()
    ...   scaled_loss = opt.get_scaled_loss(loss)
    >>> scaled_grad = tape.gradient(scaled_loss, var)
    >>> (grad,) = opt.get_unscaled_gradients([scaled_grad])
    >>> opt.apply_gradients([(grad, var)])  # Loss scale is updated here
    >>> var.numpy()
    0.25

    Warning: If you forget to call `get_scaled_loss` or `get_unscaled_gradients`
    (or both) when using a `tf.GradientTape`, the model will likely converge to
    a worse quality. Please make sure you call each function exactly once.

    When mixed precision with float16 is used, there is typically no risk of
    underflow affecting model quality if loss scaling is properly used. See
    [the mixed precision guide](
    https://www.tensorflow.org/guide/keras/mixed_precision) for more information
    on how to use mixed precision.

    Args:
      inner_optimizer: The `tf.keras.optimizers.Optimizer` or
        `tf.keras.optimizers.experimental.Optimizer` instance to wrap.
      dynamic: Bool indicating whether dynamic loss scaling is used. Defaults to
        True. If True, the loss scale will be dynamically updated over time
        using an algorithm that keeps the loss scale at approximately its
        optimal value. If False, a single fixed loss scale is used and
        `initial_scale` must be specified, which is used as the loss scale.
        Recommended to keep as True, as choosing a fixed loss scale can be
        tricky. Currently, there is a small performance overhead to dynamic loss
        scaling compared to fixed loss scaling.
      initial_scale: The initial loss scale. If `dynamic` is True, this defaults
        to `2 ** 15`. If `dynamic` is False, this must be specified and acts as
        the sole loss scale, as the loss scale does not change over time. When
        dynamic loss scaling is used, it is better for this to be a very high
        number, because a loss scale that is too high gets lowered far more
        quickly than a loss scale that is too low gets raised.
      dynamic_growth_steps: With dynamic loss scaling, every
        `dynamic_growth_steps` steps with finite gradients, the loss scale is
        doubled. Defaults to 2000. If a nonfinite gradient is encountered, the
        count is reset back to zero, gradients are skipped that step, and the
        loss scale is halved. The count can be queried with
        `LossScaleOptimizer.dynamic_counter`. This argument can only be
        specified if `dynamic` is True.

    `LossScaleOptimizer` will occasionally skip applying gradients to the
    variables, in which case the trainable variables will not change that step.
    This is done because the dynamic loss scale will sometimes be raised too
    high, causing overflow in the gradients. Typically, the first 2 to 15 steps
    of the model are skipped as the initial loss scale is very high, but
    afterwards steps will only be skipped on average 0.05% of the time (the
    fraction of steps skipped is `1 / dynamic_growth_steps`).

    `LossScaleOptimizer` delegates all public `Optimizer` methods to the inner
    optimizer. Additionally, in methods `minimize` and `get_gradients`, it
    scales the loss and unscales the gradients. In methods `minimize` and
    `apply_gradients`, it additionally updates the loss scale and skips applying
    gradients if any gradient has a nonfinite value.

    ### Hyperparameters

    If wrapping a `tf.keras.optimizers.Optimizer`, hyperparameters can be
    accessed and set on the LossScaleOptimizer, which will be delegated to the
    wrapped optimizer.

    >>> opt = tf.keras.optimizers.legacy.Adam(beta_1=0.8, epsilon=1e-5)
    >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
    >>> opt.beta_1  # Equivalent to `opt.inner_optimizer.beta_1`
    0.8
    >>> opt.beta_1 = 0.7  # Equivalent to `opt.inner_optimizer.beta_1 = 0.7`
    >>> opt.beta_1
    0.7
    >>> opt.inner_optimizer.beta_1
    0.7

    However, accessing or setting non-hyperparameters is not delegated to the
    LossScaleOptimizer. In an Adam optimizer, `beta_1` is a hyperparameter but
    `epsilon` is not, as the Adam optimizer only calls `Optimizer._set_hyper` on
    `beta_1`.

    >>> opt.inner_optimizer.epsilon
    1e-5
    >>> opt.epsilon
    Traceback (most recent call last):
    ...
    AttributeError: 'LossScaleOptimizer' object has no attribute 'epsilon'
    >>> opt.epsilon = 1e-4  # This does NOT set epsilon on `opt.inner_optimizer`
    >>> opt.inner_optimizer.epsilon
    1e-5

    In the above example, despite epsilon being set on the LossScaleOptimizer,
    the old epsilon value will still be used when training as epsilon was not
    set on the inner optimizer.
    """

    @property
    def dynamic(self):
        """Bool indicating whether dynamic loss scaling is used."""
        raise NotImplementedError

    @property
    def loss_scale(self):
        """The current loss scale as a float32 scalar tensor."""
        raise NotImplementedError

    @property
    def dynamic_counter(self):
        """The number of steps since the loss scale was last increased or
        decreased.

        This is None if `LossScaleOptimizer.dynamic` is False.

        The counter is incremented every step. Once it reaches
        `LossScaleOptimizer.dynamic_growth_steps`, the loss scale will be
        doubled and the counter will be reset back to zero. If nonfinite
        gradients are encountered, the loss scale will be halved and the counter
        will be reset back to zero.
        """
        raise NotImplementedError

    @property
    def initial_scale(self):
        """The initial loss scale.

        If `LossScaleOptimizer.dynamic` is False, this is the same number as
        `LossScaleOptimizer.loss_scale`, as the loss scale never changes.
        """
        raise NotImplementedError

    @property
    def dynamic_growth_steps(self):
        """The number of steps it takes to increase the loss scale.

        This is None if `LossScaleOptimizer.dynamic` is False.

        Every `dynamic_growth_steps` consecutive steps with finite gradients,
        the loss scale is increased.
        """
        raise NotImplementedError

    @property
    def inner_optimizer(self):
        """The optimizer that this LossScaleOptimizer is wrapping."""
        raise NotImplementedError

    def get_scaled_loss(self, loss):
        """Scales the loss by the loss scale.

        This method is only needed if you compute gradients manually, e.g. with
        `tf.GradientTape`. In that case, call this method to scale the loss
        before passing the loss to `tf.GradientTape`. If you use
        `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`,
        loss scaling is automatically applied and this method is unneeded.

        If this method is called, `get_unscaled_gradients` should also be
        called. See the `tf.keras.mixed_precision.LossScaleOptimizer` doc for
        an example.

        Args:
          loss: The loss, which will be multiplied by the loss scale. Can either
            be a tensor or a callable returning a tensor.

        Returns:
          `loss` multiplied by `LossScaleOptimizer.loss_scale`.
        """
        # Calls to this function would be delegated to `get_scaled_loss`
        # of either `LossScaleOptimizer` or `LossScaleOptimizerV3`, depending on
        # the type of `inner_optimizer`.
        raise NotImplementedError

    def get_unscaled_gradients(self, grads):
        """Unscales the gradients by the loss scale.

        This method is only needed if you compute gradients manually, e.g. with
        `tf.GradientTape`. In that case, call this method to unscale the
        gradients after computing them with `tf.GradientTape`. If you use
        `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`,
        loss scaling is automatically applied and this method is unneeded.

        If this method is called, `get_scaled_loss` should also be called. See
        the `tf.keras.mixed_precision.LossScaleOptimizer` doc for an
        example.

        Args:
          grads: A list of tensors, each which will be divided by the loss
            scale. Can have None values, which are ignored.

        Returns:
          A new list the same size as `grads`, where every non-None value in
          `grads` is divided by `LossScaleOptimizer.loss_scale`.
        """
        # Calls to this function would be delegated to `get_unscaled_gradients`
        # of either `LossScaleOptimizer` or `LossScaleOptimizerV3`, depending on
        # the type of `inner_optimizer`.
        raise NotImplementedError
584class LossScaleOptimizer(
585 tf.__internal__.tracking.DelegatingTrackableMixin,
586 optimizer_v2.OptimizerV2,
587 BaseLossScaleOptimizer,
588):
589 """An optimizer that applies loss scaling to prevent numeric underflow."""
591 _HAS_AGGREGATE_GRAD = True
593 def __init__(
594 self,
595 inner_optimizer,
596 dynamic=True,
597 initial_scale=None,
598 dynamic_growth_steps=None,
599 ):
600 if not isinstance(inner_optimizer, optimizer_v2.OptimizerV2):
601 if isinstance(inner_optimizer, optimizer.Optimizer):
602 # Give better error message if the new experimental optimizer is
603 # passed.
604 raise TypeError(
605 "You passed an instance of the new experimental "
606 "optimizer, `optimizer.Optimizer`, "
607 "to LossScaleOptimizer, but "
608 "only the classic optimizers subclassing from "
609 "`tf.keras.optimizers.Optimizer` can be passed. Please "
610 "use `loss_scale_optimizer.LossScaleOptimizerV3` "
611 "instead of "
612 "`tf.keras.mixed_precision.LossScaleOptimizer`, "
613 "as the former supports wrapping "
614 "instances of the new experimental optimizer. "
615 f"Got optimizer: {inner_optimizer}"
616 )
617 msg = (
618 '"inner_optimizer" must be an instance of '
619 "`tf.keras.optimizers.Optimizer`, but got: %s. "
620 % inner_optimizer
621 )
622 raise TypeError(msg)
623 if not isinstance(dynamic, bool):
624 # Catch errors if a user incorrectly passes a string or float to the
625 # second argument argument, as this was commonly done for the
626 # now-removed LossScaleOptimizerV1.
627 raise TypeError(
628 '"dynamic" argument to LossScaleOptimizer.__init__ must '
629 "be a bool, but got: %r" % (dynamic,)
630 )
631 if isinstance(inner_optimizer, LossScaleOptimizer):
632 raise TypeError(
633 "LossScaleOptimizer cannot wrap another "
634 "LossScaleOptimizer, but got: %s" % (inner_optimizer,)
635 )
636 _raise_if_strategy_unsupported()
637 if getattr(
638 inner_optimizer, "_is_wrapped_by_loss_scale_optimizer", False
639 ):
640 # TODO(reedwm): Maybe support this. The difficulty is that LSO has
641 # the same checkpoint format as the inner optimizer, so multiple
642 # LSOs wrapping the same optimizer causes the checkpointing logic to
643 # become confused.
644 raise ValueError(
645 '"inner_optimizer" is already wrapped by a '
646 "LossScaleOptimizer. An optimizer can only be wrapped "
647 "by a single LossScaleOptimizer"
648 )
649 self._optimizer = inner_optimizer
650 self._optimizer._is_wrapped_by_loss_scale_optimizer = True
652 # We don't call super().__init__, since we do not want to call
653 # OptimizerV2's constructor.
654 tf.__internal__.tracking.DelegatingTrackableMixin.__init__(
655 self, self._optimizer
656 )
658 if dynamic:
659 if initial_scale is None:
660 initial_scale = _DEFAULT_INITIAL_SCALE
661 if dynamic_growth_steps is None:
662 dynamic_growth_steps = _DEFAULT_GROWTH_STEPS
663 self._loss_scale = _DynamicLossScaleState(
664 initial_scale, dynamic_growth_steps, multiplier=2
665 )
666 self._track_trackable(self._loss_scale, "loss_scale")
667 else:
668 if initial_scale is None:
669 raise ValueError(
670 '"initial_scale" must be specified if "dynamic" is False'
671 )
672 self._loss_scale = float(initial_scale)
673 if dynamic_growth_steps is not None:
674 raise ValueError(
675 '"dynamic_growth_steps" must be None if "dynamic" '
676 "is False, but got: %s" % (dynamic_growth_steps,)
677 )
679 # Used to track whether get_scaled_loss() and get_unscaled_gradients()
680 # have been called
681 self._loss_has_been_scaled = False
682 self._gradients_have_been_unscaled = False
684 # To support restoring TensorFlow 2.2 checkpoints.
685 self._track_trackable(
686 FakeOptimizerForRestoration(self._optimizer), "base_optimizer"
687 )
689 @property
690 def dynamic(self):
691 return isinstance(self._loss_scale, _DynamicLossScaleState)
693 @property
694 def loss_scale(self):
695 if isinstance(self._loss_scale, _DynamicLossScaleState):
696 return tf.convert_to_tensor(self._loss_scale.current_loss_scale)
697 else:
698 return tf.convert_to_tensor(self._loss_scale)
700 @property
701 def dynamic_counter(self):
702 if isinstance(self._loss_scale, _DynamicLossScaleState):
703 return self._loss_scale.counter
704 else:
705 return None
707 @property
708 def initial_scale(self):
709 if isinstance(self._loss_scale, _DynamicLossScaleState):
710 return self._loss_scale.initial_loss_scale
711 else:
712 return self._loss_scale
714 @property
715 def dynamic_growth_steps(self):
716 if isinstance(self._loss_scale, _DynamicLossScaleState):
717 return self._loss_scale.growth_steps
718 else:
719 return None
721 @property
722 def inner_optimizer(self):
723 return self._optimizer
725 def get_scaled_loss(self, loss):
726 self._loss_has_been_scaled = True
727 if callable(loss):
729 def new_loss():
730 loss_val = loss()
731 return loss_val * tf.cast(self.loss_scale, loss_val.dtype)
733 return new_loss
734 else:
735 return loss * tf.cast(self.loss_scale, loss.dtype)
737 def get_unscaled_gradients(self, grads):
738 self._gradients_have_been_unscaled = True
739 loss_scale_reciprocal = 1.0 / self.loss_scale
740 return [
741 _multiply_gradient(g, loss_scale_reciprocal)
742 if g is not None
743 else None
744 for g in grads
745 ]
747 def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
748 tape = tf.GradientTape() if tape is None else tape
749 with tape:
750 loss = self.get_scaled_loss(loss)
751 grads_and_vars = self._optimizer._compute_gradients(
752 loss, var_list, grad_loss, tape=tape
753 )
754 grads = [g for g, _ in grads_and_vars]
755 weights = [v for _, v in grads_and_vars]
756 unscaled_grads = self.get_unscaled_gradients(grads)
757 return list(zip(unscaled_grads, weights))
759 def get_gradients(self, loss, params):
760 loss = self.get_scaled_loss(loss)
761 grads = self._optimizer.get_gradients(loss, params)
762 return self.get_unscaled_gradients(grads)
764 def _create_all_weights(self, var_list):
765 self._optimizer._create_all_weights(var_list)
    def apply_gradients(
        self, grads_and_vars, name=None, experimental_aggregate_gradients=True
    ):
        """Applies gradients, skipping the step if any gradient is nonfinite.

        Must be called in a replica context. With a dynamic loss scale, the
        loss scale state is updated from the (aggregated) gradients. The inner
        optimizer's `apply_gradients` runs only if all of them are finite;
        otherwise just `self._optimizer.iterations` is incremented. With a
        fixed loss scale the gradients are always applied.

        Args:
          grads_and_vars: Iterable of (gradient, variable) pairs. These are
            expected to be already unscaled; a warning is logged if
            `get_scaled_loss`/`get_unscaled_gradients` were not called.
          name: Optional name forwarded to the inner optimizer's
            `apply_gradients`.
          experimental_aggregate_gradients: If True, gradients are transformed
            and aggregated here, before the finiteness check.

        Returns:
          A `tf.group` of the (possibly skipped) apply op and the loss scale
          update op. It is returned directly, or via `merge_call` when the
          strategy does not support skipping merge calls.

        Raises:
          ValueError: If called in a cross-replica context.
        """
        if tf.distribute.in_cross_replica_context():
            raise ValueError(
                "apply_gradients() must be called in a replica context."
            )
        # We check for the strategy here despite already checking in the
        # constructor as frequently the optimizer is created outside the
        # strategy's scope.
        _raise_if_strategy_unsupported()
        _maybe_warn_about_scaling(
            self._loss_has_been_scaled, self._gradients_have_been_unscaled
        )

        grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
        if experimental_aggregate_gradients:
            # We must aggregate the gradients here instead of in
            # self.optimizer.apply_gradients, so that any NaN or Inf gradients
            # are propagated to each replica. If any replica has a NaN or Inf
            # gradient, they must all have a NaN or Inf gradient so that they
            # all skip the step.
            grads_and_vars = self._optimizer._transform_unaggregated_gradients(
                grads_and_vars
            )
            grads_and_vars = self._optimizer._aggregate_gradients(
                grads_and_vars
            )

        # Materialize once so the pairs can be iterated twice below.
        grads_and_vars = tuple(grads_and_vars)
        grads = [g for g, _ in grads_and_vars]
        # We do not want DistributionStrategy to unwrap any MirroredVariables in
        # grads_and_vars, because even in a replica context, the wrapped
        # optimizer expects mirrored variables. So we wrap the variables with an
        # _UnwrapPreventer, preventing DistributionStrategy from unwrapping the
        # MirroredVariables.
        wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars])

        def do_not_apply_fn():
            # Normally self._optimizer.iterations is incremented in
            # self._optimizer.apply_gradients(). Since that is not called in
            # this branch, we increment it here instead.
            return self._optimizer.iterations.assign_add(1, read_value=False)

        def _if_should_apply_grads(grads):
            # Returns (loss_scale_update_op, should_apply_grads). A fixed loss
            # scale never changes and never skips a step.
            if isinstance(self._loss_scale, _DynamicLossScaleState):
                return self._loss_scale.update(grads)
            else:
                return (tf.no_op(), True)

        if tf.__internal__.distribute.strategy_supports_no_merge_call():
            # Fast path: the decision and the apply can both happen directly in
            # the replica context.
            loss_scale_update_op, should_apply_grads = _if_should_apply_grads(
                grads
            )

            def apply_fn():
                return self._apply_gradients(grads, wrapped_vars, name)

            maybe_apply_op = tf.__internal__.smart_cond.smart_cond(
                should_apply_grads, apply_fn, do_not_apply_fn
            )
            return tf.group(maybe_apply_op, loss_scale_update_op)

        else:

            def _apply_gradients_cross_replica(
                distribution, grads, wrapped_vars, name
            ):
                (
                    loss_scale_update_op,
                    should_apply_grads,
                ) = _if_should_apply_grads(grads)

                def apply_fn():
                    return distribution.extended.call_for_each_replica(
                        self._apply_gradients, args=(grads, wrapped_vars, name)
                    )

                # Note: We must call this cond() in a cross-replica context.
                # DistributionStrategy does not support having a cond in a
                # replica context with a branch that calls `merge_call`, and
                # self._optimizer.apply_gradients calls `merge_call`.
                maybe_apply_op = tf.__internal__.smart_cond.smart_cond(
                    should_apply_grads, apply_fn, do_not_apply_fn
                )
                return tf.group(maybe_apply_op, loss_scale_update_op)

            return tf.distribute.get_replica_context().merge_call(
                _apply_gradients_cross_replica, args=(grads, wrapped_vars, name)
            )
858 def _apply_gradients(self, grads, wrapped_vars, name):
859 # Pass experimental_aggregate_gradients=False since LossScaleOptimizer
860 # already aggregated the gradients.
861 # TODO(reedwm): This will raise a fairly cryptic error message if
862 # self._optimizer.apply_gradients does not take
863 # experimental_aggregate_gradients.
864 return self._optimizer.apply_gradients(
865 list(zip(grads, wrapped_vars.value)),
866 name=name,
867 experimental_aggregate_gradients=False,
868 )
def get_config(self):
    """Returns a serializable config describing this LossScaleOptimizer.

    The wrapped optimizer is serialized with `optimizers.serialize`. The loss
    scale settings come from the `dynamic`, `initial_scale` and
    `dynamic_growth_steps` properties.
    """
    config = {"inner_optimizer": optimizers.serialize(self._optimizer)}
    config["dynamic"] = self.dynamic
    config["initial_scale"] = self.initial_scale
    config["dynamic_growth_steps"] = self.dynamic_growth_steps
    return config
@classmethod
def from_config(cls, config, custom_objects=None):
    """Rebuilds a LossScaleOptimizer from a config dict.

    Configs written by TF 2.3 and earlier contain a `"loss_scale"` entry and
    an `"optimizer"` entry. These are first translated into the current
    `"dynamic"` / `"initial_scale"` / `"dynamic_growth_steps"` /
    `"inner_optimizer"` layout.

    Args:
      config: The config dict. A copy is modified; the argument is not.
      custom_objects: Optional mapping forwarded to `optimizers.deserialize`.

    Returns:
      A new instance of `cls`.

    Raises:
      ValueError: If a legacy loss scale is neither a FixedLossScale nor a
        DynamicLossScale, or is a DynamicLossScale whose multiplier is not 2.
    """
    cfg = config.copy()  # entries are popped and rewritten below

    if "loss_scale" in cfg:
        # Legacy (TF <= 2.3) layout: convert the serialized v1 LossScale into
        # constructor arguments of the current class.
        v1_mp = tf.compat.v1.mixed_precision
        legacy_scale = serialization_lib.deserialize_keras_object(
            cfg.pop("loss_scale"),
            module_objects={
                "FixedLossScale": v1_mp.FixedLossScale,
                "DynamicLossScale": v1_mp.DynamicLossScale,
            },
            printable_module_name="loss scale",
        )
        if isinstance(legacy_scale, v1_mp.FixedLossScale):
            cfg["dynamic"] = False
            cfg["initial_scale"] = legacy_scale._loss_scale_value
        elif isinstance(legacy_scale, v1_mp.DynamicLossScale):
            cfg["dynamic"] = True
            cfg["initial_scale"] = legacy_scale.initial_loss_scale
            cfg["dynamic_growth_steps"] = legacy_scale.increment_period
            if legacy_scale.multiplier != 2:
                raise ValueError(
                    "Cannot deserialize LossScaleOptimizer with a "
                    "DynamicLossScale whose multiplier is not 2. Got "
                    "DynamicLossScale: %s" % (legacy_scale,)
                )
        else:
            raise ValueError(
                "Serialized LossScaleOptimizers with a LossScale that is "
                "neither a FixedLossScale nor a DynamicLossScale can no "
                "longer be deserialized"
            )
        cfg["inner_optimizer"] = cfg.pop("optimizer")

    serialized_inner = cfg.pop("inner_optimizer")
    # An already-built OptimizerV2 may be passed through unchanged.
    inner = (
        serialized_inner
        if isinstance(serialized_inner, optimizer_v2.OptimizerV2)
        else optimizers.deserialize(
            serialized_inner,
            custom_objects=custom_objects,
            use_legacy_optimizer=True,
        )
    )
    return cls(inner, **cfg)
930 # Delegations: We delegate most OptimizerV2 methods to the wrapped optimizer
931 # below.
@property
def iterations(self):
    """The wrapped optimizer's `iterations` counter."""
    inner = self._optimizer
    return inner.iterations

@iterations.setter
def iterations(self, variable):
    # Writes go straight through to the wrapped optimizer.
    inner = self._optimizer
    inner.iterations = variable
def get_slot_names(self):
    """Returns the slot names reported by the wrapped optimizer."""
    inner = self._optimizer
    return inner.get_slot_names()
def variables(self):
    """Returns the wrapped optimizer's `variables()` result."""
    inner = self._optimizer
    return inner.variables()
@property
def weights(self):
    """The wrapped optimizer's `weights` attribute."""
    inner = self._optimizer
    return inner.weights
def get_weights(self):
    """Returns the wrapped optimizer's `get_weights()` result."""
    inner = self._optimizer
    return inner.get_weights()
def set_weights(self, weights):
    """Forwards `weights` to the wrapped optimizer's `set_weights`."""
    inner = self._optimizer
    return inner.set_weights(weights)
@property
def clipnorm(self):
    """The wrapped optimizer's `clipnorm` setting."""
    inner = self._optimizer
    return inner.clipnorm

@clipnorm.setter
def clipnorm(self, val):
    inner = self._optimizer
    inner.clipnorm = val
@property
def global_clipnorm(self):
    """The wrapped optimizer's `global_clipnorm` setting."""
    inner = self._optimizer
    return inner.global_clipnorm

@global_clipnorm.setter
def global_clipnorm(self, val):
    inner = self._optimizer
    inner.global_clipnorm = val
@property
def clipvalue(self):
    """The wrapped optimizer's `clipvalue` setting."""
    inner = self._optimizer
    return inner.clipvalue

@clipvalue.setter
def clipvalue(self, val):
    inner = self._optimizer
    inner.clipvalue = val
def _aggregate_gradients(self, grads_and_vars):
    """Delegates gradient aggregation to the wrapped optimizer."""
    inner = self._optimizer
    return inner._aggregate_gradients(grads_and_vars)
def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Delegates slot-variable restoration to the wrapped optimizer."""
    inner = self._optimizer
    return inner._restore_slot_variable(slot_name, variable, slot_variable)
def _create_or_restore_slot_variable(
    self, slot_variable_position, slot_name, variable
):
    """Delegates slot-variable creation/restoration to the wrapped optimizer."""
    inner = self._optimizer
    return inner._create_or_restore_slot_variable(
        slot_variable_position, slot_name, variable
    )
def get_slot(self, var, slot_name):
    """Returns the wrapped optimizer's slot `slot_name` for `var`."""
    inner = self._optimizer
    return inner.get_slot(var, slot_name)
def add_slot(self, var, slot_name, initializer="zeros"):
    """Asks the wrapped optimizer to add slot `slot_name` for `var`."""
    inner = self._optimizer
    return inner.add_slot(var, slot_name, initializer)
def __getattribute__(self, name):
    """Looks up `name`, falling back to the inner optimizer's hyperparameters.

    Normal lookup is tried first. If it raises AttributeError, `lr` is
    treated as an alias of `learning_rate`, and names present in the inner
    optimizer's `_hyper` dict are served by its `_get_hyper`. In every other
    case the original AttributeError is re-raised.
    """
    try:
        return object.__getattribute__(self, name)
    except AttributeError as missing:
        # The fallback below reads these two names itself; delegating them
        # as well would recurse forever.
        if name in ("_optimizer", "_hyper"):
            raise missing
        hyper_name = "learning_rate" if name == "lr" else name
        inner = self._optimizer
        if hyper_name in inner._hyper:
            return inner._get_hyper(hyper_name)
        raise missing
def __dir__(self):
    """Lists attributes, adding the inner optimizer's hyperparameter names.

    When the inner optimizer has a `learning_rate` hyperparameter, its `lr`
    alias is listed too.
    """
    names = set(super().__dir__())
    if "_optimizer" not in names:
        return list(names)
    names |= self._optimizer._hyper.keys()
    if "learning_rate" in self._optimizer._hyper.keys():
        names.add("lr")
    return list(names)
def __setattr__(self, name, value):
    """Sets `name`, delegating unknown hyperparameters to the inner optimizer.

    `lr` is treated as `learning_rate`. A name is forwarded to the inner
    optimizer's `_set_hyper` only when it is one of the inner optimizer's
    `_hyper` keys and this object does not already expose an attribute of
    that name; everything else is set on this object normally.
    """
    if name == "lr":
        name = "learning_rate"

    if name == "iterations":
        # 'iterations' cannot be probed: it cannot be set after it has been
        # accessed, so it is always treated as an existing attribute.
        defined_here = True
    else:
        try:
            object.__getattribute__(self, name)
        except AttributeError:
            defined_here = False
        else:
            defined_here = True

    forward_to_inner = (
        name != "_optimizer"
        and name in self._optimizer._hyper
        and not defined_here
    )
    if forward_to_inner:
        self._optimizer._set_hyper(name, value)
    else:
        super().__setattr__(name, value)
# learning_rate gets a dedicated delegating property. __getattribute__ only
# forwards hyperparameters listed in self._optimizer._hyper, and that dict may
# lack learning_rate (e.g. when self._optimizer itself wraps another
# optimizer). Because learning_rate is accessed so often, it is forwarded
# explicitly here.
@property
def learning_rate(self):
    inner = self._optimizer
    return inner.learning_rate

@learning_rate.setter
def learning_rate(self, value):
    inner = self._optimizer
    inner.learning_rate = value
@property
def lr(self):
    """Alias of `learning_rate`, read from the inner optimizer."""
    return self._optimizer.learning_rate

@lr.setter
def lr(self, value):
    # Write through `learning_rate` so the setter mirrors the getter above.
    # Assigning `inner.lr` only reaches `learning_rate` if the inner object
    # aliases `lr` on assignment; otherwise it would grow an unrelated `lr`
    # attribute and the new value would never be read back.
    self._optimizer.learning_rate = value
1070 # We do not override some OptimizerV2 methods. For each, we describe why we
1071 # do not delegate them to self._optimizer:
1072 # * get_updates: get_updates() calls get_gradients(). Since we override
1073 # get_gradients(), we cannot delegate get_updates() to self._optimizer,
1074 # otherwise the overridden get_gradients() method would not be called.
1075 # Luckily, get_updates() does not access any OptimizerV2 fields, so
1076 # inheriting the OptimizerV2 version works fine.
# * minimize: We don't delegate it, for a similar reason as get_updates(): it
#   calls both self._compute_gradients() and self.apply_gradients(), and both
#   need to have the LossScaleOptimizer version called.
1081 # TODO(reedwm): Maybe throw an error if mixed precision is used without this
1082 # optimizer being used.
class LossScaleOptimizerV3(
    tf.__internal__.tracking.DelegatingTrackableMixin,
    optimizer.Optimizer,
    BaseLossScaleOptimizer,
):
    """An optimizer that applies loss scaling to prevent numeric underflow.

    This is a copy of the `mixed_precision.LossScaleOptimizer` class defined
    earlier in this module. The difference is that it subclasses and wraps
    the new experimental Optimizer class instead of the legacy
    `tf.keras.optimizers.Optimizer` (OptimizerV2) class. Some of the methods
    this class defines and calls differ from LossScaleOptimizer because the
    two Optimizer base classes differ. Additionally, this class does not
    support the legacy graph mode, but LossScaleOptimizer does.

    The new experimental Optimizer has no hyperparameter concept, so
    LossScaleOptimizerV3 does not delegate arbitrary hyperparameter accesses
    to the inner optimizer, unlike LossScaleOptimizer. It delegates only a
    fixed set of attributes: `learning_rate`, `iterations`, `variables`,
    `use_ema` and `ema_momentum`.
    """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(
        self,
        inner_optimizer,
        dynamic=True,
        initial_scale=None,
        dynamic_growth_steps=None,
    ):
        """Wraps `inner_optimizer` with loss scaling.

        Args:
          inner_optimizer: The `optimizer.Optimizer` instance to wrap.
          dynamic: Whether to use a dynamic loss scale. Must be a Python bool.
          initial_scale: Starting loss scale. Defaults to
            `_DEFAULT_INITIAL_SCALE` when `dynamic` is True; required when
            `dynamic` is False.
          dynamic_growth_steps: Growth-steps setting of the dynamic loss scale
            state. Defaults to `_DEFAULT_GROWTH_STEPS` when `dynamic` is True;
            must be None when `dynamic` is False.

        Raises:
          TypeError: If `inner_optimizer` is not an `optimizer.Optimizer`
            (including a legacy OptimizerV2), if it is itself a
            LossScaleOptimizerV3, or if `dynamic` is not a bool.
          ValueError: If the current distribution strategy does not support
            loss scaling, if `inner_optimizer` is already wrapped by a
            LossScaleOptimizer, or if `initial_scale` /
            `dynamic_growth_steps` are inconsistent with `dynamic`.
        """
        if not isinstance(inner_optimizer, optimizer.Optimizer):
            if isinstance(inner_optimizer, optimizer_v2.OptimizerV2):
                # Give better error message if the OptimizerV2 class is passed
                # instead of the new experimental optimizer. The module name is
                # read from the imported module so the message names a real,
                # importable location.
                raise TypeError(
                    "You passed a `tf.keras.optimizers.Optimizer` instance to "
                    "LossScaleOptimizerV3, but only the new experimental "
                    "optimizer defined in "
                    f"`{optimizer.__name__}` can be "
                    "passed. Please use "
                    "`tf.keras.mixed_precision.LossScaleOptimizer` "
                    "instead of LossScaleOptimizerV3, as the former supports "
                    "`tf.keras.optimizers.Optimizer`s. Got optimizer: "
                    f"{inner_optimizer}"
                )
            raise TypeError(
                '"inner_optimizer" must be an instance of '
                f"Optimizer, but got: {inner_optimizer}."
            )
        if not isinstance(dynamic, bool):
            # Catch errors if a user incorrectly passes a string or float as
            # the second argument, as this was commonly done for the
            # now-removed LossScaleOptimizerV1.
            raise TypeError(
                '"dynamic" argument to LossScaleOptimizer.__init__ must '
                f"be a bool, but got: {repr(dynamic)}"
            )
        if isinstance(inner_optimizer, LossScaleOptimizerV3):
            raise TypeError(
                "LossScaleOptimizer cannot wrap another "
                f"LossScaleOptimizer, but got: {inner_optimizer}"
            )
        _raise_if_strategy_unsupported()
        if getattr(
            inner_optimizer, "_is_wrapped_by_loss_scale_optimizer", False
        ):
            # TODO(reedwm): Maybe support this. The difficulty is that LSO has
            # the same checkpoint format as the inner optimizer, so multiple
            # LSOs wrapping the same optimizer causes the checkpointing logic to
            # become confused.
            raise ValueError(
                '"inner_optimizer" is already wrapped by a '
                "LossScaleOptimizer. An optimizer can only be wrapped "
                "by a single LossScaleOptimizer"
            )
        self._optimizer = inner_optimizer
        # Mark the inner optimizer so a second wrapper is rejected above.
        self._optimizer._is_wrapped_by_loss_scale_optimizer = True

        # We don't call super().__init__, since we do not want to call
        # Optimizer's constructor.
        tf.__internal__.tracking.DelegatingTrackableMixin.__init__(
            self, self._optimizer
        )

        if dynamic:
            if initial_scale is None:
                initial_scale = _DEFAULT_INITIAL_SCALE
            if dynamic_growth_steps is None:
                dynamic_growth_steps = _DEFAULT_GROWTH_STEPS
            self._loss_scale = _DynamicLossScaleState(
                initial_scale, dynamic_growth_steps, multiplier=2
            )
            # Track the loss scale state so it is part of checkpoints.
            self._track_trackable(self._loss_scale, "loss_scale")
        else:
            if initial_scale is None:
                raise ValueError(
                    '"initial_scale" must be specified if "dynamic" is False'
                )
            self._loss_scale = float(initial_scale)
            if dynamic_growth_steps is not None:
                raise ValueError(
                    '"dynamic_growth_steps" must be None if "dynamic" '
                    f"is False, but got: {dynamic_growth_steps}"
                )

        # Used to track whether get_scaled_loss() and get_unscaled_gradients()
        # have been called; apply_gradients() passes both flags to
        # _maybe_warn_about_scaling().
        self._loss_has_been_scaled = False
        self._gradients_have_been_unscaled = False

    @property
    def dynamic(self):
        """Whether the loss scale is dynamic (True) or fixed (False)."""
        return isinstance(self._loss_scale, _DynamicLossScaleState)

    @property
    def loss_scale(self):
        """The current loss scale, as a tensor."""
        if isinstance(self._loss_scale, _DynamicLossScaleState):
            return tf.convert_to_tensor(self._loss_scale.current_loss_scale)
        else:
            return tf.convert_to_tensor(self._loss_scale)

    @property
    def dynamic_counter(self):
        """The dynamic loss scale state's `counter`, or None if fixed."""
        if isinstance(self._loss_scale, _DynamicLossScaleState):
            return self._loss_scale.counter
        else:
            return None

    @property
    def initial_scale(self):
        """The initial loss scale (the fixed scale when `dynamic` is False)."""
        if isinstance(self._loss_scale, _DynamicLossScaleState):
            return self._loss_scale.initial_loss_scale
        else:
            return self._loss_scale

    @property
    def dynamic_growth_steps(self):
        """The dynamic state's `growth_steps`, or None if the scale is fixed."""
        if isinstance(self._loss_scale, _DynamicLossScaleState):
            return self._loss_scale.growth_steps
        else:
            return None

    @property
    def inner_optimizer(self):
        """The wrapped `optimizer.Optimizer` instance."""
        return self._optimizer

    def get_scaled_loss(self, loss):
        """Multiplies `loss` by the current loss scale.

        Args:
          loss: A loss tensor, or a zero-argument callable returning one.

        Returns:
          `loss` times the loss scale (cast to the loss dtype). If `loss` is
          callable, returns a new callable that computes that product.
        """
        self._loss_has_been_scaled = True
        if callable(loss):

            def new_loss():
                loss_val = loss()
                return loss_val * tf.cast(self.loss_scale, loss_val.dtype)

            return new_loss
        else:
            return loss * tf.cast(self.loss_scale, loss.dtype)

    def get_unscaled_gradients(self, grads):
        """Divides each gradient by the current loss scale.

        Args:
          grads: A list of gradients; `None` entries are passed through.

        Returns:
          A list of the same length with every non-None gradient multiplied
          by `1 / loss_scale`.
        """
        self._gradients_have_been_unscaled = True
        loss_scale_reciprocal = 1.0 / self.loss_scale
        return [
            _multiply_gradient(g, loss_scale_reciprocal)
            if g is not None
            else None
            for g in grads
        ]

    def compute_gradients(self, loss, var_list, tape=None):
        """Computes unscaled gradients of `loss` with respect to `var_list`.

        The loss is scaled while recording on `tape` (a new GradientTape when
        none is given). The inner optimizer differentiates the scaled loss,
        and the resulting gradients are unscaled.

        Returns:
          A list of (unscaled_gradient, variable) pairs.
        """
        tape = tf.GradientTape() if tape is None else tape
        with tape:
            loss = self.get_scaled_loss(loss)
        grads_and_vars = self._optimizer.compute_gradients(
            loss, var_list, tape=tape
        )
        grads = [g for g, _ in grads_and_vars]
        weights = [v for _, v in grads_and_vars]
        unscaled_grads = self.get_unscaled_gradients(grads)
        return list(zip(unscaled_grads, weights))

    def apply_gradients(
        self, grads_and_vars, skip_gradients_aggregation=False, **kwargs
    ):
        """Applies gradients unless the loss scale says to skip the step.

        Must be called in a replica context. Gradients are aggregated here,
        unless aggregation is disabled or DTensor is in use, so that every
        replica sees the same NaN/Inf values. With a dynamic loss scale, the
        loss scale state is updated from the gradients and decides whether
        the inner optimizer applies them. When the step is skipped, only the
        inner optimizer's `iterations` counter is incremented.

        Args:
          grads_and_vars: Iterable of (gradient, variable) pairs.
          skip_gradients_aggregation: If True, gradients are not aggregated
            before being applied.
          **kwargs: Only `experimental_aggregate_gradients` is read (the
            inverse of `skip_gradients_aggregation`, kept for backward
            compatibility with OptimizerV2); other keys are ignored.

        Raises:
          ValueError: If called in a cross-replica context, or if the current
            strategy does not support loss scaling.
        """
        if tf.distribute.in_cross_replica_context():
            raise ValueError(
                "apply_gradients() must be called in a replica context."
            )
        # We check for the strategy here despite already checking in the
        # constructor as frequently the optimizer is created outside the
        # strategy's scope.
        _raise_if_strategy_unsupported()
        _maybe_warn_about_scaling(
            self._loss_has_been_scaled, self._gradients_have_been_unscaled
        )

        grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
        # `experimental_aggregate_gradients` is an arg in `apply_gradients` of
        # v2 optimizer -- the reverse of `skip_gradients_aggregation`.
        # We read it from kwargs for backward compatibility.
        experimental_aggregate_gradients = kwargs.pop(
            "experimental_aggregate_gradients", True
        )
        run_with_dtensor = (
            # `_run_with_dtensor` is for dtensor based strategy scope, and
            # `_mesh` is when user explicitly specify the mesh setting for
            # optimizer.
            self._optimizer._run_with_dtensor
            or self._optimizer._mesh
        )

        if (
            not skip_gradients_aggregation
            and experimental_aggregate_gradients
            and not run_with_dtensor
        ):
            # We must aggregate the gradients here instead of in
            # self.optimizer.apply_gradients, so that any NaN or Inf gradients
            # are propagated to each replica. If any replica has a NaN or Inf
            # gradient, they must all have a NaN or Inf gradient so that they
            # all skip the step.
            grads_and_vars = self._optimizer.aggregate_gradients(grads_and_vars)

        grads_and_vars = tuple(grads_and_vars)
        grads = [g for g, _ in grads_and_vars]
        # We do not want DistributionStrategy to unwrap any MirroredVariables in
        # grads_and_vars, because even in a replica context, the wrapped
        # optimizer expects mirrored variables. So we wrap the variables with an
        # _UnwrapPreventer, preventing DistributionStrategy from unwrapping the
        # MirroredVariables.
        wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars])

        def do_not_apply_fn():
            # Normally self._optimizer.iterations is incremented in
            # self._optimizer.apply_gradients(). Since that is not called in
            # this branch, we increment it here instead.
            self._optimizer.iterations.assign_add(1, read_value=False)

        def _if_should_apply_grads(grads):
            # Only a dynamic loss scale can veto a step; updating its state
            # also yields the apply/skip decision.
            if isinstance(self._loss_scale, _DynamicLossScaleState):
                _, should_apply_grad = self._loss_scale.update(grads)
                return should_apply_grad
            else:
                return True

        if tf.__internal__.distribute.strategy_supports_no_merge_call():
            should_apply_grads = _if_should_apply_grads(grads)

            def apply_fn():
                return self._apply_gradients(grads, wrapped_vars)

            tf.__internal__.smart_cond.smart_cond(
                should_apply_grads, apply_fn, do_not_apply_fn
            )
        else:

            def _apply_gradients_cross_replica(
                distribution, grads, wrapped_vars
            ):
                should_apply_grads = _if_should_apply_grads(grads)

                def apply_fn():
                    distribution.extended.call_for_each_replica(
                        self._apply_gradients, args=(grads, wrapped_vars)
                    )

                # Note: We must call this cond() in a cross-replica context.
                # DistributionStrategy does not support having a cond in a
                # replica context with a branch that calls `merge_call`, and
                # self._optimizer.apply_gradients calls `merge_call`.
                tf.__internal__.smart_cond.smart_cond(
                    should_apply_grads, apply_fn, do_not_apply_fn
                )

            tf.distribute.get_replica_context().merge_call(
                _apply_gradients_cross_replica, args=(grads, wrapped_vars)
            )

    def _apply_gradients(self, grads, wrapped_vars):
        """Applies `grads` to `wrapped_vars.value` with the inner optimizer."""
        # Pass skip_gradients_aggregation=True since LossScaleOptimizer
        # already aggregated the gradients.
        self._optimizer.apply_gradients(
            list(zip(grads, wrapped_vars.value)),
            skip_gradients_aggregation=True,
        )

    def get_config(self):
        """Returns a config with the serialized inner optimizer and scale."""
        serialized_optimizer = optimizers.serialize(self._optimizer)
        return {
            "inner_optimizer": serialized_optimizer,
            "dynamic": self.dynamic,
            "initial_scale": self.initial_scale,
            "dynamic_growth_steps": self.dynamic_growth_steps,
        }

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuilds the optimizer from a `get_config` dict.

        `config["inner_optimizer"]` may be an already-built
        `optimizer.Optimizer` or a serialized one, which is deserialized with
        `use_legacy_optimizer=False`.
        """
        config = config.copy()  # Make a copy, since we mutate config
        if isinstance(config["inner_optimizer"], optimizer.Optimizer):
            inner_optimizer = config["inner_optimizer"]
        else:
            inner_optimizer = optimizers.deserialize(
                config["inner_optimizer"],
                custom_objects=custom_objects,
                use_legacy_optimizer=False,
            )
        del config["inner_optimizer"]
        return cls(inner_optimizer, **config)

    @property
    def iterations(self):
        """The inner optimizer's `iterations` counter."""
        return self._optimizer.iterations

    @iterations.setter
    def iterations(self, variable):
        self._optimizer.iterations = variable

    @property
    def variables(self):
        """The inner optimizer's `variables`."""
        return self._optimizer.variables

    def build(self, var_list):
        """Builds the inner optimizer for `var_list`."""
        return self._optimizer.build(var_list)

    @property
    def learning_rate(self):
        """The inner optimizer's learning rate."""
        return self._optimizer.learning_rate

    @learning_rate.setter
    def learning_rate(self, learning_rate):
        self._optimizer.learning_rate = learning_rate

    @property
    def use_ema(self):
        """The inner optimizer's `use_ema` flag."""
        return self._optimizer.use_ema

    @use_ema.setter
    def use_ema(self, use_ema):
        self._optimizer.use_ema = use_ema

    @property
    def ema_momentum(self):
        """The inner optimizer's `ema_momentum`."""
        return self._optimizer.ema_momentum

    @ema_momentum.setter
    def ema_momentum(self, ema_momentum):
        self._optimizer.ema_momentum = ema_momentum

    def finalize_variable_values(self, var_list):
        """Delegates to the inner optimizer's `finalize_variable_values`."""
        self._optimizer.finalize_variable_values(var_list)
class FakeOptimizerForRestoration(tf.__internal__.tracking.Trackable):
    """Placeholder optimizer that keeps TensorFlow 2.2 checkpoints restorable.

    LossScaleOptimizer's checkpoint layout changed after TF 2.2. In TF 2.2,
    its constructor tracked the wrapped optimizer explicitly:

    ```
    self._track_trackable(self._optimizer, 'base_optimizer')
    ```

    Those checkpoints therefore record a dependency from the
    LossScaleOptimizer to the wrapped optimizer. The current layout has no
    such dependency. A LossScaleOptimizer now overrides the Trackable methods
    and delegates them to the wrapped optimizer, so its checkpoint matches
    the wrapped optimizer's, plus the loss scale.

    To keep TF 2.2 checkpoints loadable, LossScaleOptimizer depends on an
    instance of this class where the inner optimizer used to be. On restore,
    this class forwards slot-variable restoration to the inner optimizer. It
    owns no variables itself, so saved checkpoints are unaffected.
    """

    def __init__(self, optimizer):
        self._optimizer = optimizer

    def get_slot_names(self):
        """Returns the wrapped optimizer's slot names."""
        wrapped = self._optimizer
        return wrapped.get_slot_names()

    def _create_or_restore_slot_variable(
        self, slot_variable_position, slot_name, variable
    ):
        """Forwards slot-variable creation/restoration to the wrapped one."""
        wrapped = self._optimizer
        return wrapped._create_or_restore_slot_variable(
            slot_variable_position, slot_name, variable
        )
def _create_loss_scale_optimizer_from_v1_loss_scale(optimizer, loss_scale):
    """Builds a LossScaleOptimizer equivalent to a v1-style loss scale spec.

    This exists only to be handed to
    `tf.__internal__.mixed_precision.register_loss_scale_wrapper`. That
    registration lets
    `tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite` wrap
    a Keras optimizer in a LossScaleOptimizer.

    Args:
      optimizer: The OptimizerV2 instance to wrap.
      loss_scale: An int or float (fixed scale), the string "dynamic", or a
        `tf.compat.v1.mixed_precision.FixedLossScale` / `DynamicLossScale`.

    Returns:
      A LossScaleOptimizer wrapping `optimizer` that uses the same loss
      scaling algorithm as `loss_scale`.

    Raises:
      ValueError: For a DynamicLossScale whose multiplier is not 2, or for any
        unrecognized `loss_scale` value.
      TypeError: For any other `tf.compat.v1.mixed_precision.LossScale`.
    """
    # A plain number means a fixed loss scale.
    if isinstance(loss_scale, (int, float)):
        return LossScaleOptimizer(
            optimizer, dynamic=False, initial_scale=loss_scale
        )

    v1_mixed_precision = tf.compat.v1.mixed_precision
    if isinstance(loss_scale, v1_mixed_precision.FixedLossScale):
        return LossScaleOptimizer(
            optimizer,
            dynamic=False,
            initial_scale=loss_scale._loss_scale_value,
        )
    if loss_scale == "dynamic":
        return LossScaleOptimizer(optimizer)
    if isinstance(loss_scale, v1_mixed_precision.DynamicLossScale):
        # Only a multiplier of 2 can be expressed with LossScaleOptimizer.
        if loss_scale.multiplier != 2:
            raise ValueError(
                'When passing a DynamicLossScale to "loss_scale", '
                "DynamicLossScale.multiplier must be 2. Got: "
                f"{loss_scale}"
            )
        return LossScaleOptimizer(
            optimizer,
            initial_scale=loss_scale.initial_loss_scale,
            dynamic_growth_steps=loss_scale.increment_period,
        )
    if isinstance(loss_scale, v1_mixed_precision.LossScale):
        raise TypeError(
            "Passing a LossScale that is not a FixedLossScale or a "
            f"DynamicLossScale is not supported. Got: {loss_scale}"
        )
    raise ValueError(
        "Invalid value passed to loss_scale. loss_scale "
        'must be the string "dynamic" (recommended), an int, '
        "a float, a FixedLossScale, or a DynamicLossScale. Got "
        f"value: {loss_scale}"
    )
# Register LossScaleOptimizer with TensorFlow's mixed-precision machinery for
# OptimizerV2 instances. The converter function translates a v1 LossScale
# spec into LossScaleOptimizer arguments, which lets
# `tf.compat.v1.mixed_precision.enable_mixed_precision_graph_rewrite` wrap a
# Keras optimizer with a LossScaleOptimizer.
tf.__internal__.mixed_precision.register_loss_scale_wrapper(
    optimizer_v2.OptimizerV2,
    _create_loss_scale_optimizer_from_v1_loss_scale,
    LossScaleOptimizer,
)
def _multiply_gradient(gradient, scale):
    """Scales a dense or `tf.IndexedSlices` gradient by `scale`.

    `scale` is cast to the gradient's dtype first. For IndexedSlices only the
    values are scaled; indices and dense_shape are kept.
    """
    factor = tf.cast(scale, gradient.dtype)
    if not isinstance(gradient, tf.IndexedSlices):
        return gradient * factor
    return tf.IndexedSlices(
        gradient.values * factor,
        gradient.indices,
        dense_shape=gradient.dense_shape,
    )
def strategy_supports_loss_scaling():
    """Returns True if the current Strategy supports loss scaling.

    Supported setups are: no strategy at all; a (MultiWorker)Mirrored or
    OneDevice strategy (v1 or v2), where each compute replica has its own
    variable replica; or a DTensor strategy. Elsewhere, model.fit() and most
    custom training loops unscale gradients once per compute replica, when it
    should be once per variable replica, so the result would be wrong.
    """
    # TODO(reedwm): Support all strategies.
    if not tf.distribute.has_strategy():
        return True
    current = tf.distribute.get_strategy()
    per_replica_variable_strategies = (
        tf.distribute.MultiWorkerMirroredStrategy,
        tf.compat.v1.distribute.experimental.MultiWorkerMirroredStrategy,
        tf.distribute.OneDeviceStrategy,
        tf.compat.v1.distribute.OneDeviceStrategy,
        tf.distribute.MirroredStrategy,
        tf.compat.v1.distribute.MirroredStrategy,
    )
    if isinstance(current, per_replica_variable_strategies):
        return True
    return dtensor_utils.running_with_dtensor_strategy()
def _raise_if_strategy_unsupported():
    """Raises ValueError if the current strategy cannot be used with loss scaling.

    TPU strategies get a dedicated message, since bfloat16 on TPUs makes loss
    scaling unnecessary. Every other unsupported strategy is reported by its
    class name.
    """
    if strategy_supports_loss_scaling():
        return
    current = tf.distribute.get_strategy()
    tpu_strategy_types = (
        tf.distribute.experimental.TPUStrategy,
        tf.compat.v1.distribute.experimental.TPUStrategy,
        tf.distribute.TPUStrategy,
    )
    if isinstance(current, tpu_strategy_types):
        raise ValueError(
            "Loss scaling is not supported with TPUStrategy. Loss scaling "
            "is unnecessary with TPUs, since they support bfloat16 instead "
            "of float16 and bfloat16 does not require loss scaling. You "
            "should remove the use of the LossScaleOptimizer when TPUs are "
            "used."
        )
    raise ValueError(
        "Loss scaling is not supported with the "
        "tf.distribute.Strategy: "
        f"{current.__class__.__name__}. Try using a different "
        "Strategy, e.g. a MirroredStrategy"
    )