# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based multi-head attention layer."""

import collections
import math
import string

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src import constraints
from keras.src import initializers
from keras.src import regularizers
from keras.src.engine.base_layer import Layer
from keras.src.layers import activation
from keras.src.layers import core
from keras.src.layers import regularization
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export

_CHR_IDX = string.ascii_lowercase


def _build_attention_equation(rank, attn_axes):
    """Builds einsum equations for the attention computation.

    Query, key, value inputs after projection are expected to have the shape
    `(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.
    `bs` and `<non-attention dims>` are treated as `<batch dims>`.

    The attention operations can be generalized:
    (1) Query-key dot product:
        `(<batch dims>, <query attention dims>, num_heads, channels),
        (<batch dims>, <key attention dims>, num_heads, channels) ->
        (<batch dims>, num_heads, <query attention dims>, <key attention dims>)`
    (2) Combination:
        `(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
        (<batch dims>, <value attention dims>, num_heads, channels) ->
        (<batch dims>, <query attention dims>, num_heads, channels)`

    Args:
        rank: Rank of query, key, value tensors.
        attn_axes: List/tuple of axes, in the range `[-1, rank)`, that
            attention will be applied to.

    Returns:
        Einsum equations.
    """
    target_notation = _CHR_IDX[:rank]
    # `batch_dims` includes the head dim.
    batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
    letter_offset = rank
    source_notation = ""
    for i in range(rank):
        if i in batch_dims or i == rank - 1:
            source_notation += target_notation[i]
        else:
            source_notation += _CHR_IDX[letter_offset]
            letter_offset += 1

    product_notation = "".join(
        [target_notation[i] for i in batch_dims]
        + [target_notation[i] for i in attn_axes]
        + [source_notation[i] for i in attn_axes]
    )
    dot_product_equation = "%s,%s->%s" % (
        source_notation,
        target_notation,
        product_notation,
    )
    attn_scores_rank = len(product_notation)
    combine_equation = "%s,%s->%s" % (
        product_notation,
        source_notation,
        target_notation,
    )
    return dot_product_equation, combine_equation, attn_scores_rank
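

# Illustrative example (not part of the original source): in the default
# self-attention case the projected tensors have rank 4, `(B, T, N, H)`, and
# `attention_axes=(1,)`. Under that assumption `_build_attention_equation`
# yields:
#
#   dot_product_equation == "aecd,abcd->acbe"  # (B,S,N,H),(B,T,N,H)->(B,N,T,S)
#   combine_equation     == "acbe,aecd->abcd"  # (B,N,T,S),(B,S,N,H)->(B,T,N,H)
#   attn_scores_rank     == 4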


def _build_proj_equation(free_dims, bound_dims, output_dims):
    """Builds an einsum equation for projections inside multi-head attention."""
    input_str = ""
    kernel_str = ""
    output_str = ""
    bias_axes = ""
    letter_offset = 0
    for i in range(free_dims):
        char = _CHR_IDX[i + letter_offset]
        input_str += char
        output_str += char

    letter_offset += free_dims
    for i in range(bound_dims):
        char = _CHR_IDX[i + letter_offset]
        input_str += char
        kernel_str += char

    letter_offset += bound_dims
    for i in range(output_dims):
        char = _CHR_IDX[i + letter_offset]
        kernel_str += char
        output_str += char
        bias_axes += char
    equation = f"{input_str},{kernel_str}->{output_str}"

    return equation, bias_axes, len(output_str)
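

# Illustrative example (not part of the original source): projecting a rank-3
# query `(B, T, dim)` onto `(B, T, num_heads, key_dim)` uses
# `free_dims=2, bound_dims=1, output_dims=2`, which gives:
#
#   equation    == "abc,cde->abde"
#   bias_axes   == "de"
#   output rank == 4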


def _get_output_shape(output_rank, known_last_dims):
    return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
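

# Illustrative example (not part of the original source):
# _get_output_shape(3, [8, 16]) == [None, 8, 16]; the leading `None` entries
# stand for dimensions that the EinsumDense layer infers from its input.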


@keras_export("keras.layers.MultiHeadAttention")
class MultiHeadAttention(Layer):
    """MultiHeadAttention layer.

    This is an implementation of multi-headed attention as described in the
    paper "Attention Is All You Need" (Vaswani et al., 2017).
    If `query`, `key`, `value` are the same, then
    this is self-attention. Each timestep in `query` attends to the
    corresponding sequence in `key`, and returns a fixed-width vector.

    This layer first projects `query`, `key` and `value`. These are
    (effectively) a list of tensors of length `num_attention_heads`, where the
    corresponding shapes are `(batch_size, <query dimensions>, key_dim)`,
    `(batch_size, <key/value dimensions>, key_dim)`,
    `(batch_size, <key/value dimensions>, value_dim)`.

    Then, the query and key tensors are dot-producted and scaled. These are
    softmaxed to obtain attention probabilities. The value tensors are then
    interpolated by these probabilities, then concatenated back to a single
    tensor.

    Finally, the result tensor, whose last dimension is `value_dim`, can take
    a linear projection and be returned.

    When using `MultiHeadAttention` inside a custom layer, the custom layer
    must implement its own `build()` method and call `MultiHeadAttention`'s
    `_build_from_signature()` there.
    This enables weights to be restored correctly when the model is loaded.

    Examples:

    Performs 1D cross-attention over two sequence inputs with an attention
    mask. Returns the additional attention weights over heads.

    >>> layer = MultiHeadAttention(num_heads=2, key_dim=2)
    >>> target = tf.keras.Input(shape=[8, 16])
    >>> source = tf.keras.Input(shape=[4, 16])
    >>> output_tensor, weights = layer(target, source,
    ...                                return_attention_scores=True)
    >>> print(output_tensor.shape)
    (None, 8, 16)
    >>> print(weights.shape)
    (None, 2, 8, 4)

    Performs 2D self-attention over a 5D input tensor on axes 2 and 3.

    >>> layer = MultiHeadAttention(
    ...     num_heads=2, key_dim=2, attention_axes=(2, 3))
    >>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
    >>> output_tensor = layer(input_tensor, input_tensor)
    >>> print(output_tensor.shape)
    (None, 5, 3, 4, 16)

    Args:
        num_heads: Number of attention heads.
        key_dim: Size of each attention head for query and key.
        value_dim: Size of each attention head for value.
        dropout: Dropout probability.
        use_bias: Boolean, whether the dense layers use bias vectors/matrices.
        output_shape: The expected shape of an output tensor, besides the
            batch and sequence dims. If not specified, projects back to the
            query feature dim (the query input's last dimension).
        attention_axes: axes over which the attention is applied. `None` means
            attention over all axes except batch, heads, and features.
        kernel_initializer: Initializer for dense layer kernels.
        bias_initializer: Initializer for dense layer biases.
        kernel_regularizer: Regularizer for dense layer kernels.
        bias_regularizer: Regularizer for dense layer biases.
        activity_regularizer: Regularizer for dense layer activity.
        kernel_constraint: Constraint for dense layer kernels.
        bias_constraint: Constraint for dense layer biases.

    Call arguments:
        query: Query `Tensor` of shape `(B, T, dim)`.
        value: Value `Tensor` of shape `(B, S, dim)`.
        key: Optional key `Tensor` of shape `(B, S, dim)`. If not given, will
            use `value` for both `key` and `value`, which is the most common
            case.
        attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
            attention to certain positions. The boolean mask specifies which
            query elements can attend to which key elements, 1 indicates
            attention and 0 indicates no attention. Broadcasting can happen
            for the missing batch dimensions and the head dimension.
        return_attention_scores: A boolean to indicate whether the output
            should be `(attention_output, attention_scores)` if `True`, or
            `attention_output` if `False`. Defaults to `False`.
        training: Python boolean indicating whether the layer should behave in
            training mode (adding dropout) or in inference mode (no dropout).
            Defaults to the training mode of the parent layer/model, or
            `False` (inference) if there is no parent layer.
        use_causal_mask: A boolean to indicate whether to apply a causal mask
            to prevent tokens from attending to future tokens (e.g., used in
            a decoder Transformer).

    Returns:
        attention_output: The result of the computation, of shape `(B, T, E)`,
            where `T` is for target sequence shapes and `E` is the query input
            last dimension if `output_shape` is `None`. Otherwise, the
            multi-head outputs are projected to the shape specified by
            `output_shape`.
        attention_scores: [Optional] multi-head attention coefficients over
            attention axes.
    """

    def __init__(
        self,
        num_heads,
        key_dim,
        value_dim=None,
        dropout=0.0,
        use_bias=True,
        output_shape=None,
        attention_axes=None,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.supports_masking = True
        self._num_heads = num_heads
        self._key_dim = key_dim
        self._value_dim = value_dim if value_dim else key_dim
        self._dropout = dropout
        self._use_bias = use_bias
        self._output_shape = output_shape
        self._kernel_initializer = initializers.get(kernel_initializer)
        self._bias_initializer = initializers.get(bias_initializer)
        self._kernel_regularizer = regularizers.get(kernel_regularizer)
        self._bias_regularizer = regularizers.get(bias_regularizer)
        self._activity_regularizer = regularizers.get(activity_regularizer)
        self._kernel_constraint = constraints.get(kernel_constraint)
        self._bias_constraint = constraints.get(bias_constraint)
        if attention_axes is not None and not isinstance(
            attention_axes, collections.abc.Sized
        ):
            self._attention_axes = (attention_axes,)
        else:
            self._attention_axes = attention_axes
        self._built_from_signature = False
        self._query_shape, self._key_shape, self._value_shape = None, None, None

    def get_config(self):
        config = {
            "num_heads": self._num_heads,
            "key_dim": self._key_dim,
            "value_dim": self._value_dim,
            "dropout": self._dropout,
            "use_bias": self._use_bias,
            "output_shape": self._output_shape,
            "attention_axes": self._attention_axes,
            "kernel_initializer": initializers.serialize(
                self._kernel_initializer
            ),
            "bias_initializer": initializers.serialize(self._bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self._kernel_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self._bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self._activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self._kernel_constraint),
            "bias_constraint": constraints.serialize(self._bias_constraint),
            "query_shape": self._query_shape,
            "key_shape": self._key_shape,
            "value_shape": self._value_shape,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        # If the layer has a different build() function from the Keras
        # default, we need to trigger the customized build to create weights.
        query_shape = config.pop("query_shape")
        key_shape = config.pop("key_shape")
        value_shape = config.pop("value_shape")
        layer = cls(**config)
        if None in [query_shape, key_shape, value_shape]:
            logging.warning(
                "One of the dimensions of the input shape is missing. It "
                "should have been memorized when the layer was serialized. "
                "%s is created without weights.",
                str(cls),
            )
        else:
            layer._build_from_signature(query_shape, value_shape, key_shape)
        return layer

    def _build_from_signature(self, query, value, key=None):
        """Builds layers and variables.

        Once the method is called, self._built_from_signature will be set to
        True.

        Args:
            query: Query tensor or TensorShape.
            value: Value tensor or TensorShape.
            key: Key tensor or TensorShape.
        """
        self._built_from_signature = True
        if hasattr(query, "shape"):
            self._query_shape = tf.TensorShape(query.shape)
        else:
            self._query_shape = tf.TensorShape(query)
        if hasattr(value, "shape"):
            self._value_shape = tf.TensorShape(value.shape)
        else:
            self._value_shape = tf.TensorShape(value)
        if key is None:
            self._key_shape = self._value_shape
        elif hasattr(key, "shape"):
            self._key_shape = tf.TensorShape(key.shape)
        else:
            self._key_shape = tf.TensorShape(key)

        # Any setup work performed only once should happen in an `init_scope`
        # to avoid creating symbolic Tensors that will later pollute any eager
        # operations.
        with tf_utils.maybe_init_scope(self):
            free_dims = self._query_shape.rank - 1
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                free_dims, bound_dims=1, output_dims=2
            )
            self._query_dense = core.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._num_heads, self._key_dim]
                ),
                bias_axes=bias_axes if self._use_bias else None,
                name="query",
                **self._get_common_kwargs_for_sublayer(),
            )
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                self._key_shape.rank - 1, bound_dims=1, output_dims=2
            )
            self._key_dense = core.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._num_heads, self._key_dim]
                ),
                bias_axes=bias_axes if self._use_bias else None,
                name="key",
                **self._get_common_kwargs_for_sublayer(),
            )
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                self._value_shape.rank - 1, bound_dims=1, output_dims=2
            )
            self._value_dense = core.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._num_heads, self._value_dim]
                ),
                bias_axes=bias_axes if self._use_bias else None,
                name="value",
                **self._get_common_kwargs_for_sublayer(),
            )

            # Builds the attention computations for multi-head dot product
            # attention. These computations could be wrapped into the keras
            # attention layer once it supports multi-head einsum computations.
            self._build_attention(output_rank)
            self._output_dense = self._make_output_dense(
                free_dims,
                self._get_common_kwargs_for_sublayer(),
                "attention_output",
            )

    def _get_common_kwargs_for_sublayer(self):
        common_kwargs = dict(
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
        )
        # Create a new clone of the kernel/bias initializer, so that we don't
        # reuse the initializer instance, which could lead to the same init
        # value since the initializer is stateless.
        kernel_initializer = self._kernel_initializer.__class__.from_config(
            self._kernel_initializer.get_config()
        )
        bias_initializer = self._bias_initializer.__class__.from_config(
            self._bias_initializer.get_config()
        )
        common_kwargs["kernel_initializer"] = kernel_initializer
        common_kwargs["bias_initializer"] = bias_initializer
        return common_kwargs

    def _make_output_dense(self, free_dims, common_kwargs, name=None):
        """Builds the output projection matrix.

        Args:
            free_dims: Number of free dimensions for einsum equation building.
            common_kwargs: Common keyword arguments for einsum layer.
            name: Name for the projection layer.

        Returns:
            Projection layer.
        """
        if self._output_shape:
            if not isinstance(self._output_shape, collections.abc.Sized):
                output_shape = [self._output_shape]
            else:
                output_shape = self._output_shape
        else:
            output_shape = [self._query_shape[-1]]
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            free_dims, bound_dims=2, output_dims=len(output_shape)
        )
        return core.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(output_rank - 1, output_shape),
            bias_axes=bias_axes if self._use_bias else None,
            name=name,
            **common_kwargs,
        )

    def _build_attention(self, rank):
        """Builds multi-head dot-product attention computations.

        This function builds attributes necessary for `_compute_attention` to
        customize attention computation to replace the default dot-product
        attention.

        Args:
            rank: the rank of query, key, value tensors.
        """
        if self._attention_axes is None:
            self._attention_axes = tuple(range(1, rank - 2))
        else:
            self._attention_axes = tuple(self._attention_axes)
        (
            self._dot_product_equation,
            self._combine_equation,
            attn_scores_rank,
        ) = _build_attention_equation(rank, attn_axes=self._attention_axes)
        norm_axes = tuple(
            range(
                attn_scores_rank - len(self._attention_axes), attn_scores_rank
            )
        )
        self._softmax = activation.Softmax(axis=norm_axes)
        self._dropout_layer = regularization.Dropout(rate=self._dropout)
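
    # Illustrative note (not part of the original source): for rank-4
    # projected tensors `(B, T, N, H)` the default `attention_axes` becomes
    # `(1,)`, `attn_scores_rank` is 4, and `norm_axes` is `(3,)`, so the
    # softmax normalizes over the key sequence axis of the `(B, N, T, S)`
    # score tensor.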

    def _masked_softmax(self, attention_scores, attention_mask=None):
        # Normalize the attention scores to probabilities.
        # `attention_scores` = [B, N, T, S]
        if attention_mask is not None:
            # The expand dim happens starting from the `num_heads` dimension,
            # (<batch_dims>, num_heads, <query_attention_dims,
            # key_attention_dims>)
            mask_expansion_axis = -len(self._attention_axes) * 2 - 1
            for _ in range(
                len(attention_scores.shape) - len(attention_mask.shape)
            ):
                attention_mask = tf.expand_dims(
                    attention_mask, axis=mask_expansion_axis
                )
        return self._softmax(attention_scores, attention_mask)
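
    # Illustrative note (not part of the original source): with a single
    # attention axis, `mask_expansion_axis` is -3, so a `(B, T, S)` mask is
    # expanded once to `(B, 1, T, S)` and broadcasts against the
    # `(B, N, T, S)` attention scores across all heads.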

    def _compute_attention(
        self, query, key, value, attention_mask=None, training=None
    ):
        """Applies dot-product attention with query, key, value tensors.

        This function defines the computation inside `call` with projected
        multi-head Q, K, V inputs. Users can override this function for
        customized attention implementation.

        Args:
            query: Projected query `Tensor` of shape `(B, T, N, key_dim)`.
            key: Projected key `Tensor` of shape `(B, S, N, key_dim)`.
            value: Projected value `Tensor` of shape `(B, S, N, value_dim)`.
            attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
                attention to certain positions. It is generally not needed if
                the `query` and `value` (and/or `key`) are masked.
            training: Python boolean indicating whether the layer should
                behave in training mode (adding dropout) or in inference mode
                (doing nothing).

        Returns:
            attention_output: Multi-headed outputs of attention computation.
            attention_scores: Multi-headed attention weights.
        """
        # Note: Applying scalar multiply at the smaller end of einsum improves
        # XLA performance, but may introduce slight numeric differences in
        # the Transformer attention head.
        query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))

        # Take the dot product between "query" and "key" to get the raw
        # attention scores.
        attention_scores = tf.einsum(self._dot_product_equation, key, query)

        attention_scores = self._masked_softmax(
            attention_scores, attention_mask
        )

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_scores_dropout = self._dropout_layer(
            attention_scores, training=training
        )

        # `context_layer` = [B, T, N, H]
        attention_output = tf.einsum(
            self._combine_equation, attention_scores_dropout, value
        )
        return attention_output, attention_scores
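
    # Illustrative summary (not part of the original source): for rank-3
    # inputs, the method above computes
    #   scores = softmax(Q @ K^T / sqrt(key_dim))    # shape (B, N, T, S)
    #   output = dropout(scores) @ V                 # shape (B, T, N, H)
    # with `attention_mask` applied inside the masked softmax, expressed
    # through the einsum equations built in `_build_attention`.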

    def call(
        self,
        query,
        value,
        key=None,
        attention_mask=None,
        return_attention_scores=False,
        training=None,
        use_causal_mask=False,
    ):
        if not self._built_from_signature:
            self._build_from_signature(query=query, value=value, key=key)
        if key is None:
            key = value

        # Convert RaggedTensor to Tensor.
        query_is_ragged = isinstance(query, tf.RaggedTensor)
        if query_is_ragged:
            query_lengths = query.nested_row_lengths()
            query = query.to_tensor()
        key_is_ragged = isinstance(key, tf.RaggedTensor)
        value_is_ragged = isinstance(value, tf.RaggedTensor)
        if key_is_ragged and value_is_ragged:
            # Ensure they have the same shape.
            bounding_shape = tf.math.maximum(
                key.bounding_shape(), value.bounding_shape()
            )
            key = key.to_tensor(shape=bounding_shape)
            value = value.to_tensor(shape=bounding_shape)
        elif key_is_ragged:
            key = key.to_tensor(shape=tf.shape(value))
        elif value_is_ragged:
            value = value.to_tensor(shape=tf.shape(key))

        attention_mask = self._compute_attention_mask(
            query,
            value,
            key=key,
            attention_mask=attention_mask,
            use_causal_mask=use_causal_mask,
        )

        # N = `num_attention_heads`
        # H = `size_per_head`
        # `query` = [B, T, N, H]
        query = self._query_dense(query)

        # `key` = [B, S, N, H]
        key = self._key_dense(key)

        # `value` = [B, S, N, H]
        value = self._value_dense(value)

        attention_output, attention_scores = self._compute_attention(
            query, key, value, attention_mask, training
        )
        attention_output = self._output_dense(attention_output)

        if query_is_ragged:
            attention_output = tf.RaggedTensor.from_tensor(
                attention_output, lengths=query_lengths
            )

        if return_attention_scores:
            return attention_output, attention_scores
        return attention_output
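
    # Illustrative usage (not part of the original source): a minimal causal
    # self-attention call, assuming a `(batch, seq, features)` input.
    #
    #   mha = MultiHeadAttention(num_heads=2, key_dim=2)
    #   x = tf.keras.Input(shape=[8, 16])
    #   y = mha(query=x, value=x, use_causal_mask=True)  # shape (None, 8, 16)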

    def _compute_attention_mask(
        self, query, value, key=None, attention_mask=None, use_causal_mask=False
    ):
        """Computes the attention mask, using the Keras masks of the inputs.

        * The `query`'s mask is reshaped from [B, T] to [B, T, 1].
        * The `value`'s mask is reshaped from [B, S] to [B, 1, S].
        * The `key`'s mask is reshaped from [B, S] to [B, 1, S]. The `key`'s
          mask is ignored if `key` is `None` or if `key is value`.
        * If `use_causal_mask=True`, then the causal mask is computed. Its
          shape is [1, T, S].

        All defined masks are merged using a logical AND operation (`&`).

        In general, if the `query` and `value` are masked, then there is no
        need to define the `attention_mask`.

        Args:
            query: Query `Tensor` of shape `(B, T, dim)`.
            value: Value `Tensor` of shape `(B, S, dim)`.
            key: Optional key `Tensor` of shape `(B, S, dim)`.
            attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
                attention to certain positions.
            use_causal_mask: A boolean to indicate whether to apply a causal
                mask to prevent tokens from attending to future tokens (e.g.,
                used in a decoder Transformer).

        Returns:
            attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
                attention to certain positions, based on the Keras masks of
                the `query`, `key`, `value`, and `attention_mask` tensors, and
                the causal mask if `use_causal_mask=True`.
        """
        query_mask = getattr(query, "_keras_mask", None)
        value_mask = getattr(value, "_keras_mask", None)
        key_mask = getattr(key, "_keras_mask", None)
        auto_mask = None
        if query_mask is not None:
            query_mask = tf.cast(query_mask, tf.bool)  # defensive casting
            # B = batch size, T = max query length
            auto_mask = query_mask[:, :, tf.newaxis]  # shape is [B, T, 1]
        if value_mask is not None:
            value_mask = tf.cast(value_mask, tf.bool)  # defensive casting
            # B = batch size, S == max value length
            mask = value_mask[:, tf.newaxis, :]  # shape is [B, 1, S]
            auto_mask = mask if auto_mask is None else auto_mask & mask
        if key_mask is not None:
            key_mask = tf.cast(key_mask, tf.bool)  # defensive casting
            # B == batch size, S == max key length == max value length
            mask = key_mask[:, tf.newaxis, :]  # shape is [B, 1, S]
            auto_mask = mask if auto_mask is None else auto_mask & mask
        if use_causal_mask:
            # the shape of the causal mask is [1, T, S]
            mask = self._compute_causal_mask(query, value)
            auto_mask = mask if auto_mask is None else auto_mask & mask
        if auto_mask is not None:
            # merge attention_mask & automatic mask, to shape [B, T, S]
            attention_mask = (
                auto_mask
                if attention_mask is None
                else tf.cast(attention_mask, bool) & auto_mask
            )
        return attention_mask
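
    # Illustrative note (not part of the original source): with padded query
    # and value masks of shapes (B, T) and (B, S), the merged automatic mask
    # is
    #   query_mask[:, :, tf.newaxis] & value_mask[:, tf.newaxis, :]
    # i.e. a (B, T, S) mask that is True only where both the query and value
    # positions are valid; a user-supplied `attention_mask` is AND-ed in last.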

    def _compute_causal_mask(self, query, value=None):
        """Computes a causal mask (e.g., for masked self-attention layers).

        For example, if query and value both contain sequences of length 4,
        this function returns a boolean `Tensor` equal to:

        ```
        [[[True, False, False, False],
          [True, True, False, False],
          [True, True, True, False],
          [True, True, True, True]]]
        ```

        Args:
            query: query `Tensor` of shape `(B, T, ...)`.
            value: value `Tensor` of shape `(B, S, ...)` (optional, defaults
                to query).

        Returns:
            mask: a boolean `Tensor` of shape [1, T, S] containing a lower
                triangular matrix of shape [T, S].
        """
        q_seq_length = tf.shape(query)[1]
        v_seq_length = q_seq_length if value is None else tf.shape(value)[1]
        return tf.linalg.band_part(  # creates a lower triangular matrix
            tf.ones((1, q_seq_length, v_seq_length), tf.bool), -1, 0
        )

    def compute_output_shape(self, query_shape, value_shape, key_shape=None):
        if key_shape is None:
            key_shape = value_shape

        query_shape = tf.TensorShape(query_shape)
        value_shape = tf.TensorShape(value_shape)
        key_shape = tf.TensorShape(key_shape)

        if query_shape[-1] != value_shape[-1]:
            raise ValueError(
                "The last dimension of `query_shape` and `value_shape` "
                f"must be equal, but are {query_shape[-1]}, {value_shape[-1]}. "
                f"Received: query_shape={query_shape}, "
                f"value_shape={value_shape}"
            )

        if value_shape[1:-1] != key_shape[1:-1]:
            raise ValueError(
                "All dimensions of `value` and `key`, except the last one, "
                f"must be equal. Received: {value_shape} and {key_shape}"
            )

        if self._output_shape:
            return query_shape[:-1].concatenate(self._output_shape)

        return query_shape
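
    # Illustrative note (not part of the original source): with the default
    # `output_shape=None`, `compute_output_shape((None, 8, 16), (None, 4, 16))`
    # returns `TensorShape([None, 8, 16])`, i.e. the query shape is preserved.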