Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/seq2seq/attention_wrapper.py: 18%

524 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""A dynamic attention wrapper for RNN cells.""" 

16 

17import collections 

18import functools 

19import math 

20from packaging.version import Version 

21 

22import numpy as np 

23 

24import tensorflow as tf 

25 

26from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell 

27from tensorflow_addons.utils import keras_utils 

28from tensorflow_addons.utils.types import ( 

29 AcceptableDTypes, 

30 FloatTensorLike, 

31 TensorLike, 

32 Initializer, 

33 Number, 

34) 

35 

36from typeguard import typechecked 

37from typing import Optional, Callable, Union, List 

38 

39 

40if Version(tf.__version__) < Version("2.13"): 

41 SERIALIZATION_ARGS = {} 

42else: 

43 SERIALIZATION_ARGS = {"use_legacy_format": True} 

44 

45 

46class AttentionMechanism(tf.keras.layers.Layer): 

47 """Base class for attention mechanisms. 

48 

49 Common functionality includes: 

50 1. Storing the query and memory layers. 

51 2. Preprocessing and storing the memory. 

52 

53 Note that this layer takes memory as its init parameter, which is an

54 anti-pattern of the Keras API; we have to keep the memory as an init parameter

55 for performance and dependency reasons. Under the hood, during `__init__()`, it

56 will invoke `base_layer.__call__(memory, setup_memory=True)`. This lets

57 Keras keep track of the memory tensor as the input of this layer. Once

58 `__init__()` is done, the user can query the attention with

59 `score = att_obj([query, state])` and use it as a normal Keras layer.

60

61 Special attention is needed when using this class as the base layer

62 for a new attention mechanism:

63 1. `build()` could be invoked at least twice, so please make sure weights

64 are not duplicated.

65 2. `Layer.get_weights()` might return a different set of weights if the

66 instance has a `query_layer`. The `query_layer` weights are not initialized

67 until the memory is configured.

68

69 Also note that this layer does not work with a Keras model compiled with

70 `model.compile(run_eagerly=True)`, because this layer is

71 stateful. Support for that will be added in a future version.

72 """ 

73 

74 @typechecked 

75 def __init__( 

76 self, 

77 memory: Union[TensorLike, None], 

78 probability_fn: callable, 

79 query_layer: Optional[tf.keras.layers.Layer] = None, 

80 memory_layer: Optional[tf.keras.layers.Layer] = None, 

81 memory_sequence_length: Optional[TensorLike] = None, 

82 **kwargs, 

83 ): 

84 """Construct base AttentionMechanism class. 

85 

86 Args: 

87 memory: The memory to query; usually the output of an RNN encoder. 

88 This tensor should be shaped `[batch_size, max_time, ...]`. 

89 probability_fn: A `callable`. Converts the score and previous 

90 alignments to probabilities. Its signature should be: 

91 `probabilities = probability_fn(score, state)`. 

92 query_layer: Optional `tf.keras.layers.Layer` instance. The layer's 

93 depth must match the depth of `memory_layer`. If `query_layer` is 

94 not provided, the shape of `query` must match that of 

95 `memory_layer`. 

96 memory_layer: Optional `tf.keras.layers.Layer` instance. The layer's 

97 depth must match the depth of `query_layer`. 

98 If `memory_layer` is not provided, the shape of `memory` must match 

99 that of `query_layer`. 

100 memory_sequence_length: (optional) Sequence lengths for the batch 

101 entries in memory. If provided, the memory tensor rows are masked 

102 with zeros for values past the respective sequence lengths. 

103 **kwargs: Dictionary that contains other common arguments for layer 

104 creation. 

105 """ 

106 self.query_layer = query_layer 

107 self.memory_layer = memory_layer 

108 super().__init__(**kwargs) 

109 self.default_probability_fn = probability_fn 

110 self.probability_fn = probability_fn 

111 

112 self.keys = None 

113 self.values = None 

114 self.batch_size = None 

115 self._memory_initialized = False 

116 self._check_inner_dims_defined = True 

117 self.supports_masking = True 

118 

119 if memory is not None: 

120 # Set up the memory by self.__call__() with memory and

121 # memory_seq_length. This will make the attention follow the keras 

122 # convention which takes all the tensor inputs via __call__(). 

123 if memory_sequence_length is None: 

124 inputs = memory 

125 else: 

126 inputs = [memory, memory_sequence_length] 

127 

128 self.values = super().__call__(inputs, setup_memory=True) 

129 

130 @property 

131 def memory_initialized(self): 

132 """Returns `True` if this attention mechanism has been initialized with 

133 a memory.""" 

134 return self._memory_initialized 

135 

136 def build(self, input_shape): 

137 if not self._memory_initialized: 

138 # This is for setting up the memory, which contains memory and 

139 # optional memory_sequence_length. Build the memory_layer with 

140 # memory shape. 

141 if self.memory_layer is not None and not self.memory_layer.built: 

142 if isinstance(input_shape, list): 

143 self.memory_layer.build(input_shape[0]) 

144 else: 

145 self.memory_layer.build(input_shape) 

146 else: 

147 # The input_shape should be query.shape and state.shape. Use the 

148 # query to init the query layer. 

149 if self.query_layer is not None and not self.query_layer.built: 

150 self.query_layer.build(input_shape[0]) 

151 

152 def __call__(self, inputs, **kwargs): 

153 """Preprocess the inputs before calling `base_layer.__call__()`. 

154 

155 Note that there are two situations here: one for setting up the memory,

156 and one with an actual query and state.

157 1. When the memory has not been configured, we just pass all the params

158 to `base_layer.__call__()`, which will then invoke `self.call()` with

159 proper inputs, which allows this class to set up the memory.

160 2. When the memory has already been set up, the input should contain

161 the query and state, and optionally the processed memory. If the processed

162 memory is not included in the input, we will have to append it to

163 the inputs and give it to `base_layer.__call__()`. The processed

164 memory is the output of the first invocation of `self.__call__()`. If we

165 don't add it here, then from the Keras perspective the graph is

166 disconnected, since the output from the previous call is never used.

167

168 Args:

169 inputs: the input tensors.

170 **kwargs: dict, other keyword arguments for the `__call__()`

171 """ 

172 # Allow manual memory reset 

173 if kwargs.get("setup_memory", False): 

174 self._memory_initialized = False 

175 

176 if self._memory_initialized: 

177 if len(inputs) not in (2, 3): 

178 raise ValueError( 

179 "Expect the inputs to have 2 or 3 tensors, got %d" % len(inputs) 

180 ) 

181 if len(inputs) == 2: 

182 # We append the calculated memory here so that the graph will be 

183 # connected. 

184 inputs.append(self.values) 

185 

186 return super().__call__(inputs, **kwargs) 

187 

188 def call(self, inputs, mask=None, setup_memory=False, **kwargs): 

189 """Setup the memory or query the attention. 

190 

191 There are two cases here: one for setting up the memory, and the second

192 for querying the attention score. `setup_memory` is the flag to indicate which mode

193 it is. The input list will be treated differently based on that flag.

194

195 Args:

196 inputs: a list of tensors that could either be `query` and `state`, or

197 `memory` and `memory_sequence_length`. 

198 `query` is the tensor of dtype matching `memory` and shape 

199 `[batch_size, query_depth]`. 

200 `state` is the tensor of dtype matching `memory` and shape 

201 `[batch_size, alignments_size]`. (`alignments_size` is memory's 

202 `max_time`). 

203 `memory` is the memory to query; usually the output of an RNN 

204 encoder. The tensor should be shaped `[batch_size, max_time, ...]`. 

205 `memory_sequence_length` (optional) is the sequence lengths for the 

206 batch entries in memory. If provided, the memory tensor rows are 

207 masked with zeros for values past the respective sequence lengths. 

208 mask: optional bool tensor with shape `[batch, max_time]` for the 

209 mask of memory. If it is not None, the corresponding item of the 

210 memory should be filtered out during calculation. 

211 setup_memory: boolean, whether the input is for setting up memory, or 

212 query attention. 

213 **kwargs: Dict, other keyword arguments for the call method. 

214 Returns: 

215 Either processed memory or attention score, based on `setup_memory`. 

216 """ 

217 if setup_memory: 

218 if isinstance(inputs, list): 

219 if len(inputs) not in (1, 2): 

220 raise ValueError( 

221 "Expect inputs to have 1 or 2 tensors, got %d" % len(inputs) 

222 ) 

223 memory = inputs[0] 

224 memory_sequence_length = inputs[1] if len(inputs) == 2 else None 

225 memory_mask = mask 

226 else: 

227 memory, memory_sequence_length = inputs, None 

228 memory_mask = mask 

229 self.setup_memory(memory, memory_sequence_length, memory_mask) 

230 # We force self.built to False here since only the memory is

231 # initialized, but the real query/state has not been passed to call()

232 # yet. The layer should be built and called again.

233 self.built = False 

234 # Return the processed memory in order to create the Keras 

235 # connectivity data for it. 

236 return self.values 

237 else: 

238 if not self._memory_initialized: 

239 raise ValueError( 

240 "Cannot query the attention before the setup of memory" 

241 ) 

242 if len(inputs) not in (2, 3): 

243 raise ValueError( 

244 "Expect the inputs to have query, state, and optional " 

245 "processed memory, got %d items" % len(inputs) 

246 ) 

247 # Ignore the rest of the inputs and only care about the query and 

248 # state 

249 query, state = inputs[0], inputs[1] 

250 return self._calculate_attention(query, state) 

251 

252 def setup_memory(self, memory, memory_sequence_length=None, memory_mask=None): 

253 """Pre-process the memory before actually query the memory. 

254 

255 This should only be called once at the first invocation of `call()`. 

256 

257 Args: 

258 memory: The memory to query; usually the output of an RNN encoder. 

259 This tensor should be shaped `[batch_size, max_time, ...]`. 

260 memory_sequence_length (optional): Sequence lengths for the batch 

261 entries in memory. If provided, the memory tensor rows are masked 

262 with zeros for values past the respective sequence lengths. 

263 memory_mask: (Optional) The boolean tensor with shape `[batch_size, 

264 max_time]`. For any value equal to False, the corresponding value 

265 in memory should be ignored. 

266 """ 

267 if memory_sequence_length is not None and memory_mask is not None: 

268 raise ValueError( 

269 "memory_sequence_length and memory_mask cannot be " 

270 "used at same time for attention." 

271 ) 

272 with tf.name_scope(self.name or "BaseAttentionMechanismInit"): 

273 self.values = _prepare_memory( 

274 memory, 

275 memory_sequence_length=memory_sequence_length, 

276 memory_mask=memory_mask, 

277 check_inner_dims_defined=self._check_inner_dims_defined, 

278 ) 

279 # Mark the value as checked since the memory and memory mask might not

280 # be passed from __call__(), which does not have proper keras metadata.

281 # TODO(omalleyt12): Remove this hack once the mask has proper

282 # keras history.

283 

284 def _mark_checked(tensor): 

285 tensor._keras_history_checked = True # pylint: disable=protected-access 

286 

287 tf.nest.map_structure(_mark_checked, self.values) 

288 if self.memory_layer is not None: 

289 self.keys = self.memory_layer(self.values) 

290 else: 

291 self.keys = self.values 

292 self.batch_size = self.keys.shape[0] or tf.shape(self.keys)[0] 

293 self._alignments_size = self.keys.shape[1] or tf.shape(self.keys)[1] 

294 if memory_mask is not None or memory_sequence_length is not None: 

295 unwrapped_probability_fn = self.default_probability_fn 

296 

297 def _mask_probability_fn(score, prev): 

298 return unwrapped_probability_fn( 

299 _maybe_mask_score( 

300 score, 

301 memory_mask=memory_mask, 

302 memory_sequence_length=memory_sequence_length, 

303 score_mask_value=score.dtype.min, 

304 ), 

305 prev, 

306 ) 

307 

308 self.probability_fn = _mask_probability_fn 

309 self._memory_initialized = True 

310 

311 def _calculate_attention(self, query, state): 

312 raise NotImplementedError( 

313 "_calculate_attention need to be implemented by subclasses." 

314 ) 

315 

316 def compute_mask(self, inputs, mask=None): 

317 # The real inputs of the attention are the query and state, and the memory

318 # layer mask shouldn't be passed down. Return None for all output masks

319 # here.

320 return None, None 

321 

322 def get_config(self): 

323 config = {} 

324 # Since the probability_fn is likely to be a wrapped function, the child 

325 # class should preserve the original function and how it's wrapped.

326 

327 if self.query_layer is not None: 

328 config["query_layer"] = { 

329 "class_name": self.query_layer.__class__.__name__, 

330 "config": self.query_layer.get_config(), 

331 } 

332 if self.memory_layer is not None: 

333 config["memory_layer"] = { 

334 "class_name": self.memory_layer.__class__.__name__, 

335 "config": self.memory_layer.get_config(), 

336 } 

337 # memory is a required init parameter and it's a tensor. It cannot be

338 # serialized to config, so we put a placeholder for it. 

339 config["memory"] = None 

340 base_config = super().get_config() 

341 return {**base_config, **config} 

342 

343 def _process_probability_fn(self, func_name): 

344 """Helper method to retrieve the probably function by string input.""" 

345 valid_probability_fns = { 

346 "softmax": tf.nn.softmax, 

347 "hardmax": hardmax, 

348 } 

349 if func_name not in valid_probability_fns.keys(): 

350 raise ValueError( 

351 "Invalid probability function: %s, options are %s" 

352 % (func_name, valid_probability_fns.keys()) 

353 ) 

354 return valid_probability_fns[func_name] 

355 

356 @classmethod 

357 def deserialize_inner_layer_from_config(cls, config, custom_objects): 

358 """Helper method that reconstruct the query and memory from the config. 

359 

360 In the get_config() method, the query and memory layer configs are 

361 serialized into dict for persistence, this method perform the reverse 

362 action to reconstruct the layer from the config. 

363 

364 Args: 

365 config: dict, the configs that will be used to reconstruct the 

366 object. 

367 custom_objects: dict mapping class names (or function names) of 

368 custom (non-Keras) objects to class/functions. 

369 Returns: 

370 config: dict, the config with layer instance created, which is ready 

371 to be used as init parameters. 

372 """ 

373 # Reconstruct the query and memory layer for parent class. 

374 # Instead of updating the input, create a copy and use that. 

375 config = config.copy() 

376 query_layer_config = config.pop("query_layer", None) 

377 if query_layer_config: 

378 query_layer = tf.keras.layers.deserialize( 

379 query_layer_config, 

380 custom_objects=custom_objects, 

381 **SERIALIZATION_ARGS, 

382 ) 

383 config["query_layer"] = query_layer 

384 memory_layer_config = config.pop("memory_layer", None) 

385 if memory_layer_config: 

386 memory_layer = tf.keras.layers.deserialize( 

387 memory_layer_config, 

388 custom_objects=custom_objects, 

389 **SERIALIZATION_ARGS, 

390 ) 

391 config["memory_layer"] = memory_layer 

392 return config 

393 

394 @property 

395 def alignments_size(self): 

396 if isinstance(self._alignments_size, int): 

397 return self._alignments_size 

398 else: 

399 return tf.TensorShape([None]) 

400 

401 @property 

402 def state_size(self): 

403 return self.alignments_size 

404 

405 def initial_alignments(self, batch_size, dtype): 

406 """Creates the initial alignment values for the `tfa.seq2seq.AttentionWrapper` 

407 class. 

408 

409 This is important for attention mechanisms that use the previous 

410 alignment to calculate the alignment at the next time step 

411 (e.g. monotonic attention). 

412 

413 The default behavior is to return a tensor of all zeros. 

414 

415 Args: 

416 batch_size: `int32` scalar, the batch_size. 

417 dtype: The `dtype`. 

418 

419 Returns: 

420 A `dtype` tensor shaped `[batch_size, alignments_size]` 

421 (`alignments_size` is the values' `max_time`). 

422 """ 

423 return tf.zeros([batch_size, self._alignments_size], dtype=dtype) 

424 

425 def initial_state(self, batch_size, dtype): 

426 """Creates the initial state values for the `tfa.seq2seq.AttentionWrapper` class. 

427 

428 This is important for attention mechanisms that use the previous 

429 alignment to calculate the alignment at the next time step 

430 (e.g. monotonic attention). 

431 

432 The default behavior is to return the same output as 

433 `initial_alignments`. 

434 

435 Args: 

436 batch_size: `int32` scalar, the batch_size. 

437 dtype: The `dtype`. 

438 

439 Returns: 

440 A structure of all-zero tensors with shapes as described by 

441 `state_size`. 

442 """ 

443 return self.initial_alignments(batch_size, dtype) 

444 
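# Illustrative usage sketch (added for this write-up, not part of the original
# module): the two-phase pattern described in the AttentionMechanism docstring,
# shown with the LuongAttention subclass defined later in this file. The helper
# name `_example_attention_mechanism_usage` is hypothetical.
def _example_attention_mechanism_usage():
    memory = tf.random.normal([4, 7, 16])            # [batch, max_time, depth]
    mechanism = LuongAttention(units=16, memory=memory,
                               memory_sequence_length=[7, 5, 6, 7])
    query = tf.random.normal([4, 16])                # [batch, query_depth]
    state = mechanism.initial_state(4, tf.float32)   # previous alignments
    # Memory was set up in __init__; now the mechanism is queried like a
    # normal Keras layer with [query, state].
    alignments, next_state = mechanism([query, state])
    return alignments, next_state                    # both shaped [4, 7]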

445 

446def _luong_score(query, keys, scale): 

447 """Implements Luong-style (multiplicative) scoring function. 

448 

449 This attention has two forms. The first is standard Luong attention, 

450 as described in: 

451 

452 Minh-Thang Luong, Hieu Pham, Christopher D. Manning. 

453 "Effective Approaches to Attention-based Neural Machine Translation." 

454 EMNLP 2015. https://arxiv.org/abs/1508.04025 

455 

456 The second is the scaled form inspired partly by the normalized form of 

457 Bahdanau attention. 

458 

459 To enable the second form, call this function with `scale=True`. 

460 

461 Args: 

462 query: Tensor, shape `[batch_size, num_units]` to compare to keys. 

463 keys: Processed memory, shape `[batch_size, max_time, num_units]`. 

464 scale: the optional tensor to scale the attention score. 

465 

466 Returns: 

467 A `[batch_size, max_time]` tensor of unnormalized score values. 

468 

469 Raises: 

470 ValueError: If `key` and `query` depths do not match. 

471 """ 

472 depth = query.shape[-1] 

473 key_units = keys.shape[-1] 

474 if depth != key_units: 

475 raise ValueError( 

476 "Incompatible or unknown inner dimensions between query and keys. " 

477 "Query (%s) has units: %s. Keys (%s) have units: %s. " 

478 "Perhaps you need to set num_units to the keys' dimension (%s)?" 

479 % (query, depth, keys, key_units, key_units) 

480 ) 

481 

482 # Reshape from [batch_size, depth] to [batch_size, 1, depth] 

483 # for matmul. 

484 query = tf.expand_dims(query, 1) 

485 

486 # Inner product along the query units dimension. 

487 # matmul shapes: query is [batch_size, 1, depth] and 

488 # keys is [batch_size, max_time, depth]. 

489 # the inner product is asked to **transpose keys' inner shape** to get a 

490 # batched matmul on: 

491 # [batch_size, 1, depth] . [batch_size, depth, max_time] 

492 # resulting in an output shape of: 

493 # [batch_size, 1, max_time]. 

494 # we then squeeze out the center singleton dimension. 

495 score = tf.matmul(query, keys, transpose_b=True) 

496 score = tf.squeeze(score, [1]) 

497 

498 if scale is not None: 

499 score = scale * score 

500 return score 

501 
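# Illustrative sketch (hypothetical helper, not part of the original module):
# the batched matmul performed by _luong_score above, with explicit shapes.
def _example_luong_score_shapes():
    query = tf.random.normal([2, 8])       # [batch_size, num_units]
    keys = tf.random.normal([2, 5, 8])     # [batch_size, max_time, num_units]
    score = _luong_score(query, keys, scale=None)
    assert score.shape == (2, 5)           # one unnormalized score per time step
    return score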

502 

503class LuongAttention(AttentionMechanism): 

504 """Implements Luong-style (multiplicative) attention scoring. 

505 

506 This attention has two forms. The first is standard Luong attention, 

507 as described in: 

508 

509 Minh-Thang Luong, Hieu Pham, Christopher D. Manning. 

510 [Effective Approaches to Attention-based Neural Machine Translation. 

511 EMNLP 2015.](https://arxiv.org/abs/1508.04025) 

512 

513 The second is the scaled form inspired partly by the normalized form of 

514 Bahdanau attention. 

515 

516 To enable the second form, construct the object with parameter 

517 `scale=True`. 

518 """ 

519 

520 @typechecked 

521 def __init__( 

522 self, 

523 units: TensorLike, 

524 memory: Optional[TensorLike] = None, 

525 memory_sequence_length: Optional[TensorLike] = None, 

526 scale: bool = False, 

527 probability_fn: str = "softmax", 

528 dtype: AcceptableDTypes = None, 

529 name: str = "LuongAttention", 

530 **kwargs, 

531 ): 

532 """Construct the AttentionMechanism mechanism. 

533 

534 Args: 

535 units: The depth of the attention mechanism. 

536 memory: The memory to query; usually the output of an RNN encoder. 

537 This tensor should be shaped `[batch_size, max_time, ...]`. 

538 memory_sequence_length: (optional): Sequence lengths for the batch 

539 entries in memory. If provided, the memory tensor rows are masked 

540 with zeros for values past the respective sequence lengths. 

541 scale: Python boolean. Whether to scale the energy term. 

542 probability_fn: (optional) string, the name of the function used to

543 convert the attention score to probabilities. The default is `softmax`,

544 which is `tf.nn.softmax`. The other option is `hardmax`, which is

545 `hardmax()` within this module. Any other value will result in a

546 validation error.

547 dtype: The data type for the memory layer of the attention mechanism. 

548 name: Name to use when creating ops. 

549 **kwargs: Dictionary that contains other common arguments for layer 

550 creation. 

551 """ 

552 # For LuongAttention, we only transform the memory layer; thus 

553 # num_units **must** match the expected query depth.

554 self.probability_fn_name = probability_fn 

555 probability_fn = self._process_probability_fn(self.probability_fn_name) 

556 

557 def wrapped_probability_fn(score, _): 

558 return probability_fn(score) 

559 

560 memory_layer = kwargs.pop("memory_layer", None) 

561 if not memory_layer: 

562 memory_layer = tf.keras.layers.Dense( 

563 units, name="memory_layer", use_bias=False, dtype=dtype 

564 ) 

565 self.units = units 

566 self.scale = scale 

567 self.scale_weight = None 

568 super().__init__( 

569 memory=memory, 

570 memory_sequence_length=memory_sequence_length, 

571 query_layer=None, 

572 memory_layer=memory_layer, 

573 probability_fn=wrapped_probability_fn, 

574 name=name, 

575 dtype=dtype, 

576 **kwargs, 

577 ) 

578 

579 def build(self, input_shape): 

580 super().build(input_shape) 

581 if self.scale and self.scale_weight is None: 

582 self.scale_weight = self.add_weight( 

583 "attention_g", initializer=tf.ones_initializer, shape=() 

584 ) 

585 self.built = True 

586 

587 def _calculate_attention(self, query, state): 

588 """Score the query based on the keys and values. 

589 

590 Args: 

591 query: Tensor of dtype matching `self.values` and shape 

592 `[batch_size, query_depth]`. 

593 state: Tensor of dtype matching `self.values` and shape 

594 `[batch_size, alignments_size]` 

595 (`alignments_size` is memory's `max_time`). 

596 

597 Returns: 

598 alignments: Tensor of dtype matching `self.values` and shape 

599 `[batch_size, alignments_size]` (`alignments_size` is memory's 

600 `max_time`). 

601 next_state: Same as the alignments. 

602 """ 

603 score = _luong_score(query, self.keys, self.scale_weight) 

604 alignments = self.probability_fn(score, state) 

605 next_state = alignments 

606 return alignments, next_state 

607 

608 def get_config(self): 

609 config = { 

610 "units": self.units, 

611 "scale": self.scale, 

612 "probability_fn": self.probability_fn_name, 

613 } 

614 base_config = super().get_config() 

615 return {**base_config, **config} 

616 

617 @classmethod 

618 def from_config(cls, config, custom_objects=None): 

619 config = AttentionMechanism.deserialize_inner_layer_from_config( 

620 config, custom_objects=custom_objects 

621 ) 

622 return cls(**config) 

623 
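# Illustrative sketch (hypothetical helper, not part of the original module):
# constructing a scaled LuongAttention with the `hardmax` probability function
# and querying it once.
def _example_luong_attention():
    memory = tf.random.normal([3, 6, 32])
    attn = LuongAttention(units=32, memory=memory, scale=True,
                          probability_fn="hardmax")
    query = tf.random.normal([3, 32])
    prev = attn.initial_alignments(3, tf.float32)
    alignments, _ = attn([query, prev])
    # With hardmax, each row of `alignments` is one-hot over the 6 time steps.
    return alignments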

624 

625def _bahdanau_score( 

626 processed_query, keys, attention_v, attention_g=None, attention_b=None 

627): 

628 """Implements Bahdanau-style (additive) scoring function. 

629 

630 This attention has two forms. The first is Bahdanau attention, 

631 as described in: 

632 

633 Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio. 

634 "Neural Machine Translation by Jointly Learning to Align and Translate." 

635 ICLR 2015. https://arxiv.org/abs/1409.0473 

636 

637 The second is the normalized form. This form is inspired by the 

638 weight normalization article: 

639 

640 Tim Salimans, Diederik P. Kingma. 

641 "Weight Normalization: A Simple Reparameterization to Accelerate 

642 Training of Deep Neural Networks." 

643 https://arxiv.org/abs/1602.07868 

644 

645 To enable the second form, please pass in attention_g and attention_b.

646 

647 Args: 

648 processed_query: Tensor, shape `[batch_size, num_units]` to compare to 

649 keys. 

650 keys: Processed memory, shape `[batch_size, max_time, num_units]`. 

651 attention_v: Tensor, shape `[num_units]`. 

652 attention_g: Optional scalar tensor for normalization. 

653 attention_b: Optional tensor with shape `[num_units]` for normalization. 

654 

655 Returns: 

656 A `[batch_size, max_time]` tensor of unnormalized score values. 

657 """ 

658 # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting. 

659 processed_query = tf.expand_dims(processed_query, 1) 

660 if attention_g is not None and attention_b is not None: 

661 normed_v = ( 

662 attention_g 

663 * attention_v 

664 * tf.math.rsqrt(tf.reduce_sum(tf.square(attention_v))) 

665 ) 

666 return tf.reduce_sum( 

667 normed_v * tf.tanh(keys + processed_query + attention_b), [2] 

668 ) 

669 else: 

670 return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query), [2]) 

671 
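# Illustrative sketch (hypothetical helper, not part of the original module):
# the additive score reduces over the units dimension, and passing attention_g
# and attention_b switches to the weight-normalized form.
def _example_bahdanau_score_shapes():
    processed_query = tf.random.normal([2, 8])
    keys = tf.random.normal([2, 5, 8])
    attention_v = tf.random.normal([8])
    plain = _bahdanau_score(processed_query, keys, attention_v)
    normed = _bahdanau_score(
        processed_query, keys, attention_v,
        attention_g=tf.constant(1.0), attention_b=tf.zeros([8]))
    assert plain.shape == normed.shape == (2, 5)
    return plain, normed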

672 

673class BahdanauAttention(AttentionMechanism): 

674 """Implements Bahdanau-style (additive) attention. 

675 

676 This attention has two forms. The first is Bahdanau attention, 

677 as described in: 

678 

679 Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio. 

680 "Neural Machine Translation by Jointly Learning to Align and Translate." 

681 ICLR 2015. https://arxiv.org/abs/1409.0473 

682 

683 The second is the normalized form. This form is inspired by the 

684 weight normalization article: 

685 

686 Tim Salimans, Diederik P. Kingma. 

687 "Weight Normalization: A Simple Reparameterization to Accelerate 

688 Training of Deep Neural Networks." 

689 https://arxiv.org/abs/1602.07868 

690 

691 To enable the second form, construct the object with parameter 

692 `normalize=True`. 

693 """ 

694 

695 @typechecked 

696 def __init__( 

697 self, 

698 units: TensorLike, 

699 memory: Optional[TensorLike] = None, 

700 memory_sequence_length: Optional[TensorLike] = None, 

701 normalize: bool = False, 

702 probability_fn: str = "softmax", 

703 kernel_initializer: Initializer = "glorot_uniform", 

704 dtype: AcceptableDTypes = None, 

705 name: str = "BahdanauAttention", 

706 **kwargs, 

707 ): 

708 """Construct the Attention mechanism. 

709 

710 Args: 

711 units: The depth of the query mechanism. 

712 memory: The memory to query; usually the output of an RNN encoder. 

713 This tensor should be shaped `[batch_size, max_time, ...]`. 

714 memory_sequence_length: (optional): Sequence lengths for the batch 

715 entries in memory. If provided, the memory tensor rows are masked 

716 with zeros for values past the respective sequence lengths. 

717 normalize: Python boolean. Whether to normalize the energy term. 

718 probability_fn: (optional) string, the name of the function used to

719 convert the attention score to probabilities. The default is `softmax`,

720 which is `tf.nn.softmax`. The other option is `hardmax`, which is

721 `hardmax()` within this module. Any other value will result in a

722 validation error.

723 kernel_initializer: (optional), the name of the initializer for the 

724 attention kernel. 

725 dtype: The data type for the query and memory layers of the attention 

726 mechanism. 

727 name: Name to use when creating ops. 

728 **kwargs: Dictionary that contains other common arguments for layer 

729 creation. 

730 """ 

731 self.probability_fn_name = probability_fn 

732 probability_fn = self._process_probability_fn(self.probability_fn_name) 

733 

734 def wrapped_probability_fn(score, _): 

735 return probability_fn(score) 

736 

737 query_layer = kwargs.pop("query_layer", None) 

738 if not query_layer: 

739 query_layer = tf.keras.layers.Dense( 

740 units, name="query_layer", use_bias=False, dtype=dtype 

741 ) 

742 memory_layer = kwargs.pop("memory_layer", None) 

743 if not memory_layer: 

744 memory_layer = tf.keras.layers.Dense( 

745 units, name="memory_layer", use_bias=False, dtype=dtype 

746 ) 

747 self.units = units 

748 self.normalize = normalize 

749 self.kernel_initializer = tf.keras.initializers.get(kernel_initializer) 

750 self.attention_v = None 

751 self.attention_g = None 

752 self.attention_b = None 

753 super().__init__( 

754 memory=memory, 

755 memory_sequence_length=memory_sequence_length, 

756 query_layer=query_layer, 

757 memory_layer=memory_layer, 

758 probability_fn=wrapped_probability_fn, 

759 name=name, 

760 dtype=dtype, 

761 **kwargs, 

762 ) 

763 

764 def build(self, input_shape): 

765 super().build(input_shape) 

766 if self.attention_v is None: 

767 self.attention_v = self.add_weight( 

768 "attention_v", 

769 [self.units], 

770 dtype=self.dtype, 

771 initializer=self.kernel_initializer, 

772 ) 

773 if self.normalize and self.attention_g is None and self.attention_b is None: 

774 self.attention_g = self.add_weight( 

775 "attention_g", 

776 initializer=tf.constant_initializer(math.sqrt(1.0 / self.units)), 

777 shape=(), 

778 ) 

779 self.attention_b = self.add_weight( 

780 "attention_b", shape=[self.units], initializer=tf.zeros_initializer() 

781 ) 

782 self.built = True 

783 

784 def _calculate_attention(self, query, state): 

785 """Score the query based on the keys and values. 

786 

787 Args: 

788 query: Tensor of dtype matching `self.values` and shape 

789 `[batch_size, query_depth]`. 

790 state: Tensor of dtype matching `self.values` and shape 

791 `[batch_size, alignments_size]` 

792 (`alignments_size` is memory's `max_time`). 

793 

794 Returns: 

795 alignments: Tensor of dtype matching `self.values` and shape 

796 `[batch_size, alignments_size]` (`alignments_size` is memory's 

797 `max_time`). 

798 next_state: same as alignments. 

799 """ 

800 processed_query = self.query_layer(query) if self.query_layer else query 

801 score = _bahdanau_score( 

802 processed_query, 

803 self.keys, 

804 self.attention_v, 

805 attention_g=self.attention_g, 

806 attention_b=self.attention_b, 

807 ) 

808 alignments = self.probability_fn(score, state) 

809 next_state = alignments 

810 return alignments, next_state 

811 

812 def get_config(self): 

813 # yapf: disable 

814 config = { 

815 "units": self.units, 

816 "normalize": self.normalize, 

817 "probability_fn": self.probability_fn_name, 

818 "kernel_initializer": tf.keras.initializers.serialize( 

819 self.kernel_initializer, 

820 **SERIALIZATION_ARGS, 

821 ) 

822 } 

823 # yapf: enable 

824 

825 base_config = super().get_config() 

826 return {**base_config, **config} 

827 

828 @classmethod 

829 def from_config(cls, config, custom_objects=None): 

830 config = AttentionMechanism.deserialize_inner_layer_from_config( 

831 config, 

832 custom_objects=custom_objects, 

833 ) 

834 return cls(**config) 

835 
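# Illustrative sketch (hypothetical helper, not part of the original module):
# normalized Bahdanau attention where padded memory positions are masked via
# memory_sequence_length and therefore receive (near-)zero probability.
def _example_bahdanau_attention():
    memory = tf.random.normal([2, 4, 16])
    attn = BahdanauAttention(units=16, memory=memory,
                             memory_sequence_length=[4, 2], normalize=True)
    query = tf.random.normal([2, 16])
    prev = attn.initial_alignments(2, tf.float32)
    alignments, _ = attn([query, prev])
    # alignments[1, 2:] is ~0 because only the first 2 steps of row 1 are valid.
    return alignments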

836 

837def safe_cumprod(x: TensorLike, *args, **kwargs) -> tf.Tensor: 

838 """Computes cumprod of x in logspace using cumsum to avoid underflow. 

839 

840 The cumprod function and its gradient can result in numerical instabilities 

841 when its argument has very small and/or zero values. As long as the 

842 argument is all positive, we can instead compute the cumulative product as 

843 exp(cumsum(log(x))). This function can be called identically to 

844 tf.cumprod. 

845 

846 Args: 

847 x: Tensor to take the cumulative product of. 

848 *args: Passed on to cumsum; these are identical to those in cumprod. 

849 **kwargs: Passed on to cumsum; these are identical to those in cumprod. 

850 Returns: 

851 Cumulative product of x. 

852 """ 

853 with tf.name_scope("SafeCumprod"): 

854 x = tf.convert_to_tensor(x, name="x") 

855 tiny = np.finfo(x.dtype.as_numpy_dtype).tiny 

856 return tf.exp( 

857 tf.cumsum(tf.math.log(tf.clip_by_value(x, tiny, 1)), *args, **kwargs) 

858 ) 

859 
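# Illustrative sketch (hypothetical helper, not part of the original module):
# for strictly positive inputs, exp(cumsum(log(x))) matches the ordinary
# cumulative product while staying numerically stable for very small entries.
def _example_safe_cumprod():
    x = tf.constant([0.5, 0.25, 1e-30, 0.1])
    stable = safe_cumprod(x)
    naive = tf.math.cumprod(x)
    tf.debugging.assert_near(stable, naive, rtol=1e-5)
    return stable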

860 

861def monotonic_attention( 

862 p_choose_i: FloatTensorLike, previous_attention: FloatTensorLike, mode: str 

863) -> tf.Tensor: 

864 """Computes monotonic attention distribution from choosing probabilities. 

865 

866 Monotonic attention implies that the input sequence is processed in an 

867 explicitly left-to-right manner when generating the output sequence. In 

868 addition, once an input sequence element is attended to at a given output 

869 timestep, elements occurring before it cannot be attended to at subsequent 

870 output timesteps. This function generates attention distributions 

871 according to these assumptions. For more information, see `Online and 

872 Linear-Time Attention by Enforcing Monotonic Alignments`. 

873 

874 Args: 

875 p_choose_i: Probability of choosing input sequence/memory element i. 

876 Should be of shape (batch_size, input_sequence_length), and should all 

877 be in the range [0, 1]. 

878 previous_attention: The attention distribution from the previous output 

879 timestep. Should be of shape (batch_size, input_sequence_length). For 

880 the first output timestep, previous_attention[n] should be

881 [1, 0, 0, ..., 0] for all n in [0, ... batch_size - 1]. 

882 mode: How to compute the attention distribution. Must be one of 

883 'recursive', 'parallel', or 'hard'. 

884 * 'recursive' uses tf.scan to recursively compute the distribution. 

885 This is slowest but is exact, general, and does not suffer from 

886 numerical instabilities. 

887 * 'parallel' uses parallelized cumulative-sum and cumulative-product 

888 operations to compute a closed-form solution to the recurrence 

889 relation defining the attention distribution. This makes it more 

890 efficient than 'recursive', but it requires numerical checks which 

891 make the distribution non-exact. This can be a problem in 

892 particular when input_sequence_length is long and/or p_choose_i has 

893 entries very close to 0 or 1. 

894 * 'hard' requires that the probabilities in p_choose_i are all either 

895 0 or 1, and subsequently uses a more efficient and exact solution. 

896 

897 Returns: 

898 A tensor of shape (batch_size, input_sequence_length) representing the 

899 attention distributions for each sequence in the batch. 

900 

901 Raises: 

902 ValueError: mode is not one of 'recursive', 'parallel', 'hard'. 

903 """ 

904 # Force things to be tensors 

905 p_choose_i = tf.convert_to_tensor(p_choose_i, name="p_choose_i") 

906 previous_attention = tf.convert_to_tensor( 

907 previous_attention, name="previous_attention" 

908 ) 

909 if mode == "recursive": 

910 # Use .shape[0] when it's not None, or fall back on symbolic shape 

911 batch_size = p_choose_i.shape[0] or tf.shape(p_choose_i)[0] 

912 # Compute [1, 1 - p_choose_i[0], 1 - p_choose_i[1], ..., 1 - p_choose_ 

913 # i[-2]] 

914 shifted_1mp_choose_i = tf.concat( 

915 [tf.ones((batch_size, 1)), 1 - p_choose_i[:, :-1]], 1 

916 ) 

917 # Compute attention distribution recursively as 

918 # q[i] = (1 - p_choose_i[i - 1])*q[i - 1] + previous_attention[i] 

919 # attention[i] = p_choose_i[i]*q[i] 

920 attention = p_choose_i * tf.transpose( 

921 tf.scan( 

922 # Need to use reshape to remind TF of the shape between loop 

923 # iterations 

924 lambda x, yz: tf.reshape(yz[0] * x + yz[1], (batch_size,)), 

925 # Loop variables yz[0] and yz[1] 

926 [tf.transpose(shifted_1mp_choose_i), tf.transpose(previous_attention)], 

927 # Initial value of x is just zeros 

928 tf.zeros((batch_size,)), 

929 ) 

930 ) 

931 elif mode == "parallel": 

932 # safe_cumprod computes cumprod in logspace with numeric checks 

933 cumprod_1mp_choose_i = safe_cumprod(1 - p_choose_i, axis=1, exclusive=True) 

934 # Compute recurrence relation solution 

935 attention = ( 

936 p_choose_i 

937 * cumprod_1mp_choose_i 

938 * tf.cumsum( 

939 previous_attention / 

940 # Clip cumprod_1mp to avoid divide-by-zero 

941 tf.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.0), 

942 axis=1, 

943 ) 

944 ) 

945 elif mode == "hard": 

946 # Remove any probabilities before the index chosen last time step 

947 p_choose_i *= tf.cumsum(previous_attention, axis=1) 

948 # Now, use exclusive cumprod to remove probabilities after the first 

949 # chosen index, like so: 

950 # p_choose_i = [0, 0, 0, 1, 1, 0, 1, 1] 

951 # cumprod(1 - p_choose_i, exclusive=True) = [1, 1, 1, 1, 0, 0, 0, 0] 

952 # Product of above: [0, 0, 0, 1, 0, 0, 0, 0] 

953 attention = p_choose_i * tf.math.cumprod(1 - p_choose_i, axis=1, exclusive=True) 

954 else: 

955 raise ValueError("mode must be 'recursive', 'parallel', or 'hard'.") 

956 return attention 

957 
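# Illustrative sketch (hypothetical helper, not part of the original module):
# in 'hard' mode with 0/1 choosing probabilities, attention jumps to the first
# chosen position at or after the previously attended one.
def _example_hard_monotonic_attention():
    p_choose_i = tf.constant([[0.0, 0.0, 1.0, 1.0, 0.0]])
    previous_attention = tf.constant([[1.0, 0.0, 0.0, 0.0, 0.0]])
    attention = monotonic_attention(p_choose_i, previous_attention, mode="hard")
    # attention == [[0., 0., 1., 0., 0.]]
    return attention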

958 

959def _monotonic_probability_fn( 

960 score, previous_alignments, sigmoid_noise, mode, seed=None 

961): 

962 """Attention probability function for monotonic attention. 

963 

964 Takes in unnormalized attention scores, adds pre-sigmoid noise to encourage 

965 the model to make discrete attention decisions, passes them through a 

966 sigmoid to obtain "choosing" probabilities, and then calls 

967 monotonic_attention to obtain the attention distribution. For more 

968 information, see 

969 

970 Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck, 

971 "Online and Linear-Time Attention by Enforcing Monotonic Alignments." 

972 ICML 2017. https://arxiv.org/abs/1704.00784 

973 

974 Args: 

975 score: Unnormalized attention scores, shape 

976 `[batch_size, alignments_size]` 

977 previous_alignments: Previous attention distribution, shape 

978 `[batch_size, alignments_size]` 

979 sigmoid_noise: Standard deviation of pre-sigmoid noise. Setting this 

980 larger than 0 will encourage the model to produce large attention 

981 scores, effectively making the choosing probabilities discrete and the 

982 resulting attention distribution one-hot. It should be set to 0 at 

983 test-time, and when hard attention is not desired. 

984 mode: How to compute the attention distribution. Must be one of 

985 'recursive', 'parallel', or 'hard'. See the docstring for 

986 `tfa.seq2seq.monotonic_attention` for more information. 

987 seed: (optional) Random seed for pre-sigmoid noise. 

988 

989 Returns: 

990 A `[batch_size, alignments_size]`-shape tensor corresponding to the 

991 resulting attention distribution. 

992 """ 

993 # Optionally add pre-sigmoid noise to the scores 

994 if sigmoid_noise > 0: 

995 noise = tf.random.normal(tf.shape(score), dtype=score.dtype, seed=seed) 

996 score += sigmoid_noise * noise 

997 # Compute "choosing" probabilities from the attention scores 

998 if mode == "hard": 

999 # When mode is hard, use a hard sigmoid 

1000 p_choose_i = tf.cast(score > 0, score.dtype) 

1001 else: 

1002 p_choose_i = tf.sigmoid(score) 

1003 # Convert from choosing probabilities to attention distribution 

1004 return monotonic_attention(p_choose_i, previous_alignments, mode) 

1005 

1006 

1007class _BaseMonotonicAttentionMechanism(AttentionMechanism): 

1008 """Base attention mechanism for monotonic attention. 

1009 

1010 Simply overrides the initial_alignments function to provide a dirac 

1011 distribution, which is needed in order for the monotonic attention 

1012 distributions to have the correct behavior. 

1013 """ 

1014 

1015 def initial_alignments(self, batch_size, dtype): 

1016 """Creates the initial alignment values for the monotonic attentions. 

1017 

1018 Initializes to Dirac distributions, i.e.

1019 [1, 0, 0, ...memory length..., 0] for all entries in the batch. 

1020 

1021 Args: 

1022 batch_size: `int32` scalar, the batch_size. 

1023 dtype: The `dtype`. 

1024 

1025 Returns: 

1026 A `dtype` tensor shaped `[batch_size, alignments_size]` 

1027 (`alignments_size` is the values' `max_time`). 

1028 """ 

1029 max_time = self._alignments_size 

1030 return tf.one_hot( 

1031 tf.zeros((batch_size,), dtype=tf.int32), max_time, dtype=dtype 

1032 ) 

1033 

1034 

1035class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): 

1036 """Monotonic attention mechanism with Bahdanau-style energy function. 

1037 

1038 This type of attention enforces a monotonic constraint on the attention 

1039 distributions; that is, once the model attends to a given point in the

1040 memory it can't attend to any prior points at subsequent output timesteps.

1041 It achieves this by using the `_monotonic_probability_fn` instead of `softmax` 

1042 to construct its attention distributions. Since the attention scores are 

1043 passed through a sigmoid, a learnable scalar bias parameter is applied 

1044 after the score function and before the sigmoid. Otherwise, it is 

1045 equivalent to `tfa.seq2seq.BahdanauAttention`. This approach is proposed in 

1046 

1047 Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck, 

1048 "Online and Linear-Time Attention by Enforcing Monotonic Alignments." 

1049 ICML 2017. https://arxiv.org/abs/1704.00784 

1050 """ 

1051 

1052 @typechecked 

1053 def __init__( 

1054 self, 

1055 units: TensorLike, 

1056 memory: Optional[TensorLike] = None, 

1057 memory_sequence_length: Optional[TensorLike] = None, 

1058 normalize: bool = False, 

1059 sigmoid_noise: FloatTensorLike = 0.0, 

1060 sigmoid_noise_seed: Optional[FloatTensorLike] = None, 

1061 score_bias_init: FloatTensorLike = 0.0, 

1062 mode: str = "parallel", 

1063 kernel_initializer: Initializer = "glorot_uniform", 

1064 dtype: AcceptableDTypes = None, 

1065 name: str = "BahdanauMonotonicAttention", 

1066 **kwargs, 

1067 ): 

1068 """Construct the attention mechanism. 

1069 

1070 Args: 

1071 units: The depth of the query mechanism. 

1072 memory: The memory to query; usually the output of an RNN encoder. 

1073 This tensor should be shaped `[batch_size, max_time, ...]`. 

1074 memory_sequence_length: (optional): Sequence lengths for the batch 

1075 entries in memory. If provided, the memory tensor rows are masked 

1076 with zeros for values past the respective sequence lengths. 

1077 normalize: Python boolean. Whether to normalize the energy term. 

1078 sigmoid_noise: Standard deviation of pre-sigmoid noise. See the 

1079 docstring for `_monotonic_probability_fn` for more information. 

1080 sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise. 

1081 score_bias_init: Initial value for score bias scalar. It's 

1082 recommended to initialize this to a negative value when the length 

1083 of the memory is large. 

1084 mode: How to compute the attention distribution. Must be one of 

1085 'recursive', 'parallel', or 'hard'. See the docstring for 

1086 `tfa.seq2seq.monotonic_attention` for more information. 

1087 kernel_initializer: (optional), the name of the initializer for the 

1088 attention kernel. 

1089 dtype: The data type for the query and memory layers of the attention 

1090 mechanism. 

1091 name: Name to use when creating ops. 

1092 **kwargs: Dictionary that contains other common arguments for layer 

1093 creation. 

1094 """ 

1095 # Set up the monotonic probability fn with supplied parameters 

1096 wrapped_probability_fn = functools.partial( 

1097 _monotonic_probability_fn, 

1098 sigmoid_noise=sigmoid_noise, 

1099 mode=mode, 

1100 seed=sigmoid_noise_seed, 

1101 ) 

1102 query_layer = kwargs.pop("query_layer", None) 

1103 if not query_layer: 

1104 query_layer = tf.keras.layers.Dense( 

1105 units, name="query_layer", use_bias=False, dtype=dtype 

1106 ) 

1107 memory_layer = kwargs.pop("memory_layer", None) 

1108 if not memory_layer: 

1109 memory_layer = tf.keras.layers.Dense( 

1110 units, name="memory_layer", use_bias=False, dtype=dtype 

1111 ) 

1112 self.units = units 

1113 self.normalize = normalize 

1114 self.sigmoid_noise = sigmoid_noise 

1115 self.sigmoid_noise_seed = sigmoid_noise_seed 

1116 self.score_bias_init = score_bias_init 

1117 self.mode = mode 

1118 self.kernel_initializer = tf.keras.initializers.get(kernel_initializer) 

1119 self.attention_v = None 

1120 self.attention_score_bias = None 

1121 self.attention_g = None 

1122 self.attention_b = None 

1123 super().__init__( 

1124 memory=memory, 

1125 memory_sequence_length=memory_sequence_length, 

1126 query_layer=query_layer, 

1127 memory_layer=memory_layer, 

1128 probability_fn=wrapped_probability_fn, 

1129 name=name, 

1130 dtype=dtype, 

1131 **kwargs, 

1132 ) 

1133 

1134 def build(self, input_shape): 

1135 super().build(input_shape) 

1136 if self.attention_v is None: 

1137 self.attention_v = self.add_weight( 

1138 "attention_v", 

1139 [self.units], 

1140 dtype=self.dtype, 

1141 initializer=self.kernel_initializer, 

1142 ) 

1143 if self.attention_score_bias is None: 

1144 self.attention_score_bias = self.add_weight( 

1145 "attention_score_bias", 

1146 shape=(), 

1147 dtype=self.dtype, 

1148 initializer=tf.constant_initializer(self.score_bias_init), 

1149 ) 

1150 if self.normalize and self.attention_g is None and self.attention_b is None: 

1151 self.attention_g = self.add_weight( 

1152 "attention_g", 

1153 dtype=self.dtype, 

1154 initializer=tf.constant_initializer(math.sqrt(1.0 / self.units)), 

1155 shape=(), 

1156 ) 

1157 self.attention_b = self.add_weight( 

1158 "attention_b", 

1159 [self.units], 

1160 dtype=self.dtype, 

1161 initializer=tf.zeros_initializer(), 

1162 ) 

1163 self.built = True 

1164 

1165 def _calculate_attention(self, query, state): 

1166 """Score the query based on the keys and values. 

1167 

1168 Args: 

1169 query: Tensor of dtype matching `self.values` and shape 

1170 `[batch_size, query_depth]`. 

1171 state: Tensor of dtype matching `self.values` and shape 

1172 `[batch_size, alignments_size]` 

1173 (`alignments_size` is memory's `max_time`). 

1174 

1175 Returns: 

1176 alignments: Tensor of dtype matching `self.values` and shape 

1177 `[batch_size, alignments_size]` (`alignments_size` is memory's 

1178 `max_time`). 

1179 """ 

1180 processed_query = self.query_layer(query) if self.query_layer else query 

1181 score = _bahdanau_score( 

1182 processed_query, 

1183 self.keys, 

1184 self.attention_v, 

1185 attention_g=self.attention_g, 

1186 attention_b=self.attention_b, 

1187 ) 

1188 score += self.attention_score_bias 

1189 alignments = self.probability_fn(score, state) 

1190 next_state = alignments 

1191 return alignments, next_state 

1192 

1193 def get_config(self): 

1194 # yapf: disable 

1195 config = { 

1196 "units": self.units, 

1197 "normalize": self.normalize, 

1198 "sigmoid_noise": self.sigmoid_noise, 

1199 "sigmoid_noise_seed": self.sigmoid_noise_seed, 

1200 "score_bias_init": self.score_bias_init, 

1201 "mode": self.mode, 

1202 "kernel_initializer": tf.keras.initializers.serialize( 

1203 self.kernel_initializer, 

1204 **SERIALIZATION_ARGS, 

1205 ), 

1206 } 

1207 # yapf: enable 

1208 

1209 base_config = super().get_config() 

1210 return {**base_config, **config} 

1211 

1212 @classmethod 

1213 def from_config(cls, config, custom_objects=None): 

1214 config = AttentionMechanism.deserialize_inner_layer_from_config( 

1215 config, custom_objects=custom_objects 

1216 ) 

1217 return cls(**config) 

1218 
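# Illustrative sketch (hypothetical helper, not part of the original module):
# monotonic attention starts from a Dirac alignment ([1, 0, ..., 0]) produced
# by initial_alignments and is queried like the non-monotonic mechanisms.
def _example_bahdanau_monotonic_attention():
    memory = tf.random.normal([2, 5, 16])
    attn = BahdanauMonotonicAttention(units=16, memory=memory, mode="parallel")
    query = tf.random.normal([2, 16])
    prev = attn.initial_alignments(2, tf.float32)
    alignments, next_state = attn([query, prev])
    return alignments, next_state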

1219 

1220class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): 

1221 """Monotonic attention mechanism with Luong-style energy function. 

1222 

1223 This type of attention enforces a monotonic constraint on the attention 

1224 distributions; that is, once the model attends to a given point in the

1225 memory it can't attend to any prior points at subsequent output timesteps.

1226 It achieves this by using the `_monotonic_probability_fn` instead of `softmax` 

1227 to construct its attention distributions. Otherwise, it is equivalent to 

1228 `tfa.seq2seq.LuongAttention`. This approach is proposed in 

1229 

1230 [Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck, 

1231 "Online and Linear-Time Attention by Enforcing Monotonic Alignments." 

1232 ICML 2017.](https://arxiv.org/abs/1704.00784) 

1233 """ 

1234 

1235 @typechecked 

1236 def __init__( 

1237 self, 

1238 units: TensorLike, 

1239 memory: Optional[TensorLike] = None, 

1240 memory_sequence_length: Optional[TensorLike] = None, 

1241 scale: bool = False, 

1242 sigmoid_noise: FloatTensorLike = 0.0, 

1243 sigmoid_noise_seed: Optional[FloatTensorLike] = None, 

1244 score_bias_init: FloatTensorLike = 0.0, 

1245 mode: str = "parallel", 

1246 dtype: AcceptableDTypes = None, 

1247 name: str = "LuongMonotonicAttention", 

1248 **kwargs, 

1249 ): 

1250 """Construct the attention mechanism. 

1251 

1252 Args: 

1253 units: The depth of the query mechanism. 

1254 memory: The memory to query; usually the output of an RNN encoder. 

1255 This tensor should be shaped `[batch_size, max_time, ...]`. 

1256 memory_sequence_length: (optional): Sequence lengths for the batch 

1257 entries in memory. If provided, the memory tensor rows are masked 

1258 with zeros for values past the respective sequence lengths. 

1259 scale: Python boolean. Whether to scale the energy term. 

1260 sigmoid_noise: Standard deviation of pre-sigmoid noise. See the 

1261 docstring for `_monotonic_probability_fn` for more information. 

1262 sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise. 

1263 score_bias_init: Initial value for score bias scalar. It's 

1264 recommended to initialize this to a negative value when the length 

1265 of the memory is large. 

1266 mode: How to compute the attention distribution. Must be one of 

1267 'recursive', 'parallel', or 'hard'. See the docstring for 

1268 `tfa.seq2seq.monotonic_attention` for more information. 

1269 dtype: The data type for the query and memory layers of the attention 

1270 mechanism. 

1271 name: Name to use when creating ops. 

1272 **kwargs: Dictionary that contains other common arguments for layer 

1273 creation. 

1274 """ 

1275 # Set up the monotonic probability fn with supplied parameters 

1276 wrapped_probability_fn = functools.partial( 

1277 _monotonic_probability_fn, 

1278 sigmoid_noise=sigmoid_noise, 

1279 mode=mode, 

1280 seed=sigmoid_noise_seed, 

1281 ) 

1282 memory_layer = kwargs.pop("memory_layer", None) 

1283 if not memory_layer: 

1284 memory_layer = tf.keras.layers.Dense( 

1285 units, name="memory_layer", use_bias=False, dtype=dtype 

1286 ) 

1287 self.units = units 

1288 self.scale = scale 

1289 self.sigmoid_noise = sigmoid_noise 

1290 self.sigmoid_noise_seed = sigmoid_noise_seed 

1291 self.score_bias_init = score_bias_init 

1292 self.mode = mode 

1293 self.attention_g = None 

1294 self.attention_score_bias = None 

1295 super().__init__( 

1296 memory=memory, 

1297 memory_sequence_length=memory_sequence_length, 

1298 query_layer=None, 

1299 memory_layer=memory_layer, 

1300 probability_fn=wrapped_probability_fn, 

1301 name=name, 

1302 dtype=dtype, 

1303 **kwargs, 

1304 ) 

1305 

1306 def build(self, input_shape): 

1307 super().build(input_shape) 

1308 if self.scale and self.attention_g is None: 

1309 self.attention_g = self.add_weight( 

1310 "attention_g", initializer=tf.ones_initializer, shape=() 

1311 ) 

1312 if self.attention_score_bias is None: 

1313 self.attention_score_bias = self.add_weight( 

1314 "attention_score_bias", 

1315 shape=(), 

1316 initializer=tf.constant_initializer(self.score_bias_init), 

1317 ) 

1318 self.built = True 

1319 

1320 def _calculate_attention(self, query, state): 

1321 """Score the query based on the keys and values. 

1322 

1323 Args: 

1324 query: Tensor of dtype matching `self.values` and shape 

1325 `[batch_size, query_depth]`. 

1326 state: Tensor of dtype matching `self.values` and shape 

1327 `[batch_size, alignments_size]` 

1328 (`alignments_size` is memory's `max_time`). 

1329 

1330 Returns: 

1331 alignments: Tensor of dtype matching `self.values` and shape 

1332 `[batch_size, alignments_size]` (`alignments_size` is memory's 

1333 `max_time`). 

1334 next_state: Same as alignments 

1335 """ 

1336 score = _luong_score(query, self.keys, self.attention_g) 

1337 score += self.attention_score_bias 

1338 alignments = self.probability_fn(score, state) 

1339 next_state = alignments 

1340 return alignments, next_state 

1341 

1342 def get_config(self): 

1343 config = { 

1344 "units": self.units, 

1345 "scale": self.scale, 

1346 "sigmoid_noise": self.sigmoid_noise, 

1347 "sigmoid_noise_seed": self.sigmoid_noise_seed, 

1348 "score_bias_init": self.score_bias_init, 

1349 "mode": self.mode, 

1350 } 

1351 base_config = super().get_config() 

1352 return {**base_config, **config} 

1353 

1354 @classmethod 

1355 def from_config(cls, config, custom_objects=None): 

1356 config = AttentionMechanism.deserialize_inner_layer_from_config( 

1357 config, custom_objects=custom_objects 

1358 ) 

1359 return cls(**config) 

1360 

1361 

1362class AttentionWrapperState( 

1363 collections.namedtuple( 

1364 "AttentionWrapperState", 

1365 ( 

1366 "cell_state", 

1367 "attention", 

1368 "alignments", 

1369 "alignment_history", 

1370 "attention_state", 

1371 ), 

1372 ) 

1373): 

1374 """State of a `tfa.seq2seq.AttentionWrapper`. 

1375 

1376 Attributes: 

1377 cell_state: The state of the wrapped RNN cell at the previous time 

1378 step. 

1379 attention: The attention emitted at the previous time step. 

1380 alignments: A single or tuple of `Tensor`(s) containing the 

1381 alignments emitted at the previous time step for each attention 

1382 mechanism. 

1383 alignment_history: (if enabled) a single or tuple of `TensorArray`(s) 

1384 containing alignment matrices from all time steps for each attention 

1385 mechanism. Call `stack()` on each to convert to a `Tensor`. 

1386 attention_state: A single or tuple of nested objects 

1387 containing attention mechanism state for each attention mechanism. 

1388 The objects may contain Tensors or TensorArrays. 

1389 """ 

1390 

1391 def clone(self, **kwargs): 

1392 """Clone this object, overriding components provided by kwargs. 

1393 

1394 The new state fields' shape must match original state fields' shape. 

1395 This will be validated, and original fields' shape will be propagated 

1396 to new fields. 

1397 

1398 Example: 

1399 

1400 >>> batch_size = 1 

1401 >>> memory = tf.random.normal(shape=[batch_size, 3, 100]) 

1402 >>> encoder_state = [tf.zeros((batch_size, 100)), tf.zeros((batch_size, 100))] 

1403 >>> attention_mechanism = tfa.seq2seq.LuongAttention(100, memory=memory, memory_sequence_length=[3] * batch_size) 

1404 >>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(100), attention_mechanism, attention_layer_size=10) 

1405 >>> decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32) 

1406 >>> decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state) 

1407 

1408 Args: 

1409 **kwargs: Any properties of the state object to replace in the 

1410 returned `AttentionWrapperState`. 

1411 

1412 Returns: 

1413 A new `AttentionWrapperState` whose properties are the same as 

1414 this one, except any overridden properties as provided in `kwargs`. 

1415 """ 

1416 

1417 def with_same_shape(old, new): 

1418 """Check and set new tensor's shape.""" 

1419 if isinstance(old, tf.Tensor) and isinstance(new, tf.Tensor): 

1420 if not tf.executing_eagerly(): 

1421 new_shape = tf.shape(new) 

1422 old_shape = tf.shape(old) 

1423 assert_equal = tf.debugging.assert_equal(new_shape, old_shape) 

1424 with tf.control_dependencies([assert_equal]): 

1425 # Add an identity op so that control deps can kick in. 

1426 return tf.identity(new) 

1427 else: 

1428 if old.shape.as_list() != new.shape.as_list(): 

1429 raise ValueError( 

1430 "The shape of the AttentionWrapperState is " 

1431 "expected to be same as the one to clone. " 

1432 "self.shape: %s, input.shape: %s" % (old.shape, new.shape) 

1433 ) 

1434 return new 

1435 return new 

1436 

1437 return tf.nest.map_structure(with_same_shape, self, super()._replace(**kwargs)) 

1438 

1439 

1440def _prepare_memory( 

1441 memory, memory_sequence_length=None, memory_mask=None, check_inner_dims_defined=True 

1442): 

1443 """Convert to tensor and possibly mask `memory`. 

1444 

1445 Args: 

1446 memory: `Tensor`, shaped `[batch_size, max_time, ...]`. 

1447 memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`. 

1448 memory_mask: `boolean` tensor with shape `[batch_size, max_time]`. A 

1449 memory entry is skipped when the corresponding mask value is `False`. 

1450 check_inner_dims_defined: Python boolean. If `True`, the `memory` 

1451 argument's shape is checked to ensure all but the two outermost 

1452 dimensions are fully defined. 

1453 

1454 Returns: 

1455 A (possibly masked), checked, new `memory`. 

1456 

1457 Raises: 

1458 ValueError: If `check_inner_dims_defined` is `True` and not 

1459 `memory.shape[2:].is_fully_defined()`. 

1460 """ 

1461 memory = tf.nest.map_structure( 

1462 lambda m: tf.convert_to_tensor(m, name="memory"), memory 

1463 ) 

1464 if memory_sequence_length is not None and memory_mask is not None: 

1465 raise ValueError( 

1466 "memory_sequence_length and memory_mask can't be provided at same time." 

1467 ) 

1468 if memory_sequence_length is not None: 

1469 memory_sequence_length = tf.convert_to_tensor( 

1470 memory_sequence_length, name="memory_sequence_length" 

1471 ) 

1472 if check_inner_dims_defined: 

1473 

1474 def _check_dims(m): 

1475 if not m.shape[2:].is_fully_defined(): 

1476 raise ValueError( 

1477 "Expected memory %s to have fully defined inner dims, " 

1478 "but saw shape: %s" % (m.name, m.shape) 

1479 ) 

1480 

1481 tf.nest.map_structure(_check_dims, memory) 

1482 if memory_sequence_length is None and memory_mask is None: 

1483 return memory 

1484 elif memory_sequence_length is not None: 

1485 seq_len_mask = tf.sequence_mask( 

1486 memory_sequence_length, 

1487 maxlen=tf.shape(tf.nest.flatten(memory)[0])[1], 

1488 dtype=tf.nest.flatten(memory)[0].dtype, 

1489 ) 

1490 else: 

1491 # Here memory_mask is not None. 

1492 seq_len_mask = tf.cast(memory_mask, dtype=tf.nest.flatten(memory)[0].dtype) 

1493 

1494 def _maybe_mask(m, seq_len_mask): 

1495 """Mask the memory based on the memory mask.""" 

1496 rank = m.shape.ndims 

1497 rank = rank if rank is not None else tf.rank(m) 

1498 extra_ones = tf.ones(rank - 2, dtype=tf.int32) 

1499 seq_len_mask = tf.reshape( 

1500 seq_len_mask, tf.concat((tf.shape(seq_len_mask), extra_ones), 0) 

1501 ) 

1502 return m * seq_len_mask 

1503 

1504 return tf.nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory) 

1505 

1506 
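# Illustrative sketch (editor's addition, not part of the upstream module):
# `_prepare_memory` zeroes out memory entries beyond each sequence length so
# that padded positions contribute nothing to the attention context.
def _example_prepare_memory():
    memory = tf.ones([2, 3, 4])  # [batch_size, max_time, depth]
    lengths = tf.constant([1, 3], dtype=tf.int32)  # valid steps per batch entry
    masked = _prepare_memory(memory, memory_sequence_length=lengths)
    # masked[0, 1:] is all zeros; masked[1] is left untouched.
    return masked

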

1507def _maybe_mask_score( 

1508 score, memory_sequence_length=None, memory_mask=None, score_mask_value=None 

1509): 

1510 """Mask the attention score based on the masks.""" 

1511 if memory_sequence_length is None and memory_mask is None: 

1512 return score 

1513 if memory_sequence_length is not None and memory_mask is not None: 

1514 raise ValueError( 

1515 "memory_sequence_length and memory_mask can't be provided at same time." 

1516 ) 

1517 if memory_sequence_length is not None: 

1518 message = "All values in memory_sequence_length must be greater than zero." 

1519 with tf.control_dependencies( 

1520 [ 

1521 tf.debugging.assert_positive( # pylint: disable=bad-continuation 

1522 memory_sequence_length, message=message 

1523 ) 

1524 ] 

1525 ): 

1526 memory_mask = tf.sequence_mask( 

1527 memory_sequence_length, maxlen=tf.shape(score)[1] 

1528 ) 

1529 score_mask_values = score_mask_value * tf.ones_like(score) 

1530 return tf.where(memory_mask, score, score_mask_values) 

1531 

1532 
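# Illustrative sketch (editor's addition): masking scores at padded positions
# with a very negative value so the downstream softmax assigns them ~0 weight.
def _example_maybe_mask_score():
    score = tf.random.normal([2, 3])  # [batch_size, max_time] raw scores
    lengths = tf.constant([2, 3], dtype=tf.int32)
    masked = _maybe_mask_score(
        score, memory_sequence_length=lengths, score_mask_value=tf.float32.min
    )
    # Position [0, 2] is padding, so it receives (near-)zero probability here.
    return tf.nn.softmax(masked, axis=-1)

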

1533def hardmax(logits: TensorLike, name: Optional[str] = None) -> tf.Tensor: 

1534 """Returns batched one-hot vectors. 

1535 

1536 The depth index containing the `1` is that of the maximum logit value. 

1537 

1538 Args: 

1539 logits: A batch tensor of logit values. 

1540 name: Name to use when creating ops. 

1541 Returns: 

1542 A batched one-hot tensor. 

1543 """ 

1544 with tf.name_scope(name or "Hardmax"): 

1545 logits = tf.convert_to_tensor(logits, name="logits") 

1546 depth = logits.shape[-1] or tf.shape(logits)[-1] 

1547 return tf.one_hot(tf.argmax(logits, -1), depth, dtype=logits.dtype) 

1548 

1549 
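# Illustrative sketch (editor's addition): `hardmax` turns each row of logits
# into a one-hot vector selecting the argmax, keeping the input dtype.
def _example_hardmax():
    logits = tf.constant([[0.1, 2.0, 0.3], [4.0, 1.0, 0.5]])
    # Returns [[0., 1., 0.], [1., 0., 0.]].
    return hardmax(logits)

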

1550def _compute_attention( 

1551 attention_mechanism, cell_output, attention_state, attention_layer 

1552): 

1553 """Computes the attention and alignments for a given 

1554 attention_mechanism.""" 

1555 alignments, next_attention_state = attention_mechanism( 

1556 [cell_output, attention_state] 

1557 ) 

1558 

1559 # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] 

1560 expanded_alignments = tf.expand_dims(alignments, 1) 

1561 # Context is the inner product of alignments and values along the 

1562 # memory time dimension. 

1563 # alignments shape is 

1564 # [batch_size, 1, memory_time] 

1565 # attention_mechanism.values shape is 

1566 # [batch_size, memory_time, memory_size] 

1567 # the batched matmul is over memory_time, so the output shape is 

1568 # [batch_size, 1, memory_size]. 

1569 # we then squeeze out the singleton dim. 

1570 context_ = tf.matmul(expanded_alignments, attention_mechanism.values) 

1571 context_ = tf.squeeze(context_, [1]) 

1572 

1573 if attention_layer is not None: 

1574 attention = attention_layer(tf.concat([cell_output, context_], 1)) 

1575 else: 

1576 attention = context_ 

1577 

1578 return attention, alignments, next_attention_state 

1579 

1580 
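# Illustrative sketch (editor's addition), assuming the LuongAttention class
# defined earlier in this module: a single attention step outside any RNN cell.
def _example_compute_attention():
    batch_size, max_time, units = 2, 5, 8
    memory = tf.random.normal([batch_size, max_time, units])
    mechanism = LuongAttention(units, memory=memory)
    cell_output = tf.random.normal([batch_size, units])  # acts as the query
    state = mechanism.initial_state(batch_size, tf.float32)
    attention, alignments, next_state = _compute_attention(
        mechanism, cell_output, state, attention_layer=None
    )
    # attention (the context here): [batch_size, units];
    # alignments: [batch_size, max_time].
    return attention, alignments, next_state

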

1581class AttentionWrapper(AbstractRNNCell): 

1582 """Wraps another RNN cell with attention. 

1583 

1584 Example: 

1585 

1586 >>> batch_size = 4 

1587 >>> max_time = 7 

1588 >>> hidden_size = 32 

1589 >>> 

1590 >>> memory = tf.random.uniform([batch_size, max_time, hidden_size]) 

1591 >>> memory_sequence_length = tf.fill([batch_size], max_time) 

1592 >>> 

1593 >>> attention_mechanism = tfa.seq2seq.LuongAttention(hidden_size) 

1594 >>> attention_mechanism.setup_memory(memory, memory_sequence_length) 

1595 >>> 

1596 >>> cell = tf.keras.layers.LSTMCell(hidden_size) 

1597 >>> cell = tfa.seq2seq.AttentionWrapper( 

1598 ... cell, attention_mechanism, attention_layer_size=hidden_size) 

1599 >>> 

1600 >>> inputs = tf.random.uniform([batch_size, hidden_size]) 

1601 >>> state = cell.get_initial_state(inputs) 

1602 >>> 

1603 >>> outputs, state = cell(inputs, state) 

1604 >>> outputs.shape 

1605 TensorShape([4, 32]) 

1606 """ 

1607 

1608 @typechecked 

1609 def __init__( 

1610 self, 

1611 cell: tf.keras.layers.Layer, 

1612 attention_mechanism: Union[AttentionMechanism, List[AttentionMechanism]], 

1613 attention_layer_size: Optional[Union[Number, List[Number]]] = None, 

1614 alignment_history: bool = False, 

1615 cell_input_fn: Optional[Callable] = None, 

1616 output_attention: bool = True, 

1617 initial_cell_state: Optional[TensorLike] = None, 

1618 name: Optional[str] = None, 

1619 attention_layer: Optional[ 

1620 Union[tf.keras.layers.Layer, List[tf.keras.layers.Layer]] 

1621 ] = None, 

1622 attention_fn: Optional[Callable] = None, 

1623 **kwargs, 

1624 ): 

1625 """Construct the `AttentionWrapper`. 

1626 

1627 **NOTE** If you are using the `tfa.seq2seq.BeamSearchDecoder` with a cell wrapped 

1628 in `AttentionWrapper`, then you must ensure that: 

1629 

1630 - The encoder output has been tiled to `beam_width` via 

1631 `tfa.seq2seq.tile_batch` (NOT `tf.tile`). 

1632 - The `batch_size` argument passed to the `get_initial_state` method of 

1633 this wrapper is equal to `true_batch_size * beam_width`. 

1634 - The initial state created with `get_initial_state` above contains a 

1635 `cell_state` value containing properly tiled final state from the 

1636 encoder. 

1637 

1638 An example: 

1639 

1640 >>> batch_size = 1 

1641 >>> beam_width = 5 

1642 >>> sequence_length = tf.convert_to_tensor([5]) 

1643 >>> encoder_outputs = tf.random.uniform(shape=(batch_size, 5, 10)) 

1644 >>> encoder_final_state = [tf.zeros((batch_size, 10)), tf.zeros((batch_size, 10))] 

1645 >>> tiled_encoder_outputs = tfa.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width) 

1646 >>> tiled_encoder_final_state = tfa.seq2seq.tile_batch(encoder_final_state, multiplier=beam_width) 

1647 >>> tiled_sequence_length = tfa.seq2seq.tile_batch(sequence_length, multiplier=beam_width) 

1648 >>> attention_mechanism = tfa.seq2seq.BahdanauAttention(10, memory=tiled_encoder_outputs, memory_sequence_length=tiled_sequence_length) 

1649 >>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(10), attention_mechanism) 

1650 >>> decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size * beam_width, dtype=tf.float32) 

1651 >>> decoder_initial_state = decoder_initial_state.clone(cell_state=tiled_encoder_final_state) 

1652 

1653 Args: 

1654 cell: A layer that implements the `tf.keras.layers.AbstractRNNCell` 

1655 interface. 

1656 attention_mechanism: A list of `tfa.seq2seq.AttentionMechanism` 

1657 instances or a single instance. 

1658 attention_layer_size: A list of Python integers or a single Python 

1659 integer, the depth of the attention (output) layer(s). If `None` 

1660 (default), use the context as attention at each time step. 

1661 Otherwise, feed the context and cell output into the attention 

1662 layer to generate attention at each time step. If 

1663 `attention_mechanism` is a list, `attention_layer_size` must be a list 

1664 of the same length. If `attention_layer` is set, this must be `None`. 

1665 If `attention_fn` is set, it must be guaranteed that the outputs of 

1666 `attention_fn` also meet the above requirements. 

1667 alignment_history: Python boolean, whether to store alignment history 

1668 from all time steps in the final output state (currently stored as 

1669 a time major `TensorArray` on which you must call `stack()`). 

1670 cell_input_fn: (optional) A `callable`. The default is: 

1671 `lambda inputs, attention: 

1672 tf.concat([inputs, attention], -1)`. 

1673 output_attention: Python bool. If `True` (default), the output at 

1674 each time step is the attention value. This is the behavior of 

1675 Luong-style attention mechanisms. If `False`, the output at each 

1676 time step is the output of `cell`. This is the behavior of 

1677 Bahdanau-style attention mechanisms. In both cases, the 

1678 `attention` tensor is propagated to the next time step via the 

1679 state and is used there. This flag only controls whether the 

1680 attention mechanism is propagated up to the next cell in an RNN 

1681 stack or to the top RNN output. 

1682 initial_cell_state: The initial state value to use for the cell when 

1683 the user calls `get_initial_state()`. Note that if this value is 

1684 provided now, and the user uses a `batch_size` argument of 

1685 `get_initial_state` which does not match the batch size of 

1686 `initial_cell_state`, proper behavior is not guaranteed. 

1687 name: Name to use when creating ops. 

1688 attention_layer: A list of `tf.keras.layers.Layer` instances or a 

1689 single `tf.keras.layers.Layer` instance taking the context 

1690 and cell output as inputs to generate attention at each time step. 

1691 If `None` (default), use the context as attention at each time step. 

1692 If `attention_mechanism` is a list, `attention_layer` must be a list of 

1693 the same length. If `attention_layer_size` is set, this must be 

1694 `None`. 

1695 attention_fn: An optional callable function that allows users to 

1696 provide their own customized attention function, which takes input 

1697 `(attention_mechanism, cell_output, attention_state, 

1698 attention_layer)` and outputs `(attention, alignments, 

1699 next_attention_state)`. If provided, the `attention_layer_size` should 

1700 be the size of the outputs of `attention_fn`. 

1701 **kwargs: Other keyword arguments for layer creation. 

1702 

1703 Raises: 

1704 TypeError: `attention_layer_size` is not `None` and 

1705 (`attention_mechanism` is a list but `attention_layer_size` is not; 

1706 or vice versa). 

1707 ValueError: if `attention_layer_size` is not `None`, 

1708 `attention_mechanism` is a list, and its length does not match that 

1709 of `attention_layer_size`; if `attention_layer_size` and 

1710 `attention_layer` are set simultaneously. 

1711 """ 

1712 super().__init__(name=name, **kwargs) 

1713 keras_utils.assert_like_rnncell("cell", cell) 

1714 if isinstance(attention_mechanism, (list, tuple)): 

1715 self._is_multi = True 

1716 attention_mechanisms = list(attention_mechanism) 

1717 else: 

1718 self._is_multi = False 

1719 attention_mechanisms = [attention_mechanism] 

1720 

1721 if cell_input_fn is None: 

1722 

1723 def cell_input_fn(inputs, attention): 

1724 return tf.concat([inputs, attention], -1) 

1725 

1726 if attention_layer_size is not None and attention_layer is not None: 

1727 raise ValueError( 

1728 "Only one of attention_layer_size and attention_layer should be set" 

1729 ) 

1730 

1731 if attention_layer_size is not None: 

1732 attention_layer_sizes = tuple( 

1733 attention_layer_size 

1734 if isinstance(attention_layer_size, (list, tuple)) 

1735 else (attention_layer_size,) 

1736 ) 

1737 if len(attention_layer_sizes) != len(attention_mechanisms): 

1738 raise ValueError( 

1739 "If provided, attention_layer_size must contain exactly " 

1740 "one integer per attention_mechanism, saw: %d vs %d" 

1741 % (len(attention_layer_sizes), len(attention_mechanisms)) 

1742 ) 

1743 dtype = kwargs.get("dtype", None) 

1744 self._attention_layers = list( 

1745 tf.keras.layers.Dense( 

1746 attention_layer_size, 

1747 name="attention_layer", 

1748 use_bias=False, 

1749 dtype=dtype, 

1750 ) 

1751 for i, attention_layer_size in enumerate(attention_layer_sizes) 

1752 ) 

1753 elif attention_layer is not None: 

1754 self._attention_layers = list( 

1755 attention_layer 

1756 if isinstance(attention_layer, (list, tuple)) 

1757 else (attention_layer,) 

1758 ) 

1759 if len(self._attention_layers) != len(attention_mechanisms): 

1760 raise ValueError( 

1761 "If provided, attention_layer must contain exactly one " 

1762 "layer per attention_mechanism, saw: %d vs %d" 

1763 % (len(self._attention_layers), len(attention_mechanisms)) 

1764 ) 

1765 else: 

1766 self._attention_layers = None 

1767 

1768 if attention_fn is None: 

1769 attention_fn = _compute_attention 

1770 self._attention_fn = attention_fn 

1771 self._attention_layer_size = None 

1772 

1773 self._cell = cell 

1774 self._attention_mechanisms = attention_mechanisms 

1775 self._cell_input_fn = cell_input_fn 

1776 self._output_attention = output_attention 

1777 self._alignment_history = alignment_history 

1778 with tf.name_scope(name or "AttentionWrapperInit"): 

1779 if initial_cell_state is None: 

1780 self._initial_cell_state = None 

1781 else: 

1782 final_state_tensor = tf.nest.flatten(initial_cell_state)[-1] 

1783 state_batch_size = ( 

1784 final_state_tensor.shape[0] or tf.shape(final_state_tensor)[0] 

1785 ) 

1786 error_message = ( 

1787 "When constructing AttentionWrapper %s: " % self.name 

1788 + "Non-matching batch sizes between the memory " 

1789 "(encoder output) and initial_cell_state. Are you using " 

1790 "the BeamSearchDecoder? You may need to tile your " 

1791 "initial state via the tfa.seq2seq.tile_batch " 

1792 "function with argument multiple=beam_width." 

1793 ) 

1794 with tf.control_dependencies( 

1795 self._batch_size_checks( # pylint: disable=bad-continuation 

1796 state_batch_size, error_message 

1797 ) 

1798 ): 

1799 self._initial_cell_state = tf.nest.map_structure( 

1800 lambda s: tf.identity(s, name="check_initial_cell_state"), 

1801 initial_cell_state, 

1802 ) 

1803 

1804 def _attention_mechanisms_checks(self): 

1805 for attention_mechanism in self._attention_mechanisms: 

1806 if not attention_mechanism.memory_initialized: 

1807 raise ValueError( 

1808 "The AttentionMechanism instances passed to " 

1809 "this AttentionWrapper should be initialized " 

1810 "with a memory first, either by passing it " 

1811 "to the AttentionMechanism constructor or " 

1812 "calling attention_mechanism.setup_memory()" 

1813 ) 

1814 

1815 def _batch_size_checks(self, batch_size, error_message): 

1816 self._attention_mechanisms_checks() 

1817 return [ 

1818 tf.debugging.assert_equal( 

1819 batch_size, attention_mechanism.batch_size, message=error_message 

1820 ) 

1821 for attention_mechanism in self._attention_mechanisms 

1822 ] 

1823 

1824 def _get_attention_layer_size(self): 

1825 if self._attention_layer_size is not None: 

1826 return self._attention_layer_size 

1827 self._attention_mechanisms_checks() 

1828 attention_output_sizes = ( 

1829 attention_mechanism.values.shape[-1] 

1830 for attention_mechanism in self._attention_mechanisms 

1831 ) 

1832 if self._attention_layers is None: 

1833 self._attention_layer_size = sum(attention_output_sizes) 

1834 else: 

1835 # Compute the layer output size from its input which is the 

1836 # concatenation of the cell output and the attention mechanism 

1837 # output. 

1838 self._attention_layer_size = sum( 

1839 layer.compute_output_shape( 

1840 [None, self._cell.output_size + attention_output_size] 

1841 )[-1] 

1842 for layer, attention_output_size in zip( 

1843 self._attention_layers, attention_output_sizes 

1844 ) 

1845 ) 

1846 return self._attention_layer_size 

1847 

1848 def _item_or_tuple(self, seq): 

1849 """Returns `seq` as tuple or the singular element. 

1850 

1851 Which is returned is determined by how the AttentionMechanism(s) were 

1852 passed to the constructor. 

1853 

1854 Args: 

1855 seq: A non-empty sequence of items or generator. 

1856 

1857 Returns: 

1858 The values in the sequence as a tuple if the AttentionMechanism(s) 

1859 were passed to the constructor as a sequence, otherwise the single 

1860 element. 

1861 """ 

1862 t = tuple(seq) 

1863 if self._is_multi: 

1864 return t 

1865 else: 

1866 return t[0] 

1867 

1868 @property 

1869 def output_size(self): 

1870 if self._output_attention: 

1871 return self._get_attention_layer_size() 

1872 else: 

1873 return self._cell.output_size 

1874 

1875 @property 

1876 def state_size(self): 

1877 """The `state_size` property of `tfa.seq2seq.AttentionWrapper`. 

1878 

1879 Returns: 

1880 A `tfa.seq2seq.AttentionWrapperState` tuple containing shapes used 

1881 by this object. 

1882 """ 

1883 return AttentionWrapperState( 

1884 cell_state=self._cell.state_size, 

1885 attention=self._get_attention_layer_size(), 

1886 alignments=self._item_or_tuple( 

1887 a.alignments_size for a in self._attention_mechanisms 

1888 ), 

1889 attention_state=self._item_or_tuple( 

1890 a.state_size for a in self._attention_mechanisms 

1891 ), 

1892 alignment_history=self._item_or_tuple( 

1893 a.alignments_size if self._alignment_history else () 

1894 for a in self._attention_mechanisms 

1895 ), 

1896 ) # sometimes a TensorArray 

1897 

1898 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 

1899 """Return an initial (zero) state tuple for this `tfa.seq2seq.AttentionWrapper`. 

1900 

1901 **NOTE** Please see the initializer documentation for details of how 

1902 to call `get_initial_state` if using a `tfa.seq2seq.AttentionWrapper` 

1903 with a `tfa.seq2seq.BeamSearchDecoder`. 

1904 

1905 Args: 

1906 inputs: The inputs that will be fed to this cell. 

1907 batch_size: `0D` integer tensor: the batch size. 

1908 dtype: The internal state data type. 

1909 

1910 Returns: 

1911 An `tfa.seq2seq.AttentionWrapperState` tuple containing zeroed out tensors and, 

1912 possibly, empty `TensorArray` objects. 

1913 

1914 Raises: 

1915 ValueError: (or, possibly at runtime, `InvalidArgument`), if 

1916 `batch_size` does not match the batch size of the encoder output passed 

1917 to the wrapper object at initialization time. 

1918 """ 

1919 if inputs is not None: 

1920 batch_size = tf.shape(inputs)[0] 

1921 dtype = inputs.dtype 

1922 with tf.name_scope( 

1923 type(self).__name__ + "ZeroState" 

1924 ): # pylint: disable=bad-continuation 

1925 if self._initial_cell_state is not None: 

1926 cell_state = self._initial_cell_state 

1927 else: 

1928 cell_state = self._cell.get_initial_state( 

1929 batch_size=batch_size, dtype=dtype 

1930 ) 

1931 error_message = ( 

1932 "When calling get_initial_state of AttentionWrapper %s: " % self.name 

1933 + "Non-matching batch sizes between the memory " 

1934 "(encoder output) and the requested batch size. Are you using " 

1935 "the BeamSearchDecoder? If so, make sure your encoder output " 

1936 "has been tiled to beam_width via " 

1937 "tfa.seq2seq.tile_batch, and the batch_size= argument " 

1938 "passed to get_initial_state is batch_size * beam_width." 

1939 ) 

1940 with tf.control_dependencies( 

1941 self._batch_size_checks(batch_size, error_message) 

1942 ): # pylint: disable=bad-continuation 

1943 cell_state = tf.nest.map_structure( 

1944 lambda s: tf.identity(s, name="checked_cell_state"), cell_state 

1945 ) 

1946 initial_alignments = [ 

1947 attention_mechanism.initial_alignments(batch_size, dtype) 

1948 for attention_mechanism in self._attention_mechanisms 

1949 ] 

1950 return AttentionWrapperState( 

1951 cell_state=cell_state, 

1952 attention=tf.zeros( 

1953 [batch_size, self._get_attention_layer_size()], dtype=dtype 

1954 ), 

1955 alignments=self._item_or_tuple(initial_alignments), 

1956 attention_state=self._item_or_tuple( 

1957 attention_mechanism.initial_state(batch_size, dtype) 

1958 for attention_mechanism in self._attention_mechanisms 

1959 ), 

1960 alignment_history=self._item_or_tuple( 

1961 tf.TensorArray( 

1962 dtype, size=0, dynamic_size=True, element_shape=alignment.shape 

1963 ) 

1964 if self._alignment_history 

1965 else () 

1966 for alignment in initial_alignments 

1967 ), 

1968 ) 

1969 

1970 def call(self, inputs, state, **kwargs): 

1971 """Perform a step of attention-wrapped RNN. 

1972 

1973 - Step 1: Mix the `inputs` and previous step's `attention` output via 

1974 `cell_input_fn`. 

1975 - Step 2: Call the wrapped `cell` with this input and its previous 

1976 state. 

1977 - Step 3: Score the cell's output with `attention_mechanism`. 

1978 - Step 4: Calculate the alignments by passing the score through the 

1979 `normalizer`. 

1980 - Step 5: Calculate the context vector as the inner product between the 

1981 alignments and the attention_mechanism's values (memory). 

1982 - Step 6: Calculate the attention output by concatenating the cell 

1983 output and context through the attention layer (a linear layer with 

1984 `attention_layer_size` outputs). 

1985 

1986 Args: 

1987 inputs: (Possibly nested tuple of) Tensor, the input at this time 

1988 step. 

1989 state: An instance of `tfa.seq2seq.AttentionWrapperState` containing 

1990 tensors from the previous time step. 

1991 **kwargs: Dict, other keyword arguments for the cell call method. 

1992 

1993 Returns: 

1994 A tuple `(attention_or_cell_output, next_state)`, where: 

1995 

1996 - `attention_or_cell_output`: the attention value if `output_attention` is `True`, otherwise the output of `cell`. 

1997 - `next_state` is an instance of `tfa.seq2seq.AttentionWrapperState` 

1998 containing the state calculated at this time step. 

1999 

2000 Raises: 

2001 TypeError: If `state` is not an instance of `tfa.seq2seq.AttentionWrapperState`. 

2002 """ 

2003 if not isinstance(state, AttentionWrapperState): 

2004 try: 

2005 state = AttentionWrapperState(*state) 

2006 except TypeError: 

2007 raise TypeError( 

2008 "Expected state to be instance of AttentionWrapperState or " 

2009 "values that can construct AttentionWrapperState. " 

2010 "Received type %s instead." % type(state) 

2011 ) 

2012 

2013 # Step 1: Calculate the true inputs to the cell based on the 

2014 # previous attention value. 

2015 cell_inputs = self._cell_input_fn(inputs, state.attention) 

2016 cell_state = state.cell_state 

2017 cell_output, next_cell_state = self._cell(cell_inputs, cell_state, **kwargs) 

2018 next_cell_state = tf.nest.pack_sequence_as( 

2019 cell_state, tf.nest.flatten(next_cell_state) 

2020 ) 

2021 

2022 cell_batch_size = cell_output.shape[0] or tf.shape(cell_output)[0] 

2023 error_message = ( 

2024 "When applying AttentionWrapper %s: " % self.name 

2025 + "Non-matching batch sizes between the memory " 

2026 "(encoder output) and the query (decoder output). Are you using " 

2027 "the BeamSearchDecoder? You may need to tile your memory input " 

2028 "via the tfa.seq2seq.tile_batch function with argument " 

2029 "multiple=beam_width." 

2030 ) 

2031 with tf.control_dependencies( 

2032 self._batch_size_checks(cell_batch_size, error_message) 

2033 ): # pylint: disable=bad-continuation 

2034 cell_output = tf.identity(cell_output, name="checked_cell_output") 

2035 

2036 if self._is_multi: 

2037 previous_attention_state = state.attention_state 

2038 previous_alignment_history = state.alignment_history 

2039 else: 

2040 previous_attention_state = [state.attention_state] 

2041 previous_alignment_history = [state.alignment_history] 

2042 

2043 all_alignments = [] 

2044 all_attentions = [] 

2045 all_attention_states = [] 

2046 maybe_all_histories = [] 

2047 for i, attention_mechanism in enumerate(self._attention_mechanisms): 

2048 attention, alignments, next_attention_state = self._attention_fn( 

2049 attention_mechanism, 

2050 cell_output, 

2051 previous_attention_state[i], 

2052 self._attention_layers[i] if self._attention_layers else None, 

2053 ) 

2054 alignment_history = ( 

2055 previous_alignment_history[i].write( 

2056 previous_alignment_history[i].size(), alignments 

2057 ) 

2058 if self._alignment_history 

2059 else () 

2060 ) 

2061 

2062 all_attention_states.append(next_attention_state) 

2063 all_alignments.append(alignments) 

2064 all_attentions.append(attention) 

2065 maybe_all_histories.append(alignment_history) 

2066 

2067 attention = tf.concat(all_attentions, 1) 

2068 next_state = AttentionWrapperState( 

2069 cell_state=next_cell_state, 

2070 attention=attention, 

2071 attention_state=self._item_or_tuple(all_attention_states), 

2072 alignments=self._item_or_tuple(all_alignments), 

2073 alignment_history=self._item_or_tuple(maybe_all_histories), 

2074 ) 

2075 

2076 if self._output_attention: 

2077 return attention, next_state 

2078 else: 

2079 return cell_output, next_state
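

# Illustrative sketch (editor's addition), assuming the LuongAttention class
# defined earlier in this module: threading AttentionWrapperState through a
# manual step-by-step decoding loop, mirroring the class docstring example.
def _example_attention_wrapper_step_loop():
    batch_size, max_time, units = 2, 6, 16
    memory = tf.random.normal([batch_size, max_time, units])
    mechanism = LuongAttention(
        units, memory=memory, memory_sequence_length=[max_time] * batch_size
    )
    cell = AttentionWrapper(
        tf.keras.layers.LSTMCell(units), mechanism, attention_layer_size=units
    )
    state = cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
    outputs = []
    for _ in range(max_time):
        step_input = tf.random.normal([batch_size, units])
        output, state = cell(step_input, state)
        outputs.append(output)  # each output: [batch_size, units]
    return tf.stack(outputs, axis=1)  # [batch_size, max_time, units]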