# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Contains LossScale classes."""
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import smart_cond
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@deprecation.deprecated_endpoints(
    'train.experimental.MixedPrecisionLossScaleOptimizer')
@tf_export(v1=['mixed_precision.MixedPrecisionLossScaleOptimizer',
               'train.experimental.MixedPrecisionLossScaleOptimizer'])
class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
32 """An optimizer that applies loss scaling.
34 Loss scaling is a process that multiplies the loss by a multiplier called the
35 loss scale, and divides each gradient by the same multiplier. The pseudocode
36 for this process is:
38 ```
39 loss = ...
40 loss *= loss_scale
41 grads = gradients(loss, vars)
42 grads /= loss_scale
43 ```
45 Mathematically, loss scaling has no effect, but can help avoid numerical
46 underflow in intermediate gradients when float16 tensors are used for mixed
47 precision training. By multiplying the loss, each intermediate gradient will
48 have the same multiplier applied.
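
  For example, the smallest positive value float16 can represent is roughly
  6e-8, so an intermediate gradient of 2e-8 would flush to zero; with a loss
  scale of 1024 that same gradient becomes about 2e-5, which is representable,
  and dividing the final gradient by 1024 approximately recovers the original
  value.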

  The loss scale can either be a fixed constant chosen by the user, or be
  dynamically determined. Determining the loss scale dynamically is convenient
  because a loss scale does not have to be chosen explicitly, but it reduces
  performance.

  This optimizer wraps another optimizer and applies loss scaling to it via a
  `LossScale`. Loss scaling is applied whenever gradients are computed, such as
  through `minimize()`.
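
  A minimal graph-mode usage sketch (the variable, loss, and hyperparameters
  below are illustrative placeholders, not part of this module):

  ```
  var = tf.compat.v1.get_variable('x', initializer=1.0)
  loss = var * var
  opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.1)
  opt = tf.compat.v1.mixed_precision.MixedPrecisionLossScaleOptimizer(
      opt, loss_scale=128)  # a number fixes the scale; 'dynamic' adjusts it
  train_op = opt.minimize(loss)
  ```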
58 """
60 def __init__(self, opt, loss_scale):
61 if not isinstance(opt, optimizer.Optimizer):
62 raise ValueError('"opt" must be an instance of Optimizer, but got: %s' %
63 type(opt))
64 self._optimizer = opt
66 use_locking = opt._use_locking # pylint: disable=protected-access
67 name = opt.get_name()
68 super(MixedPrecisionLossScaleOptimizer, self).__init__(use_locking, name)
70 self._loss_scale = loss_scale_module.get(loss_scale)
71 if self._loss_scale is None:
72 raise ValueError('loss_scale cannot be None')
73 self._track_trackable(self._optimizer, 'base_optimizer')
74 self._track_trackable(self._loss_scale, 'loss_scale')

  def _doing_dynamic_loss_scaling(self):
    """Check if `_loss_scale` dynamically manages the loss scale."""
    return isinstance(self._loss_scale, loss_scale_module.DynamicLossScale)

  def compute_gradients(self,
                        loss,
                        var_list=None,
                        gate_gradients=optimizer.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This adjusts the dynamic range of the gradient evaluation by scaling up
    the `loss` value. The gradient values are then scaled back down by the
    reciprocal of the loss scale. This is useful in reduced precision training
    where small gradient values would otherwise underflow the representable
    range.
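
    For illustration, assuming the optimizer was constructed with a fixed loss
    scale of 128 (`opt`, `loss`, and `var` below are placeholders):

    ```
    grads_and_vars = opt.compute_gradients(loss, var_list=[var])
    # The loss is multiplied by 128 before differentiation and each gradient
    # is divided by 128 afterwards, so the returned gradients match the
    # unscaled gradients up to floating-point rounding.
    ```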

    Args:
      loss: A Tensor containing the value to minimize or a callable taking no
        arguments which returns the value to minimize. When eager execution is
        enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph under
        the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with the
        corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.
    """
    loss = self._scale_loss(loss)
    grads_and_vars = self._optimizer.compute_gradients(
        loss=loss,
        var_list=var_list,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    grads = [g for g, _ in grads_and_vars]
    variables = [v for _, v in grads_and_vars]
    unscaled_grads = self._unscale_grads(grads)
    return list(zip(unscaled_grads, variables))

  def _scale_loss(self, loss):
    loss_scale = self._loss_scale()
    if callable(loss):
      def new_loss():
        loss_val = loss()
        return loss_val * math_ops.cast(loss_scale, loss_val.dtype)
      return new_loss
    else:
      return loss * math_ops.cast(loss_scale, loss.dtype)

  def _unscale_grads(self, grads):
    loss_scale = self._loss_scale()
    loss_scale_reciprocal = 1 / loss_scale
    return [
        None if g is None else self._scale_grad(g, loss_scale_reciprocal)
        for g in grads
    ]

  def _scale_grad(self, grad, loss_scale_reciprocal):
    if isinstance(grad, indexed_slices.IndexedSlices):
      grad_vals = grad.values * loss_scale_reciprocal
      return indexed_slices.IndexedSlices(grad_vals, grad.indices,
                                          grad.dense_shape)
    return grad * loss_scale_reciprocal

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    conditionally applies gradients if all gradient values are finite.
    Otherwise no update is performed (nor is `global_step` incremented).
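
    For illustration, assuming dynamic loss scaling with default settings
    (`opt`, `grads_and_vars`, and `global_step` below are placeholders):

    ```
    train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)
    # If any gradient is NaN or Inf, the variable updates and the
    # `global_step` increment are skipped and the loss scale is lowered;
    # otherwise the wrapped optimizer's update is applied as usual.
    ```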

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation. Defaults to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that conditionally applies the specified gradients. If
      `global_step` was not None, that operation also increments `global_step`.

    Raises:
      ValueError: If called in a cross-replica context; use
        `_distributed_apply()` instead.
    """
    if distribute_lib.in_cross_replica_context():
      raise ValueError(
          'apply_gradients() must be called in a replica context.')

    if not self._doing_dynamic_loss_scaling():
      return self._optimizer.apply_gradients(grads_and_vars, global_step, name)

    replica_context = distribute_lib.get_replica_context()
    grads_and_vars = tuple(grads_and_vars)

    # TODO(nluehr) cleanup GraphKeys.TRAIN_OP
    return replica_context.merge_call(
        self._distributed_apply, args=(grads_and_vars, global_step, name))

  def _distributed_apply(self,
                         distribution,
                         grads_and_vars,
                         global_step=None,
                         name=None):
193 """A version of `apply_gradients` for cross replica context.
195 When users are in a cross replica strategy, they must call this rather than
196 `apply_gradients()`.
198 Args:
199 distribution: a `DistributionStrategy` object.
200 grads_and_vars: List of (gradient, variable) pairs as returned by
201 `compute_gradients()` and then aggregated across replicas.
202 global_step: Optional (mirrored) `Variable` to increment by one after the
203 variables have been updated.
204 name: Optional name for the returned operation. Default to the name passed
205 to the `Optimizer` constructor.
207 Returns:
208 An `Operation` that applies the specified gradients across all
209 replicas. If `global_step` was not None, that operation also
210 increments `global_step`
211 """
    name = name if name is not None else self.get_name()
    grads = [g for g, _ in grads_and_vars]
    loss_scale_update_op, should_apply_grads = (self._loss_scale.update(grads))

    def apply_fn():
      return self._apply_gradients(distribution, grads_and_vars, global_step,
                                   name + '-wrapped')

    maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                           control_flow_ops.no_op)
    return control_flow_ops.group(
        maybe_apply_op, loss_scale_update_op, name=name)

  def _apply_gradients(self, distribution, grads_and_vars, global_step, name):
    """Unconditionally apply gradients in cross replica context."""
    update_ops = distribution.extended.call_for_each_replica(
        self._optimizer.apply_gradients,
        args=(grads_and_vars, global_step, name))
    return distribution.group(update_ops)

  def _apply_sparse(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _apply_dense(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_sparse(self, grad, handle, indices):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_dense(self, grad, handle):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def variables(self):
    """Returns the variables of the Optimizer."""
    return (self._optimizer.variables() +
            list(self._loss_scale._weights.values()))  # pylint: disable=protected-access