# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements Focal loss."""

import tensorflow as tf
import tensorflow.keras.backend as K
from typeguard import typechecked

from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike


@tf.keras.utils.register_keras_serializable(package="Addons")
class SigmoidFocalCrossEntropy(LossFunctionWrapper):
    """Implements the focal loss function.

    Focal loss was first introduced in the RetinaNet paper
    (https://arxiv.org/pdf/1708.02002.pdf). Focal loss is extremely useful for
    classification when you have highly imbalanced classes. It down-weights
    well-classified examples and focuses on hard examples, so the loss value
    is much higher for a sample that is misclassified than for a
    well-classified one. One of the best use cases for focal loss is object
    detection, where the imbalance between the background class and the other
    classes is extremely high.

    Usage:

    >>> fl = tfa.losses.SigmoidFocalCrossEntropy()
    >>> loss = fl(
    ...     y_true=[[1.0], [1.0], [0.0]], y_pred=[[0.97], [0.91], [0.03]])
    >>> loss
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([6.8532745e-06, 1.9097870e-04, 2.0559824e-05],
    dtype=float32)>

    Usage with `tf.keras` API:

    >>> model = tf.keras.Model()
    >>> model.compile('sgd', loss=tfa.losses.SigmoidFocalCrossEntropy())

    Args:
        from_logits: whether `y_pred` is a tensor of logits rather than
            probabilities, default value is `False`.
        alpha: balancing factor, default value is 0.25.
        gamma: modulating factor, default value is 2.0.

    Returns:
        Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
        shape as `y_true`; otherwise, it is scalar.

    Raises:
        ValueError: If the shape of `sample_weight` is invalid or the value of
            `gamma` is less than zero.
    """

    @typechecked
    def __init__(
        self,
        from_logits: bool = False,
        alpha: FloatTensorLike = 0.25,
        gamma: FloatTensorLike = 2.0,
        reduction: str = tf.keras.losses.Reduction.NONE,
        name: str = "sigmoid_focal_crossentropy",
    ):
        super().__init__(
            sigmoid_focal_crossentropy,
            name=name,
            reduction=reduction,
            from_logits=from_logits,
            alpha=alpha,
            gamma=gamma,
        )
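

# --- Illustrative usage sketch (an editorial addition, not part of the
# upstream tensorflow_addons module). It shows the class above compiled into
# a model whose final layer emits raw logits, which is why `from_logits=True`
# is passed; the layer sizes and optimizer are assumptions chosen for brevity.
def _example_compile_with_logits() -> tf.keras.Model:
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(8, activation="relu"),
            tf.keras.layers.Dense(1),  # no sigmoid: outputs are logits
        ]
    )
    model.compile(
        optimizer="sgd",
        loss=SigmoidFocalCrossEntropy(from_logits=True),
    )
    return model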


@tf.keras.utils.register_keras_serializable(package="Addons")
@tf.function
def sigmoid_focal_crossentropy(
    y_true: TensorLike,
    y_pred: TensorLike,
    alpha: FloatTensorLike = 0.25,
    gamma: FloatTensorLike = 2.0,
    from_logits: bool = False,
) -> tf.Tensor:
    """Implements the focal loss function.

    Focal loss was first introduced in the RetinaNet paper
    (https://arxiv.org/pdf/1708.02002.pdf). Focal loss is extremely useful for
    classification when you have highly imbalanced classes. It down-weights
    well-classified examples and focuses on hard examples, so the loss value
    is much higher for a sample that is misclassified than for a
    well-classified one. One of the best use cases for focal loss is object
    detection, where the imbalance between the background class and the other
    classes is extremely high.

    Args:
        y_true: true targets tensor.
        y_pred: predictions tensor.
        alpha: balancing factor.
        gamma: modulating factor.
        from_logits: whether `y_pred` is a tensor of logits rather than
            probabilities.

    Returns:
        Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the
        same shape as `y_true`; otherwise, it is scalar.
    """
    if gamma and gamma < 0:
        raise ValueError("Value of gamma should be greater than or equal to zero.")

    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, dtype=y_pred.dtype)

    # Get the cross_entropy for each entry
    ce = K.binary_crossentropy(y_true, y_pred, from_logits=from_logits)
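    # For 0/1 targets this is element-wise equal to -log(p_t), with p_t
    # (the probability of the true class) defined below.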

    # If logits are provided then convert the predictions into probabilities
    if from_logits:
        pred_prob = tf.sigmoid(y_pred)
    else:
        pred_prob = y_pred
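
    # p_t is the model's estimated probability of the ground-truth class:
    # pred_prob where y_true == 1, and 1 - pred_prob where y_true == 0.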
    p_t = (y_true * pred_prob) + ((1 - y_true) * (1 - pred_prob))
    alpha_factor = 1.0
    modulating_factor = 1.0

    if alpha:
        alpha = tf.cast(alpha, dtype=y_true.dtype)
        alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)

    if gamma:
        gamma = tf.cast(gamma, dtype=y_true.dtype)
        modulating_factor = tf.pow((1.0 - p_t), gamma)

    # Compute the final loss and return
    return tf.reduce_sum(alpha_factor * modulating_factor * ce, axis=-1)
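

# --- Editorial self-check sketch (an addition, not part of the upstream
# module). It verifies the function above against the closed-form focal loss
# from the RetinaNet paper, FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t),
# on a toy batch; the example values and tolerance are assumptions.
def _sanity_check_focal_loss() -> None:
    y_true = tf.constant([[1.0], [1.0], [0.0]])
    y_pred = tf.constant([[0.97], [0.91], [0.03]])
    loss = sigmoid_focal_crossentropy(y_true, y_pred, alpha=0.25, gamma=2.0)

    # Closed-form reference: p_t is the probability of the true class and
    # alpha_t the class-dependent balancing weight.
    p_t = y_true * y_pred + (1.0 - y_true) * (1.0 - y_pred)
    alpha_t = y_true * 0.25 + (1.0 - y_true) * (1.0 - 0.25)
    expected = tf.reduce_sum(
        -alpha_t * tf.pow(1.0 - p_t, 2.0) * tf.math.log(p_t), axis=-1
    )
    tf.debugging.assert_near(loss, expected, atol=1e-6)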