Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/losses/focal_loss.py: 41%

32 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implements Focal loss.""" 

16 

17import tensorflow as tf 

18import tensorflow.keras.backend as K 

19from typeguard import typechecked 

20 

21from tensorflow_addons.utils.keras_utils import LossFunctionWrapper 

22from tensorflow_addons.utils.types import FloatTensorLike, TensorLike 

23 

24 

@tf.keras.utils.register_keras_serializable(package="Addons")
class SigmoidFocalCrossEntropy(LossFunctionWrapper):
    """Implements the focal loss function.

    Focal loss was first introduced in the RetinaNet paper
    (https://arxiv.org/pdf/1708.02002.pdf). Focal loss is extremely useful for
    classification when you have highly imbalanced classes. It down-weights
    well-classified examples and focuses on hard examples. The loss value is
    much higher for a sample which is misclassified by the classifier as compared
    to the loss value corresponding to a well-classified example. One of the
    best use-cases of focal loss is its usage in object detection where the
    imbalance between the background class and other classes is extremely high.

    Usage:

    >>> fl = tfa.losses.SigmoidFocalCrossEntropy()
    >>> loss = fl(
    ...     y_true=[[1.0], [1.0], [0.0]], y_pred=[[0.97], [0.91], [0.03]])
    >>> loss
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([6.8532745e-06, 1.9097870e-04, 2.0559824e-05],
          dtype=float32)>

    Usage with `tf.keras` API:

    >>> model = tf.keras.Model()
    >>> model.compile('sgd', loss=tfa.losses.SigmoidFocalCrossEntropy())

    Args:
      from_logits: whether `y_pred` contains logits rather than probabilities.
        Default value is False.
      alpha: balancing factor, default value is 0.25. A falsy value (e.g. 0)
        disables the alpha weighting instead of giving positives zero weight.
      gamma: modulating factor, default value is 2.0. Must be non-negative.
      reduction: type of reduction applied to the per-sample losses, default
        value is `tf.keras.losses.Reduction.NONE`.
      name: name of the loss, default value is "sigmoid_focal_crossentropy".

    Returns:
      Weighted loss float `Tensor`. The per-element losses are summed over the
      last axis of `y_true`. If `reduction` is `NONE`, the result therefore has
      the shape of `y_true` without its last axis (e.g. `(3, 1)` -> `(3,)` in
      the example above); otherwise, it is scalar.

    Raises:
      ValueError: If the shape of `sample_weight` is invalid or value of
        `gamma` is less than zero.
    """

    @typechecked
    def __init__(
        self,
        from_logits: bool = False,
        alpha: FloatTensorLike = 0.25,
        gamma: FloatTensorLike = 2.0,
        reduction: str = tf.keras.losses.Reduction.NONE,
        name: str = "sigmoid_focal_crossentropy",
    ):
        # All hyper-parameters are forwarded as keyword arguments to the
        # functional form of the loss by the wrapper.
        super().__init__(
            sigmoid_focal_crossentropy,
            name=name,
            reduction=reduction,
            from_logits=from_logits,
            alpha=alpha,
            gamma=gamma,
        )

82 

83 

@tf.keras.utils.register_keras_serializable(package="Addons")
@tf.function
def sigmoid_focal_crossentropy(
    y_true: TensorLike,
    y_pred: TensorLike,
    alpha: FloatTensorLike = 0.25,
    gamma: FloatTensorLike = 2.0,
    from_logits: bool = False,
) -> tf.Tensor:
    """Implements the focal loss function.

    Focal loss was first introduced in the RetinaNet paper
    (https://arxiv.org/pdf/1708.02002.pdf). Focal loss is extremely useful for
    classification when you have highly imbalanced classes. It down-weights
    well-classified examples and focuses on hard examples. The loss value is
    much higher for a sample which is misclassified by the classifier as compared
    to the loss value corresponding to a well-classified example. One of the
    best use-cases of focal loss is its usage in object detection where the
    imbalance between the background class and other classes is extremely high.

    For each element the loss is

        alpha_t * (1 - p_t) ** gamma * binary_crossentropy(y_true, y_pred)

    with `p_t = y_true * p + (1 - y_true) * (1 - p)` and
    `alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)`, where `p` is the
    predicted probability. The element losses are then summed over the last
    axis.

    Args:
      y_true: true targets tensor; cast to the dtype of `y_pred`.
      y_pred: predictions tensor; probabilities, or logits if `from_logits`.
      alpha: balancing factor. A falsy value (e.g. 0) disables the alpha
        weighting entirely (the factor becomes 1.0).
      gamma: modulating factor. Must be non-negative; a falsy value (e.g. 0)
        disables the modulating term.
      from_logits: whether `y_pred` contains logits. If True, a sigmoid is
        applied to obtain probabilities for the modulating term.

    Returns:
      Weighted loss float `Tensor` with the shape of `y_true` without its last
      axis, since the element losses are summed over that axis. This function
      applies no further reduction; that is governed by the `reduction`
      argument of `SigmoidFocalCrossEntropy`.

    Raises:
      ValueError: If `gamma` is less than zero.
    """
    if gamma and gamma < 0:
        raise ValueError("Value of gamma should be greater than or equal to zero.")

    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, dtype=y_pred.dtype)

    # Get the cross_entropy for each entry
    ce = K.binary_crossentropy(y_true, y_pred, from_logits=from_logits)

    # If logits are provided then convert the predictions into probabilities
    if from_logits:
        pred_prob = tf.sigmoid(y_pred)
    else:
        pred_prob = y_pred

    # Probability assigned to the true class; linear in y_true, so soft labels
    # interpolate between p and 1 - p.
    p_t = (y_true * pred_prob) + ((1 - y_true) * (1 - pred_prob))
    alpha_factor = 1.0
    modulating_factor = 1.0

    # A falsy alpha keeps alpha_factor at 1.0, i.e. no class balancing.
    if alpha:
        alpha = tf.cast(alpha, dtype=y_true.dtype)
        alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)

    # A falsy gamma keeps modulating_factor at 1.0 (equivalent to gamma == 0).
    if gamma:
        gamma = tf.cast(gamma, dtype=y_true.dtype)
        modulating_factor = tf.pow((1.0 - p_t), gamma)

    # compute the final loss and return
    return tf.reduce_sum(alpha_factor * modulating_factor * ce, axis=-1)