Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/attention/additive_attention.py: 36%
28 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Additive attention layer that can be used in sequence DNN/CNN models.

This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
Attention is formed by three tensors: Query, Key and Value.
"""


import tensorflow.compat.v2 as tf

from keras.src.layers.attention.base_dense_attention import BaseDenseAttention

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.AdditiveAttention")
class AdditiveAttention(BaseDenseAttention):
    """Additive attention layer, a.k.a. Bahdanau-style attention.

    Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor
    of shape `[batch_size, Tv, dim]` and `key` tensor of shape
    `[batch_size, Tv, dim]`. The calculation follows the steps below (see the
    sketch after the list):

    1. Reshape `query` and `key` into shapes `[batch_size, Tq, 1, dim]`
        and `[batch_size, 1, Tv, dim]` respectively.
    2. Calculate scores with shape `[batch_size, Tq, Tv]` as a non-linear
        sum: `scores = tf.reduce_sum(tf.tanh(query + key), axis=-1)`
    3. Use scores to calculate a distribution with shape
        `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
    4. Use `distribution` to create a linear combination of `value` with
        shape `[batch_size, Tq, dim]`:
        `return tf.matmul(distribution, value)`.
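
    A minimal, shape-annotated sketch of the four steps above, written with
    raw TensorFlow ops. The batch size and the `Tq`/`Tv`/`dim` values are
    assumed examples, and the learned `scale` weight, dropout and masking that
    the layer can apply are omitted here:

    ```python
    import tensorflow as tf

    query = tf.random.normal([2, 3, 8])  # [batch_size, Tq, dim]
    key = tf.random.normal([2, 4, 8])    # [batch_size, Tv, dim]
    value = key                          # most common case: key is value

    # 1. Reshape for broadcasting.
    q = tf.expand_dims(query, -2)        # [2, 3, 1, 8]
    k = tf.expand_dims(key, -3)          # [2, 1, 4, 8]
    # 2. Non-linear additive scores.
    scores = tf.reduce_sum(tf.tanh(q + k), axis=-1)  # [2, 3, 4]
    # 3. Attention distribution over the value positions.
    distribution = tf.nn.softmax(scores)             # [2, 3, 4]
    # 4. Linear combination of the values.
    output = tf.matmul(distribution, value)          # [2, 3, 8]
    ```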
    Args:
        use_scale: If `True`, will create a variable to scale the attention
            scores.
        dropout: Float between 0 and 1. Fraction of the units to drop for the
            attention scores. Defaults to `0.0`.

    Call Args:
        inputs: List of the following tensors:
            * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
            * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
            * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`.
                If not given, will use `value` for both `key` and `value`,
                which is the most common case.
        mask: List of the following tensors:
            * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
                If given, the output will be zero at the positions where
                `mask==False`.
            * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
                If given, will apply the mask such that values at positions
                where `mask==False` do not contribute to the result.
        training: Python boolean indicating whether the layer should behave in
            training mode (adding dropout) or in inference mode (no dropout).
        return_attention_scores: bool, if `True`, returns the attention scores
            (after masking and softmax) as an additional output argument.
        use_causal_mask: Boolean. Set to `True` for decoder self-attention.
            Adds a mask such that position `i` cannot attend to positions
            `j > i`. This prevents the flow of information from the future
            towards the past. Defaults to `False`.

    Output:
        Attention outputs of shape `[batch_size, Tq, dim]`.
        [Optional] Attention scores after masking and softmax with shape
            `[batch_size, Tq, Tv]`.
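
    For example, the constructor and call arguments above can be exercised
    directly on tensors. This is a usage sketch only; the shapes, the mask
    values and the `dropout=0.1` setting are assumed examples, not defaults:

    ```python
    import tensorflow as tf

    layer = tf.keras.layers.AdditiveAttention(use_scale=True, dropout=0.1)
    query = tf.random.normal([2, 3, 16])  # [batch_size, Tq, dim]
    value = tf.random.normal([2, 5, 16])  # [batch_size, Tv, dim]
    # Mask out the last two value positions of the second batch element.
    value_mask = tf.constant([[True, True, True, True, True],
                              [True, True, True, False, False]])

    output, scores = layer(
        [query, value],              # key defaults to value
        mask=[None, value_mask],     # [query_mask, value_mask]
        return_attention_scores=True,
    )
    print(output.shape)  # (2, 3, 16)
    print(scores.shape)  # (2, 3, 5)
    ```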
    The meaning of `query`, `value` and `key` depends on the application. In
    the case of text similarity, for example, `query` is the sequence
    embeddings of the first piece of text and `value` is the sequence
    embeddings of the second piece of text. `key` is usually the same tensor
    as `value`.

    Here is a code example for using `AdditiveAttention` in a CNN+Attention
    network:

    ```python
    # Variable-length int sequences.
    query_input = tf.keras.Input(shape=(None,), dtype='int32')
    value_input = tf.keras.Input(shape=(None,), dtype='int32')

    # Embedding lookup.
    token_embedding = tf.keras.layers.Embedding(max_tokens, dimension)
    # Query embeddings of shape [batch_size, Tq, dimension].
    query_embeddings = token_embedding(query_input)
    # Value embeddings of shape [batch_size, Tv, dimension].
    value_embeddings = token_embedding(value_input)

    # CNN layer.
    cnn_layer = tf.keras.layers.Conv1D(
        filters=100,
        kernel_size=4,
        # Use 'same' padding so outputs have the same shape as inputs.
        padding='same')
    # Query encoding of shape [batch_size, Tq, filters].
    query_seq_encoding = cnn_layer(query_embeddings)
    # Value encoding of shape [batch_size, Tv, filters].
    value_seq_encoding = cnn_layer(value_embeddings)

    # Query-value attention of shape [batch_size, Tq, filters].
    query_value_attention_seq = tf.keras.layers.AdditiveAttention()(
        [query_seq_encoding, value_seq_encoding])

    # Reduce over the sequence axis to produce encodings of shape
    # [batch_size, filters].
    query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
        query_seq_encoding)
    query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
        query_value_attention_seq)

    # Concatenate query and document encodings to produce a DNN input layer.
    input_layer = tf.keras.layers.Concatenate()(
        [query_encoding, query_value_attention])

    # Add DNN layers, and create Model.
    # ...
    ```
    """

    def __init__(self, use_scale=True, **kwargs):
        super().__init__(**kwargs)
        self.use_scale = use_scale

    def build(self, input_shape):
        v_shape = tf.TensorShape(input_shape[1])
        dim = v_shape[-1]
        dim = tf.compat.dimension_value(dim)
        if self.use_scale:
            # Learned per-feature scale, applied elementwise to
            # tanh(query + key) before the reduction in `_calculate_scores`.
            self.scale = self.add_weight(
                name="scale",
                shape=[dim],
                initializer="glorot_uniform",
                dtype=self.dtype,
                trainable=True,
            )
        else:
            self.scale = None
        super().build(input_shape)

    def _calculate_scores(self, query, key):
        """Calculates attention scores as a nonlinear sum of query and key.

        Args:
            query: Query tensor of shape `[batch_size, Tq, dim]`.
            key: Key tensor of shape `[batch_size, Tv, dim]`.

        Returns:
            Tensor of shape `[batch_size, Tq, Tv]`.
        """
        # Reshape tensors to enable broadcasting.
        # Reshape into [batch_size, Tq, 1, dim].
        q_reshaped = tf.expand_dims(query, axis=-2)
        # Reshape into [batch_size, 1, Tv, dim].
        k_reshaped = tf.expand_dims(key, axis=-3)
        if self.use_scale:
            scale = self.scale
        else:
            scale = 1.0
        return tf.reduce_sum(scale * tf.tanh(q_reshaped + k_reshaped), axis=-1)

    def get_config(self):
        config = {"use_scale": self.use_scale}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))