Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/multihead_attention.py: 12% (83 statements)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import typing
import warnings

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="Addons")
class MultiHeadAttention(tf.keras.layers.Layer):
24 r"""MultiHead Attention layer.
26 Defines the MultiHead Attention operation as described in
27 [Attention Is All You Need](https://arxiv.org/abs/1706.03762) which takes
28 in the tensors `query`, `key`, and `value`, and returns the dot-product attention
29 between them:
31 >>> mha = MultiHeadAttention(head_size=128, num_heads=12)
32 >>> query = np.random.rand(3, 5, 4) # (batch_size, query_elements, query_depth)
33 >>> key = np.random.rand(3, 6, 5) # (batch_size, key_elements, key_depth)
34 >>> value = np.random.rand(3, 6, 6) # (batch_size, key_elements, value_depth)
35 >>> attention = mha([query, key, value]) # (batch_size, query_elements, value_depth)
36 >>> attention.shape
37 TensorShape([3, 5, 6])

    If `value` is not given then internally `value = key` will be used:

    >>> mha = MultiHeadAttention(head_size=128, num_heads=12)
    >>> query = np.random.rand(3, 5, 5)  # (batch_size, query_elements, query_depth)
    >>> key = np.random.rand(3, 6, 10)  # (batch_size, key_elements, key_depth)
    >>> attention = mha([query, key])  # (batch_size, query_elements, key_depth)
    >>> attention.shape
    TensorShape([3, 5, 10])
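
    If `return_attn_coef=True`, the attention coefficients are returned as an
    additional output (an illustrative sketch; the coefficients have shape
    `(batch_size, num_heads, query_elements, key_elements)`):

    >>> mha = MultiHeadAttention(head_size=128, num_heads=12, return_attn_coef=True)
    >>> query = np.random.rand(3, 5, 4)  # (batch_size, query_elements, query_depth)
    >>> key = np.random.rand(3, 6, 5)  # (batch_size, key_elements, key_depth)
    >>> attention, attn_coef = mha([query, key])
    >>> attn_coef.shape
    TensorShape([3, 12, 5, 6])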

    Args:
        head_size: int, dimensionality of the `query`, `key` and `value` tensors
            after the linear transformation.
        num_heads: int, number of attention heads.
        output_size: int, dimensionality of the output space, if `None` then the
            input dimension of `value` or `key` will be used,
            default `None`.
        dropout: float, `rate` parameter for the dropout layer that is
            applied to attention after softmax,
            default `0`.
        use_projection_bias: bool, whether to use a bias term after the linear
            output projection.
        return_attn_coef: bool, if `True`, return the attention coefficients as
            an additional output argument.
        kernel_initializer: initializer, initializer for the kernel weights.
        kernel_regularizer: regularizer, regularizer for the kernel weights.
        kernel_constraint: constraint, constraint for the kernel weights.
        bias_initializer: initializer, initializer for the bias weights.
        bias_regularizer: regularizer, regularizer for the bias weights.
        bias_constraint: constraint, constraint for the bias weights.

    Call Args:
        inputs: List of `[query, key, value]` where
            * `query`: Tensor of shape `(..., query_elements, query_depth)`
            * `key`: Tensor of shape `(..., key_elements, key_depth)`
            * `value`: Tensor of shape `(..., key_elements, value_depth)`, optional, if not given `key` will be used.
        mask: a binary Tensor of shape `[batch_size?, num_heads?, query_elements, key_elements]`
            which specifies which query elements can attend to which key elements,
            `1` indicates attention and `0` indicates no attention.

    Output shape:
        * `(..., query_elements, output_size)` if `output_size` is given, else
        * `(..., query_elements, value_depth)` if `value` is given, else
        * `(..., query_elements, key_depth)`
    """

    def __init__(
        self,
        head_size: int,
        num_heads: int,
        output_size: int = None,
        dropout: float = 0.0,
        use_projection_bias: bool = True,
        return_attn_coef: bool = False,
        kernel_initializer: typing.Union[str, typing.Callable] = "glorot_uniform",
        kernel_regularizer: typing.Union[str, typing.Callable] = None,
        kernel_constraint: typing.Union[str, typing.Callable] = None,
        bias_initializer: typing.Union[str, typing.Callable] = "zeros",
        bias_regularizer: typing.Union[str, typing.Callable] = None,
        bias_constraint: typing.Union[str, typing.Callable] = None,
        **kwargs,
    ):
        warnings.warn(
            "`MultiHeadAttention` will be deprecated in Addons 0.13. "
            "Please use `tf.keras.layers.MultiHeadAttention` instead.",
            DeprecationWarning,
        )

        super().__init__(**kwargs)

        if output_size is not None and output_size < 1:
            raise ValueError("output_size must be a positive number")

        self.head_size = head_size
        self.num_heads = num_heads
        self.output_size = output_size
        self.use_projection_bias = use_projection_bias
        self.return_attn_coef = return_attn_coef

        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
        self.kernel_constraint = tf.keras.constraints.get(kernel_constraint)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)
        self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self.bias_constraint = tf.keras.constraints.get(bias_constraint)

        self.dropout = tf.keras.layers.Dropout(dropout)
        self._dropout_rate = dropout

    def build(self, input_shape):

        num_query_features = input_shape[0][-1]
        num_key_features = input_shape[1][-1]
        num_value_features = (
            input_shape[2][-1] if len(input_shape) > 2 else num_key_features
        )
        output_size = (
            self.output_size if self.output_size is not None else num_value_features
        )
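
        # Per-head projection weights: each of query/key/value gets a kernel of
        # shape [num_heads, input_features, head_size], i.e. a separate linear
        # map to head_size features for every head.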
        self.query_kernel = self.add_weight(
            name="query_kernel",
            shape=[self.num_heads, num_query_features, self.head_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.key_kernel = self.add_weight(
            name="key_kernel",
            shape=[self.num_heads, num_key_features, self.head_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.value_kernel = self.add_weight(
            name="value_kernel",
            shape=[self.num_heads, num_value_features, self.head_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.projection_kernel = self.add_weight(
            name="projection_kernel",
            shape=[self.num_heads, self.head_size, output_size],
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )

        if self.use_projection_bias:
            self.projection_bias = self.add_weight(
                name="projection_bias",
                shape=[output_size],
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
        else:
            self.projection_bias = None

        super().build(input_shape)

    def call(self, inputs, training=None, mask=None):

        # einsum nomenclature
        # ------------------------
        # N = query elements
        # M = key/value elements
        # H = heads
        # I = input features
        # O = output features

        query = inputs[0]
        key = inputs[1]
        value = inputs[2] if len(inputs) > 2 else key

        # verify shapes
        if key.shape[-2] != value.shape[-2]:
            raise ValueError(
                "the number of elements in 'key' must be equal to "
                "the number of elements in 'value'"
            )

        if mask is not None:
            if len(mask.shape) < 2:
                raise ValueError("'mask' must have at least 2 dimensions")
            if query.shape[-2] != mask.shape[-2]:
                raise ValueError(
                    "mask's second to last dimension must be equal to "
                    "the number of elements in 'query'"
                )
            if key.shape[-2] != mask.shape[-1]:
                raise ValueError(
                    "mask's last dimension must be equal to the number of elements in 'key'"
                )

        # Linear transformations
        query = tf.einsum("...NI , HIO -> ...NHO", query, self.query_kernel)
        key = tf.einsum("...MI , HIO -> ...MHO", key, self.key_kernel)
        value = tf.einsum("...MI , HIO -> ...MHO", value, self.value_kernel)

        # Scale the dot-product: dividing query (or key) by sqrt(head_size)
        # instead of dividing their product saves some computation
        depth = tf.constant(self.head_size, dtype=query.dtype)
        query /= tf.sqrt(depth)

        # Calculate dot product attention
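        # logits shape: (..., num_heads, query_elements, key_elements)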
        logits = tf.einsum("...NHO,...MHO->...HNM", query, key)

        # apply mask
        if mask is not None:
            mask = tf.cast(mask, tf.float32)

            # possibly expand on the head dimension so broadcasting works
            if len(mask.shape) != len(logits.shape):
                mask = tf.expand_dims(mask, -3)
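
            # add a large negative bias to the masked-out positions so their
            # softmax weight becomes effectively zero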
            logits += -10e9 * (1.0 - mask)

        attn_coef = tf.nn.softmax(logits)

        # attention dropout
        attn_coef_dropout = self.dropout(attn_coef, training=training)

        # attention * value
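        # weighted sum over the key/value elements;
        # result shape: (..., query_elements, num_heads, head_size)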
        multihead_output = tf.einsum("...HNM,...MHI->...NHI", attn_coef_dropout, value)

        # Run the outputs through another linear projection layer. Recombining heads
        # is automatically done.
        output = tf.einsum(
            "...NHI,HIO->...NO", multihead_output, self.projection_kernel
        )

        if self.projection_bias is not None:
            output += self.projection_bias

        if self.return_attn_coef:
            return output, attn_coef
        else:
            return output

    def compute_output_shape(self, input_shape):
        num_value_features = (
            input_shape[2][-1] if len(input_shape) > 2 else input_shape[1][-1]
        )
        output_size = (
            self.output_size if self.output_size is not None else num_value_features
        )

        output_shape = input_shape[0][:-1] + (output_size,)

        if self.return_attn_coef:
            num_query_elements = input_shape[0][-2]
            num_key_elements = input_shape[1][-2]
            attn_coef_shape = input_shape[0][:-2] + (
                self.num_heads,
                num_query_elements,
                num_key_elements,
            )

            return output_shape, attn_coef_shape
        else:
            return output_shape

    def get_config(self):
        config = super().get_config()

        config.update(
            head_size=self.head_size,
            num_heads=self.num_heads,
            output_size=self.output_size,
            dropout=self._dropout_rate,
            use_projection_bias=self.use_projection_bias,
            return_attn_coef=self.return_attn_coef,
            kernel_initializer=tf.keras.initializers.serialize(self.kernel_initializer),
            kernel_regularizer=tf.keras.regularizers.serialize(self.kernel_regularizer),
            kernel_constraint=tf.keras.constraints.serialize(self.kernel_constraint),
            bias_initializer=tf.keras.initializers.serialize(self.bias_initializer),
            bias_regularizer=tf.keras.regularizers.serialize(self.bias_regularizer),
            bias_constraint=tf.keras.constraints.serialize(self.bias_constraint),
        )

        return config