# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A decoder that performs beam search."""

import collections
import numpy as np

import tensorflow as tf

from tensorflow_addons import options
from tensorflow_addons.seq2seq import attention_wrapper
from tensorflow_addons.seq2seq import decoder
from tensorflow_addons.utils import keras_utils
from tensorflow_addons.utils.resource_loader import LazySO
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike, Number

from typeguard import typechecked
from typing import Callable, Optional

_beam_search_so = LazySO("custom_ops/seq2seq/_beam_search_ops.so")



class BeamSearchDecoderState(
    collections.namedtuple(
        "BeamSearchDecoderState",
        (
            "cell_state",
            "log_probs",
            "finished",
            "lengths",
            "accumulated_attention_probs",
        ),
    )
):
    """State of a `tfa.seq2seq.BeamSearchDecoder`.

    Attributes:
      cell_state: The cell state returned at the previous time step.
      log_probs: The accumulated log probabilities of each beam.
        A `float32` `Tensor` of shape `[batch_size, beam_width]`.
      finished: The finished status of each beam.
        A `bool` `Tensor` of shape `[batch_size, beam_width]`.
      lengths: The accumulated length of each beam.
        An `int64` `Tensor` of shape `[batch_size, beam_width]`.
      accumulated_attention_probs: Accumulation of the attention
        probabilities (used to compute the coverage penalty).
    """

    pass



class BeamSearchDecoderOutput(
    collections.namedtuple(
        "BeamSearchDecoderOutput", ("scores", "predicted_ids", "parent_ids")
    )
):
    """Outputs of a `tfa.seq2seq.BeamSearchDecoder` step.

    Attributes:
      scores: The scores for this step, which are the log
        probabilities over the output vocabulary, possibly penalized by length
        and attention coverage. When `tfa.seq2seq.BeamSearchDecoder` is created with
        `output_all_scores=False` (default), this will be a `float32` `Tensor`
        of shape `[batch_size, beam_width]` containing the top scores
        corresponding to the predicted IDs. When `output_all_scores=True`,
        this contains the scores for all token IDs and has shape
        `[batch_size, beam_width, vocab_size]`.
      predicted_ids: The token IDs predicted for this step.
        An `int32` `Tensor` of shape `[batch_size, beam_width]`.
      parent_ids: The indices of the parent beam of each beam.
        An `int32` `Tensor` of shape `[batch_size, beam_width]`.
    """

    pass



class FinalBeamSearchDecoderOutput(
    collections.namedtuple(
        "FinalBeamDecoderOutput", ["predicted_ids", "beam_search_decoder_output"]
    )
):
    """Final outputs returned by the beam search after all decoding is finished.

    Attributes:
      predicted_ids: The final prediction. A tensor of shape
        `[batch_size, T, beam_width]` (or `[T, batch_size, beam_width]` if
        `output_time_major` is True). Beams are ordered from best to worst.
      beam_search_decoder_output: An instance of `tfa.seq2seq.BeamSearchDecoderOutput` that
        describes the state of the beam search.
    """

    pass



def _tile_batch(t, multiplier):
    """Core single-tensor implementation of tile_batch."""
    t = tf.convert_to_tensor(t, name="t")
    shape_t = tf.shape(t)
    if t.shape.ndims is None or t.shape.ndims < 1:
        raise ValueError("t must have statically known rank")
    tiling = [1] * (t.shape.ndims + 1)
    tiling[1] = multiplier
    tiled_static_batch_size = (
        t.shape[0] * multiplier if t.shape[0] is not None else None
    )
    tiled = tf.tile(tf.expand_dims(t, 1), tiling)
    tiled = tf.reshape(tiled, tf.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
    tiled.set_shape(tf.TensorShape([tiled_static_batch_size]).concatenate(t.shape[1:]))
    return tiled



def tile_batch(t: TensorLike, multiplier: int, name: Optional[str] = None) -> tf.Tensor:
    """Tiles the batch dimension of a (possibly nested structure of) tensor(s).

    For each tensor t in a (possibly nested structure) of tensors,
    this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed
    of minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a
    shape `[batch_size * multiplier, s0, s1, ...]` composed of minibatch
    entries `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is
    repeated `multiplier` times.

    Args:
      t: `Tensor` shaped `[batch_size, ...]`.
      multiplier: Python int.
      name: Name scope for any created operations.

    Returns:
      A (possibly nested structure of) `Tensor` shaped
      `[batch_size * multiplier, ...]`.

    Raises:
      ValueError: if tensor(s) `t` do not have a statically known rank or
        the rank is < 1.
    """
    with tf.name_scope(name or "tile_batch"):
        return tf.nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)


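# Illustrative sketch (not part of the original module): `tile_batch` repeats
# each batch entry `multiplier` times, keeping the copies of an entry
# adjacent, which is the layout the beam search expects for tiled encoder
# outputs:
#
#     x = tf.constant([[1, 2], [3, 4]])          # shape [2, 2]
#     tiled = tile_batch(x, multiplier=3)        # shape [6, 2]
#     # tiled -> [[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
#     # Contrast with tf.tile(x, [3, 1]), which interleaves the batch
#     # entries as [[1, 2], [3, 4], [1, 2], [3, 4], ...].

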

@tf.function(
    input_signature=(
        tf.TensorSpec([None, None, None], dtype=tf.int32),
        tf.TensorSpec([None, None, None], dtype=tf.int32),
        tf.TensorSpec([None], dtype=tf.int32),
        tf.TensorSpec([], dtype=tf.int32),
    )
)
def _gather_tree(step_ids, parent_ids, max_sequence_lengths, end_token):
    input_shape = tf.shape(parent_ids)
    max_time = input_shape[0]
    beam_width = input_shape[2]
    max_sequence_lengths = tf.math.minimum(max_sequence_lengths, max_time)
    mask = tf.expand_dims(
        tf.transpose(tf.sequence_mask(max_sequence_lengths, maxlen=max_time)), -1
    )

    # Mask out of range ids.
    end_tokens = tf.fill(input_shape, end_token)
    step_ids = tf.where(mask, x=step_ids, y=end_tokens)
    parent_ids = tf.where(mask, x=parent_ids, y=tf.zeros_like(parent_ids))
    assert_op = tf.debugging.Assert(
        tf.math.reduce_all(
            tf.math.logical_and(parent_ids >= 0, parent_ids < beam_width)
        ),
        ["All parent ids must be non-negative and less than beam_width"],
    )

    # Reverse all sequences as we need to gather from the end.
    with tf.control_dependencies([assert_op]):
        rev_step_ids = tf.reverse_sequence(
            step_ids, max_sequence_lengths, seq_axis=0, batch_axis=1
        )
        rev_parent_ids = tf.reverse_sequence(
            parent_ids, max_sequence_lengths, seq_axis=0, batch_axis=1
        )

    # Initialize output ids and parent based on last step.
    output_ids = tf.TensorArray(step_ids.dtype, size=max_time, dynamic_size=False)
    output_ids = output_ids.write(0, rev_step_ids[0])
    parent = rev_parent_ids[0]

    # For each step, gather ids based on beam origin.
    for t in tf.range(1, max_time):
        ids = tf.gather(rev_step_ids[t], parent, batch_dims=1)
        parent = tf.gather(rev_parent_ids[t], parent, batch_dims=1)
        output_ids = output_ids.write(t, ids)

    # Reverse sequences to their original order.
    output_ids = output_ids.stack()
    output_ids = tf.reverse_sequence(
        output_ids, max_sequence_lengths, seq_axis=0, batch_axis=1
    )

    # Ensure that only `end_token` appears after the first `end_token`.
    in_bound_steps = tf.math.cumsum(tf.cast(output_ids == end_token, tf.int32)) == 0
    output_ids = tf.where(in_bound_steps, x=output_ids, y=end_tokens)
    return output_ids



def gather_tree(
    step_ids: TensorLike,
    parent_ids: TensorLike,
    max_sequence_lengths: TensorLike,
    end_token: Number,
) -> tf.Tensor:
    """Calculates the full beams from the per-step ids and parent beam ids.

    For a given beam, past the time step containing the first decoded
    `end_token` all values are filled in with `end_token`.

    Args:
      step_ids: The predicted token IDs.
        An `int32` `Tensor` of shape `[max_time, batch_size, beam_width]`.
      parent_ids: The parent beam indices.
        An `int32` `Tensor` of shape `[max_time, batch_size, beam_width]`.
      max_sequence_lengths: The maximum sequence length of each batch.
        An `int32` `Tensor` of shape `[batch_size]`.
      end_token: The end token ID.

    Returns:
      The reordered token IDs based on `parent_ids`.

    Raises:
      InvalidArgumentError: if `parent_ids` contains an invalid index.
    """
    if not options.is_custom_kernel_disabled():
        try:
            return _beam_search_so.ops.addons_gather_tree(
                step_ids, parent_ids, max_sequence_lengths, end_token
            )
        except tf.errors.NotFoundError:
            options.warn_fallback("gather_tree")

    step_ids = tf.convert_to_tensor(step_ids, dtype=tf.int32)
    parent_ids = tf.convert_to_tensor(parent_ids, dtype=tf.int32)
    max_sequence_lengths = tf.convert_to_tensor(max_sequence_lengths, dtype=tf.int32)
    end_token = tf.convert_to_tensor(end_token, dtype=tf.int32)
    return _gather_tree(step_ids, parent_ids, max_sequence_lengths, end_token)


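# Illustrative sketch (not part of the original module): `gather_tree`
# backtracks through `parent_ids` to rebuild each full hypothesis. With
# max_time=3, batch_size=1, beam_width=2:
#
#     step_ids   = tf.constant([[[2, 5]], [[7, 8]], [[9, 4]]], tf.int32)
#     parent_ids = tf.constant([[[0, 0]], [[0, 1]], [[1, 0]]], tf.int32)
#     beams = gather_tree(step_ids, parent_ids,
#                         max_sequence_lengths=tf.constant([3], tf.int32),
#                         end_token=10)
#     # beams -> [[[5, 2]], [[8, 7]], [[9, 4]]]
#     # i.e. final beam 0 is [5, 8, 9]: token 9 at t=2 descends from beam 1
#     # at t=1 (token 8), which descends from beam 1 at t=0 (token 5).

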

def gather_tree_from_array(
    t: TensorLike, parent_ids: TensorLike, sequence_length: TensorLike
) -> tf.Tensor:
    """Calculates the full beams for a `TensorArray`.

    Args:
      t: A stacked `TensorArray` of size `max_time` that contains `Tensor`s of
        shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]`
        where `s` is the depth shape.
      parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
      sequence_length: The sequence length of shape `[batch_size, beam_width]`.

    Returns:
      A `Tensor` which is a stacked `TensorArray` of the same size and type as
      `t` and where beams are sorted in each `Tensor` according to
      `parent_ids`.
    """
    max_time = parent_ids.shape[0] or tf.shape(parent_ids)[0]
    batch_size = parent_ids.shape[1] or tf.shape(parent_ids)[1]
    beam_width = parent_ids.shape[2] or tf.shape(parent_ids)[2]

    # Generate beam ids that will be reordered by gather_tree.
    beam_ids = tf.reshape(tf.range(beam_width), [1, 1, -1])
    beam_ids = tf.tile(beam_ids, [max_time, batch_size, 1])

    max_sequence_lengths = tf.cast(tf.reduce_max(sequence_length, axis=1), tf.int32)
    sorted_beam_ids = gather_tree(
        step_ids=beam_ids,
        parent_ids=parent_ids,
        max_sequence_lengths=max_sequence_lengths,
        end_token=beam_width + 1,
    )

    # For out of range steps, simply copy the same beam.
    in_bound_steps = tf.transpose(
        tf.sequence_mask(sequence_length, maxlen=max_time), perm=[2, 0, 1]
    )
    sorted_beam_ids = tf.where(in_bound_steps, x=sorted_beam_ids, y=beam_ids)

    # Gather from a tensor with collapsed additional dimensions.
    final_shape = tf.shape(t)
    gather_from = tf.reshape(t, [max_time, batch_size, beam_width, -1])
    ordered = tf.gather(gather_from, sorted_beam_ids, axis=2, batch_dims=2)
    ordered = tf.reshape(ordered, final_shape)

    return ordered


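# Illustrative note (not part of the original module): `gather_tree_from_array`
# reuses `gather_tree` on synthetic beam indices to build a permutation of the
# beam axis, then applies it to `t`:
#
#     # t:               [max_time, batch_size, beam_width, depth] (after reshape)
#     # sorted_beam_ids: [max_time, batch_size, beam_width]
#     # ordered[i, j, k] = t[i, j, sorted_beam_ids[i, j, k], :]
#
# `end_token=beam_width + 1` is safe because real beam ids only range over
# 0..beam_width - 1, so the filler value can never collide with them.

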

def _check_ndims(t):
    if t.shape.ndims is None:
        raise ValueError(
            "Expected tensor (%s) to have known rank, but ndims == None." % t
        )


def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
    """Warns and returns False if dimensions are known statically and can not
    be reshaped to [batch_size, beam_size, -1]."""
    reshaped_shape = tf.TensorShape([batch_size, beam_width, None])
    assert len(shape.dims) > 0
    if batch_size is None or shape[0] is None:
        return True  # not statically known => no check
    if shape[0] == batch_size * beam_width:
        return True  # flattened, matching
    has_second_dim = shape.ndims >= 2 and shape[1] is not None
    if has_second_dim and shape[0] == batch_size and shape[1] == beam_width:
        return True  # non-flattened, matching
    # Otherwise we could not find a match and warn:
    tf.get_logger().warn(
        "TensorArray reordering expects elements to be "
        "reshapable to %s which is incompatible with the "
        "current shape %s. Consider setting "
        "reorder_tensor_arrays to False to disable TensorArray "
        "reordering during the beam search." % (reshaped_shape, shape)
    )
    return False



def _check_batch_beam(t, batch_size, beam_width):
    """Returns an Assert operation checking that the elements of the stacked
    TensorArray can be reshaped to [batch_size, beam_size, -1].

    At this point, the TensorArray elements have a known rank of at
    least 1.
    """
    error_message = (
        "TensorArray reordering expects elements to be "
        "reshapable to [batch_size, beam_size, -1] which is "
        "incompatible with the dynamic shape of %s elements. "
        "Consider setting reorder_tensor_arrays to False to disable "
        "TensorArray reordering during the beam search."
        % (t if tf.executing_eagerly() else t.name)
    )
    rank = t.shape.ndims
    shape = tf.shape(t)
    if rank == 2:
        condition = tf.equal(shape[1], batch_size * beam_width)
    else:
        condition = tf.logical_or(
            tf.equal(shape[1], batch_size * beam_width),
            tf.logical_and(
                tf.equal(shape[1], batch_size), tf.equal(shape[2], beam_width)
            ),
        )
    return tf.Assert(condition, [error_message])



def _as_shape(value):
    """Converts the argument to a TensorShape if not already one."""
    if not isinstance(value, tf.TensorShape):
        if isinstance(value, tf.Tensor):
            value = tf.get_static_value(value)
        value = tf.TensorShape(value)
    return value



class BeamSearchDecoderMixin:
    """BeamSearchDecoderMixin contains the common methods for
    BeamSearchDecoder.

    It is expected to be used as a base class for a concrete
    BeamSearchDecoder. Since this is a mixin class, it is expected to be
    used together with another class as a base.
    """


    @typechecked
    def __init__(
        self,
        cell: tf.keras.layers.Layer,
        beam_width: int,
        output_layer: Optional[tf.keras.layers.Layer] = None,
        length_penalty_weight: FloatTensorLike = 0.0,
        coverage_penalty_weight: FloatTensorLike = 0.0,
        reorder_tensor_arrays: bool = True,
        output_all_scores: bool = False,
        **kwargs,
    ):
        """Initialize the BeamSearchDecoderMixin.

        Args:
          cell: A layer that implements the `tf.keras.layers.AbstractRNNCell`
            interface.
          beam_width: Python integer, the number of beams.
          output_layer: (Optional) An instance of `tf.keras.layers.Layer`,
            e.g., `tf.keras.layers.Dense`. Optional layer to apply to the RNN
            output prior to storing the result or sampling.
          length_penalty_weight: Float weight to penalize length. Disabled with
            0.0.
          coverage_penalty_weight: Float weight to penalize the coverage of
            source sentence. Disabled with 0.0.
          reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the
            cell state will be reordered according to the beam search path. If
            the `TensorArray` can be reordered, the stacked form will be
            returned. Otherwise, the `TensorArray` will be returned as is. Set
            this flag to `False` if the cell state contains `TensorArray`s that
            are not amenable to reordering.
          output_all_scores: If `True`, `BeamSearchDecoderOutput.scores` will
            contain scores for all token IDs and be of shape
            `[batch_size, beam_width, vocab_size]`. When `False` (default),
            only the top score corresponding to the predicted token will be
            output with shape `[batch_size, beam_width]`.
          **kwargs: Dict, other keyword arguments for parent class.
        """
        keras_utils.assert_like_rnncell("cell", cell)
        self._cell = cell
        self._output_layer = output_layer
        self._reorder_tensor_arrays = reorder_tensor_arrays
        self._output_all_scores = output_all_scores

        self._start_tokens = None
        self._end_token = None
        self._batch_size = None
        self._beam_width = beam_width
        self._length_penalty_weight = length_penalty_weight
        self._coverage_penalty_weight = coverage_penalty_weight
        super().__init__(**kwargs)


    @property
    def batch_size(self):
        return self._batch_size

    def _rnn_output_size(self):
        """Get the output shape from the RNN layer."""
        size = self._cell.output_size
        if self._output_layer is None:
            return size
        else:
            # To use layer's compute_output_shape, we need to convert the
            # RNNCell's output_size entries into shapes with an unknown
            # batch size. We then pass this through the layer's
            # compute_output_shape and read off all but the first (batch)
            # dimensions to get the output size of the rnn with the layer
            # applied to the top.
            output_shape_with_unknown_batch = tf.nest.map_structure(
                lambda s: tf.TensorShape([None]).concatenate(s), size
            )
            layer_output_shape = self._output_layer.compute_output_shape(
                output_shape_with_unknown_batch
            )
            return tf.nest.map_structure(lambda s: s[1:], layer_output_shape)


    @property
    def tracks_own_finished(self):
        """The BeamSearchDecoder shuffles its beams and their finished state.

        For this reason, it conflicts with the `dynamic_decode` function's
        tracking of finished states. Setting this property to true avoids
        early stopping of decoding due to mismanagement of the finished state
        in `dynamic_decode`.

        Returns:
          `True`.
        """
        return True

    @property
    def output_size(self):
        # Return the cell output and the id
        score_size = (
            tf.TensorShape([self._beam_width, self._rnn_output_size()[-1]])
            if self._output_all_scores
            else tf.TensorShape([self._beam_width])
        )
        return BeamSearchDecoderOutput(
            scores=score_size,
            predicted_ids=tf.TensorShape([self._beam_width]),
            parent_ids=tf.TensorShape([self._beam_width]),
        )


    def finalize(self, outputs, final_state, sequence_lengths):
        """Finalize and return the predicted_ids.

        Args:
          outputs: An instance of BeamSearchDecoderOutput.
          final_state: An instance of BeamSearchDecoderState. Passed through to
            the output.
          sequence_lengths: An `int64` tensor shaped
            `[batch_size, beam_width]`. The sequence lengths determined for
            each beam during decode. **NOTE** These are ignored; the updated
            sequence lengths are stored in `final_state.lengths`.

        Returns:
          outputs: An instance of `FinalBeamSearchDecoderOutput` where the
            predicted_ids are the result of calling `gather_tree`.
          final_state: The same input instance of `BeamSearchDecoderState`.
        """
        del sequence_lengths
        # Get max_sequence_length across all beams for each batch.
        max_sequence_lengths = tf.cast(
            tf.reduce_max(final_state.lengths, axis=1), tf.int32
        )
        predicted_ids = gather_tree(
            outputs.predicted_ids,
            outputs.parent_ids,
            max_sequence_lengths=max_sequence_lengths,
            end_token=self._end_token,
        )
        if self._reorder_tensor_arrays:
            final_state = final_state._replace(
                cell_state=tf.nest.map_structure(
                    lambda t: self._maybe_sort_array_beams(
                        t, outputs.parent_ids, final_state.lengths
                    ),
                    final_state.cell_state,
                )
            )
        outputs = FinalBeamSearchDecoderOutput(
            beam_search_decoder_output=outputs, predicted_ids=predicted_ids
        )
        return outputs, final_state

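    # Illustrative sketch (not part of the original module): shapes flowing
    # through `finalize`, with B = batch_size, W = beam_width and T decoding
    # steps (outputs are stacked time-major inside the decode loop):
    #
    #     outputs.predicted_ids   # [T, B, W] per-step token ids
    #     outputs.parent_ids      # [T, B, W] beam each token descended from
    #     final_state.lengths     # [B, W] int64 lengths per beam
    #     predicted_ids           # [T, B, W] full hypotheses after gather_tree
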

    def _merge_batch_beams(self, t, s=None):
        """Merges the tensor from a batch of beams into a batch by beams.

        More exactly, `t` is a tensor of dimension `[batch_size, beam_width, s]`.
        We reshape this into `[batch_size * beam_width, s]`.

        Args:
          t: `Tensor` of dimension `[batch_size, beam_width, s]`.
          s: (Possibly known) depth shape.

        Returns:
          A reshaped version of t with dimension `[batch_size * beam_width, s]`.
        """
        s = _as_shape(s)
        t_shape = tf.shape(t)
        static_batch_size = tf.get_static_value(self._batch_size)
        batch_size_beam_width = (
            None if static_batch_size is None else static_batch_size * self._beam_width
        )
        reshaped_t = tf.reshape(
            t, tf.concat(([self._batch_size * self._beam_width], t_shape[2:]), 0)
        )
        reshaped_t.set_shape(tf.TensorShape([batch_size_beam_width]).concatenate(s))
        return reshaped_t


    def _split_batch_beams(self, t, s=None):
        """Splits the tensor from a batch by beams into a batch of beams.

        More exactly, `t` is a tensor of dimension `[batch_size * beam_width, s]`.
        We reshape this into `[batch_size, beam_width, s]`.

        Args:
          t: `Tensor` of dimension `[batch_size * beam_width, s]`.
          s: (Possibly known) depth shape.

        Returns:
          A reshaped version of t with dimension `[batch_size, beam_width, s]`.

        Raises:
          ValueError: If, after reshaping, the new tensor is not shaped
            `[batch_size, beam_width, s]` (assuming batch_size and beam_width
            are known statically).
        """
        s = _as_shape(s)
        t_shape = tf.shape(t)
        reshaped_t = tf.reshape(
            t, tf.concat(([self._batch_size, self._beam_width], t_shape[1:]), 0)
        )
        static_batch_size = tf.get_static_value(self._batch_size)
        expected_reshaped_shape = tf.TensorShape(
            [static_batch_size, self._beam_width]
        ).concatenate(s)
        if not reshaped_t.shape.is_compatible_with(expected_reshaped_shape):
            raise ValueError(
                "Unexpected behavior when reshaping between beam width "
                "and batch size. The reshaped tensor has shape: %s. "
                "We expected it to have shape "
                "(batch_size, beam_width, depth) == %s. Perhaps you "
                "forgot to call get_initial_state with "
                "batch_size=encoder_batch_size * beam_width?"
                % (reshaped_t.shape, expected_reshaped_shape)
            )
        reshaped_t.set_shape(expected_reshaped_shape)
        return reshaped_t

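    # Illustrative sketch (not part of the original module): the two reshape
    # helpers are inverses and only touch the leading dimensions. With
    # batch_size=2, beam_width=3 and depth 128:
    #
    #     t = tf.zeros([2, 3, 128])              # [batch, beam, depth]
    #     merged = self._merge_batch_beams(t, s=tf.TensorShape([128]))
    #     # merged.shape -> [6, 128], i.e. [batch * beam, depth]
    #     split = self._split_batch_beams(merged, s=tf.TensorShape([128]))
    #     # split.shape -> [2, 3, 128] again
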

    def _maybe_split_batch_beams(self, t, s):
        """Maybe splits the tensor from a batch by beams into a batch of beams.

        We do this so that we can use nest and not run into problems with
        shapes.

        Args:
          t: `Tensor`, either scalar or shaped `[batch_size * beam_width] + s`.
          s: `Tensor`, Python int, or `TensorShape`.

        Returns:
          If `t` is a matrix or higher order tensor, then the return value is
          `t` reshaped to `[batch_size, beam_width] + s`. Otherwise `t` is
          returned unchanged.

        Raises:
          ValueError: If the rank of `t` is not statically known.
        """
        if isinstance(t, tf.TensorArray):
            return t
        _check_ndims(t)
        if t.shape.ndims >= 1:
            return self._split_batch_beams(t, s)
        else:
            return t


    def _maybe_merge_batch_beams(self, t, s):
        """Maybe merges the tensor from a batch of beams into a batch by beams.

        More exactly, if `t` is a tensor of dimension
        `[batch_size, beam_width] + s`, then we reshape it to
        `[batch_size * beam_width] + s`.

        Args:
          t: `Tensor` of dimension `[batch_size, beam_width] + s`.
          s: `Tensor`, Python int, or `TensorShape`.

        Returns:
          A reshaped version of t with shape `[batch_size * beam_width] + s`.

        Raises:
          ValueError: If the rank of `t` is not statically known.
        """
        if isinstance(t, tf.TensorArray):
            return t
        _check_ndims(t)
        if t.shape.ndims >= 2:
            return self._merge_batch_beams(t, s)
        else:
            return t


    def _maybe_sort_array_beams(self, t, parent_ids, sequence_length):
        """Maybe sorts beams within a `TensorArray`.

        Args:
          t: A `TensorArray` of size `max_time` that contains `Tensor`s of
            shape `[batch_size, beam_width, s]` or
            `[batch_size * beam_width, s]` where `s` is the depth shape.
          parent_ids: The parent ids of shape
            `[max_time, batch_size, beam_width]`.
          sequence_length: The sequence length of shape
            `[batch_size, beam_width]`.

        Returns:
          A `TensorArray` where beams are sorted in each `Tensor` or `t` itself
          if it is not a `TensorArray` or does not meet shape requirements.
        """
        if not isinstance(t, tf.TensorArray):
            return t
        if t.element_shape.ndims is None or t.element_shape.ndims < 1:
            tf.get_logger().warn(
                "The TensorArray %s in the cell state is not amenable to "
                "sorting based on the beam search result. For a "
                "TensorArray to be sorted, its elements shape must be "
                "defined and have at least a rank of 1, but saw shape: %s"
                % (t.handle.name, t.element_shape)
            )
            return t
        if not _check_static_batch_beam_maybe(
            t.element_shape, tf.get_static_value(self._batch_size), self._beam_width
        ):
            return t
        t = t.stack()
        with tf.control_dependencies(
            [_check_batch_beam(t, self._batch_size, self._beam_width)]
        ):
            return gather_tree_from_array(t, parent_ids, sequence_length)


    def step(self, time, inputs, state, training=None, name=None):
        """Perform a decoding step.

        Args:
          time: scalar `int32` tensor.
          inputs: A (structure of) input tensors.
          state: A (structure of) state tensors and TensorArrays.
          training: Python boolean. Indicates whether the layer should
            behave in training mode or in inference mode. Only relevant
            when `dropout` or `recurrent_dropout` is used.
          name: Name scope for any created operations.

        Returns:
          `(outputs, next_state, next_inputs, finished)`.
        """
        batch_size = self._batch_size
        beam_width = self._beam_width
        end_token = self._end_token
        length_penalty_weight = self._length_penalty_weight
        coverage_penalty_weight = self._coverage_penalty_weight

        with tf.name_scope(name or "BeamSearchDecoderStep"):
            cell_state = state.cell_state
            inputs = tf.nest.map_structure(
                lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs
            )
            cell_state = tf.nest.map_structure(
                self._maybe_merge_batch_beams, cell_state, self._cell.state_size
            )
            cell_outputs, next_cell_state = self._cell(
                inputs, cell_state, training=training
            )
            cell_outputs = tf.nest.map_structure(
                lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs
            )
            next_cell_state = tf.nest.pack_sequence_as(
                cell_state, tf.nest.flatten(next_cell_state)
            )
            next_cell_state = tf.nest.map_structure(
                self._maybe_split_batch_beams, next_cell_state, self._cell.state_size
            )

            if self._output_layer is not None:
                cell_outputs = self._output_layer(cell_outputs)

            beam_search_output, beam_search_state = _beam_search_step(
                time=time,
                logits=cell_outputs,
                next_cell_state=next_cell_state,
                beam_state=state,
                batch_size=batch_size,
                beam_width=beam_width,
                end_token=end_token,
                length_penalty_weight=length_penalty_weight,
                coverage_penalty_weight=coverage_penalty_weight,
                output_all_scores=self._output_all_scores,
            )

            finished = beam_search_state.finished
            sample_ids = beam_search_output.predicted_ids
            next_inputs = tf.cond(
                tf.reduce_all(finished),
                lambda: self._start_inputs,
                lambda: self._embedding_fn(sample_ids),
            )

        return (beam_search_output, beam_search_state, next_inputs, finished)



class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.BaseDecoder):
    # Note that the inheritance hierarchy is important here. The Mixin has to
    # be the first parent class since we will use super().__init__(), and the
    # Mixin, being a plain object, will properly invoke the __init__ method of
    # the other parent class.
    """Beam search decoder.

    **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
    `tfa.seq2seq.AttentionWrapper`, then you must ensure that:

    - The encoder output has been tiled to `beam_width` via
      `tfa.seq2seq.tile_batch` (NOT `tf.tile`).
    - The `batch_size` argument passed to the `get_initial_state` method of
      this wrapper is equal to `true_batch_size * beam_width`.
    - The initial state created with `get_initial_state` above contains a
      `cell_state` value containing properly tiled final state from the
      encoder.

    An example:

    ```
    tiled_encoder_outputs = tfa.seq2seq.tile_batch(
        encoder_outputs, multiplier=beam_width)
    tiled_encoder_final_state = tfa.seq2seq.tile_batch(
        encoder_final_state, multiplier=beam_width)
    tiled_sequence_length = tfa.seq2seq.tile_batch(
        sequence_length, multiplier=beam_width)
    attention_mechanism = MyFavoriteAttentionMechanism(
        num_units=attention_depth,
        memory=tiled_encoder_outputs,
        memory_sequence_length=tiled_sequence_length)
    attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
    decoder_initial_state = attention_cell.get_initial_state(
        batch_size=true_batch_size * beam_width, dtype=dtype)
    decoder_initial_state = decoder_initial_state.clone(
        cell_state=tiled_encoder_final_state)
    ```

    Meanwhile, when using `tfa.seq2seq.AttentionWrapper`, it is suggested to
    enable the coverage penalty when computing scores
    (https://arxiv.org/pdf/1609.08144.pdf), as it encourages the decoding to
    cover all inputs.
    """


    @typechecked
    def __init__(
        self,
        cell: tf.keras.layers.Layer,
        beam_width: int,
        embedding_fn: Optional[Callable] = None,
        output_layer: Optional[tf.keras.layers.Layer] = None,
        length_penalty_weight: FloatTensorLike = 0.0,
        coverage_penalty_weight: FloatTensorLike = 0.0,
        reorder_tensor_arrays: bool = True,
        **kwargs,
    ):
        """Initialize the BeamSearchDecoder.

        Args:
          cell: A layer that implements the `tf.keras.layers.AbstractRNNCell`
            interface.
          beam_width: Python integer, the number of beams.
          embedding_fn: A callable that takes an `int32` `Tensor` of token IDs
            and returns embedding tensors. If set, the `embedding` argument in
            the decoder call should be set to `None`.
          output_layer: (Optional) An instance of `tf.keras.layers.Layer`,
            e.g., `tf.keras.layers.Dense`. Optional layer to apply to the RNN
            output prior to storing the result or sampling.
          length_penalty_weight: Float weight to penalize length. Disabled with
            0.0.
          coverage_penalty_weight: Float weight to penalize the coverage of
            source sentence. Disabled with 0.0.
          reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the
            cell state will be reordered according to the beam search path. If
            the `TensorArray` can be reordered, the stacked form will be
            returned. Otherwise, the `TensorArray` will be returned as is. Set
            this flag to `False` if the cell state contains `TensorArray`s that
            are not amenable to reordering.
          **kwargs: Dict, other keyword arguments for initialization.
        """
        super().__init__(
            cell,
            beam_width,
            output_layer=output_layer,
            length_penalty_weight=length_penalty_weight,
            coverage_penalty_weight=coverage_penalty_weight,
            reorder_tensor_arrays=reorder_tensor_arrays,
            **kwargs,
        )

        self._embedding_fn = embedding_fn


    def initialize(self, embedding, start_tokens, end_token, initial_state):
        """Initialize the decoder.

        Args:
          embedding: A `Tensor` (or `Variable`) to pass as the `params` argument
            for `tf.nn.embedding_lookup`. This overrides `embedding_fn` set in
            the constructor.
          start_tokens: Start the decoding from these tokens.
            An `int32` `Tensor` of shape `[batch_size]`.
          end_token: The token that marks the end of decoding.
            An `int32` scalar `Tensor`.
          initial_state: The initial cell state as a (possibly nested) structure
            of `Tensor` and `TensorArray`.

        Returns:
          `(finished, start_inputs, initial_state)`.

        Raises:
          ValueError: If `embedding` is `None` and `embedding_fn` was not set
            in the constructor.
          ValueError: If `start_tokens` is not a vector or `end_token` is not a
            scalar.
        """
        if embedding is not None:
            self._embedding_fn = lambda ids: tf.nn.embedding_lookup(embedding, ids)
        elif self._embedding_fn is None:
            raise ValueError(
                "You should either pass an embedding variable when calling the "
                "BeamSearchDecoder or set embedding_fn in the constructor."
            )

        self._start_tokens = tf.convert_to_tensor(
            start_tokens, dtype=tf.int32, name="start_tokens"
        )
        if self._start_tokens.shape.ndims != 1:
            raise ValueError("start_tokens must be a vector")
        self._end_token = tf.convert_to_tensor(
            end_token, dtype=tf.int32, name="end_token"
        )
        if self._end_token.shape.ndims != 0:
            raise ValueError("end_token must be a scalar")

        self._batch_size = tf.size(start_tokens)
        self._initial_cell_state = tf.nest.map_structure(
            self._maybe_split_batch_beams, initial_state, self._cell.state_size
        )
        self._start_tokens = tf.tile(
            tf.expand_dims(self._start_tokens, 1), [1, self._beam_width]
        )
        self._start_inputs = self._embedding_fn(self._start_tokens)

        self._finished = tf.one_hot(
            tf.zeros([self._batch_size], dtype=tf.int32),
            depth=self._beam_width,
            on_value=False,
            off_value=True,
            dtype=tf.bool,
        )

        finished, start_inputs = self._finished, self._start_inputs

        dtype = tf.nest.flatten(self._initial_cell_state)[0].dtype
        log_probs = tf.one_hot(  # shape(batch_sz, beam_sz)
            tf.zeros([self._batch_size], dtype=tf.int32),
            depth=self._beam_width,
            on_value=tf.convert_to_tensor(0.0, dtype=dtype),
            off_value=tf.convert_to_tensor(-np.Inf, dtype=dtype),
            dtype=dtype,
        )
        init_attention_probs = get_attention_probs(
            self._initial_cell_state, self._coverage_penalty_weight
        )
        if init_attention_probs is None:
            init_attention_probs = ()

        initial_state = BeamSearchDecoderState(
            cell_state=self._initial_cell_state,
            log_probs=log_probs,
            finished=finished,
            lengths=tf.zeros([self._batch_size, self._beam_width], dtype=tf.int64),
            accumulated_attention_probs=init_attention_probs,
        )

        return (finished, start_inputs, initial_state)


    @property
    def output_dtype(self):
        # Assume the dtype of the cell is the output_size structure
        # containing the input_state's first component's dtype.
        # Return that structure and int32 (the id)
        dtype = tf.nest.flatten(self._initial_cell_state)[0].dtype
        return BeamSearchDecoderOutput(
            scores=tf.nest.map_structure(lambda _: dtype, self._rnn_output_size()),
            predicted_ids=tf.int32,
            parent_ids=tf.int32,
        )


    def call(
        self, embedding, start_tokens, end_token, initial_state, training=None, **kwargs
    ):
        init_kwargs = kwargs
        init_kwargs["start_tokens"] = start_tokens
        init_kwargs["end_token"] = end_token
        init_kwargs["initial_state"] = initial_state
        return decoder.dynamic_decode(
            self,
            output_time_major=self.output_time_major,
            impute_finished=self.impute_finished,
            maximum_iterations=self.maximum_iterations,
            parallel_iterations=self.parallel_iterations,
            swap_memory=self.swap_memory,
            training=training,
            decoder_init_input=embedding,
            decoder_init_kwargs=init_kwargs,
        )


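# Illustrative end-to-end sketch (not part of the original module). The names
# `cell`, `vocab_size`, `embedding_matrix`, `encoder_state`, `start_tokens`
# and `end_token` are placeholders for user-provided values:
#
#     beam_width = 4
#     bsd = BeamSearchDecoder(
#         cell,
#         beam_width=beam_width,
#         output_layer=tf.keras.layers.Dense(vocab_size),
#         maximum_iterations=50,   # forwarded to decoder.BaseDecoder
#     )
#     # For a plain (non-attention) cell the initial state is just the tiled
#     # encoder state; see the class docstring for the AttentionWrapper case.
#     decoder_initial_state = tile_batch(encoder_state, multiplier=beam_width)
#     final_outputs, final_state, final_lengths = bsd(
#         embedding_matrix,
#         start_tokens=start_tokens,   # [batch_size] int32
#         end_token=end_token,         # scalar int32
#         initial_state=decoder_initial_state,
#     )
#     best_ids = final_outputs.predicted_ids[:, :, 0]   # best beam per batch

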

def _beam_search_step(
    time,
    logits,
    next_cell_state,
    beam_state,
    batch_size,
    beam_width,
    end_token,
    length_penalty_weight,
    coverage_penalty_weight,
    output_all_scores,
):
    """Performs a single step of Beam Search Decoding.

    Args:
      time: Beam search time step, should start at 0. At time 0 we assume
        that all beams are equal and consider only the first beam for
        continuations.
      logits: Logits at the current time step. A tensor of shape
        `[batch_size, beam_width, vocab_size]`.
      next_cell_state: The next state from the cell, e.g. an instance of
        AttentionWrapperState if the cell is attentional.
      beam_state: Current state of the beam search.
        An instance of `BeamSearchDecoderState`.
      batch_size: The batch size for this input.
      beam_width: Python int. The size of the beams.
      end_token: The int32 end token.
      length_penalty_weight: Float weight to penalize length. Disabled with
        0.0.
      coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.
      output_all_scores: Bool output scores for every token if True, else only
        output the top scores.

    Returns:
      A `(BeamSearchDecoderOutput, BeamSearchDecoderState)` pair for the
      current step.
    """
    static_batch_size = tf.get_static_value(batch_size)

    # Calculate the current lengths of the predictions.
    prediction_lengths = beam_state.lengths
    previously_finished = beam_state.finished
    not_finished = tf.logical_not(previously_finished)

    # Calculate the total log probs for the new hypotheses.
    # Final Shape: [batch_size, beam_width, vocab_size]
    step_log_probs = tf.nn.log_softmax(logits)
    step_log_probs = _mask_probs(step_log_probs, end_token, previously_finished)
    total_probs = tf.expand_dims(beam_state.log_probs, 2) + step_log_probs

    # Calculate the continuation lengths by adding to all continuing beams.
    vocab_size = logits.shape[-1] or tf.shape(logits)[-1]
    lengths_to_add = tf.one_hot(
        indices=tf.fill([batch_size, beam_width], end_token),
        depth=vocab_size,
        on_value=np.int64(0),
        off_value=np.int64(1),
        dtype=tf.int64,
    )
    add_mask = tf.cast(not_finished, tf.int64)
    lengths_to_add *= tf.expand_dims(add_mask, 2)
    new_prediction_lengths = lengths_to_add + tf.expand_dims(prediction_lengths, 2)

    # Calculate the accumulated attention probabilities if coverage penalty is
    # enabled.
    accumulated_attention_probs = None
    attention_probs = get_attention_probs(next_cell_state, coverage_penalty_weight)
    if attention_probs is not None:
        attention_probs *= tf.expand_dims(tf.cast(not_finished, tf.float32), 2)
        accumulated_attention_probs = (
            beam_state.accumulated_attention_probs + attention_probs
        )

    # Calculate the scores for each beam.
    scores = _get_scores(
        log_probs=total_probs,
        sequence_lengths=new_prediction_lengths,
        length_penalty_weight=length_penalty_weight,
        coverage_penalty_weight=coverage_penalty_weight,
        finished=previously_finished,
        accumulated_attention_probs=accumulated_attention_probs,
    )


    time = tf.convert_to_tensor(time, name="time")
    # During the first time step we only consider the initial beam.
    scores_flat = tf.reshape(scores, [batch_size, -1])

    # Pick the next beams according to the specified successors function.
    next_beam_size = tf.convert_to_tensor(beam_width, dtype=tf.int32, name="beam_width")
    next_beam_scores, word_indices = tf.math.top_k(scores_flat, k=next_beam_size)

    next_beam_scores.set_shape([static_batch_size, beam_width])
    word_indices.set_shape([static_batch_size, beam_width])

    # Pick out the probs, beam_ids, and states according to the chosen
    # predictions.
    next_beam_probs = _tensor_gather_helper(
        gather_indices=word_indices,
        gather_from=total_probs,
        batch_size=batch_size,
        range_size=beam_width * vocab_size,
        gather_shape=[-1],
        name="next_beam_probs",
    )
    # Note: just doing the following
    #   tf.to_int32(word_indices % vocab_size,
    #       name="next_beam_word_ids")
    # would be a lot cleaner but for reasons unclear, that hides the results of
    # the op which prevents capturing it with tfdbg debug ops.
    raw_next_word_ids = tf.math.floormod(
        word_indices, vocab_size, name="next_beam_word_ids"
    )
    next_word_ids = tf.cast(raw_next_word_ids, tf.int32)
    next_beam_ids = tf.cast(
        word_indices / vocab_size, tf.int32, name="next_beam_parent_ids"
    )

    # Append new ids to current predictions.
    previously_finished = _tensor_gather_helper(
        gather_indices=next_beam_ids,
        gather_from=previously_finished,
        batch_size=batch_size,
        range_size=beam_width,
        gather_shape=[-1],
    )
    next_finished = tf.logical_or(
        previously_finished,
        tf.equal(next_word_ids, end_token),
        name="next_beam_finished",
    )

    # Calculate the length of the next predictions.
    # 1. Finished beams remain unchanged.
    # 2. Beams that are now finished (EOS predicted) have their length
    #    increased by 1.
    # 3. Beams that are not yet finished have their length increased by 1.
    lengths_to_add = tf.cast(tf.logical_not(previously_finished), tf.int64)
    next_prediction_len = _tensor_gather_helper(
        gather_indices=next_beam_ids,
        gather_from=beam_state.lengths,
        batch_size=batch_size,
        range_size=beam_width,
        gather_shape=[-1],
    )
    next_prediction_len += lengths_to_add
    next_accumulated_attention_probs = ()
    if accumulated_attention_probs is not None:
        next_accumulated_attention_probs = _tensor_gather_helper(
            gather_indices=next_beam_ids,
            gather_from=accumulated_attention_probs,
            batch_size=batch_size,
            range_size=beam_width,
            gather_shape=[batch_size * beam_width, -1],
            name="next_accumulated_attention_probs",
        )

    # Pick out the cell_states according to the next_beam_ids. We use a
    # different gather_shape here because the cell_state tensors, i.e.
    # the tensors that would be gathered from, all have dimension
    # greater than two and we need to preserve those dimensions.
    next_cell_state = tf.nest.map_structure(
        lambda gather_from: _maybe_tensor_gather_helper(
            gather_indices=next_beam_ids,
            gather_from=gather_from,
            batch_size=batch_size,
            range_size=beam_width,
            gather_shape=[batch_size * beam_width, -1],
        ),
        next_cell_state,
    )

    next_state = BeamSearchDecoderState(
        cell_state=next_cell_state,
        log_probs=next_beam_probs,
        lengths=next_prediction_len,
        finished=next_finished,
        accumulated_attention_probs=next_accumulated_attention_probs,
    )

    output = BeamSearchDecoderOutput(
        scores=scores if output_all_scores else next_beam_scores,
        predicted_ids=next_word_ids,
        parent_ids=next_beam_ids,
    )

    return output, next_state



def get_attention_probs(next_cell_state, coverage_penalty_weight):
    """Get attention probabilities from the cell state.

    Args:
      next_cell_state: The next state from the cell, e.g. an instance of
        AttentionWrapperState if the cell is attentional.
      coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.

    Returns:
      The attention probabilities with shape
      `[batch_size, beam_width, max_time]` if coverage penalty is enabled.
      Otherwise, returns None.

    Raises:
      ValueError: If no cell is attentional but coverage penalty is enabled.
    """
    if coverage_penalty_weight == 0.0:
        return None

    # Attention probabilities of each attention layer. Each with shape
    # `[batch_size, beam_width, max_time]`.
    probs_per_attn_layer = []
    if isinstance(next_cell_state, attention_wrapper.AttentionWrapperState):
        probs_per_attn_layer = [attention_probs_from_attn_state(next_cell_state)]
    elif isinstance(next_cell_state, tuple):
        for state in next_cell_state:
            if isinstance(state, attention_wrapper.AttentionWrapperState):
                probs_per_attn_layer.append(attention_probs_from_attn_state(state))

    if not probs_per_attn_layer:
        raise ValueError(
            "coverage_penalty_weight must be 0.0 if no cell is attentional."
        )

    if len(probs_per_attn_layer) == 1:
        attention_probs = probs_per_attn_layer[0]
    else:
        # Calculate the average attention probabilities from all attention
        # layers.
        attention_probs = [tf.expand_dims(prob, -1) for prob in probs_per_attn_layer]
        attention_probs = tf.concat(attention_probs, -1)
        attention_probs = tf.reduce_mean(attention_probs, -1)

    return attention_probs



def _get_scores(
    log_probs,
    sequence_lengths,
    length_penalty_weight,
    coverage_penalty_weight,
    finished,
    accumulated_attention_probs,
):
    """Calculates scores for beam search hypotheses.

    Args:
      log_probs: The log probabilities with shape
        `[batch_size, beam_width, vocab_size]`.
      sequence_lengths: The array of sequence lengths.
      length_penalty_weight: Float weight to penalize length. Disabled with
        0.0.
      coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.
      finished: A boolean tensor of shape `[batch_size, beam_width]` that
        specifies which elements in the beam are finished already.
      accumulated_attention_probs: Accumulated attention probabilities up to
        the current time step, with shape `[batch_size, beam_width, max_time]`
        if coverage_penalty_weight is not 0.0.

    Returns:
      The scores normalized by the length_penalty and coverage_penalty.

    Raises:
      ValueError: If accumulated_attention_probs is None when coverage penalty
        is enabled.
    """
    length_penalty_ = _length_penalty(
        sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight
    )
    length_penalty_ = tf.cast(length_penalty_, dtype=log_probs.dtype)
    scores = log_probs / length_penalty_

    coverage_penalty_weight = tf.convert_to_tensor(
        coverage_penalty_weight, name="coverage_penalty_weight"
    )
    if coverage_penalty_weight.shape.ndims != 0:
        raise ValueError(
            "coverage_penalty_weight should be a scalar, "
            "but saw shape: %s" % coverage_penalty_weight.shape
        )

    if tf.get_static_value(coverage_penalty_weight) == 0.0:
        return scores

    if accumulated_attention_probs is None:
        raise ValueError(
            "accumulated_attention_probs can be None only if coverage penalty "
            "is disabled."
        )

    # Add source sequence length mask before computing coverage penalty.
    accumulated_attention_probs = tf.where(
        tf.equal(accumulated_attention_probs, 0.0),
        tf.ones_like(accumulated_attention_probs),
        accumulated_attention_probs,
    )

    # coverage penalty =
    #     sum over `max_time` {log(min(accumulated_attention_probs, 1.0))}
    coverage_penalty = tf.reduce_sum(
        tf.math.log(tf.minimum(accumulated_attention_probs, 1.0)), 2
    )
    # Apply coverage penalty to finished predictions.
    coverage_penalty *= tf.cast(finished, tf.float32)
    weighted_coverage_penalty = coverage_penalty * coverage_penalty_weight
    # Reshape from [batch_size, beam_width] to [batch_size, beam_width, 1]
    weighted_coverage_penalty = tf.expand_dims(weighted_coverage_penalty, 2)
    return scores + weighted_coverage_penalty


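# Illustrative summary (not part of the original module): with GNMT-style
# penalties (https://arxiv.org/abs/1609.08144) the score used for ranking is
#
#     score = log_prob / lp(length) + coverage_penalty_weight * cp
#
# where lp(length) = ((5 + length) / 6) ** length_penalty_weight and
# cp = sum_j log(min(sum_i p_{i,j}, 1.0)) over source positions j, with
# p_{i,j} the attention probability placed on source token j at decoding
# step i. Note that `_get_scores` above only applies the coverage term to
# finished hypotheses.

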

def attention_probs_from_attn_state(attention_state):
    """Calculates the average attention probabilities.

    Args:
      attention_state: An instance of `AttentionWrapperState`.

    Returns:
      The attention probabilities in the given AttentionWrapperState.
      If there are multiple attention mechanisms, returns the average value
      from all attention mechanisms.
    """
    # Attention probabilities over time steps, with shape
    # `[batch_size, beam_width, max_time]`.
    attention_probs = attention_state.alignments
    if isinstance(attention_probs, tuple):
        attention_probs = [tf.expand_dims(prob, -1) for prob in attention_probs]
        attention_probs = tf.concat(attention_probs, -1)
        attention_probs = tf.reduce_mean(attention_probs, -1)
    return attention_probs



def _length_penalty(sequence_lengths, penalty_factor):
    """Calculates the length penalty. See https://arxiv.org/abs/1609.08144.

    Returns the length penalty tensor:
    ```
    [(5 + sequence_lengths) / 6] ** penalty_factor
    ```
    where all operations are performed element-wise.

    Args:
      sequence_lengths: `Tensor`, the sequence lengths of each hypothesis.
      penalty_factor: A scalar that weights the length penalty.

    Returns:
      If the penalty is `0`, returns the scalar `1.0`. Otherwise returns
      the length penalty factor, a tensor with the same shape as
      `sequence_lengths`.
    """
    penalty_factor = tf.convert_to_tensor(penalty_factor, name="penalty_factor")
    penalty_factor.set_shape(())  # penalty should be a scalar.
    static_penalty = tf.get_static_value(penalty_factor)
    if static_penalty is not None and static_penalty == 0:
        return 1.0
    return tf.math.divide(
        (5.0 + tf.cast(sequence_lengths, tf.float32)) ** penalty_factor,
        (5.0 + 1.0) ** penalty_factor,
    )


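# Illustrative numeric check (not part of the original module): with
# penalty_factor=0.6 and a hypothesis of length 7, the divisor is
# ((5 + 7) / 6) ** 0.6 = 2.0 ** 0.6 ~= 1.516. Dividing the (negative)
# accumulated log probability by a value > 1 moves it toward zero, so
# longer hypotheses are penalized less severely than a raw sum would be.

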

def _mask_probs(probs, eos_token, finished):
    """Masks log probabilities.

    The result is that finished beams allocate all probability mass to eos and
    unfinished beams remain unchanged.

    Args:
      probs: Log probabilities of shape `[batch_size, beam_width, vocab_size]`.
      eos_token: An int32 id corresponding to the EOS token to allocate
        probability to.
      finished: A boolean tensor of shape `[batch_size, beam_width]` that
        specifies which elements in the beam are finished already.

    Returns:
      A tensor of shape `[batch_size, beam_width, vocab_size]`, where
      unfinished beams stay unchanged and finished beams are replaced with a
      tensor with all probability on the EOS token.
    """
    vocab_size = tf.shape(probs)[2]
    # All finished examples are replaced with a vector that has all
    # probability on EOS.
    finished_row = tf.one_hot(
        eos_token,
        vocab_size,
        dtype=probs.dtype,
        on_value=tf.convert_to_tensor(0.0, dtype=probs.dtype),
        off_value=probs.dtype.min,
    )
    finished_probs = tf.tile(
        tf.reshape(finished_row, [1, 1, -1]), tf.concat([tf.shape(finished), [1]], 0)
    )
    finished_mask = tf.tile(tf.expand_dims(finished, 2), [1, 1, vocab_size])

    return tf.where(finished_mask, finished_probs, probs)


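# Illustrative sketch (not part of the original module): for a finished beam
# the whole vocabulary row is replaced so that continuing with EOS costs
# nothing and anything else is effectively impossible:
#
#     log_probs = tf.math.log([[[0.2, 0.3, 0.5]]])     # [1, 1, 3]
#     finished = tf.constant([[True]])
#     masked = _mask_probs(log_probs, eos_token=2, finished=finished)
#     # masked -> [[[min, min, 0.0]]] where `min` is log_probs.dtype.min,
#     # so a finished beam keeps adding 0.0 log prob by emitting EOS.

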

def _maybe_tensor_gather_helper(
    gather_indices, gather_from, batch_size, range_size, gather_shape
):
    """Maybe applies _tensor_gather_helper.

    This applies _tensor_gather_helper when the gather_from dims is at least as
    big as the length of gather_shape. This is used in conjunction with nest so
    that we don't apply _tensor_gather_helper to inapplicable values like
    scalars.

    Args:
      gather_indices: The tensor indices that we use to gather.
      gather_from: The tensor that we are gathering from.
      batch_size: The batch size.
      range_size: The number of values in each range. Likely equal to
        beam_width.
      gather_shape: What we should reshape gather_from to in order to preserve
        the correct values. An example is when gather_from is the attention
        from an AttentionWrapperState with shape
        [batch_size, beam_width, attention_size]. There, we want to preserve
        the attention_size elements, so gather_shape is
        [batch_size * beam_width, -1]. Then, upon reshape, we still have the
        attention_size as desired.

    Returns:
      output: Gathered tensor of shape
        tf.shape(gather_from)[:1 + len(gather_shape)] or the original tensor
        if its dimensions are too small.
    """
    if isinstance(gather_from, tf.TensorArray):
        return gather_from
    _check_ndims(gather_from)
    if gather_from.shape.ndims >= len(gather_shape):
        return _tensor_gather_helper(
            gather_indices=gather_indices,
            gather_from=gather_from,
            batch_size=batch_size,
            range_size=range_size,
            gather_shape=gather_shape,
        )
    else:
        return gather_from



def _tensor_gather_helper(
    gather_indices, gather_from, batch_size, range_size, gather_shape, name=None
):
    """Helper for gathering the right indices from the tensor.

    This works by reshaping gather_from to gather_shape (e.g. [-1]) and then
    gathering from that according to the gather_indices, which are offset by
    the right amounts in order to preserve the batch order.

    Args:
      gather_indices: The tensor indices that we use to gather.
      gather_from: The tensor that we are gathering from.
      batch_size: The input batch size.
      range_size: The number of values in each range. Likely equal to
        beam_width.
      gather_shape: What we should reshape gather_from to in order to preserve
        the correct values. An example is when gather_from is the attention
        from an AttentionWrapperState with shape
        [batch_size, beam_width, attention_size]. There, we want to preserve
        the attention_size elements, so gather_shape is
        [batch_size * beam_width, -1]. Then, upon reshape, we still have the
        attention_size as desired.
      name: The tensor name for set of operations. By default this is
        'tensor_gather_helper'. The final output is named 'output'.

    Returns:
      output: Gathered tensor of shape
        tf.shape(gather_from)[:1 + len(gather_shape)].
    """
    with tf.name_scope(name or "tensor_gather_helper"):
        range_ = tf.expand_dims(tf.range(batch_size) * range_size, 1)
        gather_indices = tf.reshape(gather_indices + range_, [-1])
        output = tf.gather(tf.reshape(gather_from, gather_shape), gather_indices)
        final_shape = tf.shape(gather_from)[: 1 + len(gather_shape)]
        static_batch_size = tf.get_static_value(batch_size)
        final_static_shape = tf.TensorShape([static_batch_size]).concatenate(
            gather_from.shape[1 : 1 + len(gather_shape)]
        )
        output = tf.reshape(output, final_shape, name="output")
        output.set_shape(final_static_shape)
        return output
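

# Illustrative sketch (not part of the original module): the gather helper
# flattens the beam axis into the batch axis and offsets the indices so that
# each batch entry gathers only from its own range. With batch_size=2 and
# range_size=beam_width=3:
#
#     gather_from = tf.reshape(tf.range(6), [2, 3])    # [[0, 1, 2], [3, 4, 5]]
#     gather_indices = tf.constant([[2, 0, 0], [1, 1, 2]])
#     out = _tensor_gather_helper(gather_indices, gather_from,
#                                 batch_size=2, range_size=3,
#                                 gather_shape=[-1])
#     # range_ -> [[0], [3]]; flat indices -> [2, 0, 0, 4, 4, 5]
#     # out -> [[2, 0, 0], [4, 4, 5]]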