# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to loss functions."""

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.losses.Reduction', v1=[])
class ReductionV2(object):
  """Types of loss reduction.

  Contains the following values:

  * `AUTO`: Indicates that the reduction option will be determined by the
    usage context. For almost all cases this defaults to `SUM_OVER_BATCH_SIZE`.
    When used with `tf.distribute.Strategy`, outside of built-in training loops
    such as `tf.keras` `compile` and `fit`, we expect the reduction value to be
    `SUM` or `NONE`. Using `AUTO` in that case will raise an error.
  * `NONE`: No **additional** reduction is applied to the output of the wrapped
    loss function. When non-scalar losses are returned to Keras functions like
    `fit`/`evaluate`, the unreduced vector loss is passed to the optimizer
    but the reported loss will be a scalar value.

    Caution: **Verify the shape of the outputs when using** `Reduction.NONE`.
    The builtin loss functions wrapped by the loss classes reduce
    one dimension (`axis=-1`, or `axis` if specified by the loss function).
    `Reduction.NONE` just means that no **additional** reduction is applied by
    the class wrapper. For categorical losses with an example input shape of
    `[batch, W, H, n_classes]` the `n_classes` dimension is reduced. For
    pointwise losses you must include a dummy axis so that `[batch, W, H, 1]`
    is reduced to `[batch, W, H]`. Without the dummy axis `[batch, W, H]`
    will be incorrectly reduced to `[batch, W]`.

  * `SUM`: Scalar sum of weighted losses.
  * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by the number of elements in
    losses. This reduction type is not supported when used with
    `tf.distribute.Strategy` outside of built-in training loops like
    `tf.keras` `compile`/`fit`.

    You can implement 'SUM_OVER_BATCH_SIZE' using the global batch size like:

    ```
    with strategy.scope():
      loss_obj = tf.keras.losses.CategoricalCrossentropy(
          reduction=tf.keras.losses.Reduction.NONE)
      ....
      loss = tf.reduce_sum(loss_obj(labels, predictions)) * (
          1. / global_batch_size)
    ```

  Please see the [custom training guide](
  https://www.tensorflow.org/tutorials/distribute/custom_training) for more
  details on this.
  """

  AUTO = 'auto'
  NONE = 'none'
  SUM = 'sum'
  SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'

  @classmethod
  def all(cls):
    return (cls.AUTO, cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)

  @classmethod
  def validate(cls, key):
    if key not in cls.all():
      raise ValueError('Invalid Reduction Key %s.' % key)
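

# Illustrative sketch (not part of the original module): the reduction keys
# are plain strings, `all()` enumerates the supported ones, and `validate`
# raises on anything else.
#
#   ReductionV2.validate(ReductionV2.SUM)   # passes silently
#   ReductionV2.validate('mean')            # raises ValueError
#   assert ReductionV2.AUTO in ReductionV2.all()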


def remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=0, name=None):
  """Squeeze last dim if ranks differ from expected by exactly 1.

  In the common case where we expect shapes to match, `expected_rank_diff`
  defaults to 0, and we squeeze the last dimension of the larger rank if they
  differ by 1.

  But, for example, if `labels` contains class IDs and `predictions` contains 1
  probability per class, we expect `predictions` to have 1 more dimension than
  `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
  `labels` if `rank(predictions) - rank(labels) == 0`, and
  `predictions` if `rank(predictions) - rank(labels) == 2`.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
    Tuple of `labels` and `predictions`, possibly with last dim squeezed.
  """
  with backend.name_scope(name or 'remove_squeezable_dimensions'):
    if not isinstance(predictions, ragged_tensor.RaggedTensor):
      predictions = tensor_conversion.convert_to_tensor_v2_with_dispatch(
          predictions
      )
    if not isinstance(labels, ragged_tensor.RaggedTensor):
      labels = tensor_conversion.convert_to_tensor_v2_with_dispatch(labels)
    predictions_shape = predictions.shape
    predictions_rank = predictions_shape.ndims
    labels_shape = labels.shape
    labels_rank = labels_shape.ndims
    if (labels_rank is not None) and (predictions_rank is not None):
      # Use static rank.
      rank_diff = predictions_rank - labels_rank
      if (rank_diff == expected_rank_diff + 1 and
          predictions_shape.dims[-1].is_compatible_with(1)):
        predictions = array_ops.squeeze(predictions, [-1])
      elif (rank_diff == expected_rank_diff - 1 and
            labels_shape.dims[-1].is_compatible_with(1)):
        labels = array_ops.squeeze(labels, [-1])
      return labels, predictions

    # Use dynamic rank.
    rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
    if (predictions_rank is None) or (
        predictions_shape.dims[-1].is_compatible_with(1)):
      predictions = cond.cond(
          math_ops.equal(expected_rank_diff + 1, rank_diff),
          lambda: array_ops.squeeze(predictions, [-1]),
          lambda: predictions)
    if (labels_rank is None) or (
        labels_shape.dims[-1].is_compatible_with(1)):
      labels = cond.cond(
          math_ops.equal(expected_rank_diff - 1, rank_diff),
          lambda: array_ops.squeeze(labels, [-1]),
          lambda: labels)
    return labels, predictions
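

# Illustrative sketch (not part of the original module); the tensors below are
# hypothetical. With the default `expected_rank_diff=0`, the trailing
# singleton dimension of the higher-rank argument is squeezed:
#
#   labels = tf.constant([[1.], [0.], [1.]])    # shape (3, 1)
#   predictions = tf.constant([.8, .2, .6])     # shape (3,)
#   labels, predictions = remove_squeezable_dimensions(labels, predictions)
#   # labels now has shape (3,), matching predictions.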


def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):
  """Squeeze or expand last dimension if needed.

  1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
  (using `remove_squeezable_dimensions`).
  2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
  from the new rank of `y_pred`.
  If `sample_weight` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
    y_true: Optional label `Tensor` whose dimensions match `y_pred`.
    sample_weight: Optional weight scalar or `Tensor` whose dimensions match
      `y_pred`.

  Returns:
    Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
    the last dimension squeezed,
    `sample_weight` could be extended by one dimension.
    If `sample_weight` is None, (y_pred, y_true) is returned.
  """
  y_pred_shape = y_pred.shape
  y_pred_rank = y_pred_shape.ndims
  if y_true is not None:

    # If sparse matrix is provided as `y_true`, the last dimension in `y_pred`
    # may be > 1. Eg: y_true = [0, 1, 2] (shape=(3,)),
    # y_pred = [[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]] (shape=(3, 3))
    # In this case, we should not try to remove squeezable dimension.
    y_true_shape = y_true.shape
    y_true_rank = y_true_shape.ndims
    if (y_true_rank is not None) and (y_pred_rank is not None):
      # Use static rank for `y_true` and `y_pred`.
      if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1:
        y_true, y_pred = remove_squeezable_dimensions(
            y_true, y_pred)
    else:
      # Use dynamic rank.
      rank_diff = array_ops.rank(y_pred) - array_ops.rank(y_true)
      squeeze_dims = lambda: remove_squeezable_dimensions(  # pylint: disable=g-long-lambda
          y_true, y_pred)
      is_last_dim_1 = math_ops.equal(1, array_ops.shape(y_pred)[-1])
      maybe_squeeze_dims = lambda: cond.cond(  # pylint: disable=g-long-lambda
          is_last_dim_1, squeeze_dims, lambda: (y_true, y_pred))
      y_true, y_pred = cond.cond(
          math_ops.equal(1, rank_diff), maybe_squeeze_dims, squeeze_dims)

  if sample_weight is None:
    return y_pred, y_true

  weights_shape = sample_weight.shape
  weights_rank = weights_shape.ndims
  if weights_rank == 0:  # If weights is scalar, do nothing.
    return y_pred, y_true, sample_weight

  if (y_pred_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - y_pred_rank == 1:
      sample_weight = array_ops.squeeze(sample_weight, [-1])
    elif y_pred_rank - weights_rank == 1:
      sample_weight = array_ops.expand_dims(sample_weight, [-1])
    return y_pred, y_true, sample_weight

  # Use dynamic rank.
  weights_rank_tensor = array_ops.rank(sample_weight)
  rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
  maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])

  def _maybe_expand_weights():
    expand_weights = lambda: array_ops.expand_dims(sample_weight, [-1])
    return cond.cond(
        math_ops.equal(rank_diff, -1), expand_weights, lambda: sample_weight)

  def _maybe_adjust_weights():
    return cond.cond(
        math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
        _maybe_expand_weights)

  # squeeze or expand last dim of `sample_weight` if its rank differs by 1
  # from the new rank of `y_pred`.
  sample_weight = cond.cond(
      math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
      _maybe_adjust_weights)
  return y_pred, y_true, sample_weight
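

# Illustrative sketch (not part of the original module); the tensors are
# hypothetical. A `sample_weight` of shape (batch, 1) is squeezed to align
# with a rank-1 per-example loss:
#
#   y_pred = tf.constant([.7, .3])       # shape (2,)
#   sw = tf.constant([[1.], [2.]])       # shape (2, 1)
#   y_pred, _, sw = squeeze_or_expand_dimensions(y_pred, None, sw)
#   # sw now has shape (2,), aligned with y_pred.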


def _safe_mean(losses, num_present):
  """Computes a safe mean of the losses.

  Args:
    losses: `Tensor` whose elements contain individual loss measurements.
    num_present: The number of measurable elements in `losses`.

  Returns:
    A scalar representing the mean of `losses`. If `num_present` is zero,
    then zero is returned.
  """
  total_loss = math_ops.reduce_sum(losses)
  return math_ops.div_no_nan(total_loss, num_present, name='value')
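

# Worked note (illustrative): `div_no_nan` is what makes the mean "safe".
# If `losses` is empty, `reduce_sum` yields 0. and `num_present` is 0., so
# 0. / 0. is mapped to 0. instead of NaN.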


def _num_elements(losses):
  """Computes the number of elements in `losses` tensor."""
  with backend.name_scope('num_elements') as scope:
    return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype)


def reduce_weighted_loss(weighted_losses,
                         reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
  """Reduces the individual weighted loss measurements."""
  if reduction == ReductionV2.NONE:
    loss = weighted_losses
  else:
    loss = math_ops.reduce_sum(weighted_losses)
    if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:
      loss = _safe_mean(loss, _num_elements(weighted_losses))
  return loss
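

# Illustrative note (not part of the original module): for a hypothetical
# weighted-loss vector [1., 2., 3.] the reduction modes relate as follows:
#
#   NONE                -> [1., 2., 3.]  (returned unchanged)
#   SUM                 -> 6.
#   SUM_OVER_BATCH_SIZE -> 6. / 3 = 2.   (sum divided by the element count)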


@keras_export('keras.__internal__.losses.compute_weighted_loss', v1=[])
def compute_weighted_loss(losses,
                          sample_weight=None,
                          reduction=ReductionV2.SUM_OVER_BATCH_SIZE,
                          name=None):
  """Computes the weighted loss.

  Args:
    losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
    sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
      as `losses`, or is broadcastable to `losses`.
    reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss.
      Default value is `SUM_OVER_BATCH_SIZE`.
    name: Optional name for the op.

  Raises:
    ValueError: If the shape of `sample_weight` is not compatible with
      `losses`.

  Returns:
    Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
    `NONE`, this has the same shape as `losses`; otherwise, it is scalar.
  """
  ReductionV2.validate(reduction)

  # If this function is called directly, then we just default 'AUTO' to
  # 'SUM_OVER_BATCH_SIZE'. Eg. Canned estimator use cases.
  if reduction == ReductionV2.AUTO:
    reduction = ReductionV2.SUM_OVER_BATCH_SIZE
  if sample_weight is None:
    sample_weight = 1.0
  with backend.name_scope(name or 'weighted_loss'):
    # Save the `reduction` argument for loss normalization when distributing
    # to multiple replicas. Used only for estimator + v1 optimizer flow.
    ops.get_default_graph()._last_loss_reduction = reduction  # pylint: disable=protected-access

    if not isinstance(losses,
                      (keras_tensor.KerasTensor, ragged_tensor.RaggedTensor)):
      losses = tensor_conversion.convert_to_tensor_v2_with_dispatch(losses)
    input_dtype = losses.dtype

    if not isinstance(sample_weight, keras_tensor.KerasTensor):
      sample_weight = tensor_conversion.convert_to_tensor_v2_with_dispatch(
          sample_weight
      )

    # TODO(psv): Handle casting here in a better way, eg. if losses is float64
    # we do not want to lose precision.
    losses = math_ops.cast(losses, 'float32')
    sample_weight = math_ops.cast(sample_weight, 'float32')
    # Update dimensions of `sample_weight` to match with `losses` if possible.
    losses, _, sample_weight = squeeze_or_expand_dimensions(  # pylint: disable=unbalanced-tuple-unpacking
        losses, None, sample_weight)
    weighted_losses = math_ops.multiply(losses, sample_weight)

    # Apply reduction function to the individual weighted losses.
    loss = reduce_weighted_loss(weighted_losses, reduction)
    # Convert the result back to the input type.
    loss = math_ops.cast(loss, input_dtype)
    return loss
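

# Illustrative sketch (not part of the original module); the values are
# hypothetical. Losses are multiplied elementwise by `sample_weight` (after
# rank alignment) and then reduced:
#
#   losses = tf.constant([1., 2., 4.])
#   weights = tf.constant([1., 0., .5])
#   compute_weighted_loss(losses, weights)
#   # -> (1*1 + 2*0 + 4*.5) / 3 = 1. under the default SUM_OVER_BATCH_SIZE.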


def scale_loss_for_distribution(loss_value):
  """Scales and returns the given loss value by the number of replicas."""
  num_replicas = (
      distribute_lib.get_strategy().num_replicas_in_sync)
  if num_replicas > 1:
    loss_value *= (1. / num_replicas)
  return loss_value
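

# Illustrative note: under a hypothetical strategy with 4 replicas in sync, a
# per-replica loss of 2. is scaled to 2. * (1. / 4) = .5, so that summing the
# gradients across replicas approximates the gradient of the global-batch
# average loss.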


def cast_losses_to_common_dtype(losses):
  """Cast a list of losses to a common dtype.

  If any loss is floating-point, they will all be cast to the most-precise
  floating-point loss. Otherwise the losses are not cast. We also skip casting
  losses if there are any complex losses.

  Args:
    losses: A list of losses.

  Returns:
    `losses`, but they have been cast to a common dtype.
  """
  highest_float = None
  for loss in losses:
    if loss.dtype.is_floating:
      if highest_float is None or loss.dtype.size > highest_float.size:
        highest_float = loss.dtype
      elif {loss.dtype, highest_float} == {'bfloat16', 'float16'}:
        highest_float = 'float32'
    if loss.dtype.is_complex:
      return losses  # If we find any complex losses, do not cast any losses
  if highest_float:
    losses = [math_ops.cast(loss, highest_float) for loss in losses]
  return losses
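

# Illustrative sketch (not part of the original module); the dtypes are
# hypothetical. Mixing a float16 loss with a float32 loss promotes both to the
# most precise floating type present:
#
#   a = tf.constant(1., dtype=tf.float16)
#   b = tf.constant(2., dtype=tf.float32)
#   out = cast_losses_to_common_dtype([a, b])
#   # -> both elements are float32; a complex loss in the list disables
#   #    casting entirely.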