# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for manipulating the loss collections."""

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export


def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):
  """Squeezes or expands the last dimension if needed.

  1. Squeezes the last dim of `y_pred` or `y_true` if their rank differs by 1
     (using `confusion_matrix.remove_squeezable_dimensions`).
  2. Squeezes or expands the last dim of `sample_weight` if its rank differs
     by 1 from the new rank of `y_pred`. If `sample_weight` is scalar, it is
     kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
    y_true: Optional label `Tensor` whose dimensions match `y_pred`.
    sample_weight: Optional weight scalar or `Tensor` whose dimensions match
      `y_pred`.

  Returns:
    Tuple of `y_pred`, `y_true` and `sample_weight`, each possibly with its
    last dimension squeezed; `sample_weight` may instead be expanded by one
    dimension. If `sample_weight` is None, only `(y_pred, y_true)` is
    returned.
  """
  y_pred_shape = y_pred.shape
  y_pred_rank = y_pred_shape.ndims
  if y_true is not None:

    # If sparse labels are provided as `y_true`, the last dimension of
    # `y_pred` may be > 1. Eg: y_true = [0, 1, 2] (shape=(3,)),
    # y_pred = [[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]] (shape=(3, 3))
    # In this case, we should not try to remove the squeezable dimension.
    y_true_shape = y_true.shape
    y_true_rank = y_true_shape.ndims
    if (y_true_rank is not None) and (y_pred_rank is not None):
      # Use static rank for `y_true` and `y_pred`.
      if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1:
        y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
            y_true, y_pred)
    else:
      # Use dynamic rank.
      rank_diff = array_ops.rank(y_pred) - array_ops.rank(y_true)
      squeeze_dims = lambda: confusion_matrix.remove_squeezable_dimensions(  # pylint: disable=g-long-lambda
          y_true, y_pred)
      is_last_dim_1 = math_ops.equal(1, array_ops.shape(y_pred)[-1])
      maybe_squeeze_dims = lambda: cond.cond(  # pylint: disable=g-long-lambda
          is_last_dim_1, squeeze_dims, lambda: (y_true, y_pred))
      y_true, y_pred = cond.cond(
          math_ops.equal(1, rank_diff), maybe_squeeze_dims, squeeze_dims)

  if sample_weight is None:
    return y_pred, y_true

  weights_shape = sample_weight.shape
  weights_rank = weights_shape.ndims
  if weights_rank == 0:  # If weights is scalar, do nothing.
    return y_pred, y_true, sample_weight

  if (y_pred_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - y_pred_rank == 1:
      sample_weight = array_ops.squeeze(sample_weight, [-1])
    elif y_pred_rank - weights_rank == 1:
      sample_weight = array_ops.expand_dims(sample_weight, [-1])
    return y_pred, y_true, sample_weight

  # Use dynamic rank.
  weights_rank_tensor = array_ops.rank(sample_weight)
  rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
  maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])

  def _maybe_expand_weights():
    expand_weights = lambda: array_ops.expand_dims(sample_weight, [-1])
    return cond.cond(
        math_ops.equal(rank_diff, -1), expand_weights, lambda: sample_weight)

  def _maybe_adjust_weights():
    return cond.cond(
        math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
        _maybe_expand_weights)

  # Squeeze or expand the last dim of `sample_weight` if its rank differs
  # by 1 from the new rank of `y_pred`.
  sample_weight = cond.cond(
      math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
      _maybe_adjust_weights)
  return y_pred, y_true, sample_weight
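

# A minimal usage sketch. `_demo_squeeze_or_expand_dimensions` is a
# hypothetical helper added for illustration, not part of the original
# module: it shows a (3, 1) prediction being squeezed to match rank-1
# labels while a scalar sample weight is passed through unchanged.
def _demo_squeeze_or_expand_dimensions():
  y_pred = constant_op.constant([[0.9], [0.2], [0.7]])  # shape (3, 1)
  y_true = constant_op.constant([1.0, 0.0, 1.0])        # shape (3,)
  weight = constant_op.constant(2.0)                    # scalar weight
  y_pred, y_true, weight = squeeze_or_expand_dimensions(
      y_pred, y_true, weight)
  # y_pred and y_true now both have shape (3,); weight is still scalar.
  return y_pred, y_true, weight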


def scale_losses_by_sample_weight(losses, sample_weight):
  """Scales loss values by the given sample weights.

  `sample_weight` dimensions are updated to match the dimensions of `losses`
  where possible, using squeeze/expand/broadcast.

  Args:
    losses: Loss tensor.
    sample_weight: Sample weights tensor.

  Returns:
    `losses` scaled by `sample_weight` with dtype float32.
  """
  # TODO(psv): Handle the casting here in a better way, e.g. if losses is
  # float64 we do not want to lose precision.
  losses = math_ops.cast(losses, dtypes.float32)
  sample_weight = math_ops.cast(sample_weight, dtypes.float32)

  # Update dimensions of `sample_weight` to match `losses` if possible.
  losses, _, sample_weight = squeeze_or_expand_dimensions(
      losses, None, sample_weight)
  return math_ops.multiply(losses, sample_weight)
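

# A hypothetical sketch (not part of the original module): rank-1 losses
# scaled by a (2, 1) weight tensor, whose trailing dimension is squeezed
# before the elementwise multiply.
def _demo_scale_losses_by_sample_weight():
  losses = constant_op.constant([1.0, 4.0])       # shape (2,)
  weights = constant_op.constant([[0.5], [2.0]])  # shape (2, 1)
  return scale_losses_by_sample_weight(losses, weights)  # -> [0.5, 8.0]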


@tf_contextlib.contextmanager
def check_per_example_loss_rank(per_example_loss):
  """Context manager that checks that the rank of `per_example_loss` is at least 1.

  Args:
    per_example_loss: Per example loss tensor.

  Yields:
    A context manager.
  """
  loss_rank = per_example_loss.shape.rank
  if loss_rank is not None:
    # Handle static rank.
    if loss_rank == 0:
      raise ValueError(
          "Invalid value passed for `per_example_loss`. Expected a tensor with "
          f"at least rank 1. Received per_example_loss={per_example_loss} with "
          f"rank {loss_rank}")
    yield
  else:
    # Handle dynamic rank.
    with ops.control_dependencies([
        check_ops.assert_greater_equal(
            array_ops.rank(per_example_loss),
            math_ops.cast(1, dtype=dtypes.int32),
            message="Invalid value passed for `per_example_loss`. Expected a "
            "tensor with at least rank 1.")
    ]):
      yield
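

# A hypothetical sketch (not part of the original module): wrapping a
# reduction in check_per_example_loss_rank. A rank-1 loss passes the
# static check; a scalar loss would raise ValueError instead.
def _demo_check_per_example_loss_rank():
  per_example_loss = constant_op.constant([0.3, 0.7])  # rank 1: OK
  with check_per_example_loss_rank(per_example_loss):
    return math_ops.reduce_mean(per_example_loss)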


@tf_export(v1=["losses.add_loss"])
def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
  """Adds an externally defined loss to the collection of losses.

  Args:
    loss: A loss `Tensor`.
    loss_collection: Optional collection to add the loss to.
  """
  # Since we have no way of figuring out when a training iteration starts or
  # ends, holding on to a loss when executing eagerly is indistinguishable from
  # leaking memory. We instead leave the collection empty.
  if loss_collection and not context.executing_eagerly():
    ops.add_to_collection(loss_collection, loss)


@tf_export(v1=["losses.get_losses"])
def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
  """Gets the list of losses from the loss_collection.

  Args:
    scope: An optional scope name for filtering the losses to return.
    loss_collection: Optional losses collection.

  Returns:
    A list of loss tensors.
  """
  return ops.get_collection(loss_collection, scope)
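

# A hypothetical sketch (not part of the original module): under graph
# mode, add_loss stores the tensor in the LOSSES collection and
# get_losses retrieves it; under eager execution the collection is
# deliberately left empty.
def _demo_loss_collection():
  with ops.Graph().as_default():
    loss = constant_op.constant(0.5, name="my_loss")
    add_loss(loss)
    return get_losses()  # [<tf.Tensor 'my_loss:0' shape=() dtype=float32>]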


@tf_export(v1=["losses.get_regularization_losses"])
def get_regularization_losses(scope=None):
  """Gets the list of regularization losses.

  Args:
    scope: An optional scope name for filtering the losses to return.

  Returns:
    A list of regularization losses as Tensors.
  """
  return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope)


@tf_export(v1=["losses.get_regularization_loss"])
def get_regularization_loss(scope=None, name="total_regularization_loss"):
  """Gets the total regularization loss.

  Args:
    scope: An optional scope name for filtering the losses to return.
    name: The name of the returned tensor.

  Returns:
    A scalar regularization loss.
  """
  losses = get_regularization_losses(scope)
  if losses:
    return math_ops.add_n(losses, name=name)
  else:
    return constant_op.constant(0.0)
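

# A hypothetical sketch (not part of the original module):
# get_regularization_loss sums whatever is in the REGULARIZATION_LOSSES
# collection, and falls back to a 0.0 constant when it is empty.
def _demo_get_regularization_loss():
  with ops.Graph().as_default():
    ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                          constant_op.constant(0.25))
    return get_regularization_loss()  # add_n over one tensor -> 0.25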


@tf_export(v1=["losses.get_total_loss"])
def get_total_loss(add_regularization_losses=True,
                   name="total_loss",
                   scope=None):
  """Returns a tensor whose value represents the total loss.

  In particular, this adds any losses you have added with `tf.add_loss()` to
  any regularization losses that have been added by regularization parameters
  on layer constructors, e.g. `tf.layers`. Be sure to use this if you are
  constructing a loss op manually; otherwise regularization arguments on
  `tf.layers` methods will not take effect.

  Args:
    add_regularization_losses: A boolean indicating whether or not to use the
      regularization losses in the sum.
    name: The name of the returned tensor.
    scope: An optional scope name for filtering the losses to return. Note that
      this filters the losses added with `tf.add_loss()` as well as the
      regularization losses to that scope.

  Returns:
    A `Tensor` whose value represents the total loss.

  Raises:
    ValueError: if `losses` is not iterable.
  """
  losses = get_losses(scope=scope)
  if add_regularization_losses:
    losses += get_regularization_losses(scope=scope)
  return math_ops.add_n(losses, name=name)
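

# A hypothetical end-to-end sketch (not part of the original module):
# two collected losses plus one regularization loss, summed by
# get_total_loss. Evaluating the returned tensor in a session yields 3.1.
def _demo_get_total_loss():
  with ops.Graph().as_default():
    add_loss(constant_op.constant(1.0))
    add_loss(constant_op.constant(2.0))
    ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
                          constant_op.constant(0.1))
    return get_total_loss()  # 1.0 + 2.0 + 0.1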