# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Base classes and functions for dynamic decoding.""" 

16 

17import abc 

18 

19import tensorflow as tf 

20from tensorflow_addons.utils.types import TensorLike 

21from typeguard import typechecked 

22from typing import Any, Optional, Tuple, Union 

23 

24# TODO: Find public API alternatives to these 

25from tensorflow.python.ops import control_flow_util 

26 

27 

class Decoder(metaclass=abc.ABCMeta):
    """An RNN Decoder abstract interface object.

    Concepts used by this interface:

    - `inputs`: (structure of) tensors and TensorArrays that is passed as
      input to the RNN cell composing the decoder, at each time step.
    - `state`: (structure of) tensors and TensorArrays that is passed to the
      RNN cell instance as the state.
    - `finished`: boolean tensor telling whether each sequence in the batch
      is finished.
    - `training`: Python boolean indicating whether the decoder should behave
      in training mode or in inference mode.
    - `outputs`: instance of `tfa.seq2seq.BasicDecoderOutput`. Result of the
      decoding, at each time step.
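
    A minimal sketch of a conforming subclass (illustrative only: `_cell`,
    `_start_inputs`, `_initial_state` and `_max_steps` are assumed attributes,
    not part of this interface, and a real subclass must also implement
    `batch_size`, `output_size` and `output_dtype`):

    ```python
    class MyDecoder(Decoder):
        def initialize(self, name=None):
            finished = tf.fill([self.batch_size], False)
            return finished, self._start_inputs, self._initial_state

        def step(self, time, inputs, state, training=None, name=None):
            cell_outputs, next_state = self._cell(inputs, state, training=training)
            next_inputs = cell_outputs  # e.g. feed outputs back as inputs
            finished = tf.fill([self.batch_size], time + 1 >= self._max_steps)
            return cell_outputs, next_state, next_inputs, finished
    ```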

42 """ 

43 

44 @property 

45 def batch_size(self): 

46 """The batch size of input values.""" 

47 raise NotImplementedError 

48 

49 @property 

50 def output_size(self): 

51 """A (possibly nested tuple of...) integer[s] or `TensorShape` 

52 object[s].""" 

53 raise NotImplementedError 

54 

55 @property 

56 def output_dtype(self): 

57 """A (possibly nested tuple of...) dtype[s].""" 

58 raise NotImplementedError 

59 

60 @abc.abstractmethod 

61 def initialize(self, name=None): 

62 """Called before any decoding iterations. 

63 

64 This methods must compute initial input values and initial state. 

65 

66 Args: 

67 name: Name scope for any created operations. 

68 

69 Returns: 

70 `(finished, initial_inputs, initial_state)`: initial values of 

71 'finished' flags, inputs and state. 

72 """ 

73 raise NotImplementedError 


    @abc.abstractmethod
    def step(self, time, inputs, state, training=None, name=None):
        """Called per step of decoding (but only once for dynamic decoding).

        Args:
          time: Scalar `int32` tensor. Current step number.
          inputs: RNN cell input (possibly nested tuple of) tensor[s] for
            this time step.
          state: RNN cell state (possibly nested tuple of) tensor[s] from
            the previous time step.
          training: Python boolean. Indicates whether the layer should behave
            in training mode or in inference mode. Only relevant
            when `dropout` or `recurrent_dropout` is used.
          name: Name scope for any created operations.

        Returns:
          `(outputs, next_state, next_inputs, finished)`: `outputs` is an
          object containing the decoder output, `next_state` is a (structure
          of) state tensors and TensorArrays, `next_inputs` is the tensor
          that should be used as input for the next step, `finished` is a
          boolean tensor telling whether the sequence is complete, for each
          sequence in the batch.
        """
        raise NotImplementedError

    def finalize(self, outputs, final_state, sequence_lengths):
        """Optional hook called at the end of decoding to post-process the
        outputs and state."""
        raise NotImplementedError

    @property
    def tracks_own_finished(self):
        """Describes whether the decoder keeps track of finished states.

        Most decoders will emit a true/false `finished` value independently
        at each time step. In this case, the `tfa.seq2seq.dynamic_decode`
        function keeps track of which batch entries are already finished,
        and performs a logical OR to add newly finished entries to the
        finished set.

        Some decoders, however, shuffle batches / beams between time steps
        and `tfa.seq2seq.dynamic_decode` will mix up the finished state
        across these entries because it does not track the reshuffle across
        time steps. In this case, it is up to the decoder to declare that it
        will keep track of its own finished state by setting this property
        to `True`.

        Returns:
          Python bool.
        """
        return False


class BaseDecoder(tf.keras.layers.Layer):
    """An RNN Decoder that is based on a Keras layer.

    Concepts used by this interface:

    - `inputs`: (structure of) tensors and TensorArrays that is passed as
      input to the RNN cell composing the decoder, at each time step.
    - `state`: (structure of) tensors and TensorArrays that is passed to the
      RNN cell instance as the state.
    - `memory`: tensor that is usually the full output of the encoder, which
      is used by the attention wrapper for the RNN cell.
    - `finished`: boolean tensor telling whether each sequence in the batch
      is finished.
    - `training`: Python boolean indicating whether the decoder should behave
      in training mode or in inference mode.
    - `outputs`: instance of `tfa.seq2seq.BasicDecoderOutput`. Result of the
      decoding, at each time step.
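
    A minimal usage sketch (illustrative only; `MyBaseDecoder` is a
    hypothetical subclass that implements `initialize()` and `step()`):

    ```python
    decoder = MyBaseDecoder(maximum_iterations=10)
    # `inputs` has shape [batch, time, embedding]; calling the layer runs
    # `dynamic_decode()` under the hood.
    outputs, state, lengths = decoder(inputs, initial_state=encoder_state)
    ```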

140 """ 

    @typechecked
    def __init__(
        self,
        output_time_major: bool = False,
        impute_finished: bool = False,
        maximum_iterations: Optional[TensorLike] = None,
        parallel_iterations: int = 32,
        swap_memory: bool = False,
        **kwargs,
    ):
        self.output_time_major = output_time_major
        self.impute_finished = impute_finished
        self.maximum_iterations = maximum_iterations
        self.parallel_iterations = parallel_iterations
        self.swap_memory = swap_memory
        super().__init__(**kwargs)

    def call(self, inputs, initial_state=None, training=None, **kwargs):
        init_kwargs = kwargs
        init_kwargs["initial_state"] = initial_state
        return dynamic_decode(
            self,
            output_time_major=self.output_time_major,
            impute_finished=self.impute_finished,
            maximum_iterations=self.maximum_iterations,
            parallel_iterations=self.parallel_iterations,
            swap_memory=self.swap_memory,
            training=training,
            decoder_init_input=inputs,
            decoder_init_kwargs=init_kwargs,
        )

    @property
    def batch_size(self):
        """The batch size of input values."""
        raise NotImplementedError

    @property
    def output_size(self):
        """A (possibly nested tuple of...) integer[s] or `TensorShape`
        object[s]."""
        raise NotImplementedError

    @property
    def output_dtype(self):
        """A (possibly nested tuple of...) dtype[s]."""
        raise NotImplementedError

    def initialize(self, inputs, initial_state=None, **kwargs):
        """Called before any decoding iterations.

        This method must compute initial input values and initial state.

        Args:
          inputs: (structure of) tensors that contains the input for the
            decoder. In the normal case, it's a tensor with shape
            [batch, timestep, embedding].
          initial_state: (structure of) tensors that contains the initial
            state for the RNN cell.
          **kwargs: Other arguments that are passed in from the layer.call()
            method. It could contain items like the input `sequence_length`,
            or masking for the input.

        Returns:
          `(finished, initial_inputs, initial_state)`: initial values of
          'finished' flags, inputs and state.
        """
        raise NotImplementedError

    def step(self, time, inputs, state, training):
        """Called per step of decoding (but only once for dynamic decoding).

        Args:
          time: Scalar `int32` tensor. Current step number.
          inputs: RNN cell input (possibly nested tuple of) tensor[s] for
            this time step.
          state: RNN cell state (possibly nested tuple of) tensor[s] from
            the previous time step.
          training: Python boolean. Indicates whether the layer should
            behave in training mode or in inference mode.

        Returns:
          `(outputs, next_state, next_inputs, finished)`: `outputs` is an
          object containing the decoder output, `next_state` is a
          (structure of) state tensors and TensorArrays, `next_inputs` is
          the tensor that should be used as input for the next step,
          `finished` is a boolean tensor telling whether the sequence is
          complete, for each sequence in the batch.
        """
        raise NotImplementedError

    def finalize(self, outputs, final_state, sequence_lengths):
        """Optional hook called at the end of decoding to post-process the
        outputs and state."""
        raise NotImplementedError

    @property
    def tracks_own_finished(self):
        """Describes whether the decoder keeps track of finished states.

        Most decoders will emit a true/false `finished` value independently
        at each time step. In this case, the `tfa.seq2seq.dynamic_decode`
        function keeps track of which batch entries are already finished,
        and performs a logical OR to add newly finished entries to the
        finished set.

        Some decoders, however, shuffle batches / beams between time steps
        and `tfa.seq2seq.dynamic_decode` will mix up the finished state
        across these entries because it does not track the reshuffle across
        time steps. In this case, it is up to the decoder to declare that it
        will keep track of its own finished state by setting this property
        to `True`.

        Returns:
          Python bool.
        """
        return False

    # TODO(scottzhu): Add build/get_config/from_config and other layer methods.


@typechecked
def dynamic_decode(
    decoder: Union[Decoder, BaseDecoder],
    output_time_major: bool = False,
    impute_finished: bool = False,
    maximum_iterations: Optional[TensorLike] = None,
    parallel_iterations: int = 32,
    swap_memory: bool = False,
    training: Optional[bool] = None,
    scope: Optional[str] = None,
    enable_tflite_convertible: bool = False,
    **kwargs,
) -> Tuple[Any, Any, Any]:
272 """Runs dynamic decoding with a decoder. 

273 

274 Calls `initialize()` once and `step()` repeatedly on the decoder object. 

275 

276 Args: 

277 decoder: A `tfa.seq2seq.Decoder` or `tfa.seq2seq.BaseDecoder` instance. 

278 output_time_major: Python boolean. Default: `False` (batch major). If 

279 `True`, outputs are returned as time major tensors (this mode is 

280 faster). Otherwise, outputs are returned as batch major tensors (this 

281 adds extra time to the computation). 

282 impute_finished: Python boolean. If `True`, then states for batch 

283 entries which are marked as finished get copied through and the 

284 corresponding outputs get zeroed out. This causes some slowdown at 

285 each time step, but ensures that the final state and outputs have 

286 the correct values and that backprop ignores time steps that were 

287 marked as finished. 

288 maximum_iterations: A strictly positive `int32` scalar, the maximum 

289 allowed number of decoding steps. Default is `None` (decode until the 

290 decoder is fully done). 

291 parallel_iterations: Argument passed to `tf.while_loop`. 

292 swap_memory: Argument passed to `tf.while_loop`. 

293 training: Python boolean. Indicates whether the layer should behave 

294 in training mode or in inference mode. Only relevant 

295 when `dropout` or `recurrent_dropout` is used. 

296 scope: Optional name scope to use. 

297 enable_tflite_convertible: Python boolean. If `True`, then the variables 

298 of `TensorArray` become of 1-D static shape. Also zero pads in the 

299 output tensor will be discarded. Default: `False`. 

300 **kwargs: dict, other keyword arguments for dynamic_decode. It might 

301 contain arguments for `BaseDecoder` to initialize, which takes all 

302 tensor inputs during call(). 

303 

304 Returns: 

305 `(final_outputs, final_state, final_sequence_lengths)`. 

306 

307 Raises: 

308 ValueError: if `maximum_iterations` is provided but is not a scalar. 
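
    A minimal usage sketch (illustrative only; the sampler, cell size and
    shapes are assumptions, not requirements of this function):

    ```python
    sampler = tfa.seq2seq.sampler.TrainingSampler()
    cell = tf.keras.layers.LSTMCell(16)
    decoder = tfa.seq2seq.BasicDecoder(cell, sampler)

    batch_size = 4
    inputs = tf.random.normal([batch_size, 10, 16])
    initial_state = cell.get_initial_state(
        batch_size=batch_size, dtype=tf.float32
    )
    outputs, state, lengths = tfa.seq2seq.dynamic_decode(
        decoder,
        maximum_iterations=10,
        decoder_init_input=inputs,
        decoder_init_kwargs={
            "initial_state": initial_state,
            "sequence_length": tf.fill([batch_size], 10),
        },
    )
    ```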

309 """ 

    with tf.name_scope(scope or "decoder"):
        is_xla = (
            not tf.executing_eagerly()
            and control_flow_util.GraphOrParentsInXlaContext(
                tf.compat.v1.get_default_graph()
            )
        )

        if maximum_iterations is not None:
            maximum_iterations = tf.convert_to_tensor(
                maximum_iterations, dtype=tf.int32, name="maximum_iterations"
            )
            if maximum_iterations.shape.ndims != 0:
                raise ValueError("maximum_iterations must be a scalar")
            tf.debugging.assert_greater(
                maximum_iterations,
                0,
                message="maximum_iterations should be greater than 0",
            )
        elif is_xla:
            raise ValueError("maximum_iterations is required for XLA compilation.")

        if isinstance(decoder, Decoder):
            initial_finished, initial_inputs, initial_state = decoder.initialize()
        else:
            # For BaseDecoder that takes tensor inputs during call.
            decoder_init_input = kwargs.pop("decoder_init_input", None)
            decoder_init_kwargs = kwargs.pop("decoder_init_kwargs", {})
            initial_finished, initial_inputs, initial_state = decoder.initialize(
                decoder_init_input, **decoder_init_kwargs
            )

        if enable_tflite_convertible:
            # Assume the batch_size = 1 for inference.
            # So we can change 2-D TensorArray into 1-D by reshaping it.
            tf.debugging.assert_equal(
                decoder.batch_size,
                1,
                message="TFLite conversion requires a batch size of 1",
            )
            zero_outputs = tf.nest.map_structure(
                lambda shape, dtype: tf.reshape(
                    tf.zeros(_prepend_batch(decoder.batch_size, shape), dtype=dtype),
                    [-1],
                ),
                decoder.output_size,
                decoder.output_dtype,
            )
        else:
            zero_outputs = tf.nest.map_structure(
                lambda shape, dtype: tf.zeros(
                    _prepend_batch(decoder.batch_size, shape), dtype=dtype
                ),
                decoder.output_size,
                decoder.output_dtype,
            )

        if maximum_iterations is not None:
            initial_finished = tf.logical_or(initial_finished, 0 >= maximum_iterations)
        initial_sequence_lengths = tf.zeros_like(initial_finished, dtype=tf.int32)
        initial_time = tf.constant(0, dtype=tf.int32)

        def _shape(batch_size, from_shape):
            if not isinstance(from_shape, tf.TensorShape) or from_shape.ndims == 0:
                return None
            else:
                batch_size = tf.get_static_value(
                    tf.convert_to_tensor(batch_size, name="batch_size")
                )
                return tf.TensorShape([batch_size]).concatenate(from_shape)

        dynamic_size = maximum_iterations is None or not is_xla
        # The dynamic shape `TensorArray` is not allowed in TFLite yet.
        dynamic_size = dynamic_size and (not enable_tflite_convertible)

        def _create_ta(s, d):
            if enable_tflite_convertible:
                # TFLite requires 1D element_shape.
                if isinstance(s, tf.TensorShape) and s.ndims == 0:
                    s = (1,)
                element_shape = s
            else:
                element_shape = _shape(decoder.batch_size, s)
            return tf.TensorArray(
                dtype=d,
                size=0 if dynamic_size else maximum_iterations,
                dynamic_size=dynamic_size,
                element_shape=element_shape,
            )

        initial_outputs_ta = tf.nest.map_structure(
            _create_ta, decoder.output_size, decoder.output_dtype
        )

        def condition(
            unused_time,
            unused_outputs_ta,
            unused_state,
            unused_inputs,
            finished,
            unused_sequence_lengths,
        ):
            return tf.logical_not(tf.reduce_all(finished))

        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """Internal while_loop body.

            Args:
              time: scalar int32 tensor.
              outputs_ta: structure of TensorArray.
              state: (structure of) state tensors and TensorArrays.
              inputs: (structure of) input tensors.
              finished: bool tensor (keeping track of what's finished).
              sequence_lengths: int32 tensor (keeping track of time of finish).

            Returns:
              `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
              next_sequence_lengths)`.
            """
            (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(
                time, inputs, state, training
            )
            decoder_state_sequence_lengths = False
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
                lengths = getattr(decoder_state, "lengths", None)
                if lengths is not None:
                    # Sequence lengths are provided by decoder_state.lengths;
                    # overwrite our sequence lengths.
                    decoder_state_sequence_lengths = True
                    sequence_lengths = tf.cast(lengths, tf.int32)
            else:
                next_finished = tf.logical_or(decoder_finished, finished)

            if decoder_state_sequence_lengths:
                # Just pass something through the loop; at the next iteration
                # we'll pull the sequence lengths from the decoder_state again.
                next_sequence_lengths = sequence_lengths
            else:
                next_sequence_lengths = tf.where(
                    tf.logical_not(finished),
                    tf.fill(tf.shape(sequence_lengths), time + 1),
                    sequence_lengths,
                )

            tf.nest.assert_same_structure(state, decoder_state)
            tf.nest.assert_same_structure(outputs_ta, next_outputs)
            tf.nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:

                def zero_out_finished(out, zero):
                    if finished.shape.rank < zero.shape.rank:
                        broadcast_finished = tf.broadcast_to(
                            tf.expand_dims(finished, axis=-1), zero.shape
                        )
                        return tf.where(broadcast_finished, zero, out)
                    else:
                        return tf.where(finished, zero, out)

                emit = tf.nest.map_structure(
                    zero_out_finished, next_outputs, zero_outputs
                )
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tf.TensorArray):
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = new.shape.ndims == 0
                if not pass_through:
                    broadcast_finished = tf.broadcast_to(
                        tf.expand_dims(finished, axis=-1), new.shape
                    )
                    return tf.where(broadcast_finished, cur, new)
                else:
                    return new

            if impute_finished:
                next_state = tf.nest.map_structure(
                    _maybe_copy_state, decoder_state, state
                )
            else:
                next_state = decoder_state

            if enable_tflite_convertible:
                # Reshape to 1-D.
                emit = tf.nest.map_structure(lambda x: tf.reshape(x, [-1]), emit)

            outputs_ta = tf.nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit
            )
            return (
                time + 1,
                outputs_ta,
                next_state,
                next_inputs,
                next_finished,
                next_sequence_lengths,
            )

        res = tf.while_loop(
            condition,
            body,
            loop_vars=(
                initial_time,
                initial_outputs_ta,
                initial_state,
                initial_inputs,
                initial_finished,
                initial_sequence_lengths,
            ),
            parallel_iterations=parallel_iterations,
            maximum_iterations=maximum_iterations,
            swap_memory=swap_memory,
        )

        final_outputs_ta = res[1]
        final_state = res[2]
        final_sequence_lengths = res[5]

        final_outputs = tf.nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)

        try:
            final_outputs, final_state = decoder.finalize(
                final_outputs, final_state, final_sequence_lengths
            )
        except NotImplementedError:
            pass

        if not output_time_major:
            if enable_tflite_convertible:
                # Reshape the output to the original shape.
                def _restore_batch(x):
                    return tf.expand_dims(x, [1])

                final_outputs = tf.nest.map_structure(_restore_batch, final_outputs)

            final_outputs = tf.nest.map_structure(_transpose_batch_time, final_outputs)

    return final_outputs, final_state, final_sequence_lengths


def _prepend_batch(batch_size, shape):
    """Prepends the batch dimension to the shape.

    If the batch_size value is known statically, this function returns a
    TensorShape, otherwise a Tensor.
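
    For example (illustrative values; `inputs` is a hypothetical tensor):

    ```python
    _prepend_batch(8, tf.TensorShape([5]))  # -> TensorShape([8, 5])

    batch = tf.shape(inputs)[0]  # dynamic batch size (a Tensor)
    _prepend_batch(batch, tf.TensorShape([5]))  # -> 1-D int32 Tensor [batch, 5]
    ```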

564 """ 

565 if isinstance(batch_size, tf.Tensor): 

566 static_batch_size = tf.get_static_value(batch_size) 

567 else: 

568 static_batch_size = batch_size 

569 if static_batch_size is None: 

570 return tf.concat(([batch_size], shape), axis=0) 

571 return [static_batch_size] + shape 


def _transpose_batch_time(tensor):
    """Transposes the batch and time dimensions of a tensor if its rank is
    at least 2.
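
    For example (illustrative shape):

    ```python
    x = tf.zeros([4, 10, 16])       # [batch, time, depth]
    _transpose_batch_time(x).shape  # TensorShape([10, 4, 16])
    ```
    """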

    shape = tensor.shape
    if shape.rank is not None and shape.rank < 2:
        return tensor
    perm = tf.concat(([1, 0], tf.range(2, tf.rank(tensor))), axis=0)
    return tf.transpose(tensor, perm)