Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/attention/multi_head_attention.py: 16%

222 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based multi-head attention layer."""


import collections
import math
import string

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src import constraints
from keras.src import initializers
from keras.src import regularizers
from keras.src.engine.base_layer import Layer
from keras.src.layers import activation
from keras.src.layers import core
from keras.src.layers import regularization
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export

_CHR_IDX = string.ascii_lowercase


def _build_attention_equation(rank, attn_axes):
    """Builds einsum equations for the attention computation.

    Query, key, value inputs after projection are expected to have shape
    `(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.
    `bs` and `<non-attention dims>` are treated as `<batch dims>`.

    The attention operations can be generalized:
    (1) Query-key dot product:
    `(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
    <key attention dims>, num_heads, channels) -> (<batch dims>,
    num_heads, <query attention dims>, <key attention dims>)`
    (2) Combination:
    `(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
    (<batch dims>, <value attention dims>, num_heads, channels) -> (<batch
    dims>, <query attention dims>, num_heads, channels)`

    Args:
        rank: Rank of query, key, value tensors.
        attn_axes: List/tuple of axes, `[-1, rank)`,
            that attention will be applied to.

    Returns:
        Einsum equations.
    """
    target_notation = _CHR_IDX[:rank]
    # `batch_dims` includes the head dim.
    batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
    letter_offset = rank
    source_notation = ""
    for i in range(rank):
        if i in batch_dims or i == rank - 1:
            source_notation += target_notation[i]
        else:
            source_notation += _CHR_IDX[letter_offset]
            letter_offset += 1

    product_notation = "".join(
        [target_notation[i] for i in batch_dims]
        + [target_notation[i] for i in attn_axes]
        + [source_notation[i] for i in attn_axes]
    )
    dot_product_equation = "%s,%s->%s" % (
        source_notation,
        target_notation,
        product_notation,
    )
    attn_scores_rank = len(product_notation)
    combine_equation = "%s,%s->%s" % (
        product_notation,
        source_notation,
        target_notation,
    )
    return dot_product_equation, combine_equation, attn_scores_rank

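# For illustration: with rank=4, i.e. a projected query/key of shape
# (B, T, N, H), and attn_axes=(1,), this helper returns
#     dot_product_equation = "aecd,abcd->acbe"  # scores of shape (B, N, T, S)
#     combine_equation = "acbe,aecd->abcd"      # output of shape (B, T, N, H)
#     attn_scores_rank = 4
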

def _build_proj_equation(free_dims, bound_dims, output_dims):
    """Builds an einsum equation for projections inside multi-head attention."""
    input_str = ""
    kernel_str = ""
    output_str = ""
    bias_axes = ""
    letter_offset = 0
    for i in range(free_dims):
        char = _CHR_IDX[i + letter_offset]
        input_str += char
        output_str += char

    letter_offset += free_dims
    for i in range(bound_dims):
        char = _CHR_IDX[i + letter_offset]
        input_str += char
        kernel_str += char

    letter_offset += bound_dims
    for i in range(output_dims):
        char = _CHR_IDX[i + letter_offset]
        kernel_str += char
        output_str += char
        bias_axes += char
    equation = f"{input_str},{kernel_str}->{output_str}"

    return equation, bias_axes, len(output_str)

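# For illustration: _build_proj_equation(free_dims=2, bound_dims=1,
# output_dims=2) returns ("abc,cde->abde", "de", 4), i.e. a (batch, seq, dim)
# input multiplied by a (dim, num_heads, head_dim) kernel yields a
# (batch, seq, num_heads, head_dim) output, with a bias over the last two axes.
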

def _get_output_shape(output_rank, known_last_dims):
    return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)


@keras_export("keras.layers.MultiHeadAttention")
class MultiHeadAttention(Layer):
    """MultiHeadAttention layer.

    This is an implementation of multi-headed attention as described in the
    paper "Attention Is All You Need" (Vaswani et al., 2017).
    If `query`, `key`, and `value` are the same, then
    this is self-attention. Each timestep in `query` attends to the
    corresponding sequence in `key`, and returns a fixed-width vector.

    This layer first projects `query`, `key` and `value`. These are
    (effectively) a list of tensors of length `num_attention_heads`, where the
    corresponding shapes are `(batch_size, <query dimensions>, key_dim)`,
    `(batch_size, <key/value dimensions>, key_dim)`,
    `(batch_size, <key/value dimensions>, value_dim)`.

    Then, the query and key tensors are dot-producted and scaled. These are
    softmaxed to obtain attention probabilities. The value tensors are then
    interpolated by these probabilities, then concatenated back to a single
    tensor.

    Finally, the result tensor, whose last dimension is `value_dim`, can take
    a linear projection and be returned.

    When using `MultiHeadAttention` inside a custom layer, the custom layer
    must implement its own `build()` method and call `MultiHeadAttention`'s
    `_build_from_signature()` there.
    This enables weights to be restored correctly when the model is loaded;
    a sketch is included at the end of this module.

    Examples:

    Performs 1D cross-attention over two sequence inputs. Returns the
    attention scores over heads in addition to the output.

    >>> layer = MultiHeadAttention(num_heads=2, key_dim=2)
    >>> target = tf.keras.Input(shape=[8, 16])
    >>> source = tf.keras.Input(shape=[4, 16])
    >>> output_tensor, weights = layer(target, source,
    ...                                return_attention_scores=True)
    >>> print(output_tensor.shape)
    (None, 8, 16)
    >>> print(weights.shape)
    (None, 2, 8, 4)

    Performs 2D self-attention over a 5D input tensor on axes 2 and 3.

    >>> layer = MultiHeadAttention(
    ...     num_heads=2, key_dim=2, attention_axes=(2, 3))
    >>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
    >>> output_tensor = layer(input_tensor, input_tensor)
    >>> print(output_tensor.shape)
    (None, 5, 3, 4, 16)

    Args:
        num_heads: Number of attention heads.
        key_dim: Size of each attention head for query and key.
        value_dim: Size of each attention head for value.
        dropout: Dropout probability.
        use_bias: Boolean, whether the dense layers use bias vectors/matrices.
        output_shape: The expected shape of an output tensor, besides the
            batch and sequence dims. If not specified, projects back to the
            query feature dim (the query input's last dimension).
        attention_axes: axes over which the attention is applied. `None` means
            attention over all axes, except batch, heads, and features.
        kernel_initializer: Initializer for dense layer kernels.
        bias_initializer: Initializer for dense layer biases.
        kernel_regularizer: Regularizer for dense layer kernels.
        bias_regularizer: Regularizer for dense layer biases.
        activity_regularizer: Regularizer for dense layer activity.
        kernel_constraint: Constraint for dense layer kernels.
        bias_constraint: Constraint for dense layer biases.

    Call arguments:
        query: Query `Tensor` of shape `(B, T, dim)`.
        value: Value `Tensor` of shape `(B, S, dim)`.
        key: Optional key `Tensor` of shape `(B, S, dim)`. If not given, will
            use `value` for both `key` and `value`, which is the most common
            case.
        attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
            attention to certain positions. The boolean mask specifies which
            query elements can attend to which key elements, 1 indicates
            attention and 0 indicates no attention. Broadcasting can happen
            for the missing batch dimensions and the head dimension.
        return_attention_scores: A boolean to indicate whether the output
            should be `(attention_output, attention_scores)` if `True`, or
            `attention_output` if `False`. Defaults to `False`.
        training: Python boolean indicating whether the layer should behave in
            training mode (adding dropout) or in inference mode (no dropout).
            Defaults to the training mode of the parent layer/model, or
            `False` (inference) if there is no parent layer.
        use_causal_mask: A boolean to indicate whether to apply a causal mask
            to prevent tokens from attending to future tokens (e.g., used in
            a decoder Transformer).

    Returns:
        attention_output: The result of the computation, of shape `(B, T, E)`,
            where `T` is for target sequence shapes and `E` is the query
            input last dimension if `output_shape` is `None`. Otherwise, the
            multi-head outputs are projected to the shape specified by
            `output_shape`.
        attention_scores: [Optional] multi-head attention coefficients over
            attention axes.
    """

    def __init__(
        self,
        num_heads,
        key_dim,
        value_dim=None,
        dropout=0.0,
        use_bias=True,
        output_shape=None,
        attention_axes=None,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.supports_masking = True
        self._num_heads = num_heads
        self._key_dim = key_dim
        self._value_dim = value_dim if value_dim else key_dim
        self._dropout = dropout
        self._use_bias = use_bias
        self._output_shape = output_shape
        self._kernel_initializer = initializers.get(kernel_initializer)
        self._bias_initializer = initializers.get(bias_initializer)
        self._kernel_regularizer = regularizers.get(kernel_regularizer)
        self._bias_regularizer = regularizers.get(bias_regularizer)
        self._activity_regularizer = regularizers.get(activity_regularizer)
        self._kernel_constraint = constraints.get(kernel_constraint)
        self._bias_constraint = constraints.get(bias_constraint)
        if attention_axes is not None and not isinstance(
            attention_axes, collections.abc.Sized
        ):
            self._attention_axes = (attention_axes,)
        else:
            self._attention_axes = attention_axes
        self._built_from_signature = False
        self._query_shape, self._key_shape, self._value_shape = None, None, None

    def get_config(self):
        config = {
            "num_heads": self._num_heads,
            "key_dim": self._key_dim,
            "value_dim": self._value_dim,
            "dropout": self._dropout,
            "use_bias": self._use_bias,
            "output_shape": self._output_shape,
            "attention_axes": self._attention_axes,
            "kernel_initializer": initializers.serialize(
                self._kernel_initializer
            ),
            "bias_initializer": initializers.serialize(self._bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self._kernel_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self._bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self._activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self._kernel_constraint),
            "bias_constraint": constraints.serialize(self._bias_constraint),
            "query_shape": self._query_shape,
            "key_shape": self._key_shape,
            "value_shape": self._value_shape,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        # If the layer has a different build() function from the Keras
        # default, we need to trigger the customized build to create weights.
        query_shape = config.pop("query_shape")
        key_shape = config.pop("key_shape")
        value_shape = config.pop("value_shape")
        layer = cls(**config)
        if None in [query_shape, key_shape, value_shape]:
            logging.warning(
                "One of the dimensions of the input shape is missing. It "
                "should have been memorized when the layer was serialized. "
                "%s is created without weights.",
                str(cls),
            )
        else:
            layer._build_from_signature(query_shape, value_shape, key_shape)
        return layer

    def _build_from_signature(self, query, value, key=None):
        """Builds layers and variables.

        Once the method is called, self._built_from_signature will be set to
        True.

        Args:
            query: Query tensor or TensorShape.
            value: Value tensor or TensorShape.
            key: Key tensor or TensorShape.
        """
        self._built_from_signature = True
        if hasattr(query, "shape"):
            self._query_shape = tf.TensorShape(query.shape)
        else:
            self._query_shape = tf.TensorShape(query)
        if hasattr(value, "shape"):
            self._value_shape = tf.TensorShape(value.shape)
        else:
            self._value_shape = tf.TensorShape(value)
        if key is None:
            self._key_shape = self._value_shape
        elif hasattr(key, "shape"):
            self._key_shape = tf.TensorShape(key.shape)
        else:
            self._key_shape = tf.TensorShape(key)

        # Any setup work performed only once should happen in an `init_scope`
        # to avoid creating symbolic Tensors that will later pollute any eager
        # operations.
        with tf_utils.maybe_init_scope(self):
            free_dims = self._query_shape.rank - 1
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                free_dims, bound_dims=1, output_dims=2
            )
            self._query_dense = core.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._num_heads, self._key_dim]
                ),
                bias_axes=bias_axes if self._use_bias else None,
                name="query",
                **self._get_common_kwargs_for_sublayer(),
            )
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                self._key_shape.rank - 1, bound_dims=1, output_dims=2
            )
            self._key_dense = core.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._num_heads, self._key_dim]
                ),
                bias_axes=bias_axes if self._use_bias else None,
                name="key",
                **self._get_common_kwargs_for_sublayer(),
            )
            einsum_equation, bias_axes, output_rank = _build_proj_equation(
                self._value_shape.rank - 1, bound_dims=1, output_dims=2
            )
            self._value_dense = core.EinsumDense(
                einsum_equation,
                output_shape=_get_output_shape(
                    output_rank - 1, [self._num_heads, self._value_dim]
                ),
                bias_axes=bias_axes if self._use_bias else None,
                name="value",
                **self._get_common_kwargs_for_sublayer(),
            )

            # Builds the attention computations for multi-head dot product
            # attention. These computations could be wrapped into the keras
            # attention layer once it supports multi-head einsum computations.
            self._build_attention(output_rank)
            self._output_dense = self._make_output_dense(
                free_dims,
                self._get_common_kwargs_for_sublayer(),
                "attention_output",
            )

    def _get_common_kwargs_for_sublayer(self):
        common_kwargs = dict(
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
        )
        # Create new clones of the kernel/bias initializers, so that we don't
        # reuse the initializer instance, which could lead to the same init
        # value since the initializer is stateless.
        kernel_initializer = self._kernel_initializer.__class__.from_config(
            self._kernel_initializer.get_config()
        )
        bias_initializer = self._bias_initializer.__class__.from_config(
            self._bias_initializer.get_config()
        )
        common_kwargs["kernel_initializer"] = kernel_initializer
        common_kwargs["bias_initializer"] = bias_initializer
        return common_kwargs

    def _make_output_dense(self, free_dims, common_kwargs, name=None):
        """Builds the output projection matrix.

        Args:
            free_dims: Number of free dimensions for einsum equation building.
            common_kwargs: Common keyword arguments for einsum layer.
            name: Name for the projection layer.

        Returns:
            Projection layer.
        """
        if self._output_shape:
            if not isinstance(self._output_shape, collections.abc.Sized):
                output_shape = [self._output_shape]
            else:
                output_shape = self._output_shape
        else:
            output_shape = [self._query_shape[-1]]
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            free_dims, bound_dims=2, output_dims=len(output_shape)
        )
        return core.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(output_rank - 1, output_shape),
            bias_axes=bias_axes if self._use_bias else None,
            name=name,
            **common_kwargs,
        )

    def _build_attention(self, rank):
        """Builds multi-head dot-product attention computations.

        This function builds attributes necessary for `_compute_attention`.
        Override it to customize the attention computation and replace the
        default dot-product attention.

        Args:
            rank: the rank of query, key, value tensors.
        """
        if self._attention_axes is None:
            self._attention_axes = tuple(range(1, rank - 2))
        else:
            self._attention_axes = tuple(self._attention_axes)
        (
            self._dot_product_equation,
            self._combine_equation,
            attn_scores_rank,
        ) = _build_attention_equation(rank, attn_axes=self._attention_axes)
        norm_axes = tuple(
            range(
                attn_scores_rank - len(self._attention_axes), attn_scores_rank
            )
        )
        self._softmax = activation.Softmax(axis=norm_axes)
        self._dropout_layer = regularization.Dropout(rate=self._dropout)
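
    # For illustration: with a single attention axis the attention scores have
    # shape (B, N, T, S) and `attn_scores_rank` is 4, so `norm_axes` above is
    # (3,), i.e. the softmax normalizes over the key axis S.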

    def _masked_softmax(self, attention_scores, attention_mask=None):
        # Normalize the attention scores to probabilities.
        # `attention_scores` = [B, N, T, S]
        if attention_mask is not None:
            # The expand dim happens starting from the `num_heads` dimension,
            # (<batch_dims>, num_heads, <query_attention_dims,
            # key_attention_dims>)
            mask_expansion_axis = -len(self._attention_axes) * 2 - 1
            for _ in range(
                len(attention_scores.shape) - len(attention_mask.shape)
            ):
                attention_mask = tf.expand_dims(
                    attention_mask, axis=mask_expansion_axis
                )
        return self._softmax(attention_scores, attention_mask)
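
    # For illustration: with one attention axis, `mask_expansion_axis` is -3,
    # so a (B, T, S) attention mask is expanded to (B, 1, T, S) and broadcast
    # against the (B, N, T, S) attention scores across the heads axis.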

    def _compute_attention(
        self, query, key, value, attention_mask=None, training=None
    ):
        """Applies Dot-product attention with query, key, value tensors.

        This function defines the computation inside `call` with projected
        multi-head Q, K, V inputs. Users can override this function for
        customized attention implementation.

        Args:
            query: Projected query `Tensor` of shape `(B, T, N, key_dim)`.
            key: Projected key `Tensor` of shape `(B, S, N, key_dim)`.
            value: Projected value `Tensor` of shape `(B, S, N, value_dim)`.
            attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
                attention to certain positions. It is generally not needed if
                the `query` and `value` (and/or `key`) are masked.
            training: Python boolean indicating whether the layer should
                behave in training mode (adding dropout) or in inference mode
                (doing nothing).

        Returns:
            attention_output: Multi-headed outputs of attention computation.
            attention_scores: Multi-headed attention weights.
        """
        # Note: Applying scalar multiply at the smaller end of einsum improves
        # XLA performance, but may introduce slight numeric differences in
        # the Transformer attention head.
        query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))

        # Take the dot product between "query" and "key" to get the raw
        # attention scores.
        attention_scores = tf.einsum(self._dot_product_equation, key, query)

        attention_scores = self._masked_softmax(
            attention_scores, attention_mask
        )

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_scores_dropout = self._dropout_layer(
            attention_scores, training=training
        )

        # `context_layer` = [B, T, N, H]
        attention_output = tf.einsum(
            self._combine_equation, attention_scores_dropout, value
        )
        return attention_output, attention_scores

    def call(
        self,
        query,
        value,
        key=None,
        attention_mask=None,
        return_attention_scores=False,
        training=None,
        use_causal_mask=False,
    ):
        if not self._built_from_signature:
            self._build_from_signature(query=query, value=value, key=key)
        if key is None:
            key = value

        # Convert RaggedTensor to Tensor.
        query_is_ragged = isinstance(query, tf.RaggedTensor)
        if query_is_ragged:
            query_lengths = query.nested_row_lengths()
            query = query.to_tensor()
        key_is_ragged = isinstance(key, tf.RaggedTensor)
        value_is_ragged = isinstance(value, tf.RaggedTensor)
        if key_is_ragged and value_is_ragged:
            # Ensure they have the same shape.
            bounding_shape = tf.math.maximum(
                key.bounding_shape(), value.bounding_shape()
            )
            key = key.to_tensor(shape=bounding_shape)
            value = value.to_tensor(shape=bounding_shape)
        elif key_is_ragged:
            key = key.to_tensor(shape=tf.shape(value))
        elif value_is_ragged:
            value = value.to_tensor(shape=tf.shape(key))

        attention_mask = self._compute_attention_mask(
            query,
            value,
            key=key,
            attention_mask=attention_mask,
            use_causal_mask=use_causal_mask,
        )

        # N = `num_attention_heads`
        # H = `size_per_head`
        # `query` = [B, T, N, H]
        query = self._query_dense(query)

        # `key` = [B, S, N, H]
        key = self._key_dense(key)

        # `value` = [B, S, N, H]
        value = self._value_dense(value)

        attention_output, attention_scores = self._compute_attention(
            query, key, value, attention_mask, training
        )
        attention_output = self._output_dense(attention_output)

        if query_is_ragged:
            attention_output = tf.RaggedTensor.from_tensor(
                attention_output, lengths=query_lengths
            )

        if return_attention_scores:
            return attention_output, attention_scores
        return attention_output

    def _compute_attention_mask(
        self, query, value, key=None, attention_mask=None, use_causal_mask=False
    ):
        """Computes the attention mask, using the Keras masks of the inputs.

        * The `query`'s mask is reshaped from [B, T] to [B, T, 1].
        * The `value`'s mask is reshaped from [B, S] to [B, 1, S].
        * The `key`'s mask is reshaped from [B, S] to [B, 1, S]. The `key`'s
          mask is ignored if `key` is `None` or if `key is value`.
        * If `use_causal_mask=True`, then the causal mask is computed. Its
          shape is [1, T, S].

        All defined masks are merged using a logical AND operation (`&`).

        In general, if the `query` and `value` are masked, then there is no
        need to define the `attention_mask`.

        Args:
            query: Query `Tensor` of shape `(B, T, dim)`.
            value: Value `Tensor` of shape `(B, S, dim)`.
            key: Optional key `Tensor` of shape `(B, S, dim)`.
            attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
                attention to certain positions.
            use_causal_mask: A boolean to indicate whether to apply a causal
                mask to prevent tokens from attending to future tokens (e.g.,
                used in a decoder Transformer).

        Returns:
            attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
                attention to certain positions, based on the Keras masks of
                the `query`, `key`, `value`, and `attention_mask` tensors, and
                the causal mask if `use_causal_mask=True`.
        """
        query_mask = getattr(query, "_keras_mask", None)
        value_mask = getattr(value, "_keras_mask", None)
        key_mask = getattr(key, "_keras_mask", None)
        auto_mask = None
        if query_mask is not None:
            query_mask = tf.cast(query_mask, tf.bool)  # defensive casting
            # B = batch size, T = max query length
            auto_mask = query_mask[:, :, tf.newaxis]  # shape is [B, T, 1]
        if value_mask is not None:
            value_mask = tf.cast(value_mask, tf.bool)  # defensive casting
            # B = batch size, S == max value length
            mask = value_mask[:, tf.newaxis, :]  # shape is [B, 1, S]
            auto_mask = mask if auto_mask is None else auto_mask & mask
        if key_mask is not None:
            key_mask = tf.cast(key_mask, tf.bool)  # defensive casting
            # B == batch size, S == max key length == max value length
            mask = key_mask[:, tf.newaxis, :]  # shape is [B, 1, S]
            auto_mask = mask if auto_mask is None else auto_mask & mask
        if use_causal_mask:
            # the shape of the causal mask is [1, T, S]
            mask = self._compute_causal_mask(query, value)
            auto_mask = mask if auto_mask is None else auto_mask & mask
        if auto_mask is not None:
            # merge attention_mask & automatic mask, to shape [B, T, S]
            attention_mask = (
                auto_mask
                if attention_mask is None
                else tf.cast(attention_mask, bool) & auto_mask
            )
        return attention_mask

    def _compute_causal_mask(self, query, value=None):
        """Computes a causal mask (e.g., for masked self-attention layers).

        For example, if query and value both contain sequences of length 4,
        this function returns a boolean `Tensor` equal to:

        ```
        [[[True,  False, False, False],
          [True,  True,  False, False],
          [True,  True,  True,  False],
          [True,  True,  True,  True]]]
        ```

        Args:
            query: query `Tensor` of shape `(B, T, ...)`.
            value: value `Tensor` of shape `(B, S, ...)` (optional, defaults
                to query).

        Returns:
            mask: a boolean `Tensor` of shape [1, T, S] containing a lower
                triangular matrix of shape [T, S].
        """
        q_seq_length = tf.shape(query)[1]
        v_seq_length = q_seq_length if value is None else tf.shape(value)[1]
        return tf.linalg.band_part(  # creates a lower triangular matrix
            tf.ones((1, q_seq_length, v_seq_length), tf.bool), -1, 0
        )

    def compute_output_shape(self, query_shape, value_shape, key_shape=None):
        if key_shape is None:
            key_shape = value_shape

        query_shape = tf.TensorShape(query_shape)
        value_shape = tf.TensorShape(value_shape)
        key_shape = tf.TensorShape(key_shape)

        if query_shape[-1] != value_shape[-1]:
            raise ValueError(
                "The last dimension of `query_shape` and `value_shape` "
                f"must be equal, but are {query_shape[-1]}, {value_shape[-1]}. "
                f"Received: query_shape={query_shape}, "
                f"value_shape={value_shape}"
            )

        if value_shape[1:-1] != key_shape[1:-1]:
            raise ValueError(
                "All dimensions of `value` and `key`, except the last one, "
                f"must be equal. Received {value_shape} and "
                f"{key_shape}"
            )

        if self._output_shape:
            return query_shape[:-1].concatenate(self._output_shape)

        return query_shape
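

# For illustration, a minimal sketch (not part of the Keras API) of wrapping
# `MultiHeadAttention` inside a custom layer, as described in the class
# docstring: the wrapper implements `build()` and calls
# `_build_from_signature()` so the sublayer weights are created at build time
# and can be restored correctly when the model is loaded. The name
# `WrappedSelfAttention` is hypothetical.
#
#     class WrappedSelfAttention(Layer):
#         def __init__(self, num_heads, key_dim, **kwargs):
#             super().__init__(**kwargs)
#             self._attention = MultiHeadAttention(
#                 num_heads=num_heads, key_dim=key_dim
#             )
#
#         def build(self, input_shape):
#             # `input_shape` serves as both query and value (self-attention).
#             self._attention._build_from_signature(
#                 query=input_shape, value=input_shape
#             )
#             super().build(input_shape)
#
#         def call(self, inputs):
#             return self._attention(inputs, inputs)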