Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/seq2seq/attention_wrapper.py: 18%
524 statements
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A dynamic attention wrapper for RNN cells."""
17import collections
18import functools
19import math
20from packaging.version import Version
22import numpy as np
24import tensorflow as tf
26from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
27from tensorflow_addons.utils import keras_utils
28from tensorflow_addons.utils.types import (
29 AcceptableDTypes,
30 FloatTensorLike,
31 TensorLike,
32 Initializer,
33 Number,
34)
36from typeguard import typechecked
37from typing import Optional, Callable, Union, List
40if Version(tf.__version__) < Version("2.13"):
41 SERIALIZATION_ARGS = {}
42else:
43 SERIALIZATION_ARGS = {"use_legacy_format": True}
46class AttentionMechanism(tf.keras.layers.Layer):
47 """Base class for attention mechanisms.
49 Common functionality includes:
50 1. Storing the query and memory layers.
51 2. Preprocessing and storing the memory.
53 Note that this layer takes `memory` as its init parameter, which is an
54 anti-pattern for the Keras API; the memory has to be kept as an init
55 parameter for performance and dependency reasons. Under the hood, `__init__()`
56 invokes `base_layer.__call__(memory, setup_memory=True)`, which lets Keras
57 track the memory tensor as the input of this layer. Once `__init__()` is
58 done, the user can query the attention with `score = att_obj([query, state])`
59 and use it as a normal Keras layer (see the usage sketch after `LuongAttention` below).
61 Special care is needed when using this class as the base layer for a new
62 attention mechanism:
63 1. `build()` may be invoked more than once, so make sure weights are not
64 duplicated.
65 2. `Layer.get_weights()` might return a different set of weights if the
66 instance has a `query_layer`: the `query_layer` weights are not
67 initialized until the memory is configured.
69 Also note that this layer does not work with a Keras model compiled with
70 `model.compile(run_eagerly=True)` because this layer is stateful. Support
71 for that will be added in a future version.
72 """
74 @typechecked
75 def __init__(
76 self,
77 memory: Union[TensorLike, None],
78 probability_fn: callable,
79 query_layer: Optional[tf.keras.layers.Layer] = None,
80 memory_layer: Optional[tf.keras.layers.Layer] = None,
81 memory_sequence_length: Optional[TensorLike] = None,
82 **kwargs,
83 ):
84 """Construct base AttentionMechanism class.
86 Args:
87 memory: The memory to query; usually the output of an RNN encoder.
88 This tensor should be shaped `[batch_size, max_time, ...]`.
89 probability_fn: A `callable`. Converts the score and previous
90 alignments to probabilities. Its signature should be:
91 `probabilities = probability_fn(score, state)`.
92 query_layer: Optional `tf.keras.layers.Layer` instance. The layer's
93 depth must match the depth of `memory_layer`. If `query_layer` is
94 not provided, the shape of `query` must match that of
95 `memory_layer`.
96 memory_layer: Optional `tf.keras.layers.Layer` instance. The layer's
97 depth must match the depth of `query_layer`.
98 If `memory_layer` is not provided, the shape of `memory` must match
99 that of `query_layer`.
100 memory_sequence_length: (optional) Sequence lengths for the batch
101 entries in memory. If provided, the memory tensor rows are masked
102 with zeros for values past the respective sequence lengths.
103 **kwargs: Dictionary that contains other common arguments for layer
104 creation.
105 """
106 self.query_layer = query_layer
107 self.memory_layer = memory_layer
108 super().__init__(**kwargs)
109 self.default_probability_fn = probability_fn
110 self.probability_fn = probability_fn
112 self.keys = None
113 self.values = None
114 self.batch_size = None
115 self._memory_initialized = False
116 self._check_inner_dims_defined = True
117 self.supports_masking = True
119 if memory is not None:
120 # Setup the memory by self.__call__() with memory and
121 # memory_seq_length. This will make the attention follow the keras
122 # convention which takes all the tensor inputs via __call__().
123 if memory_sequence_length is None:
124 inputs = memory
125 else:
126 inputs = [memory, memory_sequence_length]
128 self.values = super().__call__(inputs, setup_memory=True)
130 @property
131 def memory_initialized(self):
132 """Returns `True` if this attention mechanism has been initialized with
133 a memory."""
134 return self._memory_initialized
136 def build(self, input_shape):
137 if not self._memory_initialized:
138 # This is for setting up the memory, which contains memory and
139 # optional memory_sequence_length. Build the memory_layer with
140 # memory shape.
141 if self.memory_layer is not None and not self.memory_layer.built:
142 if isinstance(input_shape, list):
143 self.memory_layer.build(input_shape[0])
144 else:
145 self.memory_layer.build(input_shape)
146 else:
147 # The input_shape should be query.shape and state.shape. Use the
148 # query to init the query layer.
149 if self.query_layer is not None and not self.query_layer.built:
150 self.query_layer.build(input_shape[0])
152 def __call__(self, inputs, **kwargs):
153 """Preprocess the inputs before calling `base_layer.__call__()`.
155 There are two situations here: one for setting up the memory, and one with
156 an actual query and state.
157 1. When the memory has not been configured, we pass all the parameters to
158 `base_layer.__call__()`, which then invokes `self.call()` with the
159 proper inputs, allowing this class to set up the memory.
160 2. When the memory has already been set up, the input should contain the
161 query and state, and optionally the processed memory. If the processed
162 memory is not included in the input, we append it to the inputs before
163 passing them to `base_layer.__call__()`. The processed memory is the
164 output of the first invocation of `self.__call__()`. If we don't add it
165 here, then from the Keras perspective the graph is disconnected, since
166 the output from the previous call is never used.
168 Args:
169 inputs: the input tensors.
170 **kwargs: dict, other keyword arguments for `__call__()`.
171 """
172 # Allow manual memory reset
173 if kwargs.get("setup_memory", False):
174 self._memory_initialized = False
176 if self._memory_initialized:
177 if len(inputs) not in (2, 3):
178 raise ValueError(
179 "Expect the inputs to have 2 or 3 tensors, got %d" % len(inputs)
180 )
181 if len(inputs) == 2:
182 # We append the calculated memory here so that the graph will be
183 # connected.
184 inputs.append(self.values)
186 return super().__call__(inputs, **kwargs)
188 def call(self, inputs, mask=None, setup_memory=False, **kwargs):
189 """Setup the memory or query the attention.
191 There are two case here, one for setup memory, and the second is query
192 the attention score. `setup_memory` is the flag to indicate which mode
193 it is. The input list will be treated differently based on that flag.
195 Args:
196 inputs: a list of tensor that could either be `query` and `state`, or
197 `memory` and `memory_sequence_length`.
198 `query` is the tensor of dtype matching `memory` and shape
199 `[batch_size, query_depth]`.
200 `state` is the tensor of dtype matching `memory` and shape
201 `[batch_size, alignments_size]`. (`alignments_size` is memory's
202 `max_time`).
203 `memory` is the memory to query; usually the output of an RNN
204 encoder. The tensor should be shaped `[batch_size, max_time, ...]`.
205 `memory_sequence_length` (optional) is the sequence lengths for the
206 batch entries in memory. If provided, the memory tensor rows are
207 masked with zeros for values past the respective sequence lengths.
208 mask: optional bool tensor with shape `[batch, max_time]` for the
209 mask of memory. If it is not None, the corresponding item of the
210 memory should be filtered out during calculation.
211 setup_memory: boolean, whether the input is for setting up memory, or
212 query attention.
213 **kwargs: Dict, other keyword arguments for the call method.
214 Returns:
215 Either processed memory or attention score, based on `setup_memory`.
216 """
217 if setup_memory:
218 if isinstance(inputs, list):
219 if len(inputs) not in (1, 2):
220 raise ValueError(
221 "Expect inputs to have 1 or 2 tensors, got %d" % len(inputs)
222 )
223 memory = inputs[0]
224 memory_sequence_length = inputs[1] if len(inputs) == 2 else None
225 memory_mask = mask
226 else:
227 memory, memory_sequence_length = inputs, None
228 memory_mask = mask
229 self.setup_memory(memory, memory_sequence_length, memory_mask)
230 # We force self.built back to False here since only the memory has been
231 # initialized, but the real query/state has not been passed to call() yet.
232 # The layer should be built and called again.
233 self.built = False
234 # Return the processed memory in order to create the Keras
235 # connectivity data for it.
236 return self.values
237 else:
238 if not self._memory_initialized:
239 raise ValueError(
240 "Cannot query the attention before the setup of memory"
241 )
242 if len(inputs) not in (2, 3):
243 raise ValueError(
244 "Expect the inputs to have query, state, and optional "
245 "processed memory, got %d items" % len(inputs)
246 )
247 # Ignore the rest of the inputs and only care about the query and
248 # state
249 query, state = inputs[0], inputs[1]
250 return self._calculate_attention(query, state)
252 def setup_memory(self, memory, memory_sequence_length=None, memory_mask=None):
253 """Pre-process the memory before actually query the memory.
255 This should only be called once at the first invocation of `call()`.
257 Args:
258 memory: The memory to query; usually the output of an RNN encoder.
259 This tensor should be shaped `[batch_size, max_time, ...]`.
260 memory_sequence_length (optional): Sequence lengths for the batch
261 entries in memory. If provided, the memory tensor rows are masked
262 with zeros for values past the respective sequence lengths.
263 memory_mask: (Optional) The boolean tensor with shape `[batch_size,
264 max_time]`. For any value equal to False, the corresponding value
265 in memory should be ignored.
266 """
267 if memory_sequence_length is not None and memory_mask is not None:
268 raise ValueError(
269 "memory_sequence_length and memory_mask cannot be "
270 "used at same time for attention."
271 )
272 with tf.name_scope(self.name or "BaseAttentionMechanismInit"):
273 self.values = _prepare_memory(
274 memory,
275 memory_sequence_length=memory_sequence_length,
276 memory_mask=memory_mask,
277 check_inner_dims_defined=self._check_inner_dims_defined,
278 )
279 # Mark the value as checked, since the memory and memory mask might not be
280 # passed from __call__(), which means they do not have proper Keras metadata.
281 # TODO(omalleyt12): Remove this hack once the mask has proper Keras
282 # history.
284 def _mark_checked(tensor):
285 tensor._keras_history_checked = True # pylint: disable=protected-access
287 tf.nest.map_structure(_mark_checked, self.values)
288 if self.memory_layer is not None:
289 self.keys = self.memory_layer(self.values)
290 else:
291 self.keys = self.values
292 self.batch_size = self.keys.shape[0] or tf.shape(self.keys)[0]
293 self._alignments_size = self.keys.shape[1] or tf.shape(self.keys)[1]
294 if memory_mask is not None or memory_sequence_length is not None:
295 unwrapped_probability_fn = self.default_probability_fn
297 def _mask_probability_fn(score, prev):
298 return unwrapped_probability_fn(
299 _maybe_mask_score(
300 score,
301 memory_mask=memory_mask,
302 memory_sequence_length=memory_sequence_length,
303 score_mask_value=score.dtype.min,
304 ),
305 prev,
306 )
308 self.probability_fn = _mask_probability_fn
309 self._memory_initialized = True
311 def _calculate_attention(self, query, state):
312 raise NotImplementedError(
313 "_calculate_attention need to be implemented by subclasses."
314 )
316 def compute_mask(self, inputs, mask=None):
317 # The real inputs of the attention are the query and state, and the memory
318 # layer mask shouldn't be passed down. Return None for all output masks
319 # here.
320 return None, None
322 def get_config(self):
323 config = {}
324 # Since the probability_fn is likely to be a wrapped function, the child
325 # class should preserve the original function and how it is wrapped.
327 if self.query_layer is not None:
328 config["query_layer"] = {
329 "class_name": self.query_layer.__class__.__name__,
330 "config": self.query_layer.get_config(),
331 }
332 if self.memory_layer is not None:
333 config["memory_layer"] = {
334 "class_name": self.memory_layer.__class__.__name__,
335 "config": self.memory_layer.get_config(),
336 }
337 # memory is a required init parameter and it is a tensor. It cannot be
338 # serialized to the config, so we put a placeholder for it.
339 config["memory"] = None
340 base_config = super().get_config()
341 return {**base_config, **config}
343 def _process_probability_fn(self, func_name):
344 """Helper method to retrieve the probably function by string input."""
345 valid_probability_fns = {
346 "softmax": tf.nn.softmax,
347 "hardmax": hardmax,
348 }
349 if func_name not in valid_probability_fns.keys():
350 raise ValueError(
351 "Invalid probability function: %s, options are %s"
352 % (func_name, valid_probability_fns.keys())
353 )
354 return valid_probability_fns[func_name]
356 @classmethod
357 def deserialize_inner_layer_from_config(cls, config, custom_objects):
358 """Helper method that reconstruct the query and memory from the config.
360 In the get_config() method, the query and memory layer configs are
361 serialized into dict for persistence, this method perform the reverse
362 action to reconstruct the layer from the config.
364 Args:
365 config: dict, the configs that will be used to reconstruct the
366 object.
367 custom_objects: dict mapping class names (or function names) of
368 custom (non-Keras) objects to class/functions.
369 Returns:
370 config: dict, the config with layer instance created, which is ready
371 to be used as init parameters.
372 """
373 # Reconstruct the query and memory layer for parent class.
374 # Instead of updating the input, create a copy and use that.
375 config = config.copy()
376 query_layer_config = config.pop("query_layer", None)
377 if query_layer_config:
378 query_layer = tf.keras.layers.deserialize(
379 query_layer_config,
380 custom_objects=custom_objects,
381 **SERIALIZATION_ARGS,
382 )
383 config["query_layer"] = query_layer
384 memory_layer_config = config.pop("memory_layer", None)
385 if memory_layer_config:
386 memory_layer = tf.keras.layers.deserialize(
387 memory_layer_config,
388 custom_objects=custom_objects,
389 **SERIALIZATION_ARGS,
390 )
391 config["memory_layer"] = memory_layer
392 return config
394 @property
395 def alignments_size(self):
396 if isinstance(self._alignments_size, int):
397 return self._alignments_size
398 else:
399 return tf.TensorShape([None])
401 @property
402 def state_size(self):
403 return self.alignments_size
405 def initial_alignments(self, batch_size, dtype):
406 """Creates the initial alignment values for the `tfa.seq2seq.AttentionWrapper`
407 class.
409 This is important for attention mechanisms that use the previous
410 alignment to calculate the alignment at the next time step
411 (e.g. monotonic attention).
413 The default behavior is to return a tensor of all zeros.
415 Args:
416 batch_size: `int32` scalar, the batch_size.
417 dtype: The `dtype`.
419 Returns:
420 A `dtype` tensor shaped `[batch_size, alignments_size]`
421 (`alignments_size` is the values' `max_time`).
422 """
423 return tf.zeros([batch_size, self._alignments_size], dtype=dtype)
425 def initial_state(self, batch_size, dtype):
426 """Creates the initial state values for the `tfa.seq2seq.AttentionWrapper` class.
428 This is important for attention mechanisms that use the previous
429 alignment to calculate the alignment at the next time step
430 (e.g. monotonic attention).
432 The default behavior is to return the same output as
433 `initial_alignments`.
435 Args:
436 batch_size: `int32` scalar, the batch_size.
437 dtype: The `dtype`.
439 Returns:
440 A structure of all-zero tensors with shapes as described by
441 `state_size`.
442 """
443 return self.initial_alignments(batch_size, dtype)
446def _luong_score(query, keys, scale):
447 """Implements Luong-style (multiplicative) scoring function.
449 This attention has two forms. The first is standard Luong attention,
450 as described in:
452 Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
453 "Effective Approaches to Attention-based Neural Machine Translation."
454 EMNLP 2015. https://arxiv.org/abs/1508.04025
456 The second is the scaled form inspired partly by the normalized form of
457 Bahdanau attention.
459 To enable the second form, call this function with `scale=True`.
461 Args:
462 query: Tensor, shape `[batch_size, num_units]` to compare to keys.
463 keys: Processed memory, shape `[batch_size, max_time, num_units]`.
464 scale: the optional tensor to scale the attention score.
466 Returns:
467 A `[batch_size, max_time]` tensor of unnormalized score values.
469 Raises:
470 ValueError: If `key` and `query` depths do not match.
471 """
472 depth = query.shape[-1]
473 key_units = keys.shape[-1]
474 if depth != key_units:
475 raise ValueError(
476 "Incompatible or unknown inner dimensions between query and keys. "
477 "Query (%s) has units: %s. Keys (%s) have units: %s. "
478 "Perhaps you need to set num_units to the keys' dimension (%s)?"
479 % (query, depth, keys, key_units, key_units)
480 )
482 # Reshape from [batch_size, depth] to [batch_size, 1, depth]
483 # for matmul.
484 query = tf.expand_dims(query, 1)
486 # Inner product along the query units dimension.
487 # matmul shapes: query is [batch_size, 1, depth] and
488 # keys is [batch_size, max_time, depth].
489 # the inner product is asked to **transpose keys' inner shape** to get a
490 # batched matmul on:
491 # [batch_size, 1, depth] . [batch_size, depth, max_time]
492 # resulting in an output shape of:
493 # [batch_size, 1, max_time].
494 # we then squeeze out the center singleton dimension.
495 score = tf.matmul(query, keys, transpose_b=True)
496 score = tf.squeeze(score, [1])
498 if scale is not None:
499 score = scale * score
500 return score
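# Editor's illustration (added for this report, not part of the original
# module): a minimal shape walk-through of `_luong_score`. The helper name and
# tensor values below are hypothetical.
def _example_luong_score_shapes():
    query = tf.random.normal([4, 16])    # [batch_size, depth]
    keys = tf.random.normal([4, 7, 16])  # [batch_size, max_time, depth]
    score = _luong_score(query, keys, scale=None)
    # One unnormalized score per memory time step.
    assert score.shape == (4, 7)         # [batch_size, max_time]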
503class LuongAttention(AttentionMechanism):
504 """Implements Luong-style (multiplicative) attention scoring.
506 This attention has two forms. The first is standard Luong attention,
507 as described in:
509 Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
510 [Effective Approaches to Attention-based Neural Machine Translation.
511 EMNLP 2015.](https://arxiv.org/abs/1508.04025)
513 The second is the scaled form inspired partly by the normalized form of
514 Bahdanau attention.
516 To enable the second form, construct the object with parameter
517 `scale=True`.
518 """
520 @typechecked
521 def __init__(
522 self,
523 units: TensorLike,
524 memory: Optional[TensorLike] = None,
525 memory_sequence_length: Optional[TensorLike] = None,
526 scale: bool = False,
527 probability_fn: str = "softmax",
528 dtype: AcceptableDTypes = None,
529 name: str = "LuongAttention",
530 **kwargs,
531 ):
532 """Construct the AttentionMechanism mechanism.
534 Args:
535 units: The depth of the attention mechanism.
536 memory: The memory to query; usually the output of an RNN encoder.
537 This tensor should be shaped `[batch_size, max_time, ...]`.
538 memory_sequence_length: (optional): Sequence lengths for the batch
539 entries in memory. If provided, the memory tensor rows are masked
540 with zeros for values past the respective sequence lengths.
541 scale: Python boolean. Whether to scale the energy term.
542 probability_fn: (optional) string, the name of the function used to
543 convert the attention score to probabilities. The default is
544 `softmax`, which is `tf.nn.softmax`. The other option is `hardmax`,
545 which is the `hardmax()` function within this module. Any other
546 value will result in a validation error.
547 dtype: The data type for the memory layer of the attention mechanism.
548 name: Name to use when creating ops.
549 **kwargs: Dictionary that contains other common arguments for layer
550 creation.
551 """
552 # For LuongAttention, we only transform the memory layer; thus
553 # num_units **must** match the expected query depth.
554 self.probability_fn_name = probability_fn
555 probability_fn = self._process_probability_fn(self.probability_fn_name)
557 def wrapped_probability_fn(score, _):
558 return probability_fn(score)
560 memory_layer = kwargs.pop("memory_layer", None)
561 if not memory_layer:
562 memory_layer = tf.keras.layers.Dense(
563 units, name="memory_layer", use_bias=False, dtype=dtype
564 )
565 self.units = units
566 self.scale = scale
567 self.scale_weight = None
568 super().__init__(
569 memory=memory,
570 memory_sequence_length=memory_sequence_length,
571 query_layer=None,
572 memory_layer=memory_layer,
573 probability_fn=wrapped_probability_fn,
574 name=name,
575 dtype=dtype,
576 **kwargs,
577 )
579 def build(self, input_shape):
580 super().build(input_shape)
581 if self.scale and self.scale_weight is None:
582 self.scale_weight = self.add_weight(
583 "attention_g", initializer=tf.ones_initializer, shape=()
584 )
585 self.built = True
587 def _calculate_attention(self, query, state):
588 """Score the query based on the keys and values.
590 Args:
591 query: Tensor of dtype matching `self.values` and shape
592 `[batch_size, query_depth]`.
593 state: Tensor of dtype matching `self.values` and shape
594 `[batch_size, alignments_size]`
595 (`alignments_size` is memory's `max_time`).
597 Returns:
598 alignments: Tensor of dtype matching `self.values` and shape
599 `[batch_size, alignments_size]` (`alignments_size` is memory's
600 `max_time`).
601 next_state: Same as the alignments.
602 """
603 score = _luong_score(query, self.keys, self.scale_weight)
604 alignments = self.probability_fn(score, state)
605 next_state = alignments
606 return alignments, next_state
608 def get_config(self):
609 config = {
610 "units": self.units,
611 "scale": self.scale,
612 "probability_fn": self.probability_fn_name,
613 }
614 base_config = super().get_config()
615 return {**base_config, **config}
617 @classmethod
618 def from_config(cls, config, custom_objects=None):
619 config = AttentionMechanism.deserialize_inner_layer_from_config(
620 config, custom_objects=custom_objects
621 )
622 return cls(**config)
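# Editor's illustration (added for this report, not part of the original
# module): the two-phase protocol described in `AttentionMechanism`, shown with
# `LuongAttention`. The helper name and shapes are hypothetical; the memory can
# also be supplied directly via the `memory` constructor argument.
def _example_luong_attention_usage():
    batch_size, max_time, depth = 4, 7, 32
    memory = tf.random.uniform([batch_size, max_time, depth])

    # Phase 1: construct, then configure the memory.
    attention = LuongAttention(units=depth)
    attention.setup_memory(memory)

    # Phase 2: query the mechanism like a normal Keras layer.
    query = tf.random.uniform([batch_size, depth])
    state = attention.initial_state(batch_size, dtype=tf.float32)
    alignments, next_state = attention([query, state])
    assert alignments.shape == (batch_size, max_time)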
625def _bahdanau_score(
626 processed_query, keys, attention_v, attention_g=None, attention_b=None
627):
628 """Implements Bahdanau-style (additive) scoring function.
630 This attention has two forms. The first is Bahdanau attention,
631 as described in:
633 Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
634 "Neural Machine Translation by Jointly Learning to Align and Translate."
635 ICLR 2015. https://arxiv.org/abs/1409.0473
637 The second is the normalized form. This form is inspired by the
638 weight normalization article:
640 Tim Salimans, Diederik P. Kingma.
641 "Weight Normalization: A Simple Reparameterization to Accelerate
642 Training of Deep Neural Networks."
643 https://arxiv.org/abs/1602.07868
645 To enable the second form, please pass in `attention_g` and `attention_b`.
647 Args:
648 processed_query: Tensor, shape `[batch_size, num_units]` to compare to
649 keys.
650 keys: Processed memory, shape `[batch_size, max_time, num_units]`.
651 attention_v: Tensor, shape `[num_units]`.
652 attention_g: Optional scalar tensor for normalization.
653 attention_b: Optional tensor with shape `[num_units]` for normalization.
655 Returns:
656 A `[batch_size, max_time]` tensor of unnormalized score values.
657 """
658 # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
659 processed_query = tf.expand_dims(processed_query, 1)
660 if attention_g is not None and attention_b is not None:
661 normed_v = (
662 attention_g
663 * attention_v
664 * tf.math.rsqrt(tf.reduce_sum(tf.square(attention_v)))
665 )
666 return tf.reduce_sum(
667 normed_v * tf.tanh(keys + processed_query + attention_b), [2]
668 )
669 else:
670 return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query), [2])
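# Editor's illustration (added for this report, not part of the original
# module): a minimal shape walk-through of `_bahdanau_score` in its
# unnormalized form. The helper name and tensor values are hypothetical.
def _example_bahdanau_score_shapes():
    processed_query = tf.random.normal([4, 16])  # [batch_size, num_units]
    keys = tf.random.normal([4, 7, 16])          # [batch_size, max_time, num_units]
    attention_v = tf.random.normal([16])         # [num_units]
    score = _bahdanau_score(processed_query, keys, attention_v)
    assert score.shape == (4, 7)                 # [batch_size, max_time]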
673class BahdanauAttention(AttentionMechanism):
674 """Implements Bahdanau-style (additive) attention.
676 This attention has two forms. The first is Bahdanau attention,
677 as described in:
679 Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
680 "Neural Machine Translation by Jointly Learning to Align and Translate."
681 ICLR 2015. https://arxiv.org/abs/1409.0473
683 The second is the normalized form. This form is inspired by the
684 weight normalization article:
686 Tim Salimans, Diederik P. Kingma.
687 "Weight Normalization: A Simple Reparameterization to Accelerate
688 Training of Deep Neural Networks."
689 https://arxiv.org/abs/1602.07868
691 To enable the second form, construct the object with parameter
692 `normalize=True`.
693 """
695 @typechecked
696 def __init__(
697 self,
698 units: TensorLike,
699 memory: Optional[TensorLike] = None,
700 memory_sequence_length: Optional[TensorLike] = None,
701 normalize: bool = False,
702 probability_fn: str = "softmax",
703 kernel_initializer: Initializer = "glorot_uniform",
704 dtype: AcceptableDTypes = None,
705 name: str = "BahdanauAttention",
706 **kwargs,
707 ):
708 """Construct the Attention mechanism.
710 Args:
711 units: The depth of the query mechanism.
712 memory: The memory to query; usually the output of an RNN encoder.
713 This tensor should be shaped `[batch_size, max_time, ...]`.
714 memory_sequence_length: (optional): Sequence lengths for the batch
715 entries in memory. If provided, the memory tensor rows are masked
716 with zeros for values past the respective sequence lengths.
717 normalize: Python boolean. Whether to normalize the energy term.
718 probability_fn: (optional) string, the name of the function used to
719 convert the attention score to probabilities. The default is
720 `softmax`, which is `tf.nn.softmax`. The other option is `hardmax`,
721 which is the `hardmax()` function within this module. Any other
722 value will result in a validation error.
723 kernel_initializer: (optional), the name of the initializer for the
724 attention kernel.
725 dtype: The data type for the query and memory layers of the attention
726 mechanism.
727 name: Name to use when creating ops.
728 **kwargs: Dictionary that contains other common arguments for layer
729 creation.
730 """
731 self.probability_fn_name = probability_fn
732 probability_fn = self._process_probability_fn(self.probability_fn_name)
734 def wrapped_probability_fn(score, _):
735 return probability_fn(score)
737 query_layer = kwargs.pop("query_layer", None)
738 if not query_layer:
739 query_layer = tf.keras.layers.Dense(
740 units, name="query_layer", use_bias=False, dtype=dtype
741 )
742 memory_layer = kwargs.pop("memory_layer", None)
743 if not memory_layer:
744 memory_layer = tf.keras.layers.Dense(
745 units, name="memory_layer", use_bias=False, dtype=dtype
746 )
747 self.units = units
748 self.normalize = normalize
749 self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
750 self.attention_v = None
751 self.attention_g = None
752 self.attention_b = None
753 super().__init__(
754 memory=memory,
755 memory_sequence_length=memory_sequence_length,
756 query_layer=query_layer,
757 memory_layer=memory_layer,
758 probability_fn=wrapped_probability_fn,
759 name=name,
760 dtype=dtype,
761 **kwargs,
762 )
764 def build(self, input_shape):
765 super().build(input_shape)
766 if self.attention_v is None:
767 self.attention_v = self.add_weight(
768 "attention_v",
769 [self.units],
770 dtype=self.dtype,
771 initializer=self.kernel_initializer,
772 )
773 if self.normalize and self.attention_g is None and self.attention_b is None:
774 self.attention_g = self.add_weight(
775 "attention_g",
776 initializer=tf.constant_initializer(math.sqrt(1.0 / self.units)),
777 shape=(),
778 )
779 self.attention_b = self.add_weight(
780 "attention_b", shape=[self.units], initializer=tf.zeros_initializer()
781 )
782 self.built = True
784 def _calculate_attention(self, query, state):
785 """Score the query based on the keys and values.
787 Args:
788 query: Tensor of dtype matching `self.values` and shape
789 `[batch_size, query_depth]`.
790 state: Tensor of dtype matching `self.values` and shape
791 `[batch_size, alignments_size]`
792 (`alignments_size` is memory's `max_time`).
794 Returns:
795 alignments: Tensor of dtype matching `self.values` and shape
796 `[batch_size, alignments_size]` (`alignments_size` is memory's
797 `max_time`).
798 next_state: same as alignments.
799 """
800 processed_query = self.query_layer(query) if self.query_layer else query
801 score = _bahdanau_score(
802 processed_query,
803 self.keys,
804 self.attention_v,
805 attention_g=self.attention_g,
806 attention_b=self.attention_b,
807 )
808 alignments = self.probability_fn(score, state)
809 next_state = alignments
810 return alignments, next_state
812 def get_config(self):
813 # yapf: disable
814 config = {
815 "units": self.units,
816 "normalize": self.normalize,
817 "probability_fn": self.probability_fn_name,
818 "kernel_initializer": tf.keras.initializers.serialize(
819 self.kernel_initializer,
820 **SERIALIZATION_ARGS,
821 )
822 }
823 # yapf: enable
825 base_config = super().get_config()
826 return {**base_config, **config}
828 @classmethod
829 def from_config(cls, config, custom_objects=None):
830 config = AttentionMechanism.deserialize_inner_layer_from_config(
831 config,
832 custom_objects=custom_objects,
833 )
834 return cls(**config)
837def safe_cumprod(x: TensorLike, *args, **kwargs) -> tf.Tensor:
838 """Computes cumprod of x in logspace using cumsum to avoid underflow.
840 The cumprod function and its gradient can result in numerical instabilities
841 when its argument has very small and/or zero values. As long as the
842 argument is all positive, we can instead compute the cumulative product as
843 exp(cumsum(log(x))). This function can be called identically to
844 tf.cumprod.
846 Args:
847 x: Tensor to take the cumulative product of.
848 *args: Passed on to cumsum; these are identical to those in cumprod.
849 **kwargs: Passed on to cumsum; these are identical to those in cumprod.
850 Returns:
851 Cumulative product of x.
852 """
853 with tf.name_scope("SafeCumprod"):
854 x = tf.convert_to_tensor(x, name="x")
855 tiny = np.finfo(x.dtype.as_numpy_dtype).tiny
856 return tf.exp(
857 tf.cumsum(tf.math.log(tf.clip_by_value(x, tiny, 1)), *args, **kwargs)
858 )
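# Editor's illustration (added for this report, not part of the original
# module): `safe_cumprod` matches `tf.math.cumprod` for well-behaved inputs
# while computing the product in log space. The helper name is hypothetical.
def _example_safe_cumprod():
    x = tf.constant([0.9, 0.5, 0.1, 0.2])
    a = safe_cumprod(x)      # exp(cumsum(log(x)))
    b = tf.math.cumprod(x)   # direct cumulative product
    assert np.allclose(a.numpy(), b.numpy(), atol=1e-6)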
861def monotonic_attention(
862 p_choose_i: FloatTensorLike, previous_attention: FloatTensorLike, mode: str
863) -> tf.Tensor:
864 """Computes monotonic attention distribution from choosing probabilities.
866 Monotonic attention implies that the input sequence is processed in an
867 explicitly left-to-right manner when generating the output sequence. In
868 addition, once an input sequence element is attended to at a given output
869 timestep, elements occurring before it cannot be attended to at subsequent
870 output timesteps. This function generates attention distributions
871 according to these assumptions. For more information, see `Online and
872 Linear-Time Attention by Enforcing Monotonic Alignments`.
874 Args:
875 p_choose_i: Probability of choosing input sequence/memory element i.
876 Should be of shape (batch_size, input_sequence_length), and should all
877 be in the range [0, 1].
878 previous_attention: The attention distribution from the previous output
879 timestep. Should be of shape (batch_size, input_sequence_length). For
880 the first output timestep, previous_attention[n] should be
881 [1, 0, 0, ..., 0] for all n in [0, ... batch_size - 1].
882 mode: How to compute the attention distribution. Must be one of
883 'recursive', 'parallel', or 'hard'.
884 * 'recursive' uses tf.scan to recursively compute the distribution.
885 This is slowest but is exact, general, and does not suffer from
886 numerical instabilities.
887 * 'parallel' uses parallelized cumulative-sum and cumulative-product
888 operations to compute a closed-form solution to the recurrence
889 relation defining the attention distribution. This makes it more
890 efficient than 'recursive', but it requires numerical checks which
891 make the distribution non-exact. This can be a problem in
892 particular when input_sequence_length is long and/or p_choose_i has
893 entries very close to 0 or 1.
894 * 'hard' requires that the probabilities in p_choose_i are all either
895 0 or 1, and subsequently uses a more efficient and exact solution.
897 Returns:
898 A tensor of shape (batch_size, input_sequence_length) representing the
899 attention distributions for each sequence in the batch.
901 Raises:
902 ValueError: mode is not one of 'recursive', 'parallel', 'hard'.
903 """
904 # Force things to be tensors
905 p_choose_i = tf.convert_to_tensor(p_choose_i, name="p_choose_i")
906 previous_attention = tf.convert_to_tensor(
907 previous_attention, name="previous_attention"
908 )
909 if mode == "recursive":
910 # Use .shape[0] when it's not None, or fall back on symbolic shape
911 batch_size = p_choose_i.shape[0] or tf.shape(p_choose_i)[0]
912 # Compute
913 # [1, 1 - p_choose_i[0], 1 - p_choose_i[1], ..., 1 - p_choose_i[-2]]
914 shifted_1mp_choose_i = tf.concat(
915 [tf.ones((batch_size, 1)), 1 - p_choose_i[:, :-1]], 1
916 )
917 # Compute attention distribution recursively as
918 # q[i] = (1 - p_choose_i[i - 1])*q[i - 1] + previous_attention[i]
919 # attention[i] = p_choose_i[i]*q[i]
920 attention = p_choose_i * tf.transpose(
921 tf.scan(
922 # Need to use reshape to remind TF of the shape between loop
923 # iterations
924 lambda x, yz: tf.reshape(yz[0] * x + yz[1], (batch_size,)),
925 # Loop variables yz[0] and yz[1]
926 [tf.transpose(shifted_1mp_choose_i), tf.transpose(previous_attention)],
927 # Initial value of x is just zeros
928 tf.zeros((batch_size,)),
929 )
930 )
931 elif mode == "parallel":
932 # safe_cumprod computes cumprod in logspace with numeric checks
933 cumprod_1mp_choose_i = safe_cumprod(1 - p_choose_i, axis=1, exclusive=True)
934 # Compute recurrence relation solution
935 attention = (
936 p_choose_i
937 * cumprod_1mp_choose_i
938 * tf.cumsum(
939 previous_attention /
940 # Clip cumprod_1mp to avoid divide-by-zero
941 tf.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.0),
942 axis=1,
943 )
944 )
945 elif mode == "hard":
946 # Remove any probabilities before the index chosen last time step
947 p_choose_i *= tf.cumsum(previous_attention, axis=1)
948 # Now, use exclusive cumprod to remove probabilities after the first
949 # chosen index, like so:
950 # p_choose_i = [0, 0, 0, 1, 1, 0, 1, 1]
951 # cumprod(1 - p_choose_i, exclusive=True) = [1, 1, 1, 1, 0, 0, 0, 0]
952 # Product of above: [0, 0, 0, 1, 0, 0, 0, 0]
953 attention = p_choose_i * tf.math.cumprod(1 - p_choose_i, axis=1, exclusive=True)
954 else:
955 raise ValueError("mode must be 'recursive', 'parallel', or 'hard'.")
956 return attention
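# Editor's illustration (added for this report, not part of the original
# module): with `mode='hard'` and a one-hot `previous_attention`, the attention
# simply advances to the next index whose choosing probability is 1. The helper
# name and values are hypothetical.
def _example_hard_monotonic_attention():
    p_choose_i = tf.constant([[0.0, 0.0, 1.0, 1.0]])         # binary, as required
    previous_attention = tf.constant([[1.0, 0.0, 0.0, 0.0]])
    attention = monotonic_attention(p_choose_i, previous_attention, mode="hard")
    # Indices before the previously attended one are zeroed out, and the
    # exclusive cumprod keeps only the first remaining index with p == 1.
    assert np.allclose(attention.numpy(), [[0.0, 0.0, 1.0, 0.0]])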
959def _monotonic_probability_fn(
960 score, previous_alignments, sigmoid_noise, mode, seed=None
961):
962 """Attention probability function for monotonic attention.
964 Takes in unnormalized attention scores, adds pre-sigmoid noise to encourage
965 the model to make discrete attention decisions, passes them through a
966 sigmoid to obtain "choosing" probabilities, and then calls
967 monotonic_attention to obtain the attention distribution. For more
968 information, see
970 Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
971 "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
972 ICML 2017. https://arxiv.org/abs/1704.00784
974 Args:
975 score: Unnormalized attention scores, shape
976 `[batch_size, alignments_size]`
977 previous_alignments: Previous attention distribution, shape
978 `[batch_size, alignments_size]`
979 sigmoid_noise: Standard deviation of pre-sigmoid noise. Setting this
980 larger than 0 will encourage the model to produce large attention
981 scores, effectively making the choosing probabilities discrete and the
982 resulting attention distribution one-hot. It should be set to 0 at
983 test-time, and when hard attention is not desired.
984 mode: How to compute the attention distribution. Must be one of
985 'recursive', 'parallel', or 'hard'. See the docstring for
986 `tfa.seq2seq.monotonic_attention` for more information.
987 seed: (optional) Random seed for pre-sigmoid noise.
989 Returns:
990 A `[batch_size, alignments_size]`-shape tensor corresponding to the
991 resulting attention distribution.
992 """
993 # Optionally add pre-sigmoid noise to the scores
994 if sigmoid_noise > 0:
995 noise = tf.random.normal(tf.shape(score), dtype=score.dtype, seed=seed)
996 score += sigmoid_noise * noise
997 # Compute "choosing" probabilities from the attention scores
998 if mode == "hard":
999 # When mode is hard, use a hard sigmoid
1000 p_choose_i = tf.cast(score > 0, score.dtype)
1001 else:
1002 p_choose_i = tf.sigmoid(score)
1003 # Convert from choosing probabilities to attention distribution
1004 return monotonic_attention(p_choose_i, previous_alignments, mode)
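# Editor's illustration (added for this report, not part of the original
# module): in 'hard' mode the scores are binarized with a step function instead
# of a sigmoid before running the monotonic recursion. The helper name and
# values are hypothetical.
def _example_monotonic_probability_fn():
    score = tf.constant([[-1.0, 0.5, 2.0]])
    previous = tf.constant([[1.0, 0.0, 0.0]])
    attention = _monotonic_probability_fn(
        score, previous, sigmoid_noise=0.0, mode="hard"
    )
    assert np.allclose(attention.numpy(), [[0.0, 1.0, 0.0]])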
1007class _BaseMonotonicAttentionMechanism(AttentionMechanism):
1008 """Base attention mechanism for monotonic attention.
1010 Simply overrides the initial_alignments function to provide a dirac
1011 distribution, which is needed in order for the monotonic attention
1012 distributions to have the correct behavior.
1013 """
1015 def initial_alignments(self, batch_size, dtype):
1016 """Creates the initial alignment values for the monotonic attentions.
1018 Initializes to dirac distributions, i.e.
1019 [1, 0, 0, ...memory length..., 0] for all entries in the batch.
1021 Args:
1022 batch_size: `int32` scalar, the batch_size.
1023 dtype: The `dtype`.
1025 Returns:
1026 A `dtype` tensor shaped `[batch_size, alignments_size]`
1027 (`alignments_size` is the values' `max_time`).
1028 """
1029 max_time = self._alignments_size
1030 return tf.one_hot(
1031 tf.zeros((batch_size,), dtype=tf.int32), max_time, dtype=dtype
1032 )
1035class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
1036 """Monotonic attention mechanism with Bahdanau-style energy function.
1038 This type of attention enforces a monotonic constraint on the attention
1039 distributions; that is, once the model attends to a given point in the
1040 memory it can't attend to any prior points at subsequent output timesteps.
1041 It achieves this by using the `_monotonic_probability_fn` instead of `softmax`
1042 to construct its attention distributions. Since the attention scores are
1043 passed through a sigmoid, a learnable scalar bias parameter is applied
1044 after the score function and before the sigmoid. Otherwise, it is
1045 equivalent to `tfa.seq2seq.BahdanauAttention`. This approach is proposed in
1047 Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
1048 "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
1049 ICML 2017. https://arxiv.org/abs/1704.00784
1050 """
1052 @typechecked
1053 def __init__(
1054 self,
1055 units: TensorLike,
1056 memory: Optional[TensorLike] = None,
1057 memory_sequence_length: Optional[TensorLike] = None,
1058 normalize: bool = False,
1059 sigmoid_noise: FloatTensorLike = 0.0,
1060 sigmoid_noise_seed: Optional[FloatTensorLike] = None,
1061 score_bias_init: FloatTensorLike = 0.0,
1062 mode: str = "parallel",
1063 kernel_initializer: Initializer = "glorot_uniform",
1064 dtype: AcceptableDTypes = None,
1065 name: str = "BahdanauMonotonicAttention",
1066 **kwargs,
1067 ):
1068 """Construct the attention mechanism.
1070 Args:
1071 units: The depth of the query mechanism.
1072 memory: The memory to query; usually the output of an RNN encoder.
1073 This tensor should be shaped `[batch_size, max_time, ...]`.
1074 memory_sequence_length: (optional): Sequence lengths for the batch
1075 entries in memory. If provided, the memory tensor rows are masked
1076 with zeros for values past the respective sequence lengths.
1077 normalize: Python boolean. Whether to normalize the energy term.
1078 sigmoid_noise: Standard deviation of pre-sigmoid noise. See the
1079 docstring for `_monotonic_probability_fn` for more information.
1080 sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise.
1081 score_bias_init: Initial value for score bias scalar. It's
1082 recommended to initialize this to a negative value when the length
1083 of the memory is large.
1084 mode: How to compute the attention distribution. Must be one of
1085 'recursive', 'parallel', or 'hard'. See the docstring for
1086 `tfa.seq2seq.monotonic_attention` for more information.
1087 kernel_initializer: (optional), the name of the initializer for the
1088 attention kernel.
1089 dtype: The data type for the query and memory layers of the attention
1090 mechanism.
1091 name: Name to use when creating ops.
1092 **kwargs: Dictionary that contains other common arguments for layer
1093 creation.
1094 """
1095 # Set up the monotonic probability fn with supplied parameters
1096 wrapped_probability_fn = functools.partial(
1097 _monotonic_probability_fn,
1098 sigmoid_noise=sigmoid_noise,
1099 mode=mode,
1100 seed=sigmoid_noise_seed,
1101 )
1102 query_layer = kwargs.pop("query_layer", None)
1103 if not query_layer:
1104 query_layer = tf.keras.layers.Dense(
1105 units, name="query_layer", use_bias=False, dtype=dtype
1106 )
1107 memory_layer = kwargs.pop("memory_layer", None)
1108 if not memory_layer:
1109 memory_layer = tf.keras.layers.Dense(
1110 units, name="memory_layer", use_bias=False, dtype=dtype
1111 )
1112 self.units = units
1113 self.normalize = normalize
1114 self.sigmoid_noise = sigmoid_noise
1115 self.sigmoid_noise_seed = sigmoid_noise_seed
1116 self.score_bias_init = score_bias_init
1117 self.mode = mode
1118 self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
1119 self.attention_v = None
1120 self.attention_score_bias = None
1121 self.attention_g = None
1122 self.attention_b = None
1123 super().__init__(
1124 memory=memory,
1125 memory_sequence_length=memory_sequence_length,
1126 query_layer=query_layer,
1127 memory_layer=memory_layer,
1128 probability_fn=wrapped_probability_fn,
1129 name=name,
1130 dtype=dtype,
1131 **kwargs,
1132 )
1134 def build(self, input_shape):
1135 super().build(input_shape)
1136 if self.attention_v is None:
1137 self.attention_v = self.add_weight(
1138 "attention_v",
1139 [self.units],
1140 dtype=self.dtype,
1141 initializer=self.kernel_initializer,
1142 )
1143 if self.attention_score_bias is None:
1144 self.attention_score_bias = self.add_weight(
1145 "attention_score_bias",
1146 shape=(),
1147 dtype=self.dtype,
1148 initializer=tf.constant_initializer(self.score_bias_init),
1149 )
1150 if self.normalize and self.attention_g is None and self.attention_b is None:
1151 self.attention_g = self.add_weight(
1152 "attention_g",
1153 dtype=self.dtype,
1154 initializer=tf.constant_initializer(math.sqrt(1.0 / self.units)),
1155 shape=(),
1156 )
1157 self.attention_b = self.add_weight(
1158 "attention_b",
1159 [self.units],
1160 dtype=self.dtype,
1161 initializer=tf.zeros_initializer(),
1162 )
1163 self.built = True
1165 def _calculate_attention(self, query, state):
1166 """Score the query based on the keys and values.
1168 Args:
1169 query: Tensor of dtype matching `self.values` and shape
1170 `[batch_size, query_depth]`.
1171 state: Tensor of dtype matching `self.values` and shape
1172 `[batch_size, alignments_size]`
1173 (`alignments_size` is memory's `max_time`).
1175 Returns:
1176 alignments: Tensor of dtype matching `self.values` and shape
1177 `[batch_size, alignments_size]` (`alignments_size` is memory's
1178 `max_time`).
1179 """
1180 processed_query = self.query_layer(query) if self.query_layer else query
1181 score = _bahdanau_score(
1182 processed_query,
1183 self.keys,
1184 self.attention_v,
1185 attention_g=self.attention_g,
1186 attention_b=self.attention_b,
1187 )
1188 score += self.attention_score_bias
1189 alignments = self.probability_fn(score, state)
1190 next_state = alignments
1191 return alignments, next_state
1193 def get_config(self):
1194 # yapf: disable
1195 config = {
1196 "units": self.units,
1197 "normalize": self.normalize,
1198 "sigmoid_noise": self.sigmoid_noise,
1199 "sigmoid_noise_seed": self.sigmoid_noise_seed,
1200 "score_bias_init": self.score_bias_init,
1201 "mode": self.mode,
1202 "kernel_initializer": tf.keras.initializers.serialize(
1203 self.kernel_initializer,
1204 **SERIALIZATION_ARGS,
1205 ),
1206 }
1207 # yapf: enable
1209 base_config = super().get_config()
1210 return {**base_config, **config}
1212 @classmethod
1213 def from_config(cls, config, custom_objects=None):
1214 config = AttentionMechanism.deserialize_inner_layer_from_config(
1215 config, custom_objects=custom_objects
1216 )
1217 return cls(**config)
1220class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
1221 """Monotonic attention mechanism with Luong-style energy function.
1223 This type of attention enforces a monotonic constraint on the attention
1224 distributions; that is, once the model attends to a given point in the
1225 memory it can't attend to any prior points at subsequent output timesteps.
1226 It achieves this by using the `_monotonic_probability_fn` instead of `softmax`
1227 to construct its attention distributions. Otherwise, it is equivalent to
1228 `tfa.seq2seq.LuongAttention`. This approach is proposed in
1230 [Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
1231 "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
1232 ICML 2017.](https://arxiv.org/abs/1704.00784)
1233 """
1235 @typechecked
1236 def __init__(
1237 self,
1238 units: TensorLike,
1239 memory: Optional[TensorLike] = None,
1240 memory_sequence_length: Optional[TensorLike] = None,
1241 scale: bool = False,
1242 sigmoid_noise: FloatTensorLike = 0.0,
1243 sigmoid_noise_seed: Optional[FloatTensorLike] = None,
1244 score_bias_init: FloatTensorLike = 0.0,
1245 mode: str = "parallel",
1246 dtype: AcceptableDTypes = None,
1247 name: str = "LuongMonotonicAttention",
1248 **kwargs,
1249 ):
1250 """Construct the attention mechanism.
1252 Args:
1253 units: The depth of the query mechanism.
1254 memory: The memory to query; usually the output of an RNN encoder.
1255 This tensor should be shaped `[batch_size, max_time, ...]`.
1256 memory_sequence_length: (optional): Sequence lengths for the batch
1257 entries in memory. If provided, the memory tensor rows are masked
1258 with zeros for values past the respective sequence lengths.
1259 scale: Python boolean. Whether to scale the energy term.
1260 sigmoid_noise: Standard deviation of pre-sigmoid noise. See the
1261 docstring for `_monotonic_probability_fn` for more information.
1262 sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise.
1263 score_bias_init: Initial value for score bias scalar. It's
1264 recommended to initialize this to a negative value when the length
1265 of the memory is large.
1266 mode: How to compute the attention distribution. Must be one of
1267 'recursive', 'parallel', or 'hard'. See the docstring for
1268 `tfa.seq2seq.monotonic_attention` for more information.
1269 dtype: The data type for the query and memory layers of the attention
1270 mechanism.
1271 name: Name to use when creating ops.
1272 **kwargs: Dictionary that contains other common arguments for layer
1273 creation.
1274 """
1275 # Set up the monotonic probability fn with supplied parameters
1276 wrapped_probability_fn = functools.partial(
1277 _monotonic_probability_fn,
1278 sigmoid_noise=sigmoid_noise,
1279 mode=mode,
1280 seed=sigmoid_noise_seed,
1281 )
1282 memory_layer = kwargs.pop("memory_layer", None)
1283 if not memory_layer:
1284 memory_layer = tf.keras.layers.Dense(
1285 units, name="memory_layer", use_bias=False, dtype=dtype
1286 )
1287 self.units = units
1288 self.scale = scale
1289 self.sigmoid_noise = sigmoid_noise
1290 self.sigmoid_noise_seed = sigmoid_noise_seed
1291 self.score_bias_init = score_bias_init
1292 self.mode = mode
1293 self.attention_g = None
1294 self.attention_score_bias = None
1295 super().__init__(
1296 memory=memory,
1297 memory_sequence_length=memory_sequence_length,
1298 query_layer=None,
1299 memory_layer=memory_layer,
1300 probability_fn=wrapped_probability_fn,
1301 name=name,
1302 dtype=dtype,
1303 **kwargs,
1304 )
1306 def build(self, input_shape):
1307 super().build(input_shape)
1308 if self.scale and self.attention_g is None:
1309 self.attention_g = self.add_weight(
1310 "attention_g", initializer=tf.ones_initializer, shape=()
1311 )
1312 if self.attention_score_bias is None:
1313 self.attention_score_bias = self.add_weight(
1314 "attention_score_bias",
1315 shape=(),
1316 initializer=tf.constant_initializer(self.score_bias_init),
1317 )
1318 self.built = True
1320 def _calculate_attention(self, query, state):
1321 """Score the query based on the keys and values.
1323 Args:
1324 query: Tensor of dtype matching `self.values` and shape
1325 `[batch_size, query_depth]`.
1326 state: Tensor of dtype matching `self.values` and shape
1327 `[batch_size, alignments_size]`
1328 (`alignments_size` is memory's `max_time`).
1330 Returns:
1331 alignments: Tensor of dtype matching `self.values` and shape
1332 `[batch_size, alignments_size]` (`alignments_size` is memory's
1333 `max_time`).
1334 next_state: Same as alignments
1335 """
1336 score = _luong_score(query, self.keys, self.attention_g)
1337 score += self.attention_score_bias
1338 alignments = self.probability_fn(score, state)
1339 next_state = alignments
1340 return alignments, next_state
1342 def get_config(self):
1343 config = {
1344 "units": self.units,
1345 "scale": self.scale,
1346 "sigmoid_noise": self.sigmoid_noise,
1347 "sigmoid_noise_seed": self.sigmoid_noise_seed,
1348 "score_bias_init": self.score_bias_init,
1349 "mode": self.mode,
1350 }
1351 base_config = super().get_config()
1352 return {**base_config, **config}
1354 @classmethod
1355 def from_config(cls, config, custom_objects=None):
1356 config = AttentionMechanism.deserialize_inner_layer_from_config(
1357 config, custom_objects=custom_objects
1358 )
1359 return cls(**config)
1362class AttentionWrapperState(
1363 collections.namedtuple(
1364 "AttentionWrapperState",
1365 (
1366 "cell_state",
1367 "attention",
1368 "alignments",
1369 "alignment_history",
1370 "attention_state",
1371 ),
1372 )
1373):
1374 """State of a `tfa.seq2seq.AttentionWrapper`.
1376 Attributes:
1377 cell_state: The state of the wrapped RNN cell at the previous time
1378 step.
1379 attention: The attention emitted at the previous time step.
1380 alignments: A single or tuple of `Tensor`(s) containing the
1381 alignments emitted at the previous time step for each attention
1382 mechanism.
1383 alignment_history: (if enabled) a single or tuple of `TensorArray`(s)
1384 containing alignment matrices from all time steps for each attention
1385 mechanism. Call `stack()` on each to convert to a `Tensor`.
1386 attention_state: A single or tuple of nested objects
1387 containing attention mechanism state for each attention mechanism.
1388 The objects may contain Tensors or TensorArrays.
1389 """
1391 def clone(self, **kwargs):
1392 """Clone this object, overriding components provided by kwargs.
1394 The new state fields' shape must match original state fields' shape.
1395 This will be validated, and original fields' shape will be propagated
1396 to new fields.
1398 Example:
1400 >>> batch_size = 1
1401 >>> memory = tf.random.normal(shape=[batch_size, 3, 100])
1402 >>> encoder_state = [tf.zeros((batch_size, 100)), tf.zeros((batch_size, 100))]
1403 >>> attention_mechanism = tfa.seq2seq.LuongAttention(100, memory=memory, memory_sequence_length=[3] * batch_size)
1404 >>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(100), attention_mechanism, attention_layer_size=10)
1405 >>> decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
1406 >>> decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)
1408 Args:
1409 **kwargs: Any properties of the state object to replace in the
1410 returned `AttentionWrapperState`.
1412 Returns:
1413 A new `AttentionWrapperState` whose properties are the same as
1414 this one, except any overridden properties as provided in `kwargs`.
1415 """
1417 def with_same_shape(old, new):
1418 """Check and set new tensor's shape."""
1419 if isinstance(old, tf.Tensor) and isinstance(new, tf.Tensor):
1420 if not tf.executing_eagerly():
1421 new_shape = tf.shape(new)
1422 old_shape = tf.shape(old)
1423 assert_equal = tf.debugging.assert_equal(new_shape, old_shape)
1424 with tf.control_dependencies([assert_equal]):
1425 # Add an identity op so that control deps can kick in.
1426 return tf.identity(new)
1427 else:
1428 if old.shape.as_list() != new.shape.as_list():
1429 raise ValueError(
1430 "The shape of the AttentionWrapperState is "
1431 "expected to be same as the one to clone. "
1432 "self.shape: %s, input.shape: %s" % (old.shape, new.shape)
1433 )
1434 return new
1435 return new
1437 return tf.nest.map_structure(with_same_shape, self, super()._replace(**kwargs))
1440def _prepare_memory(
1441 memory, memory_sequence_length=None, memory_mask=None, check_inner_dims_defined=True
1442):
1443 """Convert to tensor and possibly mask `memory`.
1445 Args:
1446 memory: `Tensor`, shaped `[batch_size, max_time, ...]`.
1447 memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`.
1448 memory_mask: `boolean` tensor with shape [batch_size, max_time]. The
1449 memory should be skipped when the corresponding mask is False.
1450 check_inner_dims_defined: Python boolean. If `True`, the `memory`
1451 argument's shape is checked to ensure all but the two outermost
1452 dimensions are fully defined.
1454 Returns:
1455 A (possibly masked), checked, new `memory`.
1457 Raises:
1458 ValueError: If `check_inner_dims_defined` is `True` and not
1459 `memory.shape[2:].is_fully_defined()`.
1460 """
1461 memory = tf.nest.map_structure(
1462 lambda m: tf.convert_to_tensor(m, name="memory"), memory
1463 )
1464 if memory_sequence_length is not None and memory_mask is not None:
1465 raise ValueError(
1466 "memory_sequence_length and memory_mask can't be provided at same time."
1467 )
1468 if memory_sequence_length is not None:
1469 memory_sequence_length = tf.convert_to_tensor(
1470 memory_sequence_length, name="memory_sequence_length"
1471 )
1472 if check_inner_dims_defined:
1474 def _check_dims(m):
1475 if not m.shape[2:].is_fully_defined():
1476 raise ValueError(
1477 "Expected memory %s to have fully defined inner dims, "
1478 "but saw shape: %s" % (m.name, m.shape)
1479 )
1481 tf.nest.map_structure(_check_dims, memory)
1482 if memory_sequence_length is None and memory_mask is None:
1483 return memory
1484 elif memory_sequence_length is not None:
1485 seq_len_mask = tf.sequence_mask(
1486 memory_sequence_length,
1487 maxlen=tf.shape(tf.nest.flatten(memory)[0])[1],
1488 dtype=tf.nest.flatten(memory)[0].dtype,
1489 )
1490 else:
1491 # Case where memory_mask is not None
1492 seq_len_mask = tf.cast(memory_mask, dtype=tf.nest.flatten(memory)[0].dtype)
1494 def _maybe_mask(m, seq_len_mask):
1495 """Mask the memory based on the memory mask."""
1496 rank = m.shape.ndims
1497 rank = rank if rank is not None else tf.rank(m)
1498 extra_ones = tf.ones(rank - 2, dtype=tf.int32)
1499 seq_len_mask = tf.reshape(
1500 seq_len_mask, tf.concat((tf.shape(seq_len_mask), extra_ones), 0)
1501 )
1502 return m * seq_len_mask
1504 return tf.nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory)
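# Editor's illustration (added for this report, not part of the original
# module): `_prepare_memory` zeroes out memory rows past each sequence length.
# The helper name and shapes are hypothetical.
def _example_prepare_memory():
    memory = tf.ones([2, 3, 4])  # [batch_size, max_time, depth]
    masked = _prepare_memory(memory, memory_sequence_length=[1, 3])
    # The first batch entry keeps only its first time step (4 ones); the second
    # keeps all three time steps (12 ones).
    assert np.allclose(masked[0].numpy().sum(), 4.0)
    assert np.allclose(masked[1].numpy().sum(), 12.0)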
1507def _maybe_mask_score(
1508 score, memory_sequence_length=None, memory_mask=None, score_mask_value=None
1509):
1510 """Mask the attention score based on the masks."""
1511 if memory_sequence_length is None and memory_mask is None:
1512 return score
1513 if memory_sequence_length is not None and memory_mask is not None:
1514 raise ValueError(
1515 "memory_sequence_length and memory_mask can't be provided at same time."
1516 )
1517 if memory_sequence_length is not None:
1518 message = "All values in memory_sequence_length must be greater than zero."
1519 with tf.control_dependencies(
1520 [
1521 tf.debugging.assert_positive( # pylint: disable=bad-continuation
1522 memory_sequence_length, message=message
1523 )
1524 ]
1525 ):
1526 memory_mask = tf.sequence_mask(
1527 memory_sequence_length, maxlen=tf.shape(score)[1]
1528 )
1529 score_mask_values = score_mask_value * tf.ones_like(score)
1530 return tf.where(memory_mask, score, score_mask_values)
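# Example (a minimal sketch with illustrative values): padded positions receive
# `score_mask_value`, so a later softmax assigns them (near-)zero probability.
#
# >>> scores = tf.constant([[1.0, 2.0, 3.0]])
# >>> masked = _maybe_mask_score(
# ...     scores,
# ...     memory_sequence_length=tf.constant([2]),
# ...     score_mask_value=tf.float32.min,
# ... )
# >>> bool(masked[0, 2] == tf.float32.min)
# True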
1533def hardmax(logits: TensorLike, name: Optional[str] = None) -> tf.Tensor:
1534 """Returns batched one-hot vectors.
1536 The depth index containing the `1` is that of the maximum logit value.
1538 Args:
1539 logits: A batch tensor of logit values.
1540 name: Name to use when creating ops.
1541 Returns:
1542 A batched one-hot tensor.
1543 """
1544 with tf.name_scope(name or "Hardmax"):
1545 logits = tf.convert_to_tensor(logits, name="logits")
1546 depth = logits.shape[-1] or tf.shape(logits)[-1]
1547 return tf.one_hot(tf.argmax(logits, -1), depth, dtype=logits.dtype)
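# Example (a minimal sketch): `hardmax` maps each row of logits to a one-hot
# vector at the argmax position. Illustrative values only.
#
# >>> hardmax(tf.constant([[0.1, 2.0, 0.3]])).numpy()
# array([[0., 1., 0.]], dtype=float32)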
1550def _compute_attention(
1551 attention_mechanism, cell_output, attention_state, attention_layer
1552):
1553 """Computes the attention and alignments for a given
1554 attention_mechanism."""
1555 alignments, next_attention_state = attention_mechanism(
1556 [cell_output, attention_state]
1557 )
1559 # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
1560 expanded_alignments = tf.expand_dims(alignments, 1)
1561 # Context is the inner product of alignments and values along the
1562 # memory time dimension.
1563 # alignments shape is
1564 # [batch_size, 1, memory_time]
1565 # attention_mechanism.values shape is
1566 # [batch_size, memory_time, memory_size]
1567 # the batched matmul is over memory_time, so the output shape is
1568 # [batch_size, 1, memory_size].
1569 # we then squeeze out the singleton dim.
1570 context_ = tf.matmul(expanded_alignments, attention_mechanism.values)
1571 context_ = tf.squeeze(context_, [1])
1573 if attention_layer is not None:
1574 attention = attention_layer(tf.concat([cell_output, context_], 1))
1575 else:
1576 attention = context_
1578 return attention, alignments, next_attention_state
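# Example (a minimal sketch with illustrative shapes): the context above is a
# batched matmul of the alignments against the memory values over the time axis.
#
# >>> alignments = tf.nn.softmax(tf.random.uniform([4, 7]))  # [batch, time]
# >>> values = tf.random.uniform([4, 7, 32])                 # [batch, time, depth]
# >>> context = tf.squeeze(tf.matmul(tf.expand_dims(alignments, 1), values), [1])
# >>> context.shape
# TensorShape([4, 32])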
1581class AttentionWrapper(AbstractRNNCell):
1582 """Wraps another RNN cell with attention.
1584 Example:
1586 >>> batch_size = 4
1587 >>> max_time = 7
1588 >>> hidden_size = 32
1589 >>>
1590 >>> memory = tf.random.uniform([batch_size, max_time, hidden_size])
1591 >>> memory_sequence_length = tf.fill([batch_size], max_time)
1592 >>>
1593 >>> attention_mechanism = tfa.seq2seq.LuongAttention(hidden_size)
1594 >>> attention_mechanism.setup_memory(memory, memory_sequence_length)
1595 >>>
1596 >>> cell = tf.keras.layers.LSTMCell(hidden_size)
1597 >>> cell = tfa.seq2seq.AttentionWrapper(
1598 ... cell, attention_mechanism, attention_layer_size=hidden_size)
1599 >>>
1600 >>> inputs = tf.random.uniform([batch_size, hidden_size])
1601 >>> state = cell.get_initial_state(inputs)
1602 >>>
1603 >>> outputs, state = cell(inputs, state)
1604 >>> outputs.shape
1605 TensorShape([4, 32])
1606 """
1608 @typechecked
1609 def __init__(
1610 self,
1611 cell: tf.keras.layers.Layer,
1612 attention_mechanism: Union[AttentionMechanism, List[AttentionMechanism]],
1613 attention_layer_size: Optional[Union[Number, List[Number]]] = None,
1614 alignment_history: bool = False,
1615 cell_input_fn: Optional[Callable] = None,
1616 output_attention: bool = True,
1617 initial_cell_state: Optional[TensorLike] = None,
1618 name: Optional[str] = None,
1619 attention_layer: Optional[
1620 Union[tf.keras.layers.Layer, List[tf.keras.layers.Layer]]
1621 ] = None,
1622 attention_fn: Optional[Callable] = None,
1623 **kwargs,
1624 ):
1625 """Construct the `AttentionWrapper`.
1627 **NOTE** If you are using the `tfa.seq2seq.BeamSearchDecoder` with a cell wrapped
1628 in `AttentionWrapper`, then you must ensure that:
1630 - The encoder output has been tiled to `beam_width` via
1631 `tfa.seq2seq.tile_batch` (NOT `tf.tile`).
1632 - The `batch_size` argument passed to the `get_initial_state` method of
1633 this wrapper is equal to `true_batch_size * beam_width`.
1634 - The initial state created with `get_initial_state` above contains a
1635 `cell_state` value containing properly tiled final state from the
1636 encoder.
1638 An example:
1640 >>> batch_size = 1
1641 >>> beam_width = 5
1642 >>> sequence_length = tf.convert_to_tensor([5])
1643 >>> encoder_outputs = tf.random.uniform(shape=(batch_size, 5, 10))
1644 >>> encoder_final_state = [tf.zeros((batch_size, 10)), tf.zeros((batch_size, 10))]
1645 >>> tiled_encoder_outputs = tfa.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
1646 >>> tiled_encoder_final_state = tfa.seq2seq.tile_batch(encoder_final_state, multiplier=beam_width)
1647 >>> tiled_sequence_length = tfa.seq2seq.tile_batch(sequence_length, multiplier=beam_width)
1648 >>> attention_mechanism = tfa.seq2seq.BahdanauAttention(10, memory=tiled_encoder_outputs, memory_sequence_length=tiled_sequence_length)
1649 >>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(10), attention_mechanism)
1650 >>> decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size * beam_width, dtype=tf.float32)
1651 >>> decoder_initial_state = decoder_initial_state.clone(cell_state=tiled_encoder_final_state)
1653 Args:
1654 cell: A layer that implements the `tf.keras.layers.AbstractRNNCell`
1655 interface.
1656 attention_mechanism: A list of `tfa.seq2seq.AttentionMechanism`
1657 instances or a single instance.
1658 attention_layer_size: A list of Python integers or a single Python
1659 integer, the depth of the attention (output) layer(s). If `None`
1660 (default), use the context as attention at each time step.
1661 Otherwise, feed the context and cell output into the attention
1662 layer to generate attention at each time step. If
1663 `attention_mechanism` is a list, `attention_layer_size` must be a list
1664 of the same length. If `attention_layer` is set, this must be `None`.
1665 If `attention_fn` is set, it must be guaranteed that the outputs of
1666 `attention_fn` also meet the above requirements.
1667 alignment_history: Python boolean, whether to store alignment history
1668 from all time steps in the final output state (currently stored as
1669 a time major `TensorArray` on which you must call `stack()`).
1670 cell_input_fn: (optional) A `callable`. The default is:
1671 `lambda inputs, attention:
1672 tf.concat([inputs, attention], -1)`.
1673 output_attention: Python bool. If `True` (default), the output at
1674 each time step is the attention value. This is the behavior of
1675 Luong-style attention mechanisms. If `False`, the output at each
1676 time step is the output of `cell`. This is the behavior of
1677 Bahdanau-style attention mechanisms. In both cases, the
1678 `attention` tensor is propagated to the next time step via the
1679 state and is used there. This flag only controls whether the
1680 attention mechanism is propagated up to the next cell in an RNN
1681 stack or to the top RNN output.
1682 initial_cell_state: The initial state value to use for the cell when
1683 the user calls `get_initial_state()`. Note that if this value is
1684 provided now, and the user uses a `batch_size` argument of
1685 `get_initial_state` which does not match the batch size of
1686 `initial_cell_state`, proper behavior is not guaranteed.
1687 name: Name to use when creating ops.
1688 attention_layer: A list of `tf.keras.layers.Layer` instances or a
1689 single `tf.keras.layers.Layer` instance taking the context
1690 and cell output as inputs to generate attention at each time step.
1691 If `None` (default), use the context as attention at each time step.
1692 If `attention_mechanism` is a list, `attention_layer` must be a list of
1693 the same length. If `attention_layer_size` is set, this must be
1694 `None`.
1695 attention_fn: An optional callable function that allows users to
1696 provide their own customized attention function, which takes input
1697 `(attention_mechanism, cell_output, attention_state,
1698 attention_layer)` and outputs `(attention, alignments,
1699 next_attention_state)`. If provided, the `attention_layer_size` should
1700 be the size of the outputs of `attention_fn`.
1701 **kwargs: Other keyword arguments for layer creation.
1703 Raises:
1704 TypeError: `attention_layer_size` is not `None` and
1705 (`attention_mechanism` is a list but `attention_layer_size` is not;
1706 or vice versa).
1707 ValueError: if `attention_layer_size` is not `None`,
1708 `attention_mechanism` is a list, and its length does not match that
1709 of `attention_layer_size`; if `attention_layer_size` and
1710 `attention_layer` are set simultaneously.
1711 """
1712 super().__init__(name=name, **kwargs)
1713 keras_utils.assert_like_rnncell("cell", cell)
1714 if isinstance(attention_mechanism, (list, tuple)):
1715 self._is_multi = True
1716 attention_mechanisms = list(attention_mechanism)
1717 else:
1718 self._is_multi = False
1719 attention_mechanisms = [attention_mechanism]
1721 if cell_input_fn is None:
1723 def cell_input_fn(inputs, attention):
1724 return tf.concat([inputs, attention], -1)
1726 if attention_layer_size is not None and attention_layer is not None:
1727 raise ValueError(
1728 "Only one of attention_layer_size and attention_layer should be set"
1729 )
1731 if attention_layer_size is not None:
1732 attention_layer_sizes = tuple(
1733 attention_layer_size
1734 if isinstance(attention_layer_size, (list, tuple))
1735 else (attention_layer_size,)
1736 )
1737 if len(attention_layer_sizes) != len(attention_mechanisms):
1738 raise ValueError(
1739 "If provided, attention_layer_size must contain exactly "
1740 "one integer per attention_mechanism, saw: %d vs %d"
1741 % (len(attention_layer_sizes), len(attention_mechanisms))
1742 )
1743 dtype = kwargs.get("dtype", None)
1744 self._attention_layers = list(
1745 tf.keras.layers.Dense(
1746 attention_layer_size,
1747 name="attention_layer",
1748 use_bias=False,
1749 dtype=dtype,
1750 )
1751 for i, attention_layer_size in enumerate(attention_layer_sizes)
1752 )
1753 elif attention_layer is not None:
1754 self._attention_layers = list(
1755 attention_layer
1756 if isinstance(attention_layer, (list, tuple))
1757 else (attention_layer,)
1758 )
1759 if len(self._attention_layers) != len(attention_mechanisms):
1760 raise ValueError(
1761 "If provided, attention_layer must contain exactly one "
1762 "layer per attention_mechanism, saw: %d vs %d"
1763 % (len(self._attention_layers), len(attention_mechanisms))
1764 )
1765 else:
1766 self._attention_layers = None
1768 if attention_fn is None:
1769 attention_fn = _compute_attention
1770 self._attention_fn = attention_fn
1771 self._attention_layer_size = None
1773 self._cell = cell
1774 self._attention_mechanisms = attention_mechanisms
1775 self._cell_input_fn = cell_input_fn
1776 self._output_attention = output_attention
1777 self._alignment_history = alignment_history
1778 with tf.name_scope(name or "AttentionWrapperInit"):
1779 if initial_cell_state is None:
1780 self._initial_cell_state = None
1781 else:
1782 final_state_tensor = tf.nest.flatten(initial_cell_state)[-1]
1783 state_batch_size = (
1784 final_state_tensor.shape[0] or tf.shape(final_state_tensor)[0]
1785 )
1786 error_message = (
1787 "When constructing AttentionWrapper %s: " % self.name
1788 + "Non-matching batch sizes between the memory "
1789 "(encoder output) and initial_cell_state. Are you using "
1790 "the BeamSearchDecoder? You may need to tile your "
1791 "initial state via the tfa.seq2seq.tile_batch "
1792 "function with argument multiple=beam_width."
1793 )
1794 with tf.control_dependencies(
1795 self._batch_size_checks( # pylint: disable=bad-continuation
1796 state_batch_size, error_message
1797 )
1798 ):
1799 self._initial_cell_state = tf.nest.map_structure(
1800 lambda s: tf.identity(s, name="check_initial_cell_state"),
1801 initial_cell_state,
1802 )
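# Example (a minimal sketch): a custom `attention_fn` must accept
# (attention_mechanism, cell_output, attention_state, attention_layer) and
# return (attention, alignments, next_attention_state). The sketch below simply
# reuses the default computation; `my_attention_fn`, `hidden_size` and
# `attention_mechanism` are illustrative names borrowed from the class docstring.
#
# >>> def my_attention_fn(mechanism, cell_output, attention_state, layer):
# ...     return _compute_attention(mechanism, cell_output, attention_state, layer)
# >>> wrapper = tfa.seq2seq.AttentionWrapper(
# ...     tf.keras.layers.LSTMCell(hidden_size),
# ...     attention_mechanism,
# ...     attention_fn=my_attention_fn)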
1804 def _attention_mechanisms_checks(self):
1805 for attention_mechanism in self._attention_mechanisms:
1806 if not attention_mechanism.memory_initialized:
1807 raise ValueError(
1808 "The AttentionMechanism instances passed to "
1809 "this AttentionWrapper should be initialized "
1810 "with a memory first, either by passing it "
1811 "to the AttentionMechanism constructor or "
1812 "calling attention_mechanism.setup_memory()"
1813 )
1815 def _batch_size_checks(self, batch_size, error_message):
1816 self._attention_mechanisms_checks()
1817 return [
1818 tf.debugging.assert_equal(
1819 batch_size, attention_mechanism.batch_size, message=error_message
1820 )
1821 for attention_mechanism in self._attention_mechanisms
1822 ]
1824 def _get_attention_layer_size(self):
1825 if self._attention_layer_size is not None:
1826 return self._attention_layer_size
1827 self._attention_mechanisms_checks()
1828 attention_output_sizes = (
1829 attention_mechanism.values.shape[-1]
1830 for attention_mechanism in self._attention_mechanisms
1831 )
1832 if self._attention_layers is None:
1833 self._attention_layer_size = sum(attention_output_sizes)
1834 else:
1835 # Compute the layer output size from its input which is the
1836 # concatenation of the cell output and the attention mechanism
1837 # output.
1838 self._attention_layer_size = sum(
1839 layer.compute_output_shape(
1840 [None, self._cell.output_size + attention_output_size]
1841 )[-1]
1842 for layer, attention_output_size in zip(
1843 self._attention_layers, attention_output_sizes
1844 )
1845 )
1846 return self._attention_layer_size
1848 def _item_or_tuple(self, seq):
1849 """Returns `seq` as tuple or the singular element.
1851 Whether a tuple or a single element is returned depends on how the
1852 AttentionMechanism(s) were passed to the constructor.
1854 Args:
1855 seq: A non-empty sequence of items or generator.
1857 Returns:
1858 Either the values in the sequence as a tuple if
1859 AttentionMechanism(s) were passed to the constructor as a sequence
1860 or the singular element.
1861 """
1862 t = tuple(seq)
1863 if self._is_multi:
1864 return t
1865 else:
1866 return t[0]
1868 @property
1869 def output_size(self):
1870 if self._output_attention:
1871 return self._get_attention_layer_size()
1872 else:
1873 return self._cell.output_size
1875 @property
1876 def state_size(self):
1877 """The `state_size` property of `tfa.seq2seq.AttentionWrapper`.
1879 Returns:
1880 A `tfa.seq2seq.AttentionWrapperState` tuple containing shapes used
1881 by this object.
1882 """
1883 return AttentionWrapperState(
1884 cell_state=self._cell.state_size,
1885 attention=self._get_attention_layer_size(),
1886 alignments=self._item_or_tuple(
1887 a.alignments_size for a in self._attention_mechanisms
1888 ),
1889 attention_state=self._item_or_tuple(
1890 a.state_size for a in self._attention_mechanisms
1891 ),
1892 alignment_history=self._item_or_tuple(
1893 a.alignments_size if self._alignment_history else ()
1894 for a in self._attention_mechanisms
1895 ),
1896 ) # sometimes a TensorArray
1898 def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
1899 """Return an initial (zero) state tuple for this `tfa.seq2seq.AttentionWrapper`.
1901 **NOTE** Please see the initializer documentation for details of how
1902 to call `get_initial_state` if using a `tfa.seq2seq.AttentionWrapper`
1903 with a `tfa.seq2seq.BeamSearchDecoder`.
1905 Args:
1906 inputs: The inputs that will be fed to this cell.
1907 batch_size: `0D` integer tensor: the batch size.
1908 dtype: The internal state data type.
1910 Returns:
1911 An `tfa.seq2seq.AttentionWrapperState` tuple containing zeroed out tensors and,
1912 possibly, empty `TensorArray` objects.
1914 Raises:
1915 ValueError: (or, possibly at runtime, `InvalidArgument`), if
1916 `batch_size` does not match the output size of the encoder passed
1917 to the wrapper object at initialization time.
1918 """
1919 if inputs is not None:
1920 batch_size = tf.shape(inputs)[0]
1921 dtype = inputs.dtype
1922 with tf.name_scope(
1923 type(self).__name__ + "ZeroState"
1924 ): # pylint: disable=bad-continuation
1925 if self._initial_cell_state is not None:
1926 cell_state = self._initial_cell_state
1927 else:
1928 cell_state = self._cell.get_initial_state(
1929 batch_size=batch_size, dtype=dtype
1930 )
1931 error_message = (
1932 "When calling get_initial_state of AttentionWrapper %s: " % self.name
1933 + "Non-matching batch sizes between the memory "
1934 "(encoder output) and the requested batch size. Are you using "
1935 "the BeamSearchDecoder? If so, make sure your encoder output "
1936 "has been tiled to beam_width via "
1937 "tfa.seq2seq.tile_batch, and the batch_size= argument "
1938 "passed to get_initial_state is batch_size * beam_width."
1939 )
1940 with tf.control_dependencies(
1941 self._batch_size_checks(batch_size, error_message)
1942 ): # pylint: disable=bad-continuation
1943 cell_state = tf.nest.map_structure(
1944 lambda s: tf.identity(s, name="checked_cell_state"), cell_state
1945 )
1946 initial_alignments = [
1947 attention_mechanism.initial_alignments(batch_size, dtype)
1948 for attention_mechanism in self._attention_mechanisms
1949 ]
1950 return AttentionWrapperState(
1951 cell_state=cell_state,
1952 attention=tf.zeros(
1953 [batch_size, self._get_attention_layer_size()], dtype=dtype
1954 ),
1955 alignments=self._item_or_tuple(initial_alignments),
1956 attention_state=self._item_or_tuple(
1957 attention_mechanism.initial_state(batch_size, dtype)
1958 for attention_mechanism in self._attention_mechanisms
1959 ),
1960 alignment_history=self._item_or_tuple(
1961 tf.TensorArray(
1962 dtype, size=0, dynamic_size=True, element_shape=alignment.shape
1963 )
1964 if self._alignment_history
1965 else ()
1966 for alignment in initial_alignments
1967 ),
1968 )
1970 def call(self, inputs, state, **kwargs):
1971 """Perform a step of attention-wrapped RNN.
1973 - Step 1: Mix the `inputs` and previous step's `attention` output via
1974 `cell_input_fn`.
1975 - Step 2: Call the wrapped `cell` with this input and its previous
1976 state.
1977 - Step 3: Score the cell's output with `attention_mechanism`.
1978 - Step 4: Calculate the alignments by passing the score through the
1979 `normalizer`.
1980 - Step 5: Calculate the context vector as the inner product between the
1981 alignments and the attention_mechanism's values (memory).
1982 - Step 6: Calculate the attention output by concatenating the cell
1983 output and context through the attention layer (a linear layer with
1984 `attention_layer_size` outputs).
1986 Args:
1987 inputs: (Possibly nested tuple of) Tensor, the input at this time
1988 step.
1989 state: An instance of `tfa.seq2seq.AttentionWrapperState` containing
1990 tensors from the previous time step.
1991 **kwargs: Dict, other keyword arguments for the cell call method.
1993 Returns:
1994 A tuple `(attention_or_cell_output, next_state)`, where:
1996 - `attention_or_cell_output`: the attention value if `output_attention` is True, otherwise the output of `cell`.
1997 - `next_state` is an instance of `tfa.seq2seq.AttentionWrapperState`
1998 containing the state calculated at this time step.
2000 Raises:
2001 TypeError: If `state` is not an instance of `tfa.seq2seq.AttentionWrapperState`.
2002 """
2003 if not isinstance(state, AttentionWrapperState):
2004 try:
2005 state = AttentionWrapperState(*state)
2006 except TypeError:
2007 raise TypeError(
2008 "Expected state to be instance of AttentionWrapperState or "
2009 "values that can construct AttentionWrapperState. "
2010 "Received type %s instead." % type(state)
2011 )
2013 # Step 1: Calculate the true inputs to the cell based on the
2014 # previous attention value.
2015 cell_inputs = self._cell_input_fn(inputs, state.attention)
2016 cell_state = state.cell_state
2017 cell_output, next_cell_state = self._cell(cell_inputs, cell_state, **kwargs)
2018 next_cell_state = tf.nest.pack_sequence_as(
2019 cell_state, tf.nest.flatten(next_cell_state)
2020 )
2022 cell_batch_size = cell_output.shape[0] or tf.shape(cell_output)[0]
2023 error_message = (
2024 "When applying AttentionWrapper %s: " % self.name
2025 + "Non-matching batch sizes between the memory "
2026 "(encoder output) and the query (decoder output). Are you using "
2027 "the BeamSearchDecoder? You may need to tile your memory input "
2028 "via the tfa.seq2seq.tile_batch function with argument "
2029 "multiple=beam_width."
2030 )
2031 with tf.control_dependencies(
2032 self._batch_size_checks(cell_batch_size, error_message)
2033 ): # pylint: disable=bad-continuation
2034 cell_output = tf.identity(cell_output, name="checked_cell_output")
2036 if self._is_multi:
2037 previous_attention_state = state.attention_state
2038 previous_alignment_history = state.alignment_history
2039 else:
2040 previous_attention_state = [state.attention_state]
2041 previous_alignment_history = [state.alignment_history]
2043 all_alignments = []
2044 all_attentions = []
2045 all_attention_states = []
2046 maybe_all_histories = []
2047 for i, attention_mechanism in enumerate(self._attention_mechanisms):
2048 attention, alignments, next_attention_state = self._attention_fn(
2049 attention_mechanism,
2050 cell_output,
2051 previous_attention_state[i],
2052 self._attention_layers[i] if self._attention_layers else None,
2053 )
2054 alignment_history = (
2055 previous_alignment_history[i].write(
2056 previous_alignment_history[i].size(), alignments
2057 )
2058 if self._alignment_history
2059 else ()
2060 )
2062 all_attention_states.append(next_attention_state)
2063 all_alignments.append(alignments)
2064 all_attentions.append(attention)
2065 maybe_all_histories.append(alignment_history)
2067 attention = tf.concat(all_attentions, 1)
2068 next_state = AttentionWrapperState(
2069 cell_state=next_cell_state,
2070 attention=attention,
2071 attention_state=self._item_or_tuple(all_attention_states),
2072 alignments=self._item_or_tuple(all_alignments),
2073 alignment_history=self._item_or_tuple(maybe_all_histories),
2074 )
2076 if self._output_attention:
2077 return attention, next_state
2078 else:
2079 return cell_output, next_state
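# Example (a minimal sketch): stepping the wrapped cell manually over a few
# time steps, reusing the illustrative names from the class docstring above
# (`cell`, `batch_size`, `hidden_size`).
#
# >>> state = cell.get_initial_state(
# ...     batch_size=batch_size, dtype=tf.float32)
# >>> for _ in range(3):
# ...     step_input = tf.random.uniform([batch_size, hidden_size])
# ...     output, state = cell(step_input, state)
# >>> output.shape
# TensorShape([4, 32])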