Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/activations/sparsemax.py: 15%

39 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import tensorflow as tf

from tensorflow_addons.utils import types


@tf.keras.utils.register_keras_serializable(package="Addons")
def sparsemax(logits: types.TensorLike, axis: int = -1) -> tf.Tensor:
    r"""Sparsemax activation function.

    For each batch $i$, and class $j$,
    compute sparsemax activation function:

    $$
    \mathrm{sparsemax}(x)[i, j] = \max(\mathrm{logits}[i, j] - \tau(\mathrm{logits}[i, :]), 0).
    $$

    See [From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification](https://arxiv.org/abs/1602.02068).

    Usage:

    >>> x = tf.constant([[-1.0, 0.0, 1.0], [-5.0, 1.0, 2.0]])
    >>> tfa.activations.sparsemax(x)
    <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
    array([[0., 0., 1.],
           [0., 0., 1.]], dtype=float32)>

    Args:
        logits: A `Tensor`.
        axis: `int`, axis along which the sparsemax operation is applied.
    Returns:
        A `Tensor`, output of sparsemax transformation. Has the same type and
        shape as `logits`.
    Raises:
        ValueError: In case `dim(logits) == 1`.
    """
    logits = tf.convert_to_tensor(logits, name="logits")

    # We need its original shape for shape inference.
    shape = logits.get_shape()
    rank = shape.rank
    is_last_axis = (axis == -1) or (axis == rank - 1)

    if is_last_axis:
        output = _compute_2d_sparsemax(logits)
        output.set_shape(shape)
        return output

    # If `axis` is not the last dimension, we have to do a transpose so that
    # we can still perform the sparsemax on the last dimension.

    # Swap logits' dimension `axis` with its last dimension.
    rank_op = tf.rank(logits)
    axis_norm = axis % rank
    logits = _swap_axis(logits, axis_norm, tf.math.subtract(rank_op, 1))

    # Do the actual sparsemax on the last dimension.
    output = _compute_2d_sparsemax(logits)
    output = _swap_axis(output, axis_norm, tf.math.subtract(rank_op, 1))

    # Make shape inference work since transpose may erase its static shape.
    output.set_shape(shape)
    return output


def _swap_axis(logits, dim_index, last_index, **kwargs):
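    """Swaps axis `dim_index` with the last axis of `logits`.

    Builds the permutation with `tf.concat` (e.g. for a rank-4 tensor,
    dim_index=1 and last_index=3 give the permutation [0, 3, 2, 1]) and
    applies it with `tf.transpose`.
    """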

    return tf.transpose(
        logits,
        tf.concat(
            [
                tf.range(dim_index),
                [last_index],
                tf.range(dim_index + 1, last_index),
                [dim_index],
            ],
            0,
        ),
        **kwargs,
    )

def _compute_2d_sparsemax(logits):
    """Performs the sparsemax operation when axis=-1."""
    shape_op = tf.shape(logits)
    obs = tf.math.reduce_prod(shape_op[:-1])
    dims = shape_op[-1]

    # In the paper, they call the logits z.
    # The mean(logits) can be subtracted from logits to make the algorithm
    # more numerically stable. The instability in this algorithm comes mostly
    # from the z_cumsum. Subtracting the mean will cause z_cumsum to be close
    # to zero. However, in practice the numerical instability issues are very
    # minor and subtracting the mean causes extra issues with inf and nan
    # input.
    # Reshape to [obs, dims] as it is almost free and means the remaining
    # code doesn't need to worry about the rank.

    z = tf.reshape(logits, [obs, dims])

    # sort z
    z_sorted, _ = tf.nn.top_k(z, k=dims)

    # calculate k(z)
    z_cumsum = tf.math.cumsum(z_sorted, axis=-1)
    k = tf.range(1, tf.cast(dims, logits.dtype) + 1, dtype=logits.dtype)
    z_check = 1 + k * z_sorted > z_cumsum

    # Because the z_check vector always has the form [1, 1, ..., 1, 0, 0, ..., 0],
    # finding the (index + 1) of the last `1` is the same as just summing the
    # number of 1s.
    k_z = tf.math.reduce_sum(tf.cast(z_check, tf.int32), axis=-1)

    # calculate tau(z)
    # If there are inf values or all values are -inf, k_z will be zero; this
    # is mathematically invalid and will also cause the gather_nd to fail.
    # Prevent this issue for now by setting k_z = 1 if k_z = 0; it is then
    # fixed later (see p_safe) by returning p = nan. This results in the same
    # behavior as softmax.
    k_z_safe = tf.math.maximum(k_z, 1)
    indices = tf.stack([tf.range(0, obs), tf.reshape(k_z_safe, [-1]) - 1], axis=1)
    tau_sum = tf.gather_nd(z_cumsum, indices)
    tau_z = (tau_sum - 1) / tf.cast(k_z, logits.dtype)
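    # As a concrete, hand-computed example: for a single row z = [1.0, 0.8, 0.1],
    # z_sorted = [1.0, 0.8, 0.1], z_cumsum = [1.0, 1.8, 1.9] and
    # z_check = [True, True, False], so k_z = 2 and tau_z = (1.8 - 1) / 2 = 0.4.
    # The projection below then yields p = [0.6, 0.4, 0.0], a sparse
    # distribution that sums to 1.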

    # calculate p
    p = tf.math.maximum(tf.cast(0, logits.dtype), z - tf.expand_dims(tau_z, -1))
    # If k_z = 0 or if z = nan, then the input is invalid
    p_safe = tf.where(
        tf.expand_dims(
            tf.math.logical_or(tf.math.equal(k_z, 0), tf.math.is_nan(z_cumsum[:, -1])),
            axis=-1,
        ),
        tf.fill([obs, dims], tf.cast(float("nan"), logits.dtype)),
        p,
    )
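    # For example, a row that is all -inf (or that contains +inf) makes every
    # entry of z_check False, hence k_z == 0, and the whole row is replaced
    # with NaN here; softmax behaves the same way on such input.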

    # Reshape back to original size
    p_safe = tf.reshape(p_safe, shape_op)
    return p_safe
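

# Usage sketch for the non-default axis path, using the same example tensor as
# the docstring above (outputs worked out by hand from the definition, not a
# captured doctest):
#
#     x = tf.constant([[-1.0, 0.0, 1.0], [-5.0, 1.0, 2.0]])
#     tfa.activations.sparsemax(x, axis=0)
#     # -> [[1., 0., 0.],
#     #     [0., 1., 1.]]
#
# With axis=0 each column is projected onto the probability simplex
# independently, so every column of the result sums to 1.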