# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loss functions for sequence models."""

import tensorflow as tf
from tensorflow_addons.utils.types import TensorLike

from typeguard import typechecked
from typing import Callable, Optional


def sequence_loss(
    logits: TensorLike,
    targets: TensorLike,
    weights: TensorLike,
    average_across_timesteps: bool = True,
    average_across_batch: bool = True,
    sum_over_timesteps: bool = False,
    sum_over_batch: bool = False,
    softmax_loss_function: Optional[Callable] = None,
    name: Optional[str] = None,
) -> tf.Tensor:
    """Computes the weighted cross-entropy loss for a sequence of logits.

    Depending on the values of `average_across_timesteps` /
    `sum_over_timesteps` and `average_across_batch` / `sum_over_batch`, the
    returned Tensor will have rank 0, 1, or 2, as these arguments reduce the
    cross-entropy at each target, which has shape
    `[batch_size, sequence_length]`, over their respective dimensions. For
    example, if `average_across_timesteps` is `True` and
    `average_across_batch` is `False`, the returned Tensor will have shape
    `[batch_size]`.

    Note that `average_across_timesteps` and `sum_over_timesteps` cannot both
    be `True` at the same time; the same holds for `average_across_batch` and
    `sum_over_batch`.

    The recommended loss reduction in TF 2.0 is summation rather than the
    weighted average, so users are encouraged to use `sum_over_timesteps` and
    `sum_over_batch` for reduction.
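
    Example (an illustrative sketch, not from the original docstring, using
    the recommended `sum_over_*` reduction on arbitrary random data):

    >>> batch_size, seq_len, vocab = 2, 5, 7
    >>> logits = tf.random.uniform([batch_size, seq_len, vocab])
    >>> targets = tf.random.uniform(
    ...     [batch_size, seq_len], maxval=vocab, dtype=tf.int32
    ... )
    >>> # Mask: the first sequence has 5 valid steps, the second has 3.
    >>> weights = tf.sequence_mask([5, 3], maxlen=seq_len, dtype=tf.float32)
    >>> loss = sequence_loss(
    ...     logits,
    ...     targets,
    ...     weights,
    ...     average_across_timesteps=False,
    ...     average_across_batch=False,
    ...     sum_over_timesteps=True,
    ...     sum_over_batch=True,
    ... )
    >>> loss.shape
    TensorShape([])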

    Args:
      logits: A Tensor of shape
        `[batch_size, sequence_length, num_decoder_symbols]` and dtype float.
        The logits correspond to the prediction across all classes at each
        timestep.
      targets: A Tensor of shape `[batch_size, sequence_length]` and dtype
        int, where each element is the index of the true class at that
        timestep, or a Tensor of shape
        `[batch_size, sequence_length, num_classes]` and dtype float, where
        the last axis holds a one-hot representation of the labels.
      weights: A Tensor of shape `[batch_size, sequence_length]` and dtype
        float. `weights` constitutes the weighting of each prediction in the
        sequence. When using `weights` as masking, set all valid timesteps to 1
        and all padded timesteps to 0, e.g. a mask returned by
        `tf.sequence_mask`.
      average_across_timesteps: If set, sum the cost across the sequence
        dimension and divide the cost by the total label weight across
        timesteps.
      average_across_batch: If set, sum the cost across the batch dimension and
        divide the returned cost by the batch size.
      sum_over_timesteps: If set, sum the cost across the sequence dimension
        and divide by the size of the sequence. Note that any element with 0
        weight is excluded from the size calculation.
      sum_over_batch: If set, sum the cost across the batch dimension and
        divide the total cost by the batch size. Note that any element with 0
        weight is excluded from the size calculation.
      softmax_loss_function: Function (labels, logits) -> loss-batch to be
        used instead of the standard softmax (the default if this is None).
        **Note that to avoid confusion, it is required for the function to
        accept named arguments.**
      name: Optional name for this operation, defaults to "sequence_loss".

    Returns:
      A float Tensor of rank 0, 1, or 2 depending on the reduction arguments.
      By default, it has rank 0 (scalar) and is the weighted average
      cross-entropy (log-perplexity) per symbol.

    Raises:
      ValueError: logits does not have 3 dimensions, targets does not have 2
        or 3 dimensions, or weights does not have 2 dimensions.
    """
    if len(logits.shape) != 3:
        raise ValueError(
            "Logits must be a [batch_size x sequence_length x logits] tensor"
        )

    targets_rank = len(targets.shape)
    if targets_rank != 2 and targets_rank != 3:
        raise ValueError(
            "Targets must be either a [batch_size x sequence_length] tensor "
            "where each element contains the labels' index "
            "or a [batch_size x sequence_length x num_classes] tensor "
            "where the third axis is a one-hot representation of the labels"
        )

    if len(weights.shape) != 2:
        raise ValueError("Weights must be a [batch_size x sequence_length] tensor")

    if average_across_timesteps and sum_over_timesteps:
        raise ValueError(
            "average_across_timesteps and sum_over_timesteps cannot "
            "be set to True at the same time."
        )
    if average_across_batch and sum_over_batch:
        raise ValueError(
            "average_across_batch and sum_over_batch cannot be set "
            "to True at the same time."
        )
    if average_across_batch and sum_over_timesteps:
        raise ValueError(
            "average_across_batch and sum_over_timesteps cannot be set "
            "to True at the same time because of ambiguous order."
        )
    if sum_over_batch and average_across_timesteps:
        raise ValueError(
            "sum_over_batch and average_across_timesteps cannot be set "
            "to True at the same time because of ambiguous order."
        )
    with tf.name_scope(name or "sequence_loss"):
        num_classes = tf.shape(input=logits)[2]
        logits_flat = tf.reshape(logits, [-1, num_classes])
        if softmax_loss_function is None:
            if targets_rank == 2:
                # Sparse integer targets: one class index per timestep.
                targets = tf.reshape(targets, [-1])
                crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=targets, logits=logits_flat
                )
            else:
                # Dense targets: one-hot (or soft) class distributions.
                targets = tf.reshape(targets, [-1, num_classes])
                crossent = tf.nn.softmax_cross_entropy_with_logits(
                    labels=targets, logits=logits_flat
                )
        else:
            targets = tf.reshape(targets, [-1])
            crossent = softmax_loss_function(labels=targets, logits=logits_flat)
        crossent *= tf.reshape(weights, [-1])
        if average_across_timesteps and average_across_batch:
            # Scalar result: weighted average over all timesteps and batches.
            crossent = tf.reduce_sum(input_tensor=crossent)
            total_size = tf.reduce_sum(input_tensor=weights)
            crossent = tf.math.divide_no_nan(crossent, total_size)
        elif sum_over_timesteps and sum_over_batch:
            # Scalar result: total cost divided by the number of nonzero weights.
            crossent = tf.reduce_sum(input_tensor=crossent)
            total_count = tf.cast(tf.math.count_nonzero(weights), crossent.dtype)
            crossent = tf.math.divide_no_nan(crossent, total_count)
        else:
            # Partial reduction: restore the [batch_size, sequence_length]
            # shape, then reduce over at most one axis.
            crossent = tf.reshape(crossent, tf.shape(input=logits)[0:2])
            if average_across_timesteps or average_across_batch:
                reduce_axis = [0] if average_across_batch else [1]
                crossent = tf.reduce_sum(input_tensor=crossent, axis=reduce_axis)
                total_size = tf.reduce_sum(input_tensor=weights, axis=reduce_axis)
                crossent = tf.math.divide_no_nan(crossent, total_size)
            elif sum_over_timesteps or sum_over_batch:
                reduce_axis = [0] if sum_over_batch else [1]
                crossent = tf.reduce_sum(input_tensor=crossent, axis=reduce_axis)
                total_count = tf.cast(
                    tf.math.count_nonzero(weights, axis=reduce_axis),
                    dtype=crossent.dtype,
                )
                crossent = tf.math.divide_no_nan(crossent, total_count)
        return crossent
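
# Illustrative sketch (added; not part of the original module): a custom
# `softmax_loss_function` must accept the named arguments `labels` and
# `logits` and return one loss value per flattened timestep. A hypothetical
# label-smoothing variant could look like this:
def _example_label_smoothing_loss(labels, logits, smoothing=0.1):
    """Hypothetical example only; this helper is not part of the library."""
    num_classes = tf.shape(logits)[-1]
    one_hot = tf.one_hot(labels, num_classes)
    # Blend the one-hot targets with a uniform distribution over classes.
    smoothed = one_hot * (1.0 - smoothing) + smoothing / tf.cast(
        num_classes, logits.dtype
    )
    return tf.nn.softmax_cross_entropy_with_logits(labels=smoothed, logits=logits)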

class SequenceLoss(tf.keras.losses.Loss):
    """Weighted cross-entropy loss for a sequence of logits."""

    @typechecked
    def __init__(
        self,
        average_across_timesteps: bool = False,
        average_across_batch: bool = False,
        sum_over_timesteps: bool = True,
        sum_over_batch: bool = True,
        softmax_loss_function: Optional[Callable] = None,
        name: Optional[str] = None,
    ):
        super().__init__(reduction=tf.keras.losses.Reduction.NONE, name=name)
        self.average_across_timesteps = average_across_timesteps
        self.average_across_batch = average_across_batch
        self.sum_over_timesteps = sum_over_timesteps
        self.sum_over_batch = sum_over_batch
        self.softmax_loss_function = softmax_loss_function

    def __call__(self, y_true, y_pred, sample_weight=None):
        """Override the parent `__call__` to apply a customized reduction."""
        return sequence_loss(
            y_pred,
            y_true,
            sample_weight,
            average_across_timesteps=self.average_across_timesteps,
            average_across_batch=self.average_across_batch,
            sum_over_timesteps=self.sum_over_timesteps,
            sum_over_batch=self.sum_over_batch,
            softmax_loss_function=self.softmax_loss_function,
            name=self.name,
        )

    def call(self, y_true, y_pred):
        # Skip this method; `__call__` contains the real implementation.
        pass
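
# A minimal usage sketch (added for illustration; not part of the original
# module). It exercises `SequenceLoss` on random data with its default
# `sum_over_timesteps`/`sum_over_batch` reduction; shapes are arbitrary.
if __name__ == "__main__":
    batch_size, seq_len, vocab = 4, 10, 32
    logits = tf.random.normal([batch_size, seq_len, vocab])
    targets = tf.random.uniform(
        [batch_size, seq_len], maxval=vocab, dtype=tf.int32
    )
    # Mask the padded tail of each sequence (valid lengths chosen arbitrarily).
    weights = tf.sequence_mask([10, 7, 5, 3], maxlen=seq_len, dtype=tf.float32)
    loss_fn = SequenceLoss()
    print("sequence loss:", float(loss_fn(targets, logits, weights)))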