# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A basic decoder that may sample to generate the next input."""

import collections

import tensorflow as tf

from tensorflow_addons.seq2seq import decoder
from tensorflow_addons.seq2seq import sampler as sampler_py
from tensorflow_addons.utils import keras_utils

from typeguard import typechecked
from typing import Optional


class BasicDecoderOutput(
    collections.namedtuple("BasicDecoderOutput", ("rnn_output", "sample_id"))
):
    """Outputs of a `tfa.seq2seq.BasicDecoder` step.

    Attributes:
      rnn_output: The output for this step. If the `output_layer` argument
        of `tfa.seq2seq.BasicDecoder` was set, it is the output of this layer,
        otherwise it is the output of the RNN cell.
      sample_id: The token IDs sampled for this step, as returned by the
        `sampler` instance passed to `tfa.seq2seq.BasicDecoder`.
    """

    pass


class BasicDecoder(decoder.BaseDecoder):
    """Basic sampling decoder for training and inference.

    The `tfa.seq2seq.Sampler` instance passed as an argument is responsible
    for sampling from the output distribution and producing the input for the
    next decoding step. The decoding loop is implemented by the decoder in its
    `__call__` method.

    Example using `tfa.seq2seq.TrainingSampler` for training:

    >>> batch_size = 4
    >>> max_time = 7
    >>> hidden_size = 32
    >>> embedding_size = 48
    >>> input_vocab_size = 128
    >>> output_vocab_size = 64
    >>>
    >>> embedding_layer = tf.keras.layers.Embedding(input_vocab_size, embedding_size)
    >>> decoder_cell = tf.keras.layers.LSTMCell(hidden_size)
    >>> sampler = tfa.seq2seq.TrainingSampler()
    >>> output_layer = tf.keras.layers.Dense(output_vocab_size)
    >>>
    >>> decoder = tfa.seq2seq.BasicDecoder(decoder_cell, sampler, output_layer)
    >>>
    >>> input_ids = tf.random.uniform(
    ...     [batch_size, max_time], maxval=input_vocab_size, dtype=tf.int64)
    >>> input_lengths = tf.fill([batch_size], max_time)
    >>> input_tensors = embedding_layer(input_ids)
    >>> initial_state = decoder_cell.get_initial_state(input_tensors)
    >>>
    >>> output, state, lengths = decoder(
    ...     input_tensors, sequence_length=input_lengths, initial_state=initial_state)
    >>>
    >>> logits = output.rnn_output
    >>> logits.shape
    TensorShape([4, 7, 64])

    Example using `tfa.seq2seq.GreedyEmbeddingSampler` for inference:

    >>> sampler = tfa.seq2seq.GreedyEmbeddingSampler(embedding_layer)
    >>> decoder = tfa.seq2seq.BasicDecoder(
    ...     decoder_cell, sampler, output_layer, maximum_iterations=10)
    >>>
    >>> initial_state = decoder_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
    >>> start_tokens = tf.fill([batch_size], 1)
    >>> end_token = 2
    >>>
    >>> output, state, lengths = decoder(
    ...     None, start_tokens=start_tokens, end_token=end_token, initial_state=initial_state)
    >>>
    >>> output.sample_id.shape
    TensorShape([4, 10])
    """

    @typechecked
    def __init__(
        self,
        cell: tf.keras.layers.Layer,
        sampler: sampler_py.Sampler,
        output_layer: Optional[tf.keras.layers.Layer] = None,
        **kwargs,
    ):
        """Initialize BasicDecoder.

        Args:
          cell: A layer that implements the `tf.keras.layers.AbstractRNNCell`
            interface.
          sampler: A `tfa.seq2seq.Sampler` instance.
          output_layer: (Optional) An instance of `tf.keras.layers.Layer`,
            e.g., `tf.keras.layers.Dense`, to apply to the RNN output prior
            to storing the result or sampling.
          **kwargs: Other keyword arguments of `tfa.seq2seq.BaseDecoder`.
        """
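        # `assert_like_rnncell` raises an error if `cell` does not expose the
        # RNN cell interface (e.g. `state_size` and `output_size`); the
        # sampler and optional output layer are stored as-is.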

        keras_utils.assert_like_rnncell("cell", cell)
        self.cell = cell
        self.sampler = sampler
        self.output_layer = output_layer
        super().__init__(**kwargs)

    def initialize(self, inputs, initial_state=None, **kwargs):
        """Initialize the decoder."""
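        # Returns the `(finished, first_inputs, initial_state)` triple expected
        # by the decoding loop: `sampler.initialize` provides the first two
        # elements and the caller-supplied `initial_state` is appended.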

        # Assume the cell outputs have the same dtype as the first component
        # of the initial state.
        self._cell_dtype = tf.nest.flatten(initial_state)[0].dtype
        return self.sampler.initialize(inputs, **kwargs) + (initial_state,)

    @property
    def batch_size(self):
        return self.sampler.batch_size

    def _rnn_output_size(self):
        size = tf.TensorShape(self.cell.output_size)
        if self.output_layer is None:
            return size
        else:
            # To use the layer's compute_output_shape, we need to convert the
            # RNNCell's output_size entries into shapes with an unknown
            # batch size. We then pass this through the layer's
            # compute_output_shape and read off all but the first (batch)
            # dimensions to get the output size of the RNN with the layer
            # applied on top.
            output_shape_with_unknown_batch = tf.nest.map_structure(
                lambda s: tf.TensorShape([None]).concatenate(s), size
            )
            layer_output_shape = self.output_layer.compute_output_shape(
                output_shape_with_unknown_batch
            )
            return tf.nest.map_structure(lambda s: s[1:], layer_output_shape)

    @property
    def output_size(self):
        # Return the shape of the cell output and of the sample ids.
        return BasicDecoderOutput(
            rnn_output=self._rnn_output_size(), sample_id=self.sampler.sample_ids_shape
        )

    @property
    def output_dtype(self):
        # Assume the cell outputs have the same dtype as the first component
        # of the initial state (see `initialize`). Return that dtype for every
        # entry of the output structure, along with the sampler's
        # `sample_ids_dtype`.
        dtype = self._cell_dtype
        return BasicDecoderOutput(
            tf.nest.map_structure(lambda _: dtype, self._rnn_output_size()),
            self.sampler.sample_ids_dtype,
        )

    def step(self, time, inputs, state, training=None):
        """Perform a decoding step.

        Args:
          time: scalar `int32` tensor.
          inputs: A (structure of) input tensors.
          state: A (structure of) state tensors and TensorArrays.
          training: Python boolean.

        Returns:
          `(outputs, next_state, next_inputs, finished)`.
        """
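        # One decoding step, in order:
        #   1. Run the wrapped RNN cell on the current inputs and state.
        #   2. Optionally project the cell output through `output_layer`.
        #   3. Ask the sampler for sample ids, then for the next inputs,
        #      the next state, and the `finished` flags.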

        cell_outputs, cell_state = self.cell(inputs, state, training=training)
        cell_state = tf.nest.pack_sequence_as(state, tf.nest.flatten(cell_state))
        if self.output_layer is not None:
            cell_outputs = self.output_layer(cell_outputs)
        sample_ids = self.sampler.sample(
            time=time, outputs=cell_outputs, state=cell_state
        )
        (finished, next_inputs, next_state) = self.sampler.next_inputs(
            time=time, outputs=cell_outputs, state=cell_state, sample_ids=sample_ids
        )
        outputs = BasicDecoderOutput(cell_outputs, sample_ids)
        return (outputs, next_state, next_inputs, finished)
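

if __name__ == "__main__":
    # Minimal sketch, not part of the library: drive the decoder manually
    # through `initialize`/`step` instead of the high-level `__call__`.
    # All sizes and variable names below are illustrative assumptions; it
    # only requires TensorFlow and TensorFlow Addons to be installed.
    batch_size, max_time, hidden_size, vocab_size = 4, 7, 32, 64

    cell = tf.keras.layers.LSTMCell(hidden_size)
    sampler = sampler_py.TrainingSampler()
    output_layer = tf.keras.layers.Dense(vocab_size)
    basic_decoder = BasicDecoder(cell, sampler, output_layer)

    # Teacher-forced inputs: one feature vector per time step.
    inputs = tf.random.normal([batch_size, max_time, hidden_size])
    sequence_length = tf.fill([batch_size], max_time)
    initial_state = cell.get_initial_state(inputs)

    # `initialize` returns (finished, first_inputs, initial_state).
    finished, first_inputs, state = basic_decoder.initialize(
        inputs, initial_state=initial_state, sequence_length=sequence_length
    )
    # A single step returns (outputs, next_state, next_inputs, finished).
    outputs, state, next_inputs, finished = basic_decoder.step(
        tf.constant(0), first_inputs, state
    )
    print(outputs.rnn_output.shape)  # (4, 64): projected logits for one step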