Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/seq2seq/basic_decoder.py: 47%
47 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A basic decoder that may sample to generate the next input."""
17import collections
19import tensorflow as tf
21from tensorflow_addons.seq2seq import decoder
22from tensorflow_addons.seq2seq import sampler as sampler_py
23from tensorflow_addons.utils import keras_utils
25from typeguard import typechecked
26from typing import Optional
class BasicDecoderOutput(
    collections.namedtuple("BasicDecoderOutput", ["rnn_output", "sample_id"])
):
    """Per-step output record produced by `tfa.seq2seq.BasicDecoder`.

    Attributes:
      rnn_output: What the decoder emitted for this step. When the decoder was
        built with an `output_layer`, this is that layer's output; otherwise it
        is the raw output of the RNN cell.
      sample_id: The token IDs chosen for this step, exactly as returned by the
        `sampler` instance given to `tfa.seq2seq.BasicDecoder`.
    """
class BasicDecoder(decoder.BaseDecoder):
    """Basic sampling decoder for training and inference.

    The `tfa.seq2seq.Sampler` instance passed as argument is responsible for
    sampling from the output distribution and producing the input for the next
    decoding step. The decoding loop is implemented by the decoder in its
    `__call__` method.

    Example using `tfa.seq2seq.TrainingSampler` for training:

    >>> batch_size = 4
    >>> max_time = 7
    >>> hidden_size = 32
    >>> embedding_size = 48
    >>> input_vocab_size = 128
    >>> output_vocab_size = 64
    >>>
    >>> embedding_layer = tf.keras.layers.Embedding(input_vocab_size, embedding_size)
    >>> decoder_cell = tf.keras.layers.LSTMCell(hidden_size)
    >>> sampler = tfa.seq2seq.TrainingSampler()
    >>> output_layer = tf.keras.layers.Dense(output_vocab_size)
    >>>
    >>> decoder = tfa.seq2seq.BasicDecoder(decoder_cell, sampler, output_layer)
    >>>
    >>> input_ids = tf.random.uniform(
    ...     [batch_size, max_time], maxval=input_vocab_size, dtype=tf.int64)
    >>> input_lengths = tf.fill([batch_size], max_time)
    >>> input_tensors = embedding_layer(input_ids)
    >>> initial_state = decoder_cell.get_initial_state(input_tensors)
    >>>
    >>> output, state, lengths = decoder(
    ...     input_tensors, sequence_length=input_lengths, initial_state=initial_state)
    >>>
    >>> logits = output.rnn_output
    >>> logits.shape
    TensorShape([4, 7, 64])

    Example using `tfa.seq2seq.GreedyEmbeddingSampler` for inference:

    >>> sampler = tfa.seq2seq.GreedyEmbeddingSampler(embedding_layer)
    >>> decoder = tfa.seq2seq.BasicDecoder(
    ...     decoder_cell, sampler, output_layer, maximum_iterations=10)
    >>>
    >>> initial_state = decoder_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
    >>> start_tokens = tf.fill([batch_size], 1)
    >>> end_token = 2
    >>>
    >>> output, state, lengths = decoder(
    ...     None, start_tokens=start_tokens, end_token=end_token, initial_state=initial_state)
    >>>
    >>> output.sample_id.shape
    TensorShape([4, 10])
    """

    @typechecked
    def __init__(
        self,
        cell: tf.keras.layers.Layer,
        sampler: sampler_py.Sampler,
        output_layer: Optional[tf.keras.layers.Layer] = None,
        **kwargs,
    ):
        """Initialize BasicDecoder.

        Args:
          cell: A layer that implements the `tf.keras.layers.AbstractRNNCell`
            interface.
          sampler: A `tfa.seq2seq.Sampler` instance.
          output_layer: (Optional) An instance of `tf.keras.layers.Layer`, e.g.,
            `tf.keras.layers.Dense`. Optional layer to apply to the RNN output
            prior to storing the result or sampling.
          **kwargs: Other keyword arguments of `tfa.seq2seq.BaseDecoder`.
        """
        keras_utils.assert_like_rnncell("cell", cell)
        self.cell = cell
        self.sampler = sampler
        self.output_layer = output_layer
        super().__init__(**kwargs)

    def initialize(self, inputs, initial_state=None, **kwargs):
        """Initialize the decoder.

        Args:
          inputs: Forwarded unchanged to `self.sampler.initialize`.
          initial_state: A (structure of) tensors used as the cell's initial
            state. Although it defaults to `None`, a value is required: the
            dtype of its first flattened component is recorded and later
            reported by `output_dtype`.
          **kwargs: Forwarded unchanged to `self.sampler.initialize`.

        Returns:
          The result of `self.sampler.initialize(inputs, **kwargs)` with
          `initial_state` appended as its last element.

        Raises:
          ValueError: If `initial_state` is `None`.
        """
        if initial_state is None:
            # tf.nest.flatten(None) is [None], so the dtype lookup below would
            # otherwise fail with an opaque "'NoneType' object has no attribute
            # 'dtype'" error. Report the actual problem instead.
            raise ValueError(
                "BasicDecoder.initialize requires an `initial_state` "
                "(a tensor or nested structure of tensors); got None."
            )
        # Assume the dtype of the cell is the output_size structure
        # containing the input_state's first component's dtype.
        self._cell_dtype = tf.nest.flatten(initial_state)[0].dtype
        return self.sampler.initialize(inputs, **kwargs) + (initial_state,)

    @property
    def batch_size(self):
        """The batch size, as reported by the sampler."""
        return self.sampler.batch_size

    def _rnn_output_size(self):
        """Return the static per-step RNN output size, without the batch dim.

        Without an `output_layer` this is `self.cell.output_size` wrapped in a
        `tf.TensorShape`. With an `output_layer` it is that layer's computed
        output shape with the leading (batch) dimension dropped.
        """
        size = tf.TensorShape(self.cell.output_size)
        if self.output_layer is None:
            return size
        # To use the layer's compute_output_shape, convert the RNNCell's
        # output_size entries into shapes with an unknown batch size, pass
        # them through compute_output_shape, and read off all but the first
        # (batch) dimension to get the output size with the layer applied.
        output_shape_with_unknown_batch = tf.nest.map_structure(
            lambda s: tf.TensorShape([None]).concatenate(s), size
        )
        layer_output_shape = self.output_layer.compute_output_shape(
            output_shape_with_unknown_batch
        )
        return tf.nest.map_structure(lambda s: s[1:], layer_output_shape)

    @property
    def output_size(self):
        """A `BasicDecoderOutput` of the per-step output sizes.

        `rnn_output` holds the (possibly output-layer-projected) RNN output size.
        `sample_id` holds the sampler's `sample_ids_shape`.
        """
        return BasicDecoderOutput(
            rnn_output=self._rnn_output_size(), sample_id=self.sampler.sample_ids_shape
        )

    @property
    def output_dtype(self):
        """A `BasicDecoderOutput` of the per-step output dtypes.

        Every entry of the RNN output structure is given the dtype that
        `initialize` recorded from the first component of `initial_state`.
        `sample_id` uses the sampler's `sample_ids_dtype`.

        Raises:
          AttributeError: If accessed before `initialize` has been called.
        """
        dtype = getattr(self, "_cell_dtype", None)
        if dtype is None:
            # Keep the AttributeError type that plain attribute access raised
            # (so `hasattr` still returns False), but explain what is missing.
            raise AttributeError(
                "BasicDecoder.output_dtype is only available after initialize() "
                "has been called with an initial_state."
            )
        return BasicDecoderOutput(
            tf.nest.map_structure(lambda _: dtype, self._rnn_output_size()),
            self.sampler.sample_ids_dtype,
        )

    def step(self, time, inputs, state, training=None):
        """Perform a decoding step.

        Args:
          time: scalar `int32` tensor.
          inputs: A (structure of) input tensors.
          state: A (structure of) state tensors and TensorArrays.
          training: Python boolean.

        Returns:
          `(outputs, next_state, next_inputs, finished)`, where `outputs` is a
          `BasicDecoderOutput(cell_outputs, sample_ids)` and the remaining three
          come from `self.sampler.next_inputs`.
        """
        cell_outputs, cell_state = self.cell(inputs, state, training=training)
        # Re-pack the cell's returned state into the exact nested structure of
        # the incoming `state`.
        cell_state = tf.nest.pack_sequence_as(state, tf.nest.flatten(cell_state))
        if self.output_layer is not None:
            cell_outputs = self.output_layer(cell_outputs)
        sample_ids = self.sampler.sample(
            time=time, outputs=cell_outputs, state=cell_state
        )
        (finished, next_inputs, next_state) = self.sampler.next_inputs(
            time=time, outputs=cell_outputs, state=cell_state, sample_ids=sample_ids
        )
        outputs = BasicDecoderOutput(cell_outputs, sample_ids)
        return (outputs, next_state, next_inputs, finished)