Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/activations/sparsemax.py: 15%
39 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16import tensorflow as tf
18from tensorflow_addons.utils import types
@tf.keras.utils.register_keras_serializable(package="Addons")
def sparsemax(logits: types.TensorLike, axis: int = -1) -> tf.Tensor:
    r"""Sparsemax activation function.

    For each batch $i$, and class $j$,
    compute sparsemax activation function:

    $$
    \mathrm{sparsemax}(\mathrm{logits})[i, j] = \max(\mathrm{logits}[i, j] - \tau(\mathrm{logits}[i, :]), 0).
    $$

    See [From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification](https://arxiv.org/abs/1602.02068).

    Usage:

    >>> x = tf.constant([[-1.0, 0.0, 1.0], [-5.0, 1.0, 2.0]])
    >>> tfa.activations.sparsemax(x)
    <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
    array([[0., 0., 1.],
           [0., 0., 1.]], dtype=float32)>

    Args:
        logits: A `Tensor`.
        axis: `int`, axis along which the sparsemax operation is applied.
            Must lie in `[-rank, rank)` where `rank` is the rank of `logits`.
    Returns:
        A `Tensor`, output of sparsemax transformation. Has the same type and
        shape as `logits`.
    Raises:
        ValueError: If `axis` is outside `[-rank, rank)` for the static rank
            of `logits` (scalar `logits` have no valid axis), or if `axis` is
            not `-1` while the rank of `logits` is not statically known.
    """
    logits = tf.convert_to_tensor(logits, name="logits")

    # We need its original shape for shape inference.
    shape = logits.get_shape()
    rank = shape.rank

    # Validate `axis` up front: `axis % rank` below would otherwise silently
    # wrap an out-of-range axis onto a different, valid one.
    if rank is None:
        if axis != -1:
            raise ValueError(
                f"sparsemax along axis {axis} requires logits with a "
                "statically known rank; only axis=-1 is supported otherwise"
            )
    elif not -rank <= axis < rank:
        raise ValueError(
            f"axis {axis} is out of bounds for logits of rank {rank}"
        )

    # When rank is unknown, axis == -1 here, so `rank - 1` is never evaluated.
    is_last_axis = (axis == -1) or (axis == rank - 1)

    if is_last_axis:
        output = _compute_2d_sparsemax(logits)
        output.set_shape(shape)
        return output

    # If axis is not the last dimension, we transpose so that we can still
    # perform sparsemax on the last dimension. The rank is statically known on
    # this path (checked above), so the last-axis index is a plain int.
    last_axis = rank - 1
    axis_norm = axis % rank

    # Swap logits' dimension `axis_norm` with its last dimension.
    logits = _swap_axis(logits, axis_norm, last_axis)

    # Do the actual sparsemax on its last dimension, then swap back.
    output = _compute_2d_sparsemax(logits)
    output = _swap_axis(output, axis_norm, last_axis)

    # Make shape inference work since transpose may erase its static shape.
    output.set_shape(shape)
    return output
def _swap_axis(logits, dim_index, last_index, **kwargs):
    """Transpose `logits` so that axes `dim_index` and `last_index` trade places.

    The permutation is `[0, ..., dim_index - 1, last_index,
    dim_index + 1, ..., last_index - 1, dim_index]`, so it expects
    `dim_index < last_index` with `last_index` being the final axis.

    Args:
        logits: The tensor to transpose.
        dim_index: Index of the axis to move to the end.
        last_index: Index of the last axis (int or scalar tensor).
        **kwargs: Extra keyword arguments forwarded to `tf.transpose`.

    Returns:
        The transposed tensor.
    """
    axes_before = tf.range(dim_index)
    axes_between = tf.range(dim_index + 1, last_index)
    permutation = tf.concat(
        [axes_before, [last_index], axes_between, [dim_index]], axis=0
    )
    return tf.transpose(logits, permutation, **kwargs)
