# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for attention layers that can be used in sequence DNN/CNN models.

This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
Attention is formed by three tensors: Query, Key and Value.
"""

import tensorflow.compat.v2 as tf
from absl import logging

from keras.src import backend
from keras.src.engine import base_layer
from keras.src.utils import control_flow_util

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.__internal__.layers.BaseDenseAttention", v1=[])
class BaseDenseAttention(base_layer.BaseRandomLayer):
    """Base Attention class for Dense networks.

    This class is suitable for Dense or CNN networks, and not for RNN
    networks.

    Implementations of attention mechanisms should inherit from this class,
    and reuse the `_apply_scores()` method.
    Args:
        dropout: Float between 0 and 1. Fraction of the units to drop for the
            attention scores.

    Call Args:
        inputs: List of the following tensors:
            * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
            * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
            * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If
                not given, will use `value` for both `key` and `value`, which
                is the most common case.
        mask: List of the following tensors:
            * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
                If given, the output will be zero at the positions where
                `mask==False`.
            * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
                If given, will apply the mask such that values at positions
                where `mask==False` do not contribute to the result.
        training: Python boolean indicating whether the layer should behave in
            training mode (adding dropout) or in inference mode (no dropout).
        return_attention_scores: bool, if `True`, returns the attention scores
            (after masking and softmax) as an additional output argument.

    Output:
        Attention outputs of shape `[batch_size, Tq, dim]`.
        [Optional] Attention scores after masking and softmax with shape
        `[batch_size, Tq, Tv]`.
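
    For illustration, a minimal usage sketch with the concrete
    `tf.keras.layers.Attention` subclass of this base class (the tensor
    shapes below are example values):

    ```python
    query = tf.random.normal([2, 4, 8])  # [batch_size, Tq, dim]
    value = tf.random.normal([2, 6, 8])  # [batch_size, Tv, dim]
    layer = tf.keras.layers.Attention()
    output, scores = layer([query, value], return_attention_scores=True)
    # output.shape == [2, 4, 8]; scores.shape == [2, 4, 6]
    ```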
    """

    def __init__(self, dropout=0.0, **kwargs):
        # The deprecated field `causal` determines whether to use causal
        # masking. Use `use_causal_mask` in the call() method instead.
        if "causal" in kwargs:
            logging.warning(
                "`causal` argument is deprecated. Please use "
                "`use_causal_mask` in call() method to specify causal "
                "masking."
            )
        self.causal = kwargs.pop("causal", False)
        super().__init__(**kwargs)
        self.dropout = dropout
        self.supports_masking = True

    def build(self, input_shape):
        # Skip RNG initialization if dropout rate is 0. This will let the
        # layer be purely stateless, with no reference to any variable.
        if self.dropout > 0:
            super().build(input_shape)
        self.built = True

    def _calculate_scores(self, query, key):
        """Calculates attention scores.

        Args:
            query: Query tensor of shape `[batch_size, Tq, dim]`.
            key: Key tensor of shape `[batch_size, Tv, dim]`.

        Returns:
            Tensor of shape `[batch_size, Tq, Tv]`.
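
        An illustrative override (a hypothetical plain dot-product score,
        not prescribed by this base class):

        ```python
        def _calculate_scores(self, query, key):
            # [batch_size, Tq, dim] . [batch_size, Tv, dim]^T
            #     -> [batch_size, Tq, Tv]
            return tf.matmul(query, key, transpose_b=True)
        ```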
        """
        raise NotImplementedError

    def _apply_scores(self, scores, value, scores_mask=None, training=None):
        """Applies attention scores to the given value tensor.

        To use this method in your attention layer, follow the steps:

        * Use `query` tensor of shape `[batch_size, Tq, dim]` and `key`
          tensor of shape `[batch_size, Tv, dim]` to calculate the attention
          `scores`.
        * Pass `scores` and `value` tensors to this method. The method
          applies `scores_mask`, calculates
          `attention_distribution = softmax(scores)`, then returns
          `matmul(attention_distribution, value)`.
        * Apply `query_mask` and return the result.

        Args:
            scores: Scores float tensor of shape `[batch_size, Tq, Tv]`.
            value: Value tensor of shape `[batch_size, Tv, dim]`.
            scores_mask: A boolean mask `Tensor` of shape
                `[batch_size, 1, Tv]` or `[batch_size, Tq, Tv]`. If given,
                scores at positions where `scores_mask==False` do not
                contribute to the result. It must contain at least one
                `True` value in each line along the last dimension.
            training: Python boolean indicating whether the layer should
                behave in training mode (adding dropout) or in inference
                mode (no dropout).

        Returns:
            Tensor of shape `[batch_size, Tq, dim]`.
            Attention scores after masking and softmax with shape
            `[batch_size, Tq, Tv]`.
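
        A shape-level sketch of the computation (example values, no mask and
        no dropout):

        ```python
        scores = tf.random.normal([2, 4, 6])  # [batch_size, Tq, Tv]
        value = tf.random.normal([2, 6, 8])   # [batch_size, Tv, dim]
        weights = tf.nn.softmax(scores)       # [2, 4, 6]
        result = tf.matmul(weights, value)    # [2, 4, 8]
        ```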
        """
        if scores_mask is not None:
            padding_mask = tf.logical_not(scores_mask)
            # Bias so padding positions do not contribute to attention
            # distribution. Note 65504. is the max float16 value.
            if scores.dtype is tf.float16:
                scores -= 65504.0 * tf.cast(padding_mask, dtype=scores.dtype)
            else:
                scores -= 1.0e9 * tf.cast(padding_mask, dtype=scores.dtype)
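            # After the subtraction above, masked positions carry a very
            # large negative score, so the softmax below assigns them an
            # effectively zero weight.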
        if training is None:
            training = backend.learning_phase()
        weights = tf.nn.softmax(scores)

        if self.dropout > 0:

            def dropped_weights():
                return self._random_generator.dropout(
                    weights, rate=self.dropout
                )

            weights = control_flow_util.smart_cond(
                training, dropped_weights, lambda: tf.identity(weights)
            )
        return tf.matmul(weights, value), weights

    # TODO(b/125916026): Consider exposing a __call__ method with named args.
    def call(
        self,
        inputs,
        mask=None,
        training=None,
        return_attention_scores=False,
        use_causal_mask=False,
    ):
        self._validate_call_args(inputs=inputs, mask=mask)
        q = inputs[0]
        v = inputs[1]
        k = inputs[2] if len(inputs) > 2 else v
        q_mask = mask[0] if mask else None
        v_mask = mask[1] if mask else None
        scores = self._calculate_scores(query=q, key=k)
        if v_mask is not None:
            # Mask of shape [batch_size, 1, Tv].
            v_mask = tf.expand_dims(v_mask, axis=-2)
        if self.causal or use_causal_mask:
            # Creates a lower triangular mask, so position i cannot attend to
            # positions j>i. This prevents the flow of information from the
            # future into the past.
            scores_shape = tf.shape(scores)
            # causal_mask_shape = [1, Tq, Tv].
            causal_mask_shape = tf.concat(
                [tf.ones_like(scores_shape[:-2]), scores_shape[-2:]], axis=0
            )
            causal_mask = _lower_triangular_mask(causal_mask_shape)
        else:
            causal_mask = None
        scores_mask = _merge_masks(v_mask, causal_mask)
        result, attention_scores = self._apply_scores(
            scores=scores, value=v, scores_mask=scores_mask, training=training
        )
        if q_mask is not None:
            # Mask of shape [batch_size, Tq, 1].
            q_mask = tf.expand_dims(q_mask, axis=-1)
            result *= tf.cast(q_mask, dtype=result.dtype)
        if return_attention_scores:
            return result, attention_scores
        return result

    def compute_mask(self, inputs, mask=None):
        self._validate_call_args(inputs=inputs, mask=mask)
        if mask:
            q_mask = mask[0]
            if q_mask is None:
                return None
            return tf.convert_to_tensor(q_mask)
        return None

    def compute_output_shape(self, input_shape):
        # The `return_attention_scores` argument of `BaseDenseAttention.call`
        # is ignored here, so the output shape of the attention scores cannot
        # be returned.
        return tf.TensorShape(input_shape[0])

    def _validate_call_args(self, inputs, mask):
        """Validates arguments of the call method."""
        class_name = self.__class__.__name__
        if not isinstance(inputs, list):
            raise ValueError(
                f"{class_name} layer must be called on a list of inputs, "
                "namely [query, value] or [query, value, key]. "
                f"Received: {inputs}."
            )
        if len(inputs) < 2 or len(inputs) > 3:
            raise ValueError(
                f"{class_name} layer accepts inputs list of length 2 or 3, "
                "namely [query, value] or [query, value, key]. "
                f"Received length: {len(inputs)}."
            )
        if mask:
            if not isinstance(mask, list):
                raise ValueError(
                    f"{class_name} layer mask must be a list, "
                    f"namely [query_mask, value_mask]. Received: {mask}."
                )
            if len(mask) < 2 or len(mask) > len(inputs):
                raise ValueError(
                    f"{class_name} layer mask must be a list of length 2, "
                    "namely [query_mask, value_mask]. "
                    f"Received length: {len(mask)}."
                )

    def get_config(self):
        config = {
            "dropout": self.dropout,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


def _lower_triangular_mask(shape):
    """Creates a lower-triangular boolean mask over the last 2 dimensions."""
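    # For example, with shape == [3, 3] (illustrative values):
    #     row_index == [[1, 1, 1],      col_index == [[1, 2, 3],
    #                   [2, 2, 2],                    [1, 2, 3],
    #                   [3, 3, 3]]                    [1, 2, 3]]
    # `row_index >= col_index` then keeps only the lower triangle (j <= i).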
    row_index = tf.cumsum(tf.ones(shape=shape, dtype=tf.int32), axis=-2)
    col_index = tf.cumsum(tf.ones(shape=shape, dtype=tf.int32), axis=-1)
    return tf.greater_equal(row_index, col_index)


def _merge_masks(x, y):
    if x is None:
        return y
    if y is None:
        return x
    return tf.logical_and(x, y)