# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A decoder that performs beam search."""

import collections
import numpy as np

import tensorflow as tf

from tensorflow_addons import options
from tensorflow_addons.seq2seq import attention_wrapper
from tensorflow_addons.seq2seq import decoder
from tensorflow_addons.utils import keras_utils
from tensorflow_addons.utils.resource_loader import LazySO
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike, Number

from typeguard import typechecked
from typing import Callable, Optional

_beam_search_so = LazySO("custom_ops/seq2seq/_beam_search_ops.so")


class BeamSearchDecoderState(
    collections.namedtuple(
        "BeamSearchDecoderState",
        (
            "cell_state",
            "log_probs",
            "finished",
            "lengths",
            "accumulated_attention_probs",
        ),
    )
):
    """State of a `tfa.seq2seq.BeamSearchDecoder`.

    Attributes:
      cell_state: The cell state returned at the previous time step.
      log_probs: The accumulated log probabilities of each beam.
        A `float32` `Tensor` of shape `[batch_size, beam_width]`.
      finished: The finished status of each beam.
        A `bool` `Tensor` of shape `[batch_size, beam_width]`.
      lengths: The accumulated length of each beam.
        An `int64` `Tensor` of shape `[batch_size, beam_width]`.
      accumulated_attention_probs: Accumulation of the attention
        probabilities (used to compute the coverage penalty).
    """

    pass


class BeamSearchDecoderOutput(
    collections.namedtuple(
        "BeamSearchDecoderOutput", ("scores", "predicted_ids", "parent_ids")
    )
):
    """Outputs of a `tfa.seq2seq.BeamSearchDecoder` step.

    Attributes:
      scores: The scores for this step, which are the log
        probabilities over the output vocabulary, possibly penalized by length
        and attention coverage. When `tfa.seq2seq.BeamSearchDecoder` is created with
        `output_all_scores=False` (default), this will be a `float32` `Tensor`
        of shape `[batch_size, beam_width]` containing the top scores
        corresponding to the predicted IDs. When `output_all_scores=True`,
        this contains the scores for all token IDs and has shape
        `[batch_size, beam_width, vocab_size]`.
      predicted_ids: The token IDs predicted for this step.
        An `int32` `Tensor` of shape `[batch_size, beam_width]`.
      parent_ids: The indices of the parent beam of each beam.
        An `int32` `Tensor` of shape `[batch_size, beam_width]`.
    """

    pass


class FinalBeamSearchDecoderOutput(
    collections.namedtuple(
        "FinalBeamDecoderOutput", ["predicted_ids", "beam_search_decoder_output"]
    )
):
    """Final outputs returned by the beam search after all decoding is finished.

    Attributes:
      predicted_ids: The final prediction. A tensor of shape
        `[batch_size, T, beam_width]` (or `[T, batch_size, beam_width]` if
        `output_time_major` is `True`). Beams are ordered from best to worst.
      beam_search_decoder_output: An instance of `tfa.seq2seq.BeamSearchDecoderOutput` that
        describes the state of the beam search.
    """

    pass


def _tile_batch(t, multiplier):
    """Core single-tensor implementation of tile_batch."""
    t = tf.convert_to_tensor(t, name="t")
    shape_t = tf.shape(t)
    if t.shape.ndims is None or t.shape.ndims < 1:
        raise ValueError("t must have statically known rank")
    tiling = [1] * (t.shape.ndims + 1)
    tiling[1] = multiplier
    tiled_static_batch_size = (
        t.shape[0] * multiplier if t.shape[0] is not None else None
    )
    tiled = tf.tile(tf.expand_dims(t, 1), tiling)
    tiled = tf.reshape(tiled, tf.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
    tiled.set_shape(tf.TensorShape([tiled_static_batch_size]).concatenate(t.shape[1:]))
    return tiled


def tile_batch(t: TensorLike, multiplier: int, name: Optional[str] = None) -> tf.Tensor:
    """Tiles the batch dimension of a (possibly nested structure of) tensor(s).

    For each tensor t in a (possibly nested structure) of tensors,
    this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed
    of minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a
    shape `[batch_size * multiplier, s0, s1, ...]` composed of minibatch
    entries `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is
    repeated `multiplier` times.

    Args:
      t: `Tensor` shaped `[batch_size, ...]`.
      multiplier: Python int.
      name: Name scope for any created operations.

    Returns:
      A (possibly nested structure of) `Tensor` shaped
      `[batch_size * multiplier, ...]`.

    Raises:
      ValueError: if tensor(s) `t` do not have a statically known rank or
        the rank is < 1.
    """
    with tf.name_scope(name or "tile_batch"):
        return tf.nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
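

# Illustrative sketch (not part of the original module): how `tile_batch`
# repeats each batch entry `multiplier` times along the batch dimension.
# The shapes and values below are hypothetical.
def _example_tile_batch():
    memory = tf.constant([[1.0, 2.0], [3.0, 4.0]])  # shape [2, 2]
    tiled = tile_batch(memory, multiplier=3)  # shape [6, 2]
    # Rows come out as b0, b0, b0, b1, b1, b1: each minibatch entry is
    # repeated `multiplier` times before the next entry starts, which is the
    # layout beam search expects (unlike a plain `tf.tile` on axis 0).
    return tiled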


@tf.function(
    input_signature=(
        tf.TensorSpec([None, None, None], dtype=tf.int32),
        tf.TensorSpec([None, None, None], dtype=tf.int32),
        tf.TensorSpec([None], dtype=tf.int32),
        tf.TensorSpec([], dtype=tf.int32),
    )
)
def _gather_tree(step_ids, parent_ids, max_sequence_lengths, end_token):
    input_shape = tf.shape(parent_ids)
    max_time = input_shape[0]
    beam_width = input_shape[2]
    max_sequence_lengths = tf.math.minimum(max_sequence_lengths, max_time)
    mask = tf.expand_dims(
        tf.transpose(tf.sequence_mask(max_sequence_lengths, maxlen=max_time)), -1
    )

    # Mask out of range ids.
    end_tokens = tf.fill(input_shape, end_token)
    step_ids = tf.where(mask, x=step_ids, y=end_tokens)
    parent_ids = tf.where(mask, x=parent_ids, y=tf.zeros_like(parent_ids))
    assert_op = tf.debugging.Assert(
        tf.math.reduce_all(
            tf.math.logical_and(parent_ids >= 0, parent_ids < beam_width)
        ),
        ["All parent ids must be non-negative and less than beam_width"],
    )

    # Reverse all sequences as we need to gather from the end.
    with tf.control_dependencies([assert_op]):
        rev_step_ids = tf.reverse_sequence(
            step_ids, max_sequence_lengths, seq_axis=0, batch_axis=1
        )
        rev_parent_ids = tf.reverse_sequence(
            parent_ids, max_sequence_lengths, seq_axis=0, batch_axis=1
        )

    # Initialize output ids and parent based on last step.
    output_ids = tf.TensorArray(step_ids.dtype, size=max_time, dynamic_size=False)
    output_ids = output_ids.write(0, rev_step_ids[0])
    parent = rev_parent_ids[0]

    # For each step, gather ids based on beam origin.
    for t in tf.range(1, max_time):
        ids = tf.gather(rev_step_ids[t], parent, batch_dims=1)
        parent = tf.gather(rev_parent_ids[t], parent, batch_dims=1)
        output_ids = output_ids.write(t, ids)

    # Reverse sequences to their original order.
    output_ids = output_ids.stack()
    output_ids = tf.reverse_sequence(
        output_ids, max_sequence_lengths, seq_axis=0, batch_axis=1
    )

    # Ensure that there are only `end_token` values after the first `end_token`.
    in_bound_steps = tf.math.cumsum(tf.cast(output_ids == end_token, tf.int32)) == 0
    output_ids = tf.where(in_bound_steps, x=output_ids, y=end_tokens)
    return output_ids


def gather_tree(
    step_ids: TensorLike,
    parent_ids: TensorLike,
    max_sequence_lengths: TensorLike,
    end_token: Number,
) -> tf.Tensor:
    """Calculates the full beams from the per-step ids and parent beam ids.

    For a given beam, past the time step containing the first decoded
    `end_token` all values are filled in with `end_token`.

    Args:
      step_ids: The predicted token IDs.
        An `int32` `Tensor` of shape `[max_time, batch_size, beam_width]`.
      parent_ids: The parent beam indices.
        An `int32` `Tensor` of shape `[max_time, batch_size, beam_width]`.
      max_sequence_lengths: The maximum sequence length of each batch.
        An `int32` `Tensor` of shape `[batch_size]`.
      end_token: The end token ID.

    Returns:
      The reordered token IDs based on `parent_ids`.

    Raises:
      InvalidArgumentError: if `parent_ids` contains an invalid index.
    """
    if not options.is_custom_kernel_disabled():
        try:
            return _beam_search_so.ops.addons_gather_tree(
                step_ids, parent_ids, max_sequence_lengths, end_token
            )
        except tf.errors.NotFoundError:
            options.warn_fallback("gather_tree")

    step_ids = tf.convert_to_tensor(step_ids, dtype=tf.int32)
    parent_ids = tf.convert_to_tensor(parent_ids, dtype=tf.int32)
    max_sequence_lengths = tf.convert_to_tensor(max_sequence_lengths, dtype=tf.int32)
    end_token = tf.convert_to_tensor(end_token, dtype=tf.int32)
    return _gather_tree(step_ids, parent_ids, max_sequence_lengths, end_token)
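

# Illustrative sketch (hypothetical values): recovering full beams from
# per-step predictions and parent pointers with `gather_tree`, here with
# max_time=3, batch_size=1, beam_width=2.
def _example_gather_tree():
    step_ids = tf.constant([[[2, 3]], [[4, 5]], [[6, 7]]], tf.int32)
    parent_ids = tf.constant([[[0, 0]], [[0, 0]], [[1, 0]]], tf.int32)
    max_sequence_lengths = tf.constant([3], tf.int32)
    # Token 6 at t=2 descends from beam 1 at t=1 (token 5), which descends
    # from beam 0 at t=0 (token 2), so the first output beam reads [2, 5, 6].
    return gather_tree(step_ids, parent_ids, max_sequence_lengths, end_token=9)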


def gather_tree_from_array(
    t: TensorLike, parent_ids: TensorLike, sequence_length: TensorLike
) -> tf.Tensor:
    """Calculates the full beams for a `TensorArray`.

    Args:
      t: A stacked `TensorArray` of size `max_time` that contains `Tensor`s of
        shape `[batch_size, beam_width, s]` or `[batch_size * beam_width, s]`
        where `s` is the depth shape.
      parent_ids: The parent ids of shape `[max_time, batch_size, beam_width]`.
      sequence_length: The sequence length of shape `[batch_size, beam_width]`.

    Returns:
      A `Tensor` which is a stacked `TensorArray` of the same size and type as
      `t` and where beams are sorted in each `Tensor` according to
      `parent_ids`.
    """
    max_time = parent_ids.shape[0] or tf.shape(parent_ids)[0]
    batch_size = parent_ids.shape[1] or tf.shape(parent_ids)[1]
    beam_width = parent_ids.shape[2] or tf.shape(parent_ids)[2]

    # Generate beam ids that will be reordered by gather_tree.
    beam_ids = tf.reshape(tf.range(beam_width), [1, 1, -1])
    beam_ids = tf.tile(beam_ids, [max_time, batch_size, 1])

    max_sequence_lengths = tf.cast(tf.reduce_max(sequence_length, axis=1), tf.int32)
    sorted_beam_ids = gather_tree(
        step_ids=beam_ids,
        parent_ids=parent_ids,
        max_sequence_lengths=max_sequence_lengths,
        end_token=beam_width + 1,
    )

    # For out of range steps, simply copy the same beam.
    in_bound_steps = tf.transpose(
        tf.sequence_mask(sequence_length, maxlen=max_time), perm=[2, 0, 1]
    )
    sorted_beam_ids = tf.where(in_bound_steps, x=sorted_beam_ids, y=beam_ids)

    # Gather from a tensor with collapsed additional dimensions.
    final_shape = tf.shape(t)
    gather_from = tf.reshape(t, [max_time, batch_size, beam_width, -1])
    ordered = tf.gather(gather_from, sorted_beam_ids, axis=2, batch_dims=2)
    ordered = tf.reshape(ordered, final_shape)

    return ordered


def _check_ndims(t):
    if t.shape.ndims is None:
        raise ValueError(
            "Expected tensor (%s) to have known rank, but ndims == None." % t
        )


def _check_static_batch_beam_maybe(shape, batch_size, beam_width):
    """Raises an exception if dimensions are known statically and cannot be
    reshaped to [batch_size, beam_width, -1]."""
    reshaped_shape = tf.TensorShape([batch_size, beam_width, None])
    assert len(shape.dims) > 0
    if batch_size is None or shape[0] is None:
        return True  # not statically known => no check
    if shape[0] == batch_size * beam_width:
        return True  # flattened, matching
    has_second_dim = shape.ndims >= 2 and shape[1] is not None
    if has_second_dim and shape[0] == batch_size and shape[1] == beam_width:
        return True  # non-flattened, matching
    # Otherwise we could not find a match and warn:
    tf.get_logger().warn(
        "TensorArray reordering expects elements to be "
        "reshapable to %s which is incompatible with the "
        "current shape %s. Consider setting "
        "reorder_tensor_arrays to False to disable TensorArray "
        "reordering during the beam search." % (reshaped_shape, shape)
    )
    return False


def _check_batch_beam(t, batch_size, beam_width):
    """Returns an Assert operation checking that the elements of the stacked
    TensorArray can be reshaped to [batch_size, beam_width, -1].

    At this point, the TensorArray elements have a known rank of at
    least 1.
    """
    error_message = (
        "TensorArray reordering expects elements to be "
        "reshapable to [batch_size, beam_width, -1] which is "
        "incompatible with the dynamic shape of %s elements. "
        "Consider setting reorder_tensor_arrays to False to disable "
        "TensorArray reordering during the beam search."
        % (t if tf.executing_eagerly() else t.name)
    )
    rank = t.shape.ndims
    shape = tf.shape(t)
    if rank == 2:
        condition = tf.equal(shape[1], batch_size * beam_width)
    else:
        condition = tf.logical_or(
            tf.equal(shape[1], batch_size * beam_width),
            tf.logical_and(
                tf.equal(shape[1], batch_size), tf.equal(shape[2], beam_width)
            ),
        )
    return tf.Assert(condition, [error_message])


def _as_shape(value):
    """Converts the argument to a TensorShape if not already one."""
    if not isinstance(value, tf.TensorShape):
        if isinstance(value, tf.Tensor):
            value = tf.get_static_value(value)
        value = tf.TensorShape(value)
    return value


class BeamSearchDecoderMixin:
    """BeamSearchDecoderMixin contains the common methods for
    BeamSearchDecoder.

    It is expected to be used as a base class for concrete
    BeamSearchDecoder implementations. Since this is a mixin class, it is
    expected to be used together with other classes as base.
    """

    @typechecked
    def __init__(
        self,
        cell: tf.keras.layers.Layer,
        beam_width: int,
        output_layer: Optional[tf.keras.layers.Layer] = None,
        length_penalty_weight: FloatTensorLike = 0.0,
        coverage_penalty_weight: FloatTensorLike = 0.0,
        reorder_tensor_arrays: bool = True,
        output_all_scores: bool = False,
        **kwargs,
    ):
        """Initialize the BeamSearchDecoderMixin.

        Args:
          cell: A layer that implements the `tf.keras.layers.AbstractRNNCell`
            interface.
          beam_width: Python integer, the number of beams.
          output_layer: (Optional) An instance of `tf.keras.layers.Layer`,
            e.g., `tf.keras.layers.Dense`. Optional layer to apply to the RNN
            output prior to storing the result or sampling.
          length_penalty_weight: Float weight to penalize length. Disabled with
            0.0.
          coverage_penalty_weight: Float weight to penalize the coverage of
            source sentence. Disabled with 0.0.
          reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the
            cell state will be reordered according to the beam search path. If
            the `TensorArray` can be reordered, the stacked form will be
            returned. Otherwise, the `TensorArray` will be returned as is. Set
            this flag to `False` if the cell state contains `TensorArray`s that
            are not amenable to reordering.
          output_all_scores: If `True`, `BeamSearchDecoderOutput.scores` will
            contain scores for all token IDs and be of shape
            `[batch_size, beam_width, vocab_size]`. When `False` (default),
            only the top score corresponding to the predicted token will be
            output with shape `[batch_size, beam_width]`.
          **kwargs: Dict, other keyword arguments for parent class.
        """
        keras_utils.assert_like_rnncell("cell", cell)
        self._cell = cell
        self._output_layer = output_layer
        self._reorder_tensor_arrays = reorder_tensor_arrays
        self._output_all_scores = output_all_scores

        self._start_tokens = None
        self._end_token = None
        self._batch_size = None
        self._beam_width = beam_width
        self._length_penalty_weight = length_penalty_weight
        self._coverage_penalty_weight = coverage_penalty_weight
        super().__init__(**kwargs)

    @property
    def batch_size(self):
        return self._batch_size

    def _rnn_output_size(self):
        """Get the output shape from the RNN layer."""
        size = self._cell.output_size
        if self._output_layer is None:
            return size
        else:
            # To use layer's compute_output_shape, we need to convert the
            # RNNCell's output_size entries into shapes with an unknown
            # batch size. We then pass this through the layer's
            # compute_output_shape and read off all but the first (batch)
            # dimensions to get the output size of the rnn with the layer
            # applied to the top.
            output_shape_with_unknown_batch = tf.nest.map_structure(
                lambda s: tf.TensorShape([None]).concatenate(s), size
            )
            layer_output_shape = self._output_layer.compute_output_shape(
                output_shape_with_unknown_batch
            )
            return tf.nest.map_structure(lambda s: s[1:], layer_output_shape)

    @property
    def tracks_own_finished(self):
        """The BeamSearchDecoder shuffles its beams and their finished state.

        For this reason, it conflicts with the `dynamic_decode` function's
        tracking of finished states. Setting this property to true avoids
        early stopping of decoding due to mismanagement of the finished state
        in `dynamic_decode`.

        Returns:
          `True`.
        """
        return True

    @property
    def output_size(self):
        # Return the cell output and the id
        score_size = (
            tf.TensorShape([self._beam_width, self._rnn_output_size()[-1]])
            if self._output_all_scores
            else tf.TensorShape([self._beam_width])
        )
        return BeamSearchDecoderOutput(
            scores=score_size,
            predicted_ids=tf.TensorShape([self._beam_width]),
            parent_ids=tf.TensorShape([self._beam_width]),
        )

    def finalize(self, outputs, final_state, sequence_lengths):
        """Finalize and return the predicted_ids.

        Args:
          outputs: An instance of BeamSearchDecoderOutput.
          final_state: An instance of BeamSearchDecoderState. Passed through to
            the output.
          sequence_lengths: An `int64` tensor shaped
            `[batch_size, beam_width]`. The sequence lengths determined for
            each beam during decode. **NOTE** These are ignored; the updated
            sequence lengths are stored in `final_state.lengths`.

        Returns:
          outputs: An instance of `FinalBeamSearchDecoderOutput` where the
            predicted_ids are the result of calling `gather_tree`.
          final_state: The same input instance of `BeamSearchDecoderState`.
        """
        del sequence_lengths
        # Get max_sequence_length across all beams for each batch.
        max_sequence_lengths = tf.cast(
            tf.reduce_max(final_state.lengths, axis=1), tf.int32
        )
        predicted_ids = gather_tree(
            outputs.predicted_ids,
            outputs.parent_ids,
            max_sequence_lengths=max_sequence_lengths,
            end_token=self._end_token,
        )
        if self._reorder_tensor_arrays:
            final_state = final_state._replace(
                cell_state=tf.nest.map_structure(
                    lambda t: self._maybe_sort_array_beams(
                        t, outputs.parent_ids, final_state.lengths
                    ),
                    final_state.cell_state,
                )
            )
        outputs = FinalBeamSearchDecoderOutput(
            beam_search_decoder_output=outputs, predicted_ids=predicted_ids
        )
        return outputs, final_state

    def _merge_batch_beams(self, t, s=None):
        """Merges the tensor from a batch of beams into a batch by beams.

        More exactly, t is a tensor of dimension [batch_size, beam_width, s].
        We reshape this into [batch_size*beam_width, s].

        Args:
          t: Tensor of dimension [batch_size, beam_width, s]
          s: (Possibly known) depth shape.

        Returns:
          A reshaped version of t with dimension [batch_size * beam_width, s].
        """
        s = _as_shape(s)
        t_shape = tf.shape(t)
        static_batch_size = tf.get_static_value(self._batch_size)
        batch_size_beam_width = (
            None if static_batch_size is None else static_batch_size * self._beam_width
        )
        reshaped_t = tf.reshape(
            t, tf.concat(([self._batch_size * self._beam_width], t_shape[2:]), 0)
        )
        reshaped_t.set_shape(tf.TensorShape([batch_size_beam_width]).concatenate(s))
        return reshaped_t

    def _split_batch_beams(self, t, s=None):
        """Splits the tensor from a batch by beams into a batch of beams.

        More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We
        reshape this into [batch_size, beam_width, s].

        Args:
          t: Tensor of dimension [batch_size*beam_width, s].
          s: (Possibly known) depth shape.

        Returns:
          A reshaped version of t with dimension [batch_size, beam_width, s].

        Raises:
          ValueError: If, after reshaping, the new tensor is not shaped
            `[batch_size, beam_width, s]` (assuming batch_size and beam_width
            are known statically).
        """
        s = _as_shape(s)
        t_shape = tf.shape(t)
        reshaped_t = tf.reshape(
            t, tf.concat(([self._batch_size, self._beam_width], t_shape[1:]), 0)
        )
        static_batch_size = tf.get_static_value(self._batch_size)
        expected_reshaped_shape = tf.TensorShape(
            [static_batch_size, self._beam_width]
        ).concatenate(s)
        if not reshaped_t.shape.is_compatible_with(expected_reshaped_shape):
            raise ValueError(
                "Unexpected behavior when reshaping between beam width "
                "and batch size. The reshaped tensor has shape: %s. "
                "We expected it to have shape "
                "(batch_size, beam_width, depth) == %s. Perhaps you "
                "forgot to call get_initial_state with "
                "batch_size=encoder_batch_size * beam_width?"
                % (reshaped_t.shape, expected_reshaped_shape)
            )
        reshaped_t.set_shape(expected_reshaped_shape)
        return reshaped_t

    def _maybe_split_batch_beams(self, t, s):
        """Maybe splits the tensor from a batch by beams into a batch of beams.

        We do this so that we can use nest and not run into problems with
        shapes.

        Args:
          t: `Tensor`, either scalar or shaped `[batch_size * beam_width] + s`.
          s: `Tensor`, Python int, or `TensorShape`.

        Returns:
          If `t` is a matrix or higher order tensor, then the return value is
          `t` reshaped to `[batch_size, beam_width] + s`. Otherwise `t` is
          returned unchanged.

        Raises:
          ValueError: If the rank of `t` is not statically known.
        """
        if isinstance(t, tf.TensorArray):
            return t
        _check_ndims(t)
        if t.shape.ndims >= 1:
            return self._split_batch_beams(t, s)
        else:
            return t

    def _maybe_merge_batch_beams(self, t, s):
        """Maybe merges the tensor from a batch of beams into a batch by beams.

        More exactly, `t` is a tensor of dimension
        `[batch_size, beam_width] + s`, then we reshape it to
        `[batch_size * beam_width] + s`.

        Args:
          t: `Tensor` of dimension `[batch_size, beam_width] + s`.
          s: `Tensor`, Python int, or `TensorShape`.

        Returns:
          A reshaped version of t with shape `[batch_size * beam_width] + s`.

        Raises:
          ValueError: If the rank of `t` is not statically known.
        """
        if isinstance(t, tf.TensorArray):
            return t
        _check_ndims(t)
        if t.shape.ndims >= 2:
            return self._merge_batch_beams(t, s)
        else:
            return t

    def _maybe_sort_array_beams(self, t, parent_ids, sequence_length):
        """Maybe sorts beams within a `TensorArray`.

        Args:
          t: A `TensorArray` of size `max_time` that contains `Tensor`s of
            shape `[batch_size, beam_width, s]` or
            `[batch_size * beam_width, s]` where `s` is the depth shape.
          parent_ids: The parent ids of shape
            `[max_time, batch_size, beam_width]`.
          sequence_length: The sequence length of shape
            `[batch_size, beam_width]`.

        Returns:
          A `TensorArray` where beams are sorted in each `Tensor` or `t` itself
          if it is not a `TensorArray` or does not meet shape requirements.
        """
        if not isinstance(t, tf.TensorArray):
            return t
        if t.element_shape.ndims is None or t.element_shape.ndims < 1:
            tf.get_logger().warn(
                "The TensorArray %s in the cell state is not amenable to "
                "sorting based on the beam search result. For a "
                "TensorArray to be sorted, its elements shape must be "
                "defined and have at least a rank of 1, but saw shape: %s"
                % (t.handle.name, t.element_shape)
            )
            return t
        if not _check_static_batch_beam_maybe(
            t.element_shape, tf.get_static_value(self._batch_size), self._beam_width
        ):
            return t
        t = t.stack()
        with tf.control_dependencies(
            [_check_batch_beam(t, self._batch_size, self._beam_width)]
        ):
            return gather_tree_from_array(t, parent_ids, sequence_length)

    def step(self, time, inputs, state, training=None, name=None):
        """Perform a decoding step.

        Args:
          time: scalar `int32` tensor.
          inputs: A (structure of) input tensors.
          state: A (structure of) state tensors and TensorArrays.
          training: Python boolean. Indicates whether the layer should
            behave in training mode or in inference mode. Only relevant
            when `dropout` or `recurrent_dropout` is used.
          name: Name scope for any created operations.

        Returns:
          `(outputs, next_state, next_inputs, finished)`.
        """
        batch_size = self._batch_size
        beam_width = self._beam_width
        end_token = self._end_token
        length_penalty_weight = self._length_penalty_weight
        coverage_penalty_weight = self._coverage_penalty_weight

        with tf.name_scope(name or "BeamSearchDecoderStep"):
            cell_state = state.cell_state
            inputs = tf.nest.map_structure(
                lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs
            )
            cell_state = tf.nest.map_structure(
                self._maybe_merge_batch_beams, cell_state, self._cell.state_size
            )
            cell_outputs, next_cell_state = self._cell(
                inputs, cell_state, training=training
            )
            cell_outputs = tf.nest.map_structure(
                lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs
            )
            next_cell_state = tf.nest.pack_sequence_as(
                cell_state, tf.nest.flatten(next_cell_state)
            )
            next_cell_state = tf.nest.map_structure(
                self._maybe_split_batch_beams, next_cell_state, self._cell.state_size
            )

            if self._output_layer is not None:
                cell_outputs = self._output_layer(cell_outputs)

            beam_search_output, beam_search_state = _beam_search_step(
                time=time,
                logits=cell_outputs,
                next_cell_state=next_cell_state,
                beam_state=state,
                batch_size=batch_size,
                beam_width=beam_width,
                end_token=end_token,
                length_penalty_weight=length_penalty_weight,
                coverage_penalty_weight=coverage_penalty_weight,
                output_all_scores=self._output_all_scores,
            )

            finished = beam_search_state.finished
            sample_ids = beam_search_output.predicted_ids
            next_inputs = tf.cond(
                tf.reduce_all(finished),
                lambda: self._start_inputs,
                lambda: self._embedding_fn(sample_ids),
            )

        return (beam_search_output, beam_search_state, next_inputs, finished)
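

# Illustrative sketch (standalone, hypothetical shapes): the reshape dance
# performed by `_merge_batch_beams` / `_split_batch_beams` inside `step`.
# The wrapped cell runs on a flat [batch_size * beam_width, ...] batch while
# the beam search logic works on the [batch_size, beam_width, ...] view.
def _example_merge_split(batch_size=2, beam_width=3, depth=4):
    t = tf.random.normal([batch_size, beam_width, depth])
    merged = tf.reshape(t, [batch_size * beam_width, depth])  # fed to the cell
    split = tf.reshape(merged, [batch_size, beam_width, depth])  # for search
    # The round trip is lossless: beam j of batch entry i lives at flat row
    # i * beam_width + j, which is why encoder tensors must be tiled with
    # `tile_batch` rather than `tf.tile`.
    return tf.reduce_all(tf.equal(t, split))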


class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.BaseDecoder):
    # Note that the inheritance hierarchy is important here. The mixin has to
    # be the first parent class since we use super().__init__(), and the mixin
    # will then properly invoke the __init__ method of the other parent class.
    """Beam search decoder.

    **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
    `tfa.seq2seq.AttentionWrapper`, then you must ensure that:

    - The encoder output has been tiled to `beam_width` via
      `tfa.seq2seq.tile_batch` (NOT `tf.tile`).
    - The `batch_size` argument passed to the `get_initial_state` method of
      this wrapper is equal to `true_batch_size * beam_width`.
    - The initial state created with `get_initial_state` above contains a
      `cell_state` value containing properly tiled final state from the
      encoder.

    An example:

    ```
    tiled_encoder_outputs = tfa.seq2seq.tile_batch(
        encoder_outputs, multiplier=beam_width)
    tiled_encoder_final_state = tfa.seq2seq.tile_batch(
        encoder_final_state, multiplier=beam_width)
    tiled_sequence_length = tfa.seq2seq.tile_batch(
        sequence_length, multiplier=beam_width)
    attention_mechanism = MyFavoriteAttentionMechanism(
        num_units=attention_depth,
        memory=tiled_encoder_outputs,
        memory_sequence_length=tiled_sequence_length)
    attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
    decoder_initial_state = attention_cell.get_initial_state(
        batch_size=true_batch_size * beam_width, dtype=dtype)
    decoder_initial_state = decoder_initial_state.clone(
        cell_state=tiled_encoder_final_state)
    ```

    Meanwhile, when using `tfa.seq2seq.AttentionWrapper`, a coverage penalty is
    suggested when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It
    encourages the decoding to cover all inputs.
    """

    @typechecked
    def __init__(
        self,
        cell: tf.keras.layers.Layer,
        beam_width: int,
        embedding_fn: Optional[Callable] = None,
        output_layer: Optional[tf.keras.layers.Layer] = None,
        length_penalty_weight: FloatTensorLike = 0.0,
        coverage_penalty_weight: FloatTensorLike = 0.0,
        reorder_tensor_arrays: bool = True,
        **kwargs,
    ):
        """Initialize the BeamSearchDecoder.

        Args:
          cell: A layer that implements the `tf.keras.layers.AbstractRNNCell`
            interface.
          beam_width: Python integer, the number of beams.
          embedding_fn: A callable that takes an `int32` `Tensor` of token IDs
            and returns embedding tensors. If set, the `embedding` argument in
            the decoder call should be set to `None`.
          output_layer: (Optional) An instance of `tf.keras.layers.Layer`,
            e.g., `tf.keras.layers.Dense`. Optional layer to apply to the RNN
            output prior to storing the result or sampling.
          length_penalty_weight: Float weight to penalize length. Disabled with
            0.0.
          coverage_penalty_weight: Float weight to penalize the coverage of
            source sentence. Disabled with 0.0.
          reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the
            cell state will be reordered according to the beam search path. If
            the `TensorArray` can be reordered, the stacked form will be
            returned. Otherwise, the `TensorArray` will be returned as is. Set
            this flag to `False` if the cell state contains `TensorArray`s that
            are not amenable to reordering.
          **kwargs: Dict, other keyword arguments for initialization.
        """
        super().__init__(
            cell,
            beam_width,
            output_layer=output_layer,
            length_penalty_weight=length_penalty_weight,
            coverage_penalty_weight=coverage_penalty_weight,
            reorder_tensor_arrays=reorder_tensor_arrays,
            **kwargs,
        )

        self._embedding_fn = embedding_fn

    def initialize(self, embedding, start_tokens, end_token, initial_state):
        """Initialize the decoder.

        Args:
          embedding: A `Tensor` (or `Variable`) to pass as the `params` argument
            for `tf.nn.embedding_lookup`. This overrides `embedding_fn` set in
            the constructor.
          start_tokens: Start the decoding from these tokens.
            An `int32` `Tensor` of shape `[batch_size]`.
          end_token: The token that marks the end of decoding.
            An `int32` scalar `Tensor`.
          initial_state: The initial cell state as a (possibly nested) structure
            of `Tensor` and `TensorArray`.

        Returns:
          `(finished, start_inputs, initial_state)`.

        Raises:
          ValueError: If `embedding` is `None` and `embedding_fn` was not set
            in the constructor.
          ValueError: If `start_tokens` is not a vector or `end_token` is not a
            scalar.
        """
        if embedding is not None:
            self._embedding_fn = lambda ids: tf.nn.embedding_lookup(embedding, ids)
        elif self._embedding_fn is None:
            raise ValueError(
                "You should either pass an embedding variable when calling the "
                "BeamSearchDecoder or set embedding_fn in the constructor."
            )

        self._start_tokens = tf.convert_to_tensor(
            start_tokens, dtype=tf.int32, name="start_tokens"
        )
        if self._start_tokens.shape.ndims != 1:
            raise ValueError("start_tokens must be a vector")
        self._end_token = tf.convert_to_tensor(
            end_token, dtype=tf.int32, name="end_token"
        )
        if self._end_token.shape.ndims != 0:
            raise ValueError("end_token must be a scalar")

        self._batch_size = tf.size(start_tokens)
        self._initial_cell_state = tf.nest.map_structure(
            self._maybe_split_batch_beams, initial_state, self._cell.state_size
        )
        self._start_tokens = tf.tile(
            tf.expand_dims(self._start_tokens, 1), [1, self._beam_width]
        )
        self._start_inputs = self._embedding_fn(self._start_tokens)

        self._finished = tf.one_hot(
            tf.zeros([self._batch_size], dtype=tf.int32),
            depth=self._beam_width,
            on_value=False,
            off_value=True,
            dtype=tf.bool,
        )

        finished, start_inputs = self._finished, self._start_inputs

        dtype = tf.nest.flatten(self._initial_cell_state)[0].dtype
        log_probs = tf.one_hot(  # shape(batch_sz, beam_sz)
            tf.zeros([self._batch_size], dtype=tf.int32),
            depth=self._beam_width,
            on_value=tf.convert_to_tensor(0.0, dtype=dtype),
            off_value=tf.convert_to_tensor(-np.inf, dtype=dtype),
            dtype=dtype,
        )
        init_attention_probs = get_attention_probs(
            self._initial_cell_state, self._coverage_penalty_weight
        )
        if init_attention_probs is None:
            init_attention_probs = ()

        initial_state = BeamSearchDecoderState(
            cell_state=self._initial_cell_state,
            log_probs=log_probs,
            finished=finished,
            lengths=tf.zeros([self._batch_size, self._beam_width], dtype=tf.int64),
            accumulated_attention_probs=init_attention_probs,
        )

        return (finished, start_inputs, initial_state)

    @property
    def output_dtype(self):
        # Assume the dtype of the cell is the output_size structure
        # containing the input_state's first component's dtype.
        # Return that structure and int32 (the id)
        dtype = tf.nest.flatten(self._initial_cell_state)[0].dtype
        return BeamSearchDecoderOutput(
            scores=tf.nest.map_structure(lambda _: dtype, self._rnn_output_size()),
            predicted_ids=tf.int32,
            parent_ids=tf.int32,
        )

    def call(
        self, embedding, start_tokens, end_token, initial_state, training=None, **kwargs
    ):
        init_kwargs = kwargs
        init_kwargs["start_tokens"] = start_tokens
        init_kwargs["end_token"] = end_token
        init_kwargs["initial_state"] = initial_state
        return decoder.dynamic_decode(
            self,
            output_time_major=self.output_time_major,
            impute_finished=self.impute_finished,
            maximum_iterations=self.maximum_iterations,
            parallel_iterations=self.parallel_iterations,
            swap_memory=self.swap_memory,
            training=training,
            decoder_init_input=embedding,
            decoder_init_kwargs=init_kwargs,
        )
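

# Illustrative sketch (hypothetical sizes, no attention): minimal end-to-end
# use of `BeamSearchDecoder`. The cell, layer sizes, and token ids below are
# made up for the example; the key point is that the initial cell state is
# built for `batch_size * beam_width` sequences.
def _example_beam_search_decode(batch_size=2, beam_width=4, vocab_size=10, units=8):
    embedding = tf.random.normal([vocab_size, units])
    cell = tf.keras.layers.LSTMCell(units)
    bsd = BeamSearchDecoder(
        cell,
        beam_width,
        output_layer=tf.keras.layers.Dense(vocab_size),
        maximum_iterations=5,
    )
    # Build a (here: zero) initial state for all batch entries and beams.
    initial_state = cell.get_initial_state(
        batch_size=batch_size * beam_width, dtype=tf.float32
    )
    outputs, _, _ = bsd(
        embedding,
        start_tokens=tf.fill([batch_size], 1),
        end_token=2,
        initial_state=initial_state,
    )
    return outputs.predicted_ids  # [batch_size, max_time, beam_width]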


def _beam_search_step(
    time,
    logits,
    next_cell_state,
    beam_state,
    batch_size,
    beam_width,
    end_token,
    length_penalty_weight,
    coverage_penalty_weight,
    output_all_scores,
):
    """Performs a single step of Beam Search Decoding.

    Args:
      time: Beam search time step, should start at 0. At time 0 we assume
        that all beams are equal and consider only the first beam for
        continuations.
      logits: Logits at the current time step. A tensor of shape
        `[batch_size, beam_width, vocab_size]`.
      next_cell_state: The next state from the cell, e.g. an instance of
        AttentionWrapperState if the cell is attentional.
      beam_state: Current state of the beam search.
        An instance of `BeamSearchDecoderState`.
      batch_size: The batch size for this input.
      beam_width: Python int. The size of the beams.
      end_token: The int32 end token.
      length_penalty_weight: Float weight to penalize length. Disabled with
        0.0.
      coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.
      output_all_scores: Bool output scores for every token if True, else only
        output the top scores.

    Returns:
      A `(output, next_state)` tuple, where `output` is a
      `BeamSearchDecoderOutput` and `next_state` is a new
      `BeamSearchDecoderState`.
    """
    static_batch_size = tf.get_static_value(batch_size)

    # Calculate the current lengths of the predictions.
    prediction_lengths = beam_state.lengths
    previously_finished = beam_state.finished
    not_finished = tf.logical_not(previously_finished)

    # Calculate the total log probs for the new hypotheses.
    # Final Shape: [batch_size, beam_width, vocab_size]
    step_log_probs = tf.nn.log_softmax(logits)
    step_log_probs = _mask_probs(step_log_probs, end_token, previously_finished)
    total_probs = tf.expand_dims(beam_state.log_probs, 2) + step_log_probs

    # Calculate the continuation lengths by adding to all continuing beams.
    vocab_size = logits.shape[-1] or tf.shape(logits)[-1]
    lengths_to_add = tf.one_hot(
        indices=tf.fill([batch_size, beam_width], end_token),
        depth=vocab_size,
        on_value=np.int64(0),
        off_value=np.int64(1),
        dtype=tf.int64,
    )
    add_mask = tf.cast(not_finished, tf.int64)
    lengths_to_add *= tf.expand_dims(add_mask, 2)
    new_prediction_lengths = lengths_to_add + tf.expand_dims(prediction_lengths, 2)

    # Calculate the accumulated attention probabilities if coverage penalty is
    # enabled.
    accumulated_attention_probs = None
    attention_probs = get_attention_probs(next_cell_state, coverage_penalty_weight)
    if attention_probs is not None:
        attention_probs *= tf.expand_dims(tf.cast(not_finished, tf.float32), 2)
        accumulated_attention_probs = (
            beam_state.accumulated_attention_probs + attention_probs
        )

    # Calculate the scores for each beam.
    scores = _get_scores(
        log_probs=total_probs,
        sequence_lengths=new_prediction_lengths,
        length_penalty_weight=length_penalty_weight,
        coverage_penalty_weight=coverage_penalty_weight,
        finished=previously_finished,
        accumulated_attention_probs=accumulated_attention_probs,
    )

    time = tf.convert_to_tensor(time, name="time")
    # During the first time step we only consider the initial beam.
    scores_flat = tf.reshape(scores, [batch_size, -1])

    # Pick the next beams according to the specified successors function.
    next_beam_size = tf.convert_to_tensor(beam_width, dtype=tf.int32, name="beam_width")
    next_beam_scores, word_indices = tf.math.top_k(scores_flat, k=next_beam_size)

    next_beam_scores.set_shape([static_batch_size, beam_width])
    word_indices.set_shape([static_batch_size, beam_width])

    # Pick out the probs, beam_ids, and states according to the chosen
    # predictions.
    next_beam_probs = _tensor_gather_helper(
        gather_indices=word_indices,
        gather_from=total_probs,
        batch_size=batch_size,
        range_size=beam_width * vocab_size,
        gather_shape=[-1],
        name="next_beam_probs",
    )
    # Note: just doing the following
    #   tf.to_int32(word_indices % vocab_size,
    #       name="next_beam_word_ids")
    # would be a lot cleaner but for reasons unclear, that hides the results of
    # the op which prevents capturing it with tfdbg debug ops.
    raw_next_word_ids = tf.math.floormod(
        word_indices, vocab_size, name="next_beam_word_ids"
    )
    next_word_ids = tf.cast(raw_next_word_ids, tf.int32)
    next_beam_ids = tf.cast(
        word_indices / vocab_size, tf.int32, name="next_beam_parent_ids"
    )

    # Append new ids to current predictions.
    previously_finished = _tensor_gather_helper(
        gather_indices=next_beam_ids,
        gather_from=previously_finished,
        batch_size=batch_size,
        range_size=beam_width,
        gather_shape=[-1],
    )
    next_finished = tf.logical_or(
        previously_finished,
        tf.equal(next_word_ids, end_token),
        name="next_beam_finished",
    )

    # Calculate the length of the next predictions.
    # 1. Finished beams remain unchanged.
    # 2. Beams that are now finished (EOS predicted) have their length
    #    increased by 1.
    # 3. Beams that are not yet finished have their length increased by 1.
    lengths_to_add = tf.cast(tf.logical_not(previously_finished), tf.int64)
    next_prediction_len = _tensor_gather_helper(
        gather_indices=next_beam_ids,
        gather_from=beam_state.lengths,
        batch_size=batch_size,
        range_size=beam_width,
        gather_shape=[-1],
    )
    next_prediction_len += lengths_to_add
    next_accumulated_attention_probs = ()
    if accumulated_attention_probs is not None:
        next_accumulated_attention_probs = _tensor_gather_helper(
            gather_indices=next_beam_ids,
            gather_from=accumulated_attention_probs,
            batch_size=batch_size,
            range_size=beam_width,
            gather_shape=[batch_size * beam_width, -1],
            name="next_accumulated_attention_probs",
        )

    # Pick out the cell_states according to the next_beam_ids. We use a
    # different gather_shape here because the cell_state tensors, i.e.
    # the tensors that would be gathered from, all have dimension
    # greater than two and we need to preserve those dimensions.
    next_cell_state = tf.nest.map_structure(
        lambda gather_from: _maybe_tensor_gather_helper(
            gather_indices=next_beam_ids,
            gather_from=gather_from,
            batch_size=batch_size,
            range_size=beam_width,
            gather_shape=[batch_size * beam_width, -1],
        ),
        next_cell_state,
    )

    next_state = BeamSearchDecoderState(
        cell_state=next_cell_state,
        log_probs=next_beam_probs,
        lengths=next_prediction_len,
        finished=next_finished,
        accumulated_attention_probs=next_accumulated_attention_probs,
    )

    output = BeamSearchDecoderOutput(
        scores=scores if output_all_scores else next_beam_scores,
        predicted_ids=next_word_ids,
        parent_ids=next_beam_ids,
    )

    return output, next_state
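

# Illustrative sketch (hypothetical numbers): how `_beam_search_step` turns
# top-k indices over the flattened [beam_width * vocab_size] score axis back
# into (parent beam, token id) pairs with modulo and integer division.
def _example_index_decomposition(vocab_size=6):
    word_indices = tf.constant([[8, 15]])  # top-k indices over the flat axis
    next_word_ids = tf.math.floormod(word_indices, vocab_size)  # [[2, 3]]
    next_beam_ids = tf.math.floordiv(word_indices, vocab_size)  # [[1, 2]]
    # Flat index 8 = beam 1, token 2; flat index 15 = beam 2, token 3.
    return next_word_ids, next_beam_ids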


def get_attention_probs(next_cell_state, coverage_penalty_weight):
    """Get attention probabilities from the cell state.

    Args:
      next_cell_state: The next state from the cell, e.g. an instance of
        AttentionWrapperState if the cell is attentional.
      coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.

    Returns:
      The attention probabilities with shape
      `[batch_size, beam_width, max_time]` if coverage penalty is enabled.
      Otherwise, returns None.

    Raises:
      ValueError: If no cell is attentional but coverage penalty is enabled.
    """
    if coverage_penalty_weight == 0.0:
        return None

    # Attention probabilities of each attention layer. Each with shape
    # `[batch_size, beam_width, max_time]`.
    probs_per_attn_layer = []
    if isinstance(next_cell_state, attention_wrapper.AttentionWrapperState):
        probs_per_attn_layer = [attention_probs_from_attn_state(next_cell_state)]
    elif isinstance(next_cell_state, tuple):
        for state in next_cell_state:
            if isinstance(state, attention_wrapper.AttentionWrapperState):
                probs_per_attn_layer.append(attention_probs_from_attn_state(state))

    if not probs_per_attn_layer:
        raise ValueError(
            "coverage_penalty_weight must be 0.0 if no cell is attentional."
        )

    if len(probs_per_attn_layer) == 1:
        attention_probs = probs_per_attn_layer[0]
    else:
        # Calculate the average attention probabilities from all attention
        # layers.
        attention_probs = [tf.expand_dims(prob, -1) for prob in probs_per_attn_layer]
        attention_probs = tf.concat(attention_probs, -1)
        attention_probs = tf.reduce_mean(attention_probs, -1)

    return attention_probs


def _get_scores(
    log_probs,
    sequence_lengths,
    length_penalty_weight,
    coverage_penalty_weight,
    finished,
    accumulated_attention_probs,
):
    """Calculates scores for beam search hypotheses.

    Args:
      log_probs: The log probabilities with shape
        `[batch_size, beam_width, vocab_size]`.
      sequence_lengths: The array of sequence lengths.
      length_penalty_weight: Float weight to penalize length. Disabled with
        0.0.
      coverage_penalty_weight: Float weight to penalize the coverage of source
        sentence. Disabled with 0.0.
      finished: A boolean tensor of shape `[batch_size, beam_width]` that
        specifies which elements in the beam are finished already.
      accumulated_attention_probs: Accumulated attention probabilities up to
        the current time step, with shape `[batch_size, beam_width, max_time]`
        if coverage_penalty_weight is not 0.0.

    Returns:
      The scores normalized by the length_penalty and coverage_penalty.

    Raises:
      ValueError: If `accumulated_attention_probs` is None when coverage
        penalty is enabled.
    """
    length_penalty_ = _length_penalty(
        sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight
    )
    length_penalty_ = tf.cast(length_penalty_, dtype=log_probs.dtype)
    scores = log_probs / length_penalty_

    coverage_penalty_weight = tf.convert_to_tensor(
        coverage_penalty_weight, name="coverage_penalty_weight"
    )
    if coverage_penalty_weight.shape.ndims != 0:
        raise ValueError(
            "coverage_penalty_weight should be a scalar, "
            "but saw shape: %s" % coverage_penalty_weight.shape
        )

    if tf.get_static_value(coverage_penalty_weight) == 0.0:
        return scores

    if accumulated_attention_probs is None:
        raise ValueError(
            "accumulated_attention_probs can be None only if coverage penalty "
            "is disabled."
        )

    # Add source sequence length mask before computing coverage penalty.
    accumulated_attention_probs = tf.where(
        tf.equal(accumulated_attention_probs, 0.0),
        tf.ones_like(accumulated_attention_probs),
        accumulated_attention_probs,
    )

    # coverage penalty =
    #     sum over `max_time` {log(min(accumulated_attention_probs, 1.0))}
    coverage_penalty = tf.reduce_sum(
        tf.math.log(tf.minimum(accumulated_attention_probs, 1.0)), 2
    )
    # Apply coverage penalty to finished predictions.
    coverage_penalty *= tf.cast(finished, tf.float32)
    weighted_coverage_penalty = coverage_penalty * coverage_penalty_weight
    # Reshape from [batch_size, beam_width] to [batch_size, beam_width, 1]
    weighted_coverage_penalty = tf.expand_dims(weighted_coverage_penalty, 2)
    return scores + weighted_coverage_penalty
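

# Illustrative sketch (hypothetical values): the coverage penalty term used
# in `_get_scores` for a single beam over three source positions. Accumulated
# attention mass is clipped at 1.0 before the log, so positions attended with
# total mass >= 1 contribute 0 and under-covered positions are penalized.
def _example_coverage_penalty(coverage_penalty_weight=0.2):
    accumulated_attention_probs = tf.constant([0.5, 1.5, 1.0])
    penalty = tf.reduce_sum(
        tf.math.log(tf.minimum(accumulated_attention_probs, 1.0))
    )
    # log(0.5) + log(1.0) + log(1.0) ~= -0.693, so the weighted penalty is
    # about -0.139, which is added to the length-normalized scores.
    return coverage_penalty_weight * penalty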


def attention_probs_from_attn_state(attention_state):
    """Calculates the average attention probabilities.

    Args:
      attention_state: An instance of `AttentionWrapperState`.

    Returns:
      The attention probabilities in the given AttentionWrapperState.
      If there are multiple attention mechanisms, return the average value from
      all attention mechanisms.
    """
    # Attention probabilities over time steps, with shape
    # `[batch_size, beam_width, max_time]`.
    attention_probs = attention_state.alignments
    if isinstance(attention_probs, tuple):
        attention_probs = [tf.expand_dims(prob, -1) for prob in attention_probs]
        attention_probs = tf.concat(attention_probs, -1)
        attention_probs = tf.reduce_mean(attention_probs, -1)
    return attention_probs


def _length_penalty(sequence_lengths, penalty_factor):
    """Calculates the length penalty. See https://arxiv.org/abs/1609.08144.

    Returns the length penalty tensor:

    ```
    [(5 + sequence_lengths) / 6] ** penalty_factor
    ```

    where all operations are performed element-wise.

    Args:
      sequence_lengths: `Tensor`, the sequence lengths of each hypothesis.
      penalty_factor: A scalar that weights the length penalty.

    Returns:
      If the penalty is `0`, returns the scalar `1.0`. Otherwise returns
      the length penalty factor, a tensor with the same shape as
      `sequence_lengths`.
    """
    penalty_factor = tf.convert_to_tensor(penalty_factor, name="penalty_factor")
    penalty_factor.set_shape(())  # penalty should be a scalar.
    static_penalty = tf.get_static_value(penalty_factor)
    if static_penalty is not None and static_penalty == 0:
        return 1.0
    return tf.math.divide(
        (5.0 + tf.cast(sequence_lengths, tf.float32)) ** penalty_factor,
        (5.0 + 1.0) ** penalty_factor,
    )
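

# Illustrative sketch: the GNMT length penalty for a few lengths at
# penalty_factor=0.6 (a value commonly suggested in the paper linked above).
def _example_length_penalty():
    lengths = tf.constant([1, 5, 10, 20], tf.int32)
    penalty = _length_penalty(lengths, penalty_factor=0.6)
    # ((5 + len) / 6) ** 0.6 -> approximately [1.0, 1.36, 1.73, 2.35];
    # dividing log probabilities (which are negative) by this growing factor
    # softens the natural bias toward short hypotheses.
    return penalty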


def _mask_probs(probs, eos_token, finished):
    """Masks log probabilities.

    The result is that finished beams allocate all probability mass to eos and
    unfinished beams remain unchanged.

    Args:
      probs: Log probabilities of shape `[batch_size, beam_width, vocab_size]`.
      eos_token: An int32 id corresponding to the EOS token to allocate
        probability to.
      finished: A boolean tensor of shape `[batch_size, beam_width]` that
        specifies which elements in the beam are finished already.

    Returns:
      A tensor of shape `[batch_size, beam_width, vocab_size]`, where
      unfinished beams stay unchanged and finished beams are replaced with a
      tensor with all probability on the EOS token.
    """
    vocab_size = tf.shape(probs)[2]
    # All finished examples are replaced with a vector that has all
    # probability on EOS.
    finished_row = tf.one_hot(
        eos_token,
        vocab_size,
        dtype=probs.dtype,
        on_value=tf.convert_to_tensor(0.0, dtype=probs.dtype),
        off_value=probs.dtype.min,
    )
    finished_probs = tf.tile(
        tf.reshape(finished_row, [1, 1, -1]), tf.concat([tf.shape(finished), [1]], 0)
    )
    finished_mask = tf.tile(tf.expand_dims(finished, 2), [1, 1, vocab_size])

    return tf.where(finished_mask, finished_probs, probs)
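

# Illustrative sketch (hypothetical values): after `_mask_probs`, a finished
# beam's row keeps all probability mass on the EOS token (log-prob 0.0 there,
# dtype min everywhere else), so it can never extend with another token.
def _example_mask_probs(eos_token=1):
    probs = tf.math.log(tf.constant([[[0.4, 0.3, 0.3], [0.2, 0.5, 0.3]]]))
    finished = tf.constant([[False, True]])
    masked = _mask_probs(probs, eos_token, finished)
    # Beam 0 is unchanged; beam 1 becomes [min, 0.0, min] in log space.
    return masked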


def _maybe_tensor_gather_helper(
    gather_indices, gather_from, batch_size, range_size, gather_shape
):
    """Maybe applies _tensor_gather_helper.

    This applies _tensor_gather_helper when the rank of gather_from is at
    least as large as the length of gather_shape. This is used in conjunction
    with nest so that we don't apply _tensor_gather_helper to inapplicable
    values like scalars.

    Args:
      gather_indices: The tensor indices that we use to gather.
      gather_from: The tensor that we are gathering from.
      batch_size: The batch size.
      range_size: The number of values in each range. Likely equal to
        beam_width.
      gather_shape: What we should reshape gather_from to in order to preserve
        the correct values. An example is when gather_from is the attention
        from an AttentionWrapperState with shape
        [batch_size, beam_width, attention_size]. There, we want to preserve
        the attention_size elements, so gather_shape is
        [batch_size * beam_width, -1]. Then, upon reshape, we still have the
        attention_size as desired.

    Returns:
      output: Gathered tensor of shape
        tf.shape(gather_from)[:1+len(gather_shape)] or the original tensor if
        its dimensions are too small.
    """
    if isinstance(gather_from, tf.TensorArray):
        return gather_from
    _check_ndims(gather_from)
    if gather_from.shape.ndims >= len(gather_shape):
        return _tensor_gather_helper(
            gather_indices=gather_indices,
            gather_from=gather_from,
            batch_size=batch_size,
            range_size=range_size,
            gather_shape=gather_shape,
        )
    else:
        return gather_from


def _tensor_gather_helper(
    gather_indices, gather_from, batch_size, range_size, gather_shape, name=None
):
    """Helper for gathering the right indices from the tensor.

    This works by reshaping gather_from to gather_shape (e.g. [-1]) and then
    gathering from that according to the gather_indices, which are offset by
    the right amounts in order to preserve the batch order.

    Args:
      gather_indices: The tensor indices that we use to gather.
      gather_from: The tensor that we are gathering from.
      batch_size: The input batch size.
      range_size: The number of values in each range. Likely equal to
        beam_width.
      gather_shape: What we should reshape gather_from to in order to preserve
        the correct values. An example is when gather_from is the attention
        from an AttentionWrapperState with shape
        [batch_size, beam_width, attention_size]. There, we want to preserve
        the attention_size elements, so gather_shape is
        [batch_size * beam_width, -1]. Then, upon reshape, we still have the
        attention_size as desired.
      name: The tensor name for set of operations. By default this is
        'tensor_gather_helper'. The final output is named 'output'.

    Returns:
      output: Gathered tensor of shape
        tf.shape(gather_from)[:1+len(gather_shape)]
    """
    with tf.name_scope(name or "tensor_gather_helper"):
        range_ = tf.expand_dims(tf.range(batch_size) * range_size, 1)
        gather_indices = tf.reshape(gather_indices + range_, [-1])
        output = tf.gather(tf.reshape(gather_from, gather_shape), gather_indices)
        final_shape = tf.shape(gather_from)[: 1 + len(gather_shape)]
        static_batch_size = tf.get_static_value(batch_size)
        final_static_shape = tf.TensorShape([static_batch_size]).concatenate(
            gather_from.shape[1 : 1 + len(gather_shape)]
        )
        output = tf.reshape(output, final_shape, name="output")
        output.set_shape(final_static_shape)
        return output
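

# Illustrative sketch (hypothetical values): the offset trick used by
# `_tensor_gather_helper`. Indices are local to each batch entry's
# [range_size] block, so adding i * range_size to row i turns them into
# indices into the flattened tensor while preserving the batch order.
def _example_tensor_gather(batch_size=2, range_size=3):
    gather_from = tf.constant([[10, 11, 12], [20, 21, 22]])
    gather_indices = tf.constant([[2, 0, 1], [1, 1, 0]])
    # Offsets [[0], [3]] give flat indices [2, 0, 1, 4, 4, 3], so the result
    # is [[12, 10, 11], [21, 21, 20]].
    return _tensor_gather_helper(
        gather_indices=gather_indices,
        gather_from=gather_from,
        batch_size=batch_size,
        range_size=range_size,
        gather_shape=[-1],
    )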