Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/seq2seq/decoder.py: 24%
169 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base classes and functions for dynamic decoding."""
17import abc
19import tensorflow as tf
20from tensorflow_addons.utils.types import TensorLike
21from typeguard import typechecked
22from typing import Any, Optional, Tuple, Union
24# TODO: Find public API alternatives to these
25from tensorflow.python.ops import control_flow_util
class Decoder(metaclass=abc.ABCMeta):
    """An RNN Decoder abstract interface object.

    Concepts used by this interface:
    - `inputs`: (structure of) tensors and TensorArrays that is passed as
      input to the RNN cell composing the decoder, at each time step.
    - `state`: (structure of) tensors and TensorArrays that is passed to the
      RNN cell instance as the state.
    - `finished`: boolean tensor telling whether each sequence in the batch
      is finished.
    - `training`: boolean whether it should behave in training mode or in
      inference mode.
    - `outputs`: instance of `tfa.seq2seq.BasicDecoderOutput`. Result of the
      decoding, at each time step.

    Subclasses must implement `initialize` and `step` (both are abstract).
    The `batch_size`, `output_size` and `output_dtype` properties and the
    `finalize` method raise `NotImplementedError` until overridden.
    """

    @property
    def batch_size(self):
        """The batch size of input values."""
        raise NotImplementedError

    @property
    def output_size(self):
        """A (possibly nested tuple of...) integer[s] or `TensorShape`
        object[s]."""
        raise NotImplementedError

    @property
    def output_dtype(self):
        """A (possibly nested tuple of...) dtype[s]."""
        raise NotImplementedError

    @abc.abstractmethod
    def initialize(self, name=None):
        """Called before any decoding iterations.

        This method must compute initial input values and initial state.

        Args:
          name: Name scope for any created operations.

        Returns:
          `(finished, initial_inputs, initial_state)`: initial values of
          'finished' flags, inputs and state.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def step(self, time, inputs, state, training=None, name=None):
        """Called per step of decoding (but only once for dynamic decoding).

        Args:
          time: Scalar `int32` tensor. Current step number.
          inputs: RNN cell input (possibly nested tuple of) tensor[s] for
            this time step.
          state: RNN cell state (possibly nested tuple of) tensor[s] from
            previous time step.
          training: Python boolean. Indicates whether the layer should behave
            in training mode or in inference mode. Only relevant
            when `dropout` or `recurrent_dropout` is used.
          name: Name scope for any created operations.

        Returns:
          `(outputs, next_state, next_inputs, finished)`: `outputs` is an
          object containing the decoder output, `next_state` is a (structure
          of) state tensors and TensorArrays, `next_inputs` is the tensor
          that should be used as input for the next step, `finished` is a
          boolean tensor telling whether the sequence is complete, for each
          sequence in the batch.
        """
        raise NotImplementedError

    def finalize(self, outputs, final_state, sequence_lengths):
        """Optional post-processing hook run once decoding has finished.

        `tfa.seq2seq.dynamic_decode` calls this after its decoding loop and
        unpacks the result as `(final_outputs, final_state)`. When this
        method raises `NotImplementedError` (the default), `dynamic_decode`
        catches it and keeps the outputs and state unchanged.

        Args:
          outputs: (structure of) stacked decoder outputs.
          final_state: (structure of) state produced by the last step.
          sequence_lengths: `int32` tensor with the decoded length of each
            batch entry.

        Returns:
          `(final_outputs, final_state)`.

        Raises:
          NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @property
    def tracks_own_finished(self):
        """Describes whether the Decoder keeps track of finished states.

        Most decoders will emit a true/false `finished` value independently
        at each time step. In this case, the `tfa.seq2seq.dynamic_decode`
        function keeps track of which batch entries are already finished,
        and performs a logical OR to insert new batches to the finished set.

        Some decoders, however, shuffle batches / beams between time steps
        and `tfa.seq2seq.dynamic_decode` will mix up the finished state
        across these entries because it does not track the reshuffle across
        time steps. In this case, it is up to the decoder to declare that it
        will keep track of its own finished state by setting this property
        to `True`.

        Returns:
          Python bool.
        """
        return False
class BaseDecoder(tf.keras.layers.Layer):
    """An RNN Decoder that is based on a Keras layer.

    Calling the layer runs `dynamic_decode` on the layer itself, using the
    decoding options captured by the constructor.

    Concepts used by this interface:
    - `inputs`: (structure of) Tensors and TensorArrays that is passed as
      input to the RNN cell composing the decoder, at each time step.
    - `state`: (structure of) Tensors and TensorArrays that is passed to the
      RNN cell instance as the state.
    - `memory`: tensor that is usually the full output of the encoder, which
      will be used for the attention wrapper for the RNN cell.
    - `finished`: boolean tensor telling whether each sequence in the batch
      is finished.
    - `training`: boolean whether it should behave in training mode or in
      inference mode.
    - `outputs`: instance of `tfa.seq2seq.BasicDecoderOutput`. Result of the
      decoding, at each time step.
    """

    @typechecked
    def __init__(
        self,
        output_time_major: bool = False,
        impute_finished: bool = False,
        maximum_iterations: Optional[TensorLike] = None,
        parallel_iterations: int = 32,
        swap_memory: bool = False,
        **kwargs,
    ):
        """Stores the decoding options later forwarded to `dynamic_decode`.

        Args:
          output_time_major: forwarded to `dynamic_decode` by `call()`.
          impute_finished: forwarded to `dynamic_decode` by `call()`.
          maximum_iterations: forwarded to `dynamic_decode` by `call()`.
          parallel_iterations: forwarded to `dynamic_decode` by `call()`.
          swap_memory: forwarded to `dynamic_decode` by `call()`.
          **kwargs: passed to the `tf.keras.layers.Layer` constructor.
        """
        # The options are assigned before the base-class constructor runs,
        # exactly as in the original ordering.
        self.output_time_major = output_time_major
        self.impute_finished = impute_finished
        self.maximum_iterations = maximum_iterations
        self.parallel_iterations = parallel_iterations
        self.swap_memory = swap_memory
        super().__init__(**kwargs)

    def call(self, inputs, initial_state=None, training=None, **kwargs):
        """Decodes `inputs` by delegating to `dynamic_decode`.

        Args:
          inputs: handed to `initialize()` as its first argument.
          initial_state: handed to `initialize()` as `initial_state`.
          training: forwarded to `dynamic_decode`.
          **kwargs: extra keyword arguments handed to `initialize()`.

        Returns:
          Whatever `dynamic_decode` returns.
        """
        # `initial_state` travels together with the remaining keyword
        # arguments so that `initialize()` receives all of them.
        kwargs["initial_state"] = initial_state
        decode_options = dict(
            output_time_major=self.output_time_major,
            impute_finished=self.impute_finished,
            maximum_iterations=self.maximum_iterations,
            parallel_iterations=self.parallel_iterations,
            swap_memory=self.swap_memory,
        )
        return dynamic_decode(
            self,
            training=training,
            decoder_init_input=inputs,
            decoder_init_kwargs=kwargs,
            **decode_options,
        )

    @property
    def batch_size(self):
        """The batch size of input values."""
        raise NotImplementedError

    @property
    def output_size(self):
        """A (possibly nested tuple of...) integer[s] or `TensorShape`
        object[s]."""
        raise NotImplementedError

    @property
    def output_dtype(self):
        """A (possibly nested tuple of...) dtype[s]."""
        raise NotImplementedError

    def initialize(self, inputs, initial_state=None, **kwargs):
        """Called before any decoding iterations.

        This method must compute initial input values and initial state.

        Args:
          inputs: (structure of) tensors that contains the input for the
            decoder. In the normal case, it's a tensor with shape
            [batch, timestep, embedding].
          initial_state: (structure of) tensors that contains the initial
            state for the RNN cell.
          **kwargs: Other arguments that are passed in from layer.call()
            method. It could contain items like input `sequence_length`, or
            masking for input.

        Returns:
          `(finished, initial_inputs, initial_state)`: initial values of
          'finished' flags, inputs and state.
        """
        raise NotImplementedError

    def step(self, time, inputs, state, training):
        """Called per step of decoding (but only once for dynamic decoding).

        Args:
          time: Scalar `int32` tensor. Current step number.
          inputs: RNN cell input (possibly nested tuple of) tensor[s] for
            this time step.
          state: RNN cell state (possibly nested tuple of) tensor[s] from
            previous time step.
          training: Python boolean. Indicates whether the layer should
            behave in training mode or in inference mode.

        Returns:
          `(outputs, next_state, next_inputs, finished)`: `outputs` is an
          object containing the decoder output, `next_state` is a
          (structure of) state tensors and TensorArrays, `next_inputs` is
          the tensor that should be used as input for the next step,
          `finished` is a boolean tensor telling whether the sequence is
          complete, for each sequence in the batch.
        """
        raise NotImplementedError

    def finalize(self, outputs, final_state, sequence_lengths):
        """Post-decoding hook; raises `NotImplementedError` by default.

        `dynamic_decode` calls this after its loop, unpacks the result as
        `(final_outputs, final_state)`, and ignores a
        `NotImplementedError`.
        """
        raise NotImplementedError

    @property
    def tracks_own_finished(self):
        """Describes whether the Decoder keeps track of finished states.

        Most decoders will emit a true/false `finished` value independently
        at each time step. In this case, the `tfa.seq2seq.dynamic_decode`
        function keeps track of which batch entries are already finished,
        and performs a logical OR to insert new batches to the finished set.

        Some decoders, however, shuffle batches / beams between time steps
        and `tfa.seq2seq.dynamic_decode` will mix up the finished state
        across these entries because it does not track the reshuffle across
        time steps. In this case, it is up to the decoder to declare that it
        will keep track of its own finished state by setting this property
        to `True`.

        Returns:
          Python bool.
        """
        return False

    # TODO(scottzhu): Add build/get_config/from_config and other layer methods.
@typechecked
def dynamic_decode(
    decoder: Union[Decoder, BaseDecoder],
    output_time_major: bool = False,
    impute_finished: bool = False,
    maximum_iterations: Optional[TensorLike] = None,
    parallel_iterations: int = 32,
    swap_memory: bool = False,
    training: Optional[bool] = None,
    scope: Optional[str] = None,
    enable_tflite_convertible: bool = False,
    **kwargs,
) -> Tuple[Any, Any, Any]:
    """Runs dynamic decoding with a decoder.

    Calls `initialize()` once and `step()` repeatedly on the decoder object.

    Args:
      decoder: A `tfa.seq2seq.Decoder` or `tfa.seq2seq.BaseDecoder` instance.
      output_time_major: Python boolean. Default: `False` (batch major). If
        `True`, outputs are returned as time major tensors (this mode is
        faster). Otherwise, outputs are returned as batch major tensors (this
        adds extra time to the computation).
      impute_finished: Python boolean. If `True`, then states for batch
        entries which are marked as finished get copied through and the
        corresponding outputs get zeroed out. This causes some slowdown at
        each time step, but ensures that the final state and outputs have
        the correct values and that backprop ignores time steps that were
        marked as finished.
      maximum_iterations: A strictly positive `int32` scalar, the maximum
        allowed number of decoding steps. Default is `None` (decode until the
        decoder is fully done).
      parallel_iterations: Argument passed to `tf.while_loop`.
      swap_memory: Argument passed to `tf.while_loop`.
      training: Python boolean. Indicates whether the layer should behave
        in training mode or in inference mode. Only relevant
        when `dropout` or `recurrent_dropout` is used.
      scope: Optional name scope to use.
      enable_tflite_convertible: Python boolean. If `True`, then the variables
        of `TensorArray` become of 1-D static shape. Also zero pads in the
        output tensor will be discarded. Default: `False`.
      **kwargs: dict, other keyword arguments for dynamic_decode. It might
        contain arguments for `BaseDecoder` to initialize, which takes all
        tensor inputs during call(). The keys read here are
        `decoder_init_input` and `decoder_init_kwargs`.

    Returns:
      `(final_outputs, final_state, final_sequence_lengths)`.

    Raises:
      ValueError: if `maximum_iterations` is provided but is not a scalar,
        or if it is omitted while building a graph in an XLA context.
    """
    with tf.name_scope(scope or "decoder"):
        in_xla_graph = (
            not tf.executing_eagerly()
            and control_flow_util.GraphOrParentsInXlaContext(
                tf.compat.v1.get_default_graph()
            )
        )

        if maximum_iterations is None:
            if in_xla_graph:
                raise ValueError("maximum_iterations is required for XLA compilation.")
        else:
            maximum_iterations = tf.convert_to_tensor(
                maximum_iterations, dtype=tf.int32, name="maximum_iterations"
            )
            if maximum_iterations.shape.ndims != 0:
                raise ValueError("maximum_iterations must be a scalar")
            tf.debugging.assert_greater(
                maximum_iterations,
                0,
                message="maximum_iterations should be greater than 0",
            )

        if isinstance(decoder, Decoder):
            initial_finished, initial_inputs, initial_state = decoder.initialize()
        else:
            # A BaseDecoder receives its tensor inputs through call(), which
            # forwards them here via these two keyword arguments.
            init_input = kwargs.pop("decoder_init_input", None)
            init_kwargs = kwargs.pop("decoder_init_kwargs", {})
            initial_finished, initial_inputs, initial_state = decoder.initialize(
                init_input, **init_kwargs
            )

        if enable_tflite_convertible:
            # Assume the batch_size = 1 for inference, so every 2-D
            # TensorArray element can be flattened to 1-D.
            tf.debugging.assert_equal(
                decoder.batch_size,
                1,
                message="TFLite conversion requires a batch size of 1",
            )

        def _zero_output(shape, dtype):
            # All-zero stand-in for one output component; flattened in
            # TFLite mode to match the 1-D emitted values.
            zeros = tf.zeros(_prepend_batch(decoder.batch_size, shape), dtype=dtype)
            if enable_tflite_convertible:
                zeros = tf.reshape(zeros, [-1])
            return zeros

        zero_outputs = tf.nest.map_structure(
            _zero_output, decoder.output_size, decoder.output_dtype
        )

        if maximum_iterations is not None:
            initial_finished = tf.logical_or(initial_finished, maximum_iterations <= 0)
        initial_sequence_lengths = tf.zeros_like(initial_finished, dtype=tf.int32)
        initial_time = tf.constant(0, dtype=tf.int32)

        def _batched_element_shape(batch_size, from_shape):
            # Static element shape with the batch dimension in front, or
            # None when `from_shape` is not a non-scalar TensorShape.
            if isinstance(from_shape, tf.TensorShape) and from_shape.ndims != 0:
                static_batch = tf.get_static_value(
                    tf.convert_to_tensor(batch_size, name="batch_size")
                )
                return tf.TensorShape([static_batch]).concatenate(from_shape)
            return None

        # The dynamic shape `TensorArray` is not allowed in TFLite yet.
        dynamic_size = (
            maximum_iterations is None or not in_xla_graph
        ) and not enable_tflite_convertible
        ta_size = 0 if dynamic_size else maximum_iterations

        def _new_output_ta(elem_shape, elem_dtype):
            if not enable_tflite_convertible:
                ta_elem_shape = _batched_element_shape(decoder.batch_size, elem_shape)
            elif isinstance(elem_shape, tf.TensorShape) and elem_shape.ndims == 0:
                # TFLite requires 1D element_shape.
                ta_elem_shape = (1,)
            else:
                ta_elem_shape = elem_shape
            return tf.TensorArray(
                dtype=elem_dtype,
                size=ta_size,
                dynamic_size=dynamic_size,
                element_shape=ta_elem_shape,
            )

        initial_outputs_ta = tf.nest.map_structure(
            _new_output_ta, decoder.output_size, decoder.output_dtype
        )

        def condition(
            _time, _outputs_ta, _state, _inputs, finished, _sequence_lengths
        ):
            # Keep looping while at least one batch entry is unfinished.
            return tf.logical_not(tf.reduce_all(finished))

        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """Internal while_loop body.

            Args:
              time: scalar int32 tensor.
              outputs_ta: structure of TensorArray.
              state: (structure of) state tensors and TensorArrays.
              inputs: (structure of) input tensors.
              finished: bool tensor (keeping track of what's finished).
              sequence_lengths: int32 tensor (keeping track of time of
                finish).

            Returns:
              `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
              next_sequence_lengths)`.
            """
            step_outputs, step_state, next_inputs, step_finished = decoder.step(
                time, inputs, state, training
            )

            lengths_from_state = False
            if decoder.tracks_own_finished:
                next_finished = step_finished
                lengths = getattr(step_state, "lengths", None)
                if lengths is not None:
                    # Sequence lengths are provided by the decoder state's
                    # `lengths`; they overwrite the loop-carried ones.
                    lengths_from_state = True
                    sequence_lengths = tf.cast(lengths, tf.int32)
            else:
                next_finished = tf.logical_or(step_finished, finished)

            if lengths_from_state:
                # Just pass something through the loop; at the next
                # iteration the lengths come from the decoder state again.
                next_sequence_lengths = sequence_lengths
            else:
                next_sequence_lengths = tf.where(
                    tf.logical_not(finished),
                    tf.fill(tf.shape(sequence_lengths), time + 1),
                    sequence_lengths,
                )

            tf.nest.assert_same_structure(state, step_state)
            tf.nest.assert_same_structure(outputs_ta, step_outputs)
            tf.nest.assert_same_structure(inputs, next_inputs)

            def _zero_out_finished(out, zero):
                # Zero out output values past finish.
                if finished.shape.rank < zero.shape.rank:
                    mask = tf.broadcast_to(
                        tf.expand_dims(finished, axis=-1), zero.shape
                    )
                else:
                    mask = finished
                return tf.where(mask, zero, out)

            def _copy_through_finished(new, cur):
                # Copy through states past finish. TensorArrays and scalar
                # states get passed through untouched.
                if isinstance(cur, tf.TensorArray):
                    return new
                new.set_shape(cur.shape)
                if new.shape.ndims == 0:
                    return new
                mask = tf.broadcast_to(tf.expand_dims(finished, axis=-1), new.shape)
                return tf.where(mask, cur, new)

            if impute_finished:
                emit = tf.nest.map_structure(
                    _zero_out_finished, step_outputs, zero_outputs
                )
                next_state = tf.nest.map_structure(
                    _copy_through_finished, step_state, state
                )
            else:
                emit = step_outputs
                next_state = step_state

            if enable_tflite_convertible:
                # Reshape to 1-D.
                emit = tf.nest.map_structure(lambda x: tf.reshape(x, [-1]), emit)

            outputs_ta = tf.nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit
            )
            return (
                time + 1,
                outputs_ta,
                next_state,
                next_inputs,
                next_finished,
                next_sequence_lengths,
            )

        loop_result = tf.while_loop(
            condition,
            body,
            loop_vars=(
                initial_time,
                initial_outputs_ta,
                initial_state,
                initial_inputs,
                initial_finished,
                initial_sequence_lengths,
            ),
            parallel_iterations=parallel_iterations,
            maximum_iterations=maximum_iterations,
            swap_memory=swap_memory,
        )

        # Loop variables by position: 1 = output TensorArrays, 2 = state,
        # 5 = sequence lengths.
        final_outputs_ta = loop_result[1]
        final_state = loop_result[2]
        final_sequence_lengths = loop_result[5]

        final_outputs = tf.nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)

        try:
            final_outputs, final_state = decoder.finalize(
                final_outputs, final_state, final_sequence_lengths
            )
        except NotImplementedError:
            # finalize() is optional; keep the raw outputs and state.
            pass

        if not output_time_major:
            if enable_tflite_convertible:
                # Reshape the output to the original shape by re-inserting
                # an axis at position 1.
                final_outputs = tf.nest.map_structure(
                    lambda x: tf.expand_dims(x, [1]), final_outputs
                )

            final_outputs = tf.nest.map_structure(_transpose_batch_time, final_outputs)

    return final_outputs, final_state, final_sequence_lengths
def _prepend_batch(batch_size, shape):
    """Prepends the batch dimension to the shape.

    When the batch size is statically known the result is
    `[static_batch_size] + shape` (presumably a `TensorShape` when `shape`
    is one -- confirm against callers). Otherwise the result is a Tensor
    built with `tf.concat`.
    """
    static_batch_size = (
        tf.get_static_value(batch_size)
        if isinstance(batch_size, tf.Tensor)
        else batch_size
    )
    if static_batch_size is not None:
        return [static_batch_size] + shape
    # Batch size only known at run time: build the shape as a tensor.
    return tf.concat(([batch_size], shape), axis=0)
def _transpose_batch_time(tensor):
    """Swaps the first two axes (batch and time) of `tensor`.

    A tensor whose static rank is known and below 2 is returned unchanged.
    """
    static_rank = tensor.shape.rank
    if static_rank is not None and static_rank < 2:
        return tensor
    # Permutation [1, 0, 2, 3, ...], built from the dynamic rank so that
    # tensors of unknown static rank are handled as well.
    trailing_axes = tf.range(2, tf.rank(tensor))
    permutation = tf.concat(([1, 0], trailing_axes), axis=0)
    return tf.transpose(tensor, permutation)