def _compute_2d_sparsemax(logits):
    """Performs the sparsemax operation along the last axis (axis=-1).

    All leading axes are flattened into a single "observations" axis, the
    sparsemax projection is computed for each row, and the result is reshaped
    back to the shape of `logits`.

    Args:
        logits: A `Tensor` of rank >= 1; sparsemax is applied over its last
            axis. Presumably a floating-point dtype, since `k`, the zero and
            the NaN fill are all cast to `logits.dtype` — TODO confirm callers.

    Returns:
        A `Tensor` with the same shape and dtype as `logits`. Rows where no
        valid support size is found (`k_z == 0`) or whose total sum is NaN
        are filled entirely with NaN.
    """
    shape_op = tf.shape(logits)
    # Number of rows (product of all leading dims) and the row length.
    obs = tf.math.reduce_prod(shape_op[:-1])
    dims = shape_op[-1]

    # In the paper, they call the logits z.
    # The mean(logits) can be subtracted from logits to make the algorithm
    # more numerically stable. The instability in this algorithm comes mostly
    # from the z_cumsum. Subtracting the mean will cause z_cumsum to be close
    # to zero. However, in practice the numerical instability issues are very
    # minor and subtracting the mean causes extra issues with inf and nan
    # input.
    # Reshape to [obs, dims] as it is almost free and means the remaining
    # code doesn't need to worry about the rank.
    z = tf.reshape(logits, [obs, dims])

    # Sort each row of z, largest first (top_k over the full row length).
    z_sorted, _ = tf.nn.top_k(z, k=dims)

    # Calculate k(z), the support size: count the 1-based positions k where
    # 1 + k * z_sorted[k] exceeds the cumulative sum of the k largest values.
    z_cumsum = tf.math.cumsum(z_sorted, axis=-1)
    k = tf.range(1, tf.cast(dims, logits.dtype) + 1, dtype=logits.dtype)
    z_check = 1 + k * z_sorted > z_cumsum
    # Because the z_check vector is always [1,1,...1,0,0,...0], finding the
    # (index + 1) of the last `1` is the same as just summing the number of 1s.
    k_z = tf.math.reduce_sum(tf.cast(z_check, tf.int32), axis=-1)

    # Calculate tau(z) = (sum of the k_z largest values - 1) / k_z.
    # If there are inf values or all values are -inf, k_z will be zero;
    # this is mathematically invalid and would also cause gather_nd to fail.
    # Prevent this issue for now by using k_z = 1 when k_z = 0; this is then
    # fixed later (see p_safe) by returning p = nan. This results in the same
    # behavior as softmax.
    k_z_safe = tf.math.maximum(k_z, 1)
    # (row, k_z_safe - 1) pairs pick each row's cumulative sum at the
    # edge of the support.
    indices = tf.stack([tf.range(0, obs), tf.reshape(k_z_safe, [-1]) - 1], axis=1)
    tau_sum = tf.gather_nd(z_cumsum, indices)
    # Divides by the unclamped k_z; rows with k_z == 0 are replaced below.
    tau_z = (tau_sum - 1) / tf.cast(k_z, logits.dtype)

    # Calculate p = max(0, z - tau(z)), broadcasting tau over each row.
    p = tf.math.maximum(tf.cast(0, logits.dtype), z - tf.expand_dims(tau_z, -1))
    # If k_z = 0, or if the row's total sum z_cumsum[:, -1] is nan (e.g. z
    # contains nan), the input is invalid: fill the whole row with nan.
    p_safe = tf.where(
        tf.expand_dims(
            tf.math.logical_or(tf.math.equal(k_z, 0), tf.math.is_nan(z_cumsum[:, -1])),
            axis=-1,
        ),
        tf.fill([obs, dims], tf.cast(float("nan"), logits.dtype)),
        p,
    )

    # Reshape back to original size
    p_safe = tf.reshape(p_safe, shape_op)
    return p_safe