# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Objects sampling from the decoder output distribution and producing the next input."""

import abc

from typing import Callable, Optional

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.seq2seq import decoder
from tensorflow_addons.utils import types
from tensorflow_addons.utils.types import Initializer, TensorLike

_transpose_batch_time = decoder._transpose_batch_time


class Sampler(metaclass=abc.ABCMeta):
    """Interface for implementing sampling in seq2seq decoders.

    Sampler classes implement the logic of sampling from the decoder output distribution
    and producing the inputs for the next decoding step. In most cases, they should not be
    used directly but passed to a `tfa.seq2seq.BasicDecoder` instance that will manage the
    sampling.

    Here is an example using a training sampler directly to implement a custom decoding
    loop:

    >>> batch_size = 4
    >>> max_time = 7
    >>> hidden_size = 16
    >>>
    >>> sampler = tfa.seq2seq.TrainingSampler()
    >>> cell = tf.keras.layers.LSTMCell(hidden_size)
    >>>
    >>> input_tensors = tf.random.uniform([batch_size, max_time, hidden_size])
    >>> initial_finished, initial_inputs = sampler.initialize(input_tensors)
    >>>
    >>> cell_input = initial_inputs
    >>> cell_state = cell.get_initial_state(initial_inputs)
    >>>
    >>> for time_step in tf.range(max_time):
    ...     cell_output, cell_state = cell(cell_input, cell_state)
    ...     sample_ids = sampler.sample(time_step, cell_output, cell_state)
    ...     finished, cell_input, cell_state = sampler.next_inputs(
    ...         time_step, cell_output, cell_state, sample_ids)
    ...     if tf.reduce_all(finished):
    ...         break
    """

    @abc.abstractmethod
    def initialize(self, inputs, **kwargs):
        """Initializes the sampler with the input tensors.

        This method must be invoked exactly once before calling other
        methods of the Sampler.

        Args:
          inputs: A (structure of) input tensors; it could be a nested tuple
            or a single tensor.
          **kwargs: Other kwargs for initialization. It could contain tensors
            like a mask for the inputs, or non-tensor parameters.

        Returns:
          `(initial_finished, initial_inputs)`.
        """
        pass

    @abc.abstractmethod
    def sample(self, time, outputs, state):
        """Returns `sample_ids`."""
        pass

    @abc.abstractmethod
    def next_inputs(self, time, outputs, state, sample_ids):
        """Returns `(finished, next_inputs, next_state)`."""
        pass

    @property
    @abc.abstractmethod
    def batch_size(self):
        """Batch size of the tensor returned by `sample`.

        Returns a scalar int32 tensor. The return value might not be
        available before the invocation of `initialize()`, in which case
        a ValueError is raised.
        """
        raise NotImplementedError("batch_size has not been implemented")

    @property
    @abc.abstractmethod
    def sample_ids_shape(self):
        """Shape of the tensor returned by `sample`, excluding the batch dimension.

        Returns a `TensorShape`. The return value might not be available
        before the invocation of `initialize()`.
        """
        raise NotImplementedError("sample_ids_shape has not been implemented")

    @property
    @abc.abstractmethod
    def sample_ids_dtype(self):
        """DType of the tensor returned by `sample`.

        Returns a DType. The return value might not be available before the
        invocation of `initialize()`.
        """
        raise NotImplementedError("sample_ids_dtype has not been implemented")


class CustomSampler(Sampler):
    """A sampler that lets the user customize sampling with plain callables."""

    @typechecked
    def __init__(
        self,
        initialize_fn: Initializer,
        sample_fn: Callable,
        next_inputs_fn: Callable,
        sample_ids_shape: Optional[TensorLike] = None,
        sample_ids_dtype: types.AcceptableDTypes = None,
    ):
        """Initializer.

        Args:
          initialize_fn: callable that returns `(finished, next_inputs)` for
            the first iteration.
          sample_fn: callable that takes `(time, outputs, state)` and emits
            tensor `sample_ids`.
          next_inputs_fn: callable that takes
            `(time, outputs, state, sample_ids)` and emits
            `(finished, next_inputs, next_state)`.
          sample_ids_shape: Either a list of integers, or a 1-D Tensor of type
            `int32`, the shape of each value in the `sample_ids` batch.
            Defaults to a scalar.
          sample_ids_dtype: The dtype of the `sample_ids` tensor. Defaults to
            int32.
        """
        self._initialize_fn = initialize_fn
        self._sample_fn = sample_fn
        self._next_inputs_fn = next_inputs_fn
        self._batch_size = None
        self._sample_ids_shape = tf.TensorShape(sample_ids_shape or [])
        self._sample_ids_dtype = sample_ids_dtype or tf.int32

    @property
    def batch_size(self):
        if self._batch_size is None:
            raise ValueError("batch_size accessed before initialize was called")
        return self._batch_size

    @property
    def sample_ids_shape(self):
        return self._sample_ids_shape

    @property
    def sample_ids_dtype(self):
        return self._sample_ids_dtype

    def initialize(self, inputs, **kwargs):
        (finished, next_inputs) = self._initialize_fn(inputs, **kwargs)
        if self._batch_size is None:
            self._batch_size = tf.size(finished)
        return (finished, next_inputs)

    def sample(self, time, outputs, state):
        return self._sample_fn(time=time, outputs=outputs, state=state)

    def next_inputs(self, time, outputs, state, sample_ids):
        return self._next_inputs_fn(
            time=time, outputs=outputs, state=state, sample_ids=sample_ids
        )

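
# A minimal sketch (editor's illustration, not part of the upstream library) of
# assembling a CustomSampler from plain callables. The greedy argmax `sample_fn`
# and the feed-outputs-back `next_inputs_fn` below are assumptions chosen for
# brevity; a real model would typically embed `sample_ids` instead.
def _example_custom_sampler(max_time):
    def _initialize_fn(inputs, **kwargs):
        # `inputs` is assumed batch-major: [batch_size, max_time, depth].
        finished = tf.tile([False], [tf.shape(inputs)[0]])
        return finished, inputs[:, 0]

    def _sample_fn(time, outputs, state):
        return tf.argmax(outputs, axis=-1, output_type=tf.int32)

    def _next_inputs_fn(time, outputs, state, sample_ids):
        # Stop once the assumed maximum number of steps is reached.
        finished = tf.fill([tf.shape(outputs)[0]], time + 1 >= max_time)
        return finished, outputs, state

    return CustomSampler(_initialize_fn, _sample_fn, _next_inputs_fn)
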

class TrainingSampler(Sampler):
    """A training sampler that simply reads its inputs.

    Returned sample_ids are the argmax of the RNN output logits.
    """

    @typechecked
    def __init__(self, time_major: bool = False):
        """Initializer.

        Args:
          time_major: Python bool. Whether the tensors in `inputs` are time
            major. If `False` (default), they are assumed to be batch major.
        """
        self.time_major = time_major
        self._batch_size = None

    @property
    def batch_size(self):
        if self._batch_size is None:
            raise ValueError("batch_size accessed before initialize was called")
        return self._batch_size

    @property
    def sample_ids_shape(self):
        return tf.TensorShape([])

    @property
    def sample_ids_dtype(self):
        return tf.int32

    def initialize(self, inputs, sequence_length=None, mask=None):
        """Initialize the TrainingSampler.

        Args:
          inputs: A (structure of) input tensors.
          sequence_length: An int32 vector tensor.
          mask: A boolean 2D tensor.

        Returns:
          (finished, next_inputs), a tuple of two items. The first item is a
            boolean vector to indicate whether the item in the batch has
            finished. The second item is the first slice of the input data
            along the timestep dimension (usually the second dim of the
            input).

        Raises:
          ValueError: if both `sequence_length` and `mask` are provided, if
            `sequence_length` is not a 1D tensor, or if `mask` is not a 2D
            boolean tensor.
        """
        self.inputs = tf.convert_to_tensor(inputs, name="inputs")
        if not self.time_major:
            inputs = tf.nest.map_structure(_transpose_batch_time, inputs)

        self._batch_size = tf.shape(tf.nest.flatten(inputs)[0])[1]

        self.input_tas = tf.nest.map_structure(_unstack_ta, inputs)
        if sequence_length is not None and mask is not None:
            raise ValueError(
                "sequence_length and mask can't be provided at the same time."
            )
        if sequence_length is not None:
            self.sequence_length = tf.convert_to_tensor(
                sequence_length, name="sequence_length"
            )
            if self.sequence_length.shape.ndims != 1:
                raise ValueError(
                    "Expected sequence_length to be vector, but received "
                    "shape: %s" % self.sequence_length.shape
                )
        elif mask is not None:
            mask = tf.convert_to_tensor(mask)
            if mask.shape.ndims != 2:
                raise ValueError(
                    "Expected mask to be a 2D tensor, but received shape: %s" % mask
                )
            if not mask.dtype.is_bool:
                raise ValueError(
                    "Expected mask to be a boolean tensor, but received "
                    "dtype: %s" % repr(mask.dtype)
                )

            axis = 1 if not self.time_major else 0
            with tf.control_dependencies(
                [_check_sequence_is_right_padded(mask, self.time_major)]
            ):
                self.sequence_length = tf.math.reduce_sum(
                    tf.cast(mask, tf.int32), axis=axis, name="sequence_length"
                )
        else:
            # As the input tensor has been converted to time major,
            # the maximum sequence length should be inferred from
            # the first dimension.
            max_seq_len = tf.shape(tf.nest.flatten(inputs)[0])[0]
            self.sequence_length = tf.fill(
                [self.batch_size], max_seq_len, name="sequence_length"
            )

        self.zero_inputs = tf.nest.map_structure(
            lambda inp: tf.zeros_like(inp[0, :]), inputs
        )

        finished = tf.equal(0, self.sequence_length)
        all_finished = tf.reduce_all(finished)
        next_inputs = tf.cond(
            all_finished,
            lambda: self.zero_inputs,
            lambda: tf.nest.map_structure(lambda inp: inp.read(0), self.input_tas),
        )
        return (finished, next_inputs)

    def sample(self, time, outputs, state):
        del state
        sample_ids = tf.cast(tf.argmax(outputs, axis=-1), tf.int32)
        return sample_ids

    def next_inputs(self, time, outputs, state, sample_ids):
        del sample_ids
        next_time = time + 1
        finished = next_time >= self.sequence_length
        all_finished = tf.reduce_all(finished)

        def read_from_ta(inp):
            return inp.read(next_time)

        next_inputs = tf.cond(
            all_finished,
            lambda: self.zero_inputs,
            lambda: tf.nest.map_structure(read_from_ta, self.input_tas),
        )
        return (finished, next_inputs, state)

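
# A minimal sketch (editor's illustration, not part of the upstream library)
# showing TrainingSampler with variable-length, right-padded sequences. The
# shapes and lengths below are arbitrary assumptions.
def _example_training_sampler_with_mask():
    inputs = tf.random.uniform([2, 5, 8])  # [batch, max_time, depth]
    # Right-padded mask: first sequence has length 3, second has length 5.
    mask = tf.sequence_mask([3, 5], maxlen=5)
    sampler = TrainingSampler()
    finished, first_inputs = sampler.initialize(inputs, mask=mask)
    return finished, first_inputs
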

class ScheduledEmbeddingTrainingSampler(TrainingSampler):
    """A training sampler that adds scheduled sampling.

    Returns -1s for sample_ids where no sampling took place; valid
    sample id values elsewhere.
    """

    @typechecked
    def __init__(
        self,
        sampling_probability: TensorLike,
        embedding_fn: Optional[Callable] = None,
        time_major: bool = False,
        seed: Optional[int] = None,
        scheduling_seed: Optional[TensorLike] = None,
    ):
        """Initializer.

        Args:
          sampling_probability: A `float32` 0-D or 1-D tensor: the probability
            of sampling categorically from the output ids instead of reading
            directly from the inputs.
          embedding_fn: A callable that takes a vector tensor of `ids`
            (argmax ids).
          time_major: Python bool. Whether the tensors in `inputs` are time
            major. If `False` (default), they are assumed to be batch major.
          seed: The sampling seed.
          scheduling_seed: The schedule decision rule sampling seed.

        Raises:
          ValueError: if `sampling_probability` is not a scalar or vector.
        """
        self.embedding_fn = embedding_fn
        if isinstance(sampling_probability, tf.Variable):
            self.sampling_probability = sampling_probability
        else:
            self.sampling_probability = tf.convert_to_tensor(
                sampling_probability, name="sampling_probability"
            )
        if self.sampling_probability.shape.ndims not in (0, 1):
            raise ValueError(
                "sampling_probability must be either a scalar or a vector. "
                "saw shape: %s" % (self.sampling_probability.shape)
            )
        self.seed = seed
        self.scheduling_seed = scheduling_seed
        super().__init__(time_major=time_major)

    def initialize(self, inputs, sequence_length=None, mask=None, embedding=None):
        if self.embedding_fn is None:
            if embedding is None:
                raise ValueError(
                    "embedding is required as a keyword argument for "
                    "ScheduledEmbeddingTrainingSampler"
                )
            self.embedding_fn = lambda ids: tf.nn.embedding_lookup(embedding, ids)
        return super().initialize(inputs, sequence_length=sequence_length, mask=mask)

    def sample(self, time, outputs, state):
        del state
        # Return -1s where we did not sample, and sample_ids elsewhere.
        select_sample = bernoulli_sample(
            probs=self.sampling_probability,
            dtype=tf.bool,
            sample_shape=self.batch_size,
            seed=self.scheduling_seed,
        )
        return tf.where(
            select_sample,
            categorical_sample(logits=outputs, seed=self.seed),
            tf.fill([self.batch_size], -1),
        )

    def next_inputs(self, time, outputs, state, sample_ids):
        (finished, base_next_inputs, state) = super().next_inputs(
            time=time, outputs=outputs, state=state, sample_ids=sample_ids
        )

        def maybe_sample():
            """Perform scheduled sampling."""
            where_sampling = tf.cast(tf.where(sample_ids > -1), tf.int32)
            where_not_sampling = tf.cast(tf.where(sample_ids <= -1), tf.int32)
            sample_ids_sampling = tf.gather_nd(sample_ids, where_sampling)
            inputs_not_sampling = tf.gather_nd(base_next_inputs, where_not_sampling)
            sampled_next_inputs = self.embedding_fn(sample_ids_sampling)
            sampled_next_inputs = tf.cast(
                sampled_next_inputs, inputs_not_sampling.dtype
            )
            base_shape = tf.shape(base_next_inputs)
            return tf.scatter_nd(
                indices=where_sampling, updates=sampled_next_inputs, shape=base_shape
            ) + tf.scatter_nd(
                indices=where_not_sampling,
                updates=inputs_not_sampling,
                shape=base_shape,
            )

        all_finished = tf.reduce_all(finished)
        next_inputs = tf.cond(all_finished, lambda: base_next_inputs, maybe_sample)
        return (finished, next_inputs, state)

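
# A minimal sketch (editor's illustration, not part of the upstream library) of
# scheduled sampling at the embedding level. The vocabulary size, embedding
# matrix, and 25% sampling probability are assumptions of the example.
def _example_scheduled_embedding_sampler():
    vocab_size, max_time, batch_size = 10, 5, 2
    embeddings = tf.random.uniform([vocab_size, 8])
    token_ids = tf.random.uniform(
        [batch_size, max_time], maxval=vocab_size, dtype=tf.int32
    )
    inputs = tf.nn.embedding_lookup(embeddings, token_ids)
    sampler = ScheduledEmbeddingTrainingSampler(sampling_probability=0.25)
    # The embedding matrix is passed at initialize() when no embedding_fn
    # was given to the constructor.
    finished, first_inputs = sampler.initialize(inputs, embedding=embeddings)
    return sampler, finished, first_inputs
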

class ScheduledOutputTrainingSampler(TrainingSampler):
    """A training sampler that adds scheduled sampling directly to outputs.

    Returns False for sample_ids where no sampling took place; True
    elsewhere.
    """

    @typechecked
    def __init__(
        self,
        sampling_probability: TensorLike,
        time_major: bool = False,
        seed: Optional[int] = None,
        next_inputs_fn: Optional[Callable] = None,
    ):
        """Initializer.

        Args:
          sampling_probability: A `float32` scalar tensor: the probability of
            sampling from the outputs instead of reading directly from the
            inputs.
          time_major: Python bool. Whether the tensors in `inputs` are time
            major. If `False` (default), they are assumed to be batch major.
          seed: The sampling seed.
          next_inputs_fn: (Optional) callable to apply to the RNN outputs to
            create the next input when sampling. If `None` (default), the RNN
            outputs will be used as the next inputs.

        Raises:
          ValueError: if `sampling_probability` is not a scalar or vector.
        """
        if isinstance(sampling_probability, tf.Variable):
            self.sampling_probability = sampling_probability
        else:
            self.sampling_probability = tf.convert_to_tensor(
                sampling_probability, name="sampling_probability"
            )
        if self.sampling_probability.shape.ndims not in (0, 1):
            raise ValueError(
                "sampling_probability must be either a scalar or a vector. "
                "saw shape: %s" % (self.sampling_probability.shape)
            )

        self.seed = seed
        self.next_inputs_fn = next_inputs_fn

        super().__init__(time_major=time_major)

    def initialize(
        self, inputs, sequence_length=None, mask=None, auxiliary_inputs=None
    ):
        if auxiliary_inputs is None:
            maybe_concatenated_inputs = inputs
        else:
            inputs = tf.convert_to_tensor(inputs)
            auxiliary_inputs = tf.convert_to_tensor(auxiliary_inputs)
            maybe_concatenated_inputs = tf.nest.map_structure(
                lambda x, y: tf.concat((x, y), -1), inputs, auxiliary_inputs
            )
            if not self.time_major:
                auxiliary_inputs = tf.nest.map_structure(
                    _transpose_batch_time, auxiliary_inputs
                )
        if auxiliary_inputs is not None:
            self._auxiliary_input_tas = tf.nest.map_structure(
                _unstack_ta, auxiliary_inputs
            )
        else:
            self._auxiliary_input_tas = None

        return super().initialize(
            maybe_concatenated_inputs, sequence_length=sequence_length, mask=mask
        )

    def sample(self, time, outputs, state):
        del state
        return bernoulli_sample(
            probs=self.sampling_probability,
            sample_shape=self.batch_size,
            seed=self.seed,
        )

    def next_inputs(self, time, outputs, state, sample_ids):
        (finished, base_next_inputs, state) = super().next_inputs(
            time=time, outputs=outputs, state=state, sample_ids=sample_ids
        )
        sample_ids = tf.cast(sample_ids, tf.bool)

        def maybe_sample():
            """Perform scheduled sampling."""

            def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
                """Concatenate outputs with auxiliary inputs, if they exist."""
                if self._auxiliary_input_tas is None:
                    return outputs_

                next_time = time + 1
                auxiliary_inputs = tf.nest.map_structure(
                    lambda ta: ta.read(next_time), self._auxiliary_input_tas
                )
                if indices is not None:
                    auxiliary_inputs = tf.gather_nd(auxiliary_inputs, indices)
                return tf.nest.map_structure(
                    lambda x, y: tf.concat((x, y), -1), outputs_, auxiliary_inputs
                )

            if self.next_inputs_fn is None:
                return tf.where(
                    tf.broadcast_to(
                        tf.expand_dims(sample_ids, axis=-1), base_next_inputs.shape
                    ),
                    maybe_concatenate_auxiliary_inputs(outputs),
                    base_next_inputs,
                )

            where_sampling = tf.cast(tf.where(sample_ids), tf.int32)
            where_not_sampling = tf.cast(tf.where(tf.logical_not(sample_ids)), tf.int32)
            outputs_sampling = tf.gather_nd(outputs, where_sampling)
            inputs_not_sampling = tf.gather_nd(base_next_inputs, where_not_sampling)
            sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
                self.next_inputs_fn(outputs_sampling), where_sampling
            )

            base_shape = tf.shape(base_next_inputs)
            return tf.scatter_nd(
                indices=where_sampling, updates=sampled_next_inputs, shape=base_shape
            ) + tf.scatter_nd(
                indices=where_not_sampling,
                updates=inputs_not_sampling,
                shape=base_shape,
            )

        all_finished = tf.reduce_all(finished)
        no_samples = tf.logical_not(tf.reduce_any(sample_ids))
        next_inputs = tf.cond(
            tf.logical_or(all_finished, no_samples),
            lambda: base_next_inputs,
            maybe_sample,
        )
        return (finished, next_inputs, state)

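
# A minimal sketch (editor's illustration, not part of the upstream library) of
# output-level scheduled sampling: with probability 0.5 the raw RNN output is
# fed back as the next input instead of the ground-truth input. For this to
# work end to end, the cell output depth must match the input depth (an
# assumption of this example).
def _example_scheduled_output_sampler():
    inputs = tf.random.uniform([2, 5, 8])  # [batch, max_time, depth]
    sampler = ScheduledOutputTrainingSampler(sampling_probability=0.5)
    finished, first_inputs = sampler.initialize(inputs)
    return sampler, finished, first_inputs
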

class GreedyEmbeddingSampler(Sampler):
    """An inference sampler that takes the maximum from the output distribution.

    Uses the argmax of the output (treated as logits) and passes the
    result through an embedding layer to get the next input.
    """

    @typechecked
    def __init__(self, embedding_fn: Optional[Callable] = None):
        """Initializer.

        Args:
          embedding_fn: An optional callable that takes a vector tensor of
            `ids` (argmax ids). The returned tensor will be passed to the
            decoder input. Defaults to `tf.nn.embedding_lookup`.
        """
        self.embedding_fn = embedding_fn
        self._batch_size = None

    @property
    def batch_size(self):
        if self._batch_size is None:
            raise ValueError("batch_size accessed before initialize was called")
        return self._batch_size

    @property
    def sample_ids_shape(self):
        return tf.TensorShape([])

    @property
    def sample_ids_dtype(self):
        return tf.int32

    def initialize(self, embedding, start_tokens=None, end_token=None):
        """Initialize the GreedyEmbeddingSampler.

        Args:
          embedding: tensor that contains the embedding states matrix. It will
            be used to generate outputs with `start_tokens` and `end_token`.
            The embedding will be ignored if `embedding_fn` was provided at
            `__init__()`.
          start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
          end_token: `int32` scalar, the token that marks the end of decoding.

        Returns:
          Tuple of two items: `(finished, self.start_inputs)`.

        Raises:
          ValueError: if `start_tokens` is not a 1D tensor or `end_token` is
            not a scalar.
        """
        if self.embedding_fn is None:
            self.embedding_fn = lambda ids: tf.nn.embedding_lookup(embedding, ids)

        self.start_tokens = tf.convert_to_tensor(
            start_tokens, dtype=tf.int32, name="start_tokens"
        )
        self.end_token = tf.convert_to_tensor(
            end_token, dtype=tf.int32, name="end_token"
        )
        if self.start_tokens.shape.ndims != 1:
            raise ValueError("start_tokens must be a vector")
        self._batch_size = tf.size(start_tokens)
        if self.end_token.shape.ndims != 0:
            raise ValueError("end_token must be a scalar")
        self.start_inputs = self.embedding_fn(self.start_tokens)

        finished = tf.tile([False], [self._batch_size])
        return (finished, self.start_inputs)

    def sample(self, time, outputs, state):
        """sample for GreedyEmbeddingSampler."""
        del time, state  # unused by sample_fn
        # Outputs are logits, use argmax to get the most probable id.
        if not isinstance(outputs, tf.Tensor):
            raise TypeError(
                "Expected outputs to be a single Tensor, got: %s" % type(outputs)
            )
        sample_ids = tf.argmax(outputs, axis=-1, output_type=tf.int32)
        return sample_ids

    def next_inputs(self, time, outputs, state, sample_ids):
        """next_inputs_fn for GreedyEmbeddingSampler."""
        del time, outputs  # unused by next_inputs_fn
        finished = tf.equal(sample_ids, self.end_token)
        all_finished = tf.reduce_all(finished)
        next_inputs = tf.cond(
            all_finished,
            # If we're finished, the next_inputs value doesn't matter
            lambda: self.start_inputs,
            lambda: self.embedding_fn(sample_ids),
        )
        return (finished, next_inputs, state)

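
# A minimal sketch (editor's illustration, not part of the upstream library) of
# a greedy decoding setup. Token ids 1 and 2, standing in for <start> and
# <end>, are assumptions of the example.
def _example_greedy_embedding_sampler():
    embeddings = tf.random.uniform([10, 8])  # [vocab_size, depth]
    sampler = GreedyEmbeddingSampler()
    finished, first_inputs = sampler.initialize(
        embeddings,
        start_tokens=tf.fill([4], 1),  # batch of 4 <start> tokens
        end_token=2,
    )
    return sampler, finished, first_inputs
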

class SampleEmbeddingSampler(GreedyEmbeddingSampler):
    """An inference sampler that randomly samples from the output distribution.

    Uses sampling (from a distribution) instead of argmax and passes the
    result through an embedding layer to get the next input.
    """

    @typechecked
    def __init__(
        self,
        embedding_fn: Optional[Callable] = None,
        softmax_temperature: Optional[TensorLike] = None,
        seed: Optional[TensorLike] = None,
    ):
        """Initializer.

        Args:
          embedding_fn: (Optional) A callable that takes a vector tensor of
            `ids` (argmax ids). The returned tensor will be passed to the
            decoder input.
          softmax_temperature: (Optional) `float32` scalar, value to divide the
            logits by before computing the softmax. Larger values (above 1.0)
            result in more random samples, while smaller values push the
            sampling distribution towards the argmax. Must be strictly greater
            than 0. Defaults to 1.0.
          seed: (Optional) The sampling seed.

        Raises:
          ValueError: raised from `initialize`, if `start_tokens` is not a 1D
            tensor or `end_token` is not a scalar.
        """
        super().__init__(embedding_fn)
        self.softmax_temperature = softmax_temperature
        self.seed = seed

    def sample(self, time, outputs, state):
        """sample for SampleEmbeddingSampler."""
        del time, state  # unused by sample_fn
        # Outputs are logits; we sample instead of argmax (greedy).
        if not isinstance(outputs, tf.Tensor):
            raise TypeError(
                "Expected outputs to be a single Tensor, got: %s" % type(outputs)
            )
        if self.softmax_temperature is None:
            logits = outputs
        else:
            logits = outputs / self.softmax_temperature

        return categorical_sample(logits=logits, seed=self.seed)

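
# A minimal sketch (editor's illustration, not part of the upstream library):
# identical wiring to GreedyEmbeddingSampler above, but drawing random samples
# from a softened distribution. The temperature of 1.5 is an arbitrary choice.
def _example_sample_embedding_sampler():
    embeddings = tf.random.uniform([10, 8])
    sampler = SampleEmbeddingSampler(softmax_temperature=1.5, seed=42)
    finished, first_inputs = sampler.initialize(
        embeddings, start_tokens=tf.fill([4], 1), end_token=2
    )
    return sampler, finished, first_inputs
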

class InferenceSampler(Sampler):
    """An inference sampler that uses a custom sampling function."""

    @typechecked
    def __init__(
        self,
        sample_fn: Callable,
        sample_shape: TensorLike,
        sample_dtype: types.AcceptableDTypes,
        end_fn: Callable,
        next_inputs_fn: Optional[Callable] = None,
    ):
        """Initializer.

        Args:
          sample_fn: A callable that takes `outputs` and emits tensor
            `sample_ids`.
          sample_shape: Either a list of integers, or a 1-D Tensor of type
            `int32`, the shape of each sample in the batch returned by
            `sample_fn`.
          sample_dtype: the dtype of the sample returned by `sample_fn`.
          end_fn: A callable that takes `sample_ids` and emits a `bool` vector
            shaped `[batch_size]` indicating whether each sample is an end
            token.
          next_inputs_fn: (Optional) A callable that takes `sample_ids` and
            returns the next batch of inputs. If not provided, `sample_ids` is
            used as the next batch of inputs.
        """
        self.sample_fn = sample_fn
        self.sample_shape = tf.TensorShape(sample_shape)
        self.sample_dtype = sample_dtype
        self.end_fn = end_fn
        self.next_inputs_fn = next_inputs_fn
        self._batch_size = None

    @property
    def batch_size(self):
        if self._batch_size is None:
            raise ValueError("batch_size accessed before initialize was called")
        return self._batch_size

    @property
    def sample_ids_shape(self):
        return self.sample_shape

    @property
    def sample_ids_dtype(self):
        return self.sample_dtype

    def initialize(self, start_inputs):
        self.start_inputs = tf.convert_to_tensor(start_inputs, name="start_inputs")
        self._batch_size = tf.shape(start_inputs)[0]
        finished = tf.tile([False], [self._batch_size])
        return (finished, self.start_inputs)

    def sample(self, time, outputs, state):
        del time, state  # unused by sample
        return self.sample_fn(outputs)

    def next_inputs(self, time, outputs, state, sample_ids):
        del time, outputs  # unused by next_inputs
        if self.next_inputs_fn is None:
            next_inputs = sample_ids
        else:
            next_inputs = self.next_inputs_fn(sample_ids)
        finished = self.end_fn(sample_ids)
        return (finished, next_inputs, state)

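
# A minimal sketch (editor's illustration, not part of the upstream library) of
# an InferenceSampler that thresholds each output unit and stops once every
# unit of a sample is "on". The 0.5 threshold and the sigmoid squashing are
# assumptions of the example.
def _example_inference_sampler(output_depth):
    return InferenceSampler(
        sample_fn=lambda outputs: tf.cast(tf.sigmoid(outputs) > 0.5, tf.float32),
        sample_shape=[output_depth],
        sample_dtype=tf.float32,
        # A sample counts as an end token when all of its units are positive.
        end_fn=lambda sample_ids: tf.reduce_all(sample_ids > 0.0, axis=-1),
    )
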

# The following sample functions (_call_sampler, bernoulli_sample,
# categorical_sample) mimic TensorFlow Probability distribution semantics.
def _call_sampler(sample_n_fn, sample_shape, name=None):
    """Reshapes vector of samples."""
    with tf.name_scope(name or "call_sampler"):
        sample_shape = tf.convert_to_tensor(
            sample_shape, dtype=tf.int32, name="sample_shape"
        )
        # Ensure sample_shape is a vector (vs just a scalar).
        pad = tf.cast(tf.equal(tf.rank(sample_shape), 0), tf.int32)
        sample_shape = tf.reshape(
            sample_shape,
            tf.pad(tf.shape(sample_shape), paddings=[[pad, 0]], constant_values=1),
        )
        samples = sample_n_fn(tf.reduce_prod(sample_shape))
        batch_event_shape = tf.shape(samples)[1:]
        final_shape = tf.concat([sample_shape, batch_event_shape], 0)
        return tf.reshape(samples, final_shape)


def bernoulli_sample(
    probs=None, logits=None, dtype=tf.int32, sample_shape=(), seed=None
):
    """Samples from a Bernoulli distribution."""
    if probs is None:
        probs = tf.sigmoid(logits, name="probs")
    else:
        probs = tf.convert_to_tensor(probs, name="probs")
    batch_shape_tensor = tf.shape(probs)

    def _sample_n(n):
        """Sample vector of Bernoullis."""
        new_shape = tf.concat([[n], batch_shape_tensor], 0)
        uniform = tf.random.uniform(new_shape, seed=seed, dtype=probs.dtype)
        return tf.cast(tf.less(uniform, probs), dtype)

    return _call_sampler(_sample_n, sample_shape)


def categorical_sample(logits, dtype=tf.int32, sample_shape=(), seed=None):
    """Samples from a categorical distribution."""
    logits = tf.convert_to_tensor(logits, name="logits")
    event_size = tf.shape(logits)[-1]
    batch_shape_tensor = tf.shape(logits)[:-1]

    def _sample_n(n):
        """Sample vector of categoricals."""
        if logits.shape.ndims == 2:
            logits_2d = logits
        else:
            logits_2d = tf.reshape(logits, [-1, event_size])
        sample_dtype = tf.int64 if logits.dtype.size > 4 else tf.int32
        draws = tf.random.categorical(logits_2d, n, dtype=sample_dtype, seed=seed)
        draws = tf.reshape(
            tf.transpose(draws), tf.concat([[n], batch_shape_tensor], 0)
        )
        return tf.cast(draws, dtype)

    return _call_sampler(_sample_n, sample_shape)

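
# A minimal sketch (editor's illustration, not part of the upstream library)
# exercising the two sampling helpers above; the shapes and probabilities are
# arbitrary.
def _example_sampling_helpers():
    # Five independent draws from Bernoulli(0.3), returned as int32 0s and 1s.
    coin_flips = bernoulli_sample(probs=0.3, sample_shape=[5], seed=7)
    # One batch entry over three classes; returns a class id per batch entry.
    logits = tf.math.log([[0.1, 0.6, 0.3]])
    class_ids = categorical_sample(logits=logits, seed=7)
    return coin_flips, class_ids
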

def _unstack_ta(inp):
    """Unstacks a time-major tensor into a TensorArray along its first axis."""
    return tf.TensorArray(
        dtype=inp.dtype, size=tf.shape(inp)[0], element_shape=inp.shape[1:]
    ).unstack(inp)


def _check_sequence_is_right_padded(mask, time_major):
    """Returns an Assert operation checking that the mask tensor is right
    padded."""
    if time_major:
        mask = tf.transpose(mask)
    sequence_length = tf.math.reduce_sum(tf.cast(mask, tf.int32), axis=1)
    max_seq_length = tf.shape(mask)[1]
    right_padded_mask = tf.sequence_mask(
        sequence_length, maxlen=max_seq_length, dtype=tf.bool
    )
    all_equal = tf.math.equal(mask, right_padded_mask)

    condition = tf.math.reduce_all(all_equal)
    error_message = "The input sequence should be right padded."

    return tf.Assert(condition, [error_message])