Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/losses/sparsemax_loss.py: 46%

35 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import tensorflow as tf
from tensorflow_addons.activations.sparsemax import sparsemax

from tensorflow_addons.utils.types import TensorLike
from typeguard import typechecked
from typing import Optional


@tf.keras.utils.register_keras_serializable(package="Addons")
def sparsemax_loss(
    logits: TensorLike,
    sparsemax: TensorLike,
    labels: TensorLike,
    name: Optional[str] = None,
) -> tf.Tensor:
    """Sparsemax loss function [1].

    Computes the generalized multi-label classification loss for the sparsemax
    function. The implementation is a reformulation of the original loss
    function such that it uses the sparsemax probability output instead of the
    internal $ \tau $ variable. However, the output is identical to the
    original loss function.

    [1]: https://arxiv.org/abs/1602.02068

    Args:
        logits: A `Tensor`. Must be one of the following types: `float32`,
            `float64`.
        sparsemax: A `Tensor`. Must have the same type as `logits`.
        labels: A `Tensor`. Must have the same type as `logits`.
        name: A name for the operation (optional).

    Returns:
        A `Tensor`. Has the same type as `logits`.
    """

    logits = tf.convert_to_tensor(logits, name="logits")
    sparsemax = tf.convert_to_tensor(sparsemax, name="sparsemax")
    labels = tf.convert_to_tensor(labels, name="labels")

    # In the paper, the logits are called z.
    # A constant could be subtracted from the logits to make the algorithm
    # more numerically stable in theory; however, there is no major source of
    # numerical instability in this algorithm.
    z = logits

    # Sum over the support.
    # Use a conditional `where` instead of a multiplication to support
    # z = -inf. If z = -inf and there is no support (sparsemax = 0), a
    # multiplication would cause 0 * -inf = nan, which is not correct here.
    sum_s = tf.where(
        tf.math.logical_or(sparsemax > 0, tf.math.is_nan(sparsemax)),
        sparsemax * (z - 0.5 * sparsemax),
        tf.zeros_like(sparsemax),
    )

    # - z_k + ||q||^2
    q_part = labels * (0.5 * labels - z)
    # Fix the case where labels = 0 and z = -inf, where q_part would
    # otherwise be 0 * -inf = nan. Since labels = 0, no cost for z = -inf
    # should be considered.
    # The code below also covers the case where z = inf. However, in that
    # case the sparsemax will be nan, which means sum_s will also be nan,
    # so that case needs no additional special treatment.
    q_part_safe = tf.where(
        tf.math.logical_and(tf.math.equal(labels, 0), tf.math.is_inf(z)),
        tf.zeros_like(z),
        q_part,
    )

    return tf.math.reduce_sum(sum_s + q_part_safe, axis=1)
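

# Illustrative usage sketch (added here, not part of the original module): the
# functional form expects the caller to supply the logits, the sparsemax
# probabilities computed from those same logits, and one-hot labels; the
# result is one loss value per batch row. The tensor values are arbitrary and
# the helper name `_example_sparsemax_loss` is hypothetical.
def _example_sparsemax_loss() -> tf.Tensor:
    logits = tf.constant([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]])
    probs = sparsemax(logits)  # shape [batch_size, num_classes]
    labels = tf.constant([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
    return sparsemax_loss(logits, probs, labels)  # shape [batch_size]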


@tf.function
@tf.keras.utils.register_keras_serializable(package="Addons")
def sparsemax_loss_from_logits(
    y_true: TensorLike, logits_pred: TensorLike
) -> tf.Tensor:
    y_pred = sparsemax(logits_pred)
    loss = sparsemax_loss(logits_pred, y_pred, y_true)
    return loss
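

# Illustrative usage sketch (added here, not part of the original module): the
# wrapper above takes only the true labels and the raw logits and computes the
# sparsemax probabilities internally before delegating to `sparsemax_loss`.
# The tensor values are arbitrary and the helper name is hypothetical.
def _example_sparsemax_loss_from_logits() -> tf.Tensor:
    y_true = tf.constant([[0.0, 1.0, 0.0]])
    logits = tf.constant([[0.1, 2.0, -1.0]])
    return sparsemax_loss_from_logits(y_true, logits)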


@tf.keras.utils.register_keras_serializable(package="Addons")
class SparsemaxLoss(tf.keras.losses.Loss):
    """Sparsemax loss function.

    Computes the generalized multi-label classification loss for the sparsemax
    function.

    Because the sparsemax loss function needs both the probability output and
    the logits to compute the loss value, `from_logits` must be `True`.

    Because it computes the generalized multi-label loss, the shape of both
    `y_pred` and `y_true` must be `[batch_size, num_classes]`.

    Args:
        from_logits: Whether `y_pred` is expected to be a logits tensor.
            Default is `True`, meaning `y_pred` is the logits.
        reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
            loss. Default value is `SUM_OVER_BATCH_SIZE`.
        name: Optional name for the op.
    """

    @typechecked
    def __init__(
        self,
        from_logits: bool = True,
        reduction: str = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
        name: str = "sparsemax_loss",
    ):
        if from_logits is not True:
            raise ValueError("from_logits must be True")

        super().__init__(name=name, reduction=reduction)
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        return sparsemax_loss_from_logits(y_true, y_pred)

    def get_config(self):
        config = {
            "from_logits": self.from_logits,
        }
        base_config = super().get_config()
        return {**base_config, **config}
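

# Illustrative usage sketch (added here, not part of the original module):
# wiring the loss class into a Keras model. The layer sizes, optimizer, and
# helper name are arbitrary choices. The final Dense layer emits raw logits
# (no activation), matching the `from_logits=True` requirement, and the
# targets must be one-hot encoded with shape [batch_size, num_classes].
def _example_keras_usage(num_classes: int = 10) -> tf.keras.Model:
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(num_classes),  # raw logits, no softmax
        ]
    )
    model.compile(optimizer="adam", loss=SparsemaxLoss())
    return model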