Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/seq2seq/loss.py: 17%
66 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Loss functions for sequence models."""
17import tensorflow as tf
18from tensorflow_addons.utils.types import TensorLike
20from typeguard import typechecked
21from typing import Callable, Optional
24def sequence_loss(
25 logits: TensorLike,
26 targets: TensorLike,
27 weights: TensorLike,
28 average_across_timesteps: bool = True,
29 average_across_batch: bool = True,
30 sum_over_timesteps: bool = False,
31 sum_over_batch: bool = False,
32 softmax_loss_function: Optional[Callable] = None,
33 name: Optional[str] = None,
34) -> tf.Tensor:
35 """Computes the weighted cross-entropy loss for a sequence of logits.
37 Depending on the values of `average_across_timesteps` /
38 `sum_over_timesteps` and `average_across_batch` / `sum_over_batch`, the
39 return Tensor will have rank 0, 1, or 2 as these arguments reduce the
40 cross-entropy at each target, which has shape
41 `[batch_size, sequence_length]`, over their respective dimensions. For
42 example, if `average_across_timesteps` is `True` and `average_across_batch`
43 is `False`, then the return Tensor will have shape `[batch_size]`.
45 Note that `average_across_timesteps` and `sum_over_timesteps` cannot be
46 True at same time. Same for `average_across_batch` and `sum_over_batch`.
48 The recommended loss reduction in tf 2.0 has been changed to sum_over,
49 instead of weighted average. User are recommend to use `sum_over_timesteps`
50 and `sum_over_batch` for reduction.
52 Args:
53 logits: A Tensor of shape
54 `[batch_size, sequence_length, num_decoder_symbols]` and dtype float.
55 The logits correspond to the prediction across all classes at each
56 timestep.
57 targets: A Tensor of shape `[batch_size, sequence_length]` and dtype
58 int. The target represents the true class at each timestep.
59 weights: A Tensor of shape `[batch_size, sequence_length]` and dtype
60 float. `weights` constitutes the weighting of each prediction in the
61 sequence. When using `weights` as masking, set all valid timesteps to 1
62 and all padded timesteps to 0, e.g. a mask returned by
63 `tf.sequence_mask`.
64 average_across_timesteps: If set, sum the cost across the sequence
65 dimension and divide the cost by the total label weight across
66 timesteps.
67 average_across_batch: If set, sum the cost across the batch dimension and
68 divide the returned cost by the batch size.
69 sum_over_timesteps: If set, sum the cost across the sequence dimension
70 and divide the size of the sequence. Note that any element with 0
71 weights will be excluded from size calculation.
72 sum_over_batch: if set, sum the cost across the batch dimension and
73 divide the total cost by the batch size. Not that any element with 0
74 weights will be excluded from size calculation.
75 softmax_loss_function: Function (labels, logits) -> loss-batch
76 to be used instead of the standard softmax (the default if this is
77 None). **Note that to avoid confusion, it is required for the function
78 to accept named arguments.**
79 name: Optional name for this operation, defaults to "sequence_loss".
81 Returns:
82 A float Tensor of rank 0, 1, or 2 depending on the
83 `average_across_timesteps` and `average_across_batch` arguments. By
84 default, it has rank 0 (scalar) and is the weighted average cross-entropy
85 (log-perplexity) per symbol.
87 Raises:
88 ValueError: logits does not have 3 dimensions or targets does not have 2
89 dimensions or weights does not have 2 dimensions.
90 """
91 if len(logits.shape) != 3:
92 raise ValueError(
93 "Logits must be a [batch_size x sequence_length x logits] tensor"
94 )
96 targets_rank = len(targets.shape)
97 if targets_rank != 2 and targets_rank != 3:
98 raise ValueError(
99 "Targets must be either a [batch_size x sequence_length] tensor "
100 + "where each element contains the labels' index"
101 + "or a [batch_size x sequence_length x num_classes] tensor "
102 + "where the third axis is a one-hot representation of the labels"
103 )
105 if len(weights.shape) != 2:
106 raise ValueError("Weights must be a [batch_size x sequence_length] tensor")
108 if average_across_timesteps and sum_over_timesteps:
109 raise ValueError(
110 "average_across_timesteps and sum_over_timesteps cannot "
111 "be set to True at same time."
112 )
113 if average_across_batch and sum_over_batch:
114 raise ValueError(
115 "average_across_batch and sum_over_batch cannot be set "
116 "to True at same time."
117 )
118 if average_across_batch and sum_over_timesteps:
119 raise ValueError(
120 "average_across_batch and sum_over_timesteps cannot be set "
121 "to True at same time because of ambiguous order."
122 )
123 if sum_over_batch and average_across_timesteps:
124 raise ValueError(
125 "sum_over_batch and average_across_timesteps cannot be set "
126 "to True at same time because of ambiguous order."
127 )
128 with tf.name_scope(name or "sequence_loss"):
129 num_classes = tf.shape(input=logits)[2]
130 logits_flat = tf.reshape(logits, [-1, num_classes])
131 if softmax_loss_function is None:
132 if targets_rank == 2:
133 targets = tf.reshape(targets, [-1])
134 crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
135 labels=targets, logits=logits_flat
136 )
137 else:
138 targets = tf.reshape(targets, [-1, num_classes])
139 crossent = tf.nn.softmax_cross_entropy_with_logits(
140 labels=targets, logits=logits_flat
141 )
142 else:
143 targets = tf.reshape(targets, [-1])
144 crossent = softmax_loss_function(labels=targets, logits=logits_flat)
145 crossent *= tf.reshape(weights, [-1])
146 if average_across_timesteps and average_across_batch:
147 crossent = tf.reduce_sum(input_tensor=crossent)
148 total_size = tf.reduce_sum(input_tensor=weights)
149 crossent = tf.math.divide_no_nan(crossent, total_size)
150 elif sum_over_timesteps and sum_over_batch:
151 crossent = tf.reduce_sum(input_tensor=crossent)
152 total_count = tf.cast(tf.math.count_nonzero(weights), crossent.dtype)
153 crossent = tf.math.divide_no_nan(crossent, total_count)
154 else:
155 crossent = tf.reshape(crossent, tf.shape(input=logits)[0:2])
156 if average_across_timesteps or average_across_batch:
157 reduce_axis = [0] if average_across_batch else [1]
158 crossent = tf.reduce_sum(input_tensor=crossent, axis=reduce_axis)
159 total_size = tf.reduce_sum(input_tensor=weights, axis=reduce_axis)
160 crossent = tf.math.divide_no_nan(crossent, total_size)
161 elif sum_over_timesteps or sum_over_batch:
162 reduce_axis = [0] if sum_over_batch else [1]
163 crossent = tf.reduce_sum(input_tensor=crossent, axis=reduce_axis)
164 total_count = tf.cast(
165 tf.math.count_nonzero(weights, axis=reduce_axis),
166 dtype=crossent.dtype,
167 )
168 crossent = tf.math.divide_no_nan(crossent, total_count)
169 return crossent
172class SequenceLoss(tf.keras.losses.Loss):
173 """Weighted cross-entropy loss for a sequence of logits."""
175 @typechecked
176 def __init__(
177 self,
178 average_across_timesteps: bool = False,
179 average_across_batch: bool = False,
180 sum_over_timesteps: bool = True,
181 sum_over_batch: bool = True,
182 softmax_loss_function: Optional[Callable] = None,
183 name: Optional[str] = None,
184 ):
185 super().__init__(reduction=tf.keras.losses.Reduction.NONE, name=name)
186 self.average_across_timesteps = average_across_timesteps
187 self.average_across_batch = average_across_batch
188 self.sum_over_timesteps = sum_over_timesteps
189 self.sum_over_batch = sum_over_batch
190 self.softmax_loss_function = softmax_loss_function
192 def __call__(self, y_true, y_pred, sample_weight=None):
193 """Override the parent __call__ to have a customized reduce
194 behavior."""
195 return sequence_loss(
196 y_pred,
197 y_true,
198 sample_weight,
199 average_across_timesteps=self.average_across_timesteps,
200 average_across_batch=self.average_across_batch,
201 sum_over_timesteps=self.sum_over_timesteps,
202 sum_over_batch=self.sum_over_batch,
203 softmax_loss_function=self.softmax_loss_function,
204 name=self.name,
205 )
207 def call(self, y_true, y_pred):
208 # Skip this method since the __call__ contains real implementation.
209 pass