# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Attention layer that can be used in sequence DNN/CNN models.

This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
Attention is formed by three tensors: Query, Key and Value.
"""

import tensorflow.compat.v2 as tf

from keras.src.layers.attention.base_dense_attention import BaseDenseAttention

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.Attention")
class Attention(BaseDenseAttention):
    """Dot-product attention layer, a.k.a. Luong-style attention.

    Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor
    of shape `[batch_size, Tv, dim]` and `key` tensor of shape
    `[batch_size, Tv, dim]`. The calculation follows the steps:

    1. Calculate scores with shape `[batch_size, Tq, Tv]` as a `query`-`key`
       dot product: `scores = tf.matmul(query, key, transpose_b=True)`.
    2. Use scores to calculate a distribution with shape
       `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
    3. Use `distribution` to create a linear combination of `value` with
       shape `[batch_size, Tq, dim]`:
       `return tf.matmul(distribution, value)`.
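
    For illustration, these three steps can be reproduced with plain
    TensorFlow ops (a minimal sketch using made-up example shapes):

    ```python
    query = tf.random.normal([2, 3, 4])  # [batch_size=2, Tq=3, dim=4]
    key = tf.random.normal([2, 5, 4])    # [batch_size=2, Tv=5, dim=4]
    value = tf.random.normal([2, 5, 4])  # [batch_size=2, Tv=5, dim=4]

    scores = tf.matmul(query, key, transpose_b=True)  # [2, 3, 5]
    distribution = tf.nn.softmax(scores)              # [2, 3, 5]
    output = tf.matmul(distribution, value)           # [2, 3, 4]
    ```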

    Args:
        use_scale: If `True`, will create a scalar variable to scale the
            attention scores.
        dropout: Float between 0 and 1. Fraction of the units to drop for the
            attention scores. Defaults to 0.0.
        score_mode: Function to use to compute attention scores, one of
            `{"dot", "concat"}`. `"dot"` refers to the dot product between the
            query and key vectors. `"concat"` refers to the hyperbolic tangent
            of the concatenation of the query and key vectors.

    Call Args:

        inputs: List of the following tensors:
            * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
            * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
            * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If
                not given, will use `value` for both `key` and `value`, which
                is the most common case.
        mask: List of the following tensors:
            * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
                If given, the output will be zero at the positions where
                `mask==False`.
            * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
                If given, will apply the mask such that values at positions
                where `mask==False` do not contribute to the result.
        return_attention_scores: bool, if `True`, returns the attention scores
            (after masking and softmax) as an additional output argument.
        training: Python boolean indicating whether the layer should behave in
            training mode (adding dropout) or in inference mode (no dropout).
        use_causal_mask: Boolean. Set to `True` for decoder self-attention.
            Adds a mask such that position `i` cannot attend to positions
            `j > i`. This prevents the flow of information from the future
            towards the past. Defaults to `False`.

    Output:

        Attention outputs of shape `[batch_size, Tq, dim]`.
        [Optional] Attention scores after masking and softmax with shape
            `[batch_size, Tq, Tv]`.
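
    A minimal sketch of this call signature (arbitrary example shapes; masks
    and flags shown purely for illustration):

    ```python
    layer = tf.keras.layers.Attention()
    query = tf.random.normal([2, 3, 4])  # [batch_size, Tq, dim]
    value = tf.random.normal([2, 5, 4])  # [batch_size, Tv, dim]
    # Boolean masks of shape [batch_size, Tq] and [batch_size, Tv].
    query_mask = tf.constant([[True, True, False]] * 2)
    value_mask = tf.constant([[True, True, True, False, False]] * 2)

    output, scores = layer(
        [query, value],
        mask=[query_mask, value_mask],
        return_attention_scores=True,
    )
    # output.shape == (2, 3, 4); scores.shape == (2, 3, 5)
    ```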

    The meaning of `query`, `value` and `key` depends on the application. In
    the case of text similarity, for example, `query` is the sequence
    embeddings of the first piece of text and `value` is the sequence
    embeddings of the second piece of text. `key` is usually the same tensor
    as `value`.

    Here is a code example for using `Attention` in a CNN+Attention network:

    ```python
    # Variable-length int sequences.
    query_input = tf.keras.Input(shape=(None,), dtype='int32')
    value_input = tf.keras.Input(shape=(None,), dtype='int32')

    # Embedding lookup.
    token_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
    # Query embeddings of shape [batch_size, Tq, dimension].
    query_embeddings = token_embedding(query_input)
    # Value embeddings of shape [batch_size, Tv, dimension].
    value_embeddings = token_embedding(value_input)

    # CNN layer.
    cnn_layer = tf.keras.layers.Conv1D(
        filters=100,
        kernel_size=4,
        # Use 'same' padding so outputs have the same shape as inputs.
        padding='same')
    # Query encoding of shape [batch_size, Tq, filters].
    query_seq_encoding = cnn_layer(query_embeddings)
    # Value encoding of shape [batch_size, Tv, filters].
    value_seq_encoding = cnn_layer(value_embeddings)

    # Query-value attention of shape [batch_size, Tq, filters].
    query_value_attention_seq = tf.keras.layers.Attention()(
        [query_seq_encoding, value_seq_encoding])

    # Reduce over the sequence axis to produce encodings of shape
    # [batch_size, filters].
    query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
        query_seq_encoding)
    query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
        query_value_attention_seq)

    # Concatenate query and document encodings to produce a DNN input layer.
    input_layer = tf.keras.layers.Concatenate()(
        [query_encoding, query_value_attention])

    # Add DNN layers, and create Model.
    # ...
    ```
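
    The constructor arguments and the `use_causal_mask` flag described above
    can be combined with the same pattern, e.g. (a brief illustrative sketch,
    reusing `query_seq_encoding` from the example above):

    ```python
    # Learnable scalar scale applied to the dot-product scores.
    scaled_attention = tf.keras.layers.Attention(use_scale=True)

    # Additive ("concat") scoring instead of the dot product.
    concat_attention = tf.keras.layers.Attention(score_mode="concat")

    # Causal self-attention: each position attends only to earlier positions.
    causal_self_attention_seq = tf.keras.layers.Attention()(
        [query_seq_encoding, query_seq_encoding], use_causal_mask=True)
    ```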
135 """
137 def __init__(self, use_scale=False, score_mode="dot", **kwargs):
138 super().__init__(**kwargs)
139 self.use_scale = use_scale
140 self.score_mode = score_mode
141 if self.score_mode not in ["dot", "concat"]:
142 raise ValueError(
143 f"Received: score_mode={score_mode}. Acceptable values "
144 'are: ["dot", "concat"]'
145 )
147 def build(self, input_shape):
148 """Creates variable when `use_scale` is True or `score_mode` is
149 `concat`."""
150 if self.use_scale:
151 self.scale = self.add_weight(
152 name="scale",
153 shape=(),
154 initializer="ones",
155 dtype=self.dtype,
156 trainable=True,
157 )
158 else:
159 self.scale = None
160 if self.score_mode == "concat":
161 self.concat_score_weight = self.add_weight(
162 name="concat_score_weight",
163 shape=(),
164 initializer="ones",
165 dtype=self.dtype,
166 trainable=True,
167 )
168 else:
169 self.concat_score_weight = None
170 super().build(input_shape)
172 def _calculate_scores(self, query, key):
173 """Calculates attention scores as a query-key dot product.
175 Args:
176 query: Query tensor of shape `[batch_size, Tq, dim]`.
177 key: Key tensor of shape `[batch_size, Tv, dim]`.
178 Returns:
179 Tensor of shape `[batch_size, Tq, Tv]`.
180 """
181 if self.score_mode == "dot":
182 scores = tf.matmul(query, key, transpose_b=True)
183 if self.scale is not None:
184 scores *= self.scale
185 elif self.score_mode == "concat":
186 # Reshape tensors to enable broadcasting.
187 # Reshape into [batch_size, Tq, 1, dim].
188 q_reshaped = tf.expand_dims(query, axis=-2)
189 # Reshape into [batch_size, 1, Tv, dim].
190 k_reshaped = tf.expand_dims(key, axis=-3)
191 if self.scale is not None:
192 scores = self.concat_score_weight * tf.reduce_sum(
193 tf.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1
194 )
195 else:
196 scores = self.concat_score_weight * tf.reduce_sum(
197 tf.tanh(q_reshaped + k_reshaped), axis=-1
198 )
200 return scores
202 def get_config(self):
203 config = {"use_scale": self.use_scale, "score_mode": self.score_mode}
204 base_config = super().get_config()
205 return dict(list(base_config.items()) + list(config.items()))
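

if __name__ == "__main__":
    # Illustrative smoke test: a minimal sketch with assumed example shapes,
    # added only to demonstrate the layer described above (it has no effect
    # when the module is imported).
    query = tf.random.normal([2, 3, 4])  # [batch_size, Tq, dim]
    value = tf.random.normal([2, 5, 4])  # [batch_size, Tv, dim]

    for score_mode in ("dot", "concat"):
        # Exercise both score modes implemented in `_calculate_scores`.
        layer = Attention(use_scale=True, score_mode=score_mode)
        output = layer([query, value])
        print(score_mode, output.shape)  # Expected: (2, 3, 4) for both modes.

    # `get_config` captures the constructor arguments, so the layer can be
    # reconstructed from its config.
    restored = Attention.from_config(layer.get_config())
    print(restored.use_scale, restored.score_mode)  # Expected: True concat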