Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/losses/sparsemax_loss.py: 46%
35 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16import tensorflow as tf
17from tensorflow_addons.activations.sparsemax import sparsemax
19from tensorflow_addons.utils.types import TensorLike
20from typeguard import typechecked
21from typing import Optional
@tf.keras.utils.register_keras_serializable(package="Addons")
def sparsemax_loss(
    logits: TensorLike,
    sparsemax: TensorLike,
    labels: TensorLike,
    name: Optional[str] = None,
) -> tf.Tensor:
    r"""Sparsemax loss function [1].

    Computes the generalized multi-label classification loss for the sparsemax
    function. This implementation reformulates the original loss so that it
    consumes the sparsemax probability output instead of the internal
    $ \tau $ threshold variable; the resulting value is identical to the
    original formulation.

    [1]: https://arxiv.org/abs/1602.02068

    Args:
      logits: A `Tensor`. Must be one of the following types: `float32`,
        `float64`.
      sparsemax: A `Tensor`. Must have the same type as `logits`.
      labels: A `Tensor`. Must have the same type as `logits`.
      name: A name for the operation (optional).
    Returns:
      A `Tensor`. Has the same type as `logits`.
    """
    z = tf.convert_to_tensor(logits, name="logits")
    probs = tf.convert_to_tensor(sparsemax, name="sparsemax")
    q = tf.convert_to_tensor(labels, name="labels")

    # The paper denotes the logits by z. In theory a constant could be
    # subtracted from z for extra numerical stability, but this algorithm has
    # no significant source of instability, so z is used as-is.

    # Sum over the support. A conditional `where` replaces a plain
    # multiplication so that z = -inf is handled correctly: when there is no
    # support (sparsemax = 0), the product 0 * -inf would produce nan, which
    # is wrong for this case.
    support_sum = tf.where(
        tf.math.logical_or(probs > 0, tf.math.is_nan(probs)),
        probs * (z - 0.5 * probs),
        tf.zeros_like(probs),
    )

    # - z_k + ||q||^2
    label_term = q * (0.5 * q - z)
    # Guard the case labels = 0 with z = -inf, where the product above would
    # otherwise be 0 * -inf = nan; since the label is 0, z = -inf should add
    # no cost. This also covers z = +inf: there the sparsemax is nan, so
    # support_sum is already nan and needs no additional special treatment.
    label_term_safe = tf.where(
        tf.math.logical_and(tf.math.equal(q, 0), tf.math.is_inf(z)),
        tf.zeros_like(z),
        label_term,
    )

    return tf.math.reduce_sum(support_sum + label_term_safe, axis=1)
@tf.function
@tf.keras.utils.register_keras_serializable(package="Addons")
def sparsemax_loss_from_logits(
    y_true: TensorLike, logits_pred: TensorLike
) -> tf.Tensor:
    """Compute the sparsemax loss directly from raw logits.

    Applies the sparsemax activation to `logits_pred`, then evaluates
    `sparsemax_loss` with both the logits and the resulting probabilities
    against `y_true`.
    """
    probabilities = sparsemax(logits_pred)
    return sparsemax_loss(logits_pred, probabilities, y_true)
@tf.keras.utils.register_keras_serializable(package="Addons")
class SparsemaxLoss(tf.keras.losses.Loss):
    """Sparsemax loss function.

    Computes the generalized multi-label classification loss for the sparsemax
    function.

    The sparsemax loss requires both the probability output and the raw
    logits to compute the loss value, so `from_logits` must be `True`.

    Since this is a generalized multi-label loss, both `y_pred` and `y_true`
    must have shape `[batch_size, num_classes]`.

    Args:
      from_logits: Whether `y_pred` is expected to be a logits tensor. Default
        is `True`, meaning `y_pred` is the logits.
      reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
        loss. Default value is `SUM_OVER_BATCH_SIZE`.
      name: Optional name for the op.
    """

    @typechecked
    def __init__(
        self,
        from_logits: bool = True,
        reduction: str = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
        name: str = "sparsemax_loss",
    ):
        # The loss cannot be computed from probabilities alone, so only the
        # logits form is accepted.
        if from_logits is not True:
            raise ValueError("from_logits must be True")

        super().__init__(name=name, reduction=reduction)
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        # `y_pred` carries logits here (from_logits is always True).
        return sparsemax_loss_from_logits(y_true, y_pred)

    def get_config(self):
        # Merge the base Keras loss config with this class's own fields so
        # the loss round-trips through serialization.
        base_config = super().get_config()
        return {**base_config, "from_logits": self.from_logits}