# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Objects sampling from the decoder output distribution and producing the next input."""
import abc

import tensorflow as tf
from tensorflow_addons.seq2seq import decoder
from tensorflow_addons.utils.types import Initializer, TensorLike
from typeguard import typechecked
from typing import Callable, Optional
from tensorflow_addons.utils import types

_transpose_batch_time = decoder._transpose_batch_time


class Sampler(metaclass=abc.ABCMeta):
    """Interface for implementing sampling in seq2seq decoders.

    Sampler classes implement the logic of sampling from the decoder output distribution
    and producing the inputs for the next decoding step. In most cases, they should not be
    used directly but passed to a `tfa.seq2seq.BasicDecoder` instance that will manage the
    sampling.

    Here is an example using a training sampler directly to implement a custom decoding
    loop:

    >>> batch_size = 4
    >>> max_time = 7
    >>> hidden_size = 16
    >>>
    >>> sampler = tfa.seq2seq.TrainingSampler()
    >>> cell = tf.keras.layers.LSTMCell(hidden_size)
    >>>
    >>> input_tensors = tf.random.uniform([batch_size, max_time, hidden_size])
    >>> initial_finished, initial_inputs = sampler.initialize(input_tensors)
    >>>
    >>> cell_input = initial_inputs
    >>> cell_state = cell.get_initial_state(initial_inputs)
    >>>
    >>> for time_step in tf.range(max_time):
    ...     cell_output, cell_state = cell(cell_input, cell_state)
    ...     sample_ids = sampler.sample(time_step, cell_output, cell_state)
    ...     finished, cell_input, cell_state = sampler.next_inputs(
    ...         time_step, cell_output, cell_state, sample_ids)
    ...     if tf.reduce_all(finished):
    ...         break
    """

    @abc.abstractmethod
    def initialize(self, inputs, **kwargs):
        """Initialize the sampler with the input tensors.

        This method must be invoked exactly once before calling other
        methods of the Sampler.

        Args:
          inputs: A (structure of) input tensors; it can be a nested tuple or
            a single tensor.
          **kwargs: Other keyword arguments for initialization. They can
            contain tensors such as a mask for the inputs, or non-tensor
            parameters.

        Returns:
          `(initial_finished, initial_inputs)`.
        """
        pass

    @abc.abstractmethod
    def sample(self, time, outputs, state):
        """Returns `sample_ids`."""
        pass

    @abc.abstractmethod
    def next_inputs(self, time, outputs, state, sample_ids):
        """Returns `(finished, next_inputs, next_state)`."""
        pass

    @abc.abstractproperty
    def batch_size(self):
        """Batch size of tensor returned by `sample`.

        Returns a scalar int32 tensor. The return value might not be
        available before the invocation of `initialize()`; in that case,
        ValueError is raised.
        """
        raise NotImplementedError("batch_size has not been implemented")

    @abc.abstractproperty
    def sample_ids_shape(self):
        """Shape of tensor returned by `sample`, excluding the batch dimension.

        Returns a `TensorShape`. The return value might not be available
        before the invocation of `initialize()`.
        """
        raise NotImplementedError("sample_ids_shape has not been implemented")

    @abc.abstractproperty
    def sample_ids_dtype(self):
        """DType of tensor returned by `sample`.

        Returns a DType. The return value might not be available before the
        invocation of `initialize()`.
        """
        raise NotImplementedError("sample_ids_dtype has not been implemented")


class CustomSampler(Sampler):
    """Base abstract class that allows the user to customize sampling."""

    @typechecked
    def __init__(
        self,
        initialize_fn: Initializer,
        sample_fn: Callable,
        next_inputs_fn: Callable,
        sample_ids_shape: Optional[TensorLike] = None,
        sample_ids_dtype: types.AcceptableDTypes = None,
    ):
        """Initializer.

        Args:
          initialize_fn: callable that returns `(finished, next_inputs)` for
            the first iteration.
          sample_fn: callable that takes `(time, outputs, state)` and emits
            tensor `sample_ids`.
          next_inputs_fn: callable that takes
            `(time, outputs, state, sample_ids)` and emits
            `(finished, next_inputs, next_state)`.
          sample_ids_shape: Either a list of integers, or a 1-D Tensor of type
            `int32`, the shape of each value in the `sample_ids` batch.
            Defaults to a scalar.
          sample_ids_dtype: The dtype of the `sample_ids` tensor. Defaults to
            int32.
        """
        self._initialize_fn = initialize_fn
        self._sample_fn = sample_fn
        self._next_inputs_fn = next_inputs_fn
        self._batch_size = None
        self._sample_ids_shape = tf.TensorShape(sample_ids_shape or [])
        self._sample_ids_dtype = sample_ids_dtype or tf.int32

    @property
    def batch_size(self):
        if self._batch_size is None:
            raise ValueError("batch_size accessed before initialize was called")
        return self._batch_size

    @property
    def sample_ids_shape(self):
        return self._sample_ids_shape

    @property
    def sample_ids_dtype(self):
        return self._sample_ids_dtype

    def initialize(self, inputs, **kwargs):
        (finished, next_inputs) = self._initialize_fn(inputs, **kwargs)
        if self._batch_size is None:
            self._batch_size = tf.size(finished)
        return (finished, next_inputs)

    def sample(self, time, outputs, state):
        return self._sample_fn(time=time, outputs=outputs, state=state)

    def next_inputs(self, time, outputs, state, sample_ids):
        return self._next_inputs_fn(
            time=time, outputs=outputs, state=state, sample_ids=sample_ids
        )
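
# Illustrative sketch (not from the upstream file; the toy vocabulary of 5 ids,
# the batch of 2 and the fixed three-step cutoff are assumptions): wiring three
# plain callables into a CustomSampler. Note that sample_fn and next_inputs_fn
# are invoked with keyword arguments.
#
# >>> vocab_size = 5
# >>> start_inputs = tf.one_hot([0, 1], vocab_size)
# >>> sampler = CustomSampler(
# ...     initialize_fn=lambda inputs: (
# ...         tf.tile([False], [tf.shape(inputs)[0]]), inputs),
# ...     sample_fn=lambda time, outputs, state: tf.argmax(
# ...         outputs, axis=-1, output_type=tf.int32),
# ...     next_inputs_fn=lambda time, outputs, state, sample_ids: (
# ...         tf.fill([tf.shape(sample_ids)[0]], time + 1 >= 3),
# ...         tf.one_hot(sample_ids, vocab_size),
# ...         state),
# ... )
# >>> finished, first_inputs = sampler.initialize(start_inputs)
# >>> first_inputs.shape
# TensorShape([2, 5])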


class TrainingSampler(Sampler):
    """A training sampler that simply reads its inputs.

    Returned sample_ids are the argmax of the RNN output logits.
    """

    @typechecked
    def __init__(self, time_major: bool = False):
        """Initializer.

        Args:
          time_major: Python bool. Whether the tensors in `inputs` are time
            major. If `False` (default), they are assumed to be batch major.

        Raises:
          ValueError: if `sequence_length` is not a 1D tensor or `mask` is
            not a 2D boolean tensor.
        """
        self.time_major = time_major
        self._batch_size = None

    @property
    def batch_size(self):
        if self._batch_size is None:
            raise ValueError("batch_size accessed before initialize was called")
        return self._batch_size

    @property
    def sample_ids_shape(self):
        return tf.TensorShape([])

    @property
    def sample_ids_dtype(self):
        return tf.int32

    def initialize(self, inputs, sequence_length=None, mask=None):
        """Initialize the TrainingSampler.

        Args:
          inputs: A (structure of) input tensors.
          sequence_length: An int32 vector tensor.
          mask: A boolean 2D tensor.

        Returns:
          (finished, next_inputs), a tuple of two items. The first item is a
          boolean vector to indicate whether the item in the batch has
          finished. The second item is the first slice of input data based on
          the timestep dimension (usually the second dim of the input).
        """
        self.inputs = tf.convert_to_tensor(inputs, name="inputs")
        if not self.time_major:
            inputs = tf.nest.map_structure(_transpose_batch_time, inputs)

        self._batch_size = tf.shape(tf.nest.flatten(inputs)[0])[1]

        self.input_tas = tf.nest.map_structure(_unstack_ta, inputs)
        if sequence_length is not None and mask is not None:
            raise ValueError(
                "sequence_length and mask can't be provided at the same time."
            )
        if sequence_length is not None:
            self.sequence_length = tf.convert_to_tensor(
                sequence_length, name="sequence_length"
            )
            if self.sequence_length.shape.ndims != 1:
                raise ValueError(
                    "Expected sequence_length to be vector, but received "
                    "shape: %s" % self.sequence_length.shape
                )
        elif mask is not None:
            mask = tf.convert_to_tensor(mask)
            if mask.shape.ndims != 2:
                raise ValueError(
                    "Expected mask to be a 2D tensor, but received shape: %s" % mask
                )
            if not mask.dtype.is_bool:
                raise ValueError(
                    "Expected mask to be a boolean tensor, but received "
                    "dtype: %s" % repr(mask.dtype)
                )

            axis = 1 if not self.time_major else 0
            with tf.control_dependencies(
                [_check_sequence_is_right_padded(mask, self.time_major)]
            ):
                self.sequence_length = tf.math.reduce_sum(
                    tf.cast(mask, tf.int32), axis=axis, name="sequence_length"
                )
        else:
            # As the input tensor has been converted to time major,
            # the maximum sequence length should be inferred from
            # the first dimension.
            max_seq_len = tf.shape(tf.nest.flatten(inputs)[0])[0]
            self.sequence_length = tf.fill(
                [self.batch_size], max_seq_len, name="sequence_length"
            )

        self.zero_inputs = tf.nest.map_structure(
            lambda inp: tf.zeros_like(inp[0, :]), inputs
        )

        finished = tf.equal(0, self.sequence_length)
        all_finished = tf.reduce_all(finished)
        next_inputs = tf.cond(
            all_finished,
            lambda: self.zero_inputs,
            lambda: tf.nest.map_structure(lambda inp: inp.read(0), self.input_tas),
        )
        return (finished, next_inputs)

    def sample(self, time, outputs, state):
        del state
        sample_ids = tf.cast(tf.argmax(outputs, axis=-1), tf.int32)
        return sample_ids

    def next_inputs(self, time, outputs, state, sample_ids):
        del sample_ids
        next_time = time + 1
        finished = next_time >= self.sequence_length
        all_finished = tf.reduce_all(finished)

        def read_from_ta(inp):
            return inp.read(next_time)

        next_inputs = tf.cond(
            all_finished,
            lambda: self.zero_inputs,
            lambda: tf.nest.map_structure(read_from_ta, self.input_tas),
        )
        return (finished, next_inputs, state)
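
# Illustrative sketch (the [batch=2, time=3, depth=4] input and the mask are
# toy assumptions): initializing a TrainingSampler with an explicit boolean
# mask. `sequence_length` and `mask` are mutually exclusive arguments.
#
# >>> sampler = TrainingSampler()
# >>> inputs = tf.random.uniform([2, 3, 4])
# >>> mask = tf.constant([[True, True, False], [True, True, True]])
# >>> finished, first_inputs = sampler.initialize(inputs, mask=mask)
# >>> sampler.sequence_length.numpy()
# array([2, 3], dtype=int32)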


class ScheduledEmbeddingTrainingSampler(TrainingSampler):
    """A training sampler that adds scheduled sampling.

    Returns -1s for sample_ids where no sampling took place; valid
    sample id values elsewhere.
    """

    @typechecked
    def __init__(
        self,
        sampling_probability: TensorLike,
        embedding_fn: Optional[Callable] = None,
        time_major: bool = False,
        seed: Optional[int] = None,
        scheduling_seed: Optional[TensorLike] = None,
    ):
        """Initializer.

        Args:
          sampling_probability: A `float32` 0-D or 1-D tensor: the probability
            of sampling categorically from the output ids instead of reading
            directly from the inputs.
          embedding_fn: A callable that takes a vector tensor of `ids`
            (argmax ids).
          time_major: Python bool. Whether the tensors in `inputs` are time
            major. If `False` (default), they are assumed to be batch major.
          seed: The sampling seed.
          scheduling_seed: The schedule decision rule sampling seed.

        Raises:
          ValueError: if `sampling_probability` is not a scalar or vector.
        """
        self.embedding_fn = embedding_fn
        if isinstance(sampling_probability, tf.Variable):
            self.sampling_probability = sampling_probability
        else:
            self.sampling_probability = tf.convert_to_tensor(
                sampling_probability, name="sampling_probability"
            )
        if self.sampling_probability.shape.ndims not in (0, 1):
            raise ValueError(
                "sampling_probability must be either a scalar or a vector. "
                "saw shape: %s" % (self.sampling_probability.shape)
            )
        self.seed = seed
        self.scheduling_seed = scheduling_seed
        super().__init__(time_major=time_major)

    def initialize(self, inputs, sequence_length=None, mask=None, embedding=None):
        if self.embedding_fn is None:
            if embedding is None:
                raise ValueError(
                    "embedding is required as a keyword argument for "
                    "ScheduledEmbeddingTrainingSampler"
                )
            self.embedding_fn = lambda ids: tf.nn.embedding_lookup(embedding, ids)
        return super().initialize(inputs, sequence_length=sequence_length, mask=mask)

    def sample(self, time, outputs, state):
        del state
        # Return -1s where we did not sample, and sample_ids elsewhere
        select_sample = bernoulli_sample(
            probs=self.sampling_probability,
            dtype=tf.bool,
            sample_shape=self.batch_size,
            seed=self.scheduling_seed,
        )
        return tf.where(
            select_sample,
            categorical_sample(logits=outputs, seed=self.seed),
            tf.fill([self.batch_size], -1),
        )

    def next_inputs(self, time, outputs, state, sample_ids):
        (finished, base_next_inputs, state) = super().next_inputs(
            time=time, outputs=outputs, state=state, sample_ids=sample_ids
        )

        def maybe_sample():
            """Perform scheduled sampling."""
            where_sampling = tf.cast(tf.where(sample_ids > -1), tf.int32)
            where_not_sampling = tf.cast(tf.where(sample_ids <= -1), tf.int32)
            sample_ids_sampling = tf.gather_nd(sample_ids, where_sampling)
            inputs_not_sampling = tf.gather_nd(base_next_inputs, where_not_sampling)
            sampled_next_inputs = self.embedding_fn(sample_ids_sampling)
            sampled_next_inputs = tf.cast(
                sampled_next_inputs, inputs_not_sampling.dtype
            )
            base_shape = tf.shape(base_next_inputs)
            return tf.scatter_nd(
                indices=where_sampling, updates=sampled_next_inputs, shape=base_shape
            ) + tf.scatter_nd(
                indices=where_not_sampling,
                updates=inputs_not_sampling,
                shape=base_shape,
            )

        all_finished = tf.reduce_all(finished)
        next_inputs = tf.cond(all_finished, lambda: base_next_inputs, maybe_sample)
        return (finished, next_inputs, state)
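
# Illustrative sketch (the embedding matrix and the shapes are toy
# assumptions): scheduled sampling over embedded ids. With
# sampling_probability=0.5, roughly half of the next-step inputs are replaced
# by embeddings of ids sampled from the decoder output logits.
#
# >>> embeddings = tf.random.uniform([10, 4])  # [vocab_size, embedding_dim]
# >>> sampler = ScheduledEmbeddingTrainingSampler(sampling_probability=0.5)
# >>> decoder_inputs = tf.random.uniform([2, 3, 4])
# >>> finished, first_inputs = sampler.initialize(
# ...     decoder_inputs, embedding=embeddings)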


class ScheduledOutputTrainingSampler(TrainingSampler):
    """A training sampler that adds scheduled sampling directly to outputs.

    Returns False for sample_ids where no sampling took place; True
    elsewhere.
    """

    @typechecked
    def __init__(
        self,
        sampling_probability: TensorLike,
        time_major: bool = False,
        seed: Optional[int] = None,
        next_inputs_fn: Optional[Callable] = None,
    ):
        """Initializer.

        Args:
          sampling_probability: A `float32` scalar tensor: the probability of
            sampling from the outputs instead of reading directly from the
            inputs.
          time_major: Python bool. Whether the tensors in `inputs` are time
            major. If `False` (default), they are assumed to be batch major.
          seed: The sampling seed.
          next_inputs_fn: (Optional) callable to apply to the RNN outputs to
            create the next input when sampling. If `None` (default), the RNN
            outputs will be used as the next inputs.

        Raises:
          ValueError: if `sampling_probability` is not a scalar or vector.
        """
        if isinstance(sampling_probability, tf.Variable):
            self.sampling_probability = sampling_probability
        else:
            self.sampling_probability = tf.convert_to_tensor(
                sampling_probability, name="sampling_probability"
            )
        if self.sampling_probability.shape.ndims not in (0, 1):
            raise ValueError(
                "sampling_probability must be either a scalar or a vector. "
                "saw shape: %s" % (self.sampling_probability.shape)
            )

        self.seed = seed
        self.next_inputs_fn = next_inputs_fn

        super().__init__(time_major=time_major)

    def initialize(
        self, inputs, sequence_length=None, mask=None, auxiliary_inputs=None
    ):
        if auxiliary_inputs is None:
            maybe_concatenated_inputs = inputs
        else:
            inputs = tf.convert_to_tensor(inputs)
            auxiliary_inputs = tf.convert_to_tensor(auxiliary_inputs)
            maybe_concatenated_inputs = tf.nest.map_structure(
                lambda x, y: tf.concat((x, y), -1), inputs, auxiliary_inputs
            )
            if not self.time_major:
                auxiliary_inputs = tf.nest.map_structure(
                    _transpose_batch_time, auxiliary_inputs
                )
        if auxiliary_inputs is not None:
            self._auxiliary_input_tas = tf.nest.map_structure(
                _unstack_ta, auxiliary_inputs
            )
        else:
            self._auxiliary_input_tas = None

        return super().initialize(
            maybe_concatenated_inputs, sequence_length=sequence_length, mask=mask
        )

    def sample(self, time, outputs, state):
        del state
        return bernoulli_sample(
            probs=self.sampling_probability,
            sample_shape=self.batch_size,
            seed=self.seed,
        )

    def next_inputs(self, time, outputs, state, sample_ids):
        (finished, base_next_inputs, state) = super().next_inputs(
            time=time, outputs=outputs, state=state, sample_ids=sample_ids
        )
        sample_ids = tf.cast(sample_ids, tf.bool)

        def maybe_sample():
            """Perform scheduled sampling."""

            def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
                """Concatenate outputs with auxiliary inputs, if they exist."""
                if self._auxiliary_input_tas is None:
                    return outputs_

                next_time = time + 1
                auxiliary_inputs = tf.nest.map_structure(
                    lambda ta: ta.read(next_time), self._auxiliary_input_tas
                )
                if indices is not None:
                    auxiliary_inputs = tf.gather_nd(auxiliary_inputs, indices)
                return tf.nest.map_structure(
                    lambda x, y: tf.concat((x, y), -1), outputs_, auxiliary_inputs
                )

            if self.next_inputs_fn is None:
                return tf.where(
                    tf.broadcast_to(
                        tf.expand_dims(sample_ids, axis=-1), base_next_inputs.shape
                    ),
                    maybe_concatenate_auxiliary_inputs(outputs),
                    base_next_inputs,
                )

            where_sampling = tf.cast(tf.where(sample_ids), tf.int32)
            where_not_sampling = tf.cast(tf.where(tf.logical_not(sample_ids)), tf.int32)
            outputs_sampling = tf.gather_nd(outputs, where_sampling)
            inputs_not_sampling = tf.gather_nd(base_next_inputs, where_not_sampling)
            sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
                self.next_inputs_fn(outputs_sampling), where_sampling
            )

            base_shape = tf.shape(base_next_inputs)
            return tf.scatter_nd(
                indices=where_sampling, updates=sampled_next_inputs, shape=base_shape
            ) + tf.scatter_nd(
                indices=where_not_sampling,
                updates=inputs_not_sampling,
                shape=base_shape,
            )

        all_finished = tf.reduce_all(finished)
        no_samples = tf.logical_not(tf.reduce_any(sample_ids))
        next_inputs = tf.cond(
            tf.logical_or(all_finished, no_samples),
            lambda: base_next_inputs,
            maybe_sample,
        )
        return (finished, next_inputs, state)
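
# Illustrative sketch (toy shapes; assumes the decoder output depth equals the
# input depth, which is required when the raw outputs are fed back as the next
# inputs and no `next_inputs_fn` is given):
#
# >>> sampler = ScheduledOutputTrainingSampler(sampling_probability=0.25)
# >>> decoder_inputs = tf.random.uniform([2, 3, 4])
# >>> finished, first_inputs = sampler.initialize(decoder_inputs)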


class GreedyEmbeddingSampler(Sampler):
    """An inference sampler that takes the maximum from the output distribution.

    Uses the argmax of the output (treated as logits) and passes the
    result through an embedding layer to get the next input.
    """

    @typechecked
    def __init__(self, embedding_fn: Optional[Callable] = None):
        """Initializer.

        Args:
          embedding_fn: An optional callable that takes a vector tensor of
            `ids` (argmax ids). The returned tensor will be passed to the
            decoder input. Defaults to using `tf.nn.embedding_lookup`.
        """
        self.embedding_fn = embedding_fn
        self._batch_size = None

    @property
    def batch_size(self):
        if self._batch_size is None:
            raise ValueError("batch_size accessed before initialize was called")
        return self._batch_size

    @property
    def sample_ids_shape(self):
        return tf.TensorShape([])

    @property
    def sample_ids_dtype(self):
        return tf.int32

    def initialize(self, embedding, start_tokens=None, end_token=None):
        """Initialize the GreedyEmbeddingSampler.

        Args:
          embedding: tensor that contains the embedding matrix. It will be
            used to generate outputs with `start_tokens` and `end_token`. The
            embedding will be ignored if the `embedding_fn` has been provided
            at __init__().
          start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
          end_token: `int32` scalar, the token that marks end of decoding.

        Returns:
          Tuple of two items: `(finished, self.start_inputs)`.

        Raises:
          ValueError: if `start_tokens` is not a 1D tensor or `end_token` is
            not a scalar.
        """
        if self.embedding_fn is None:
            self.embedding_fn = lambda ids: tf.nn.embedding_lookup(embedding, ids)

        self.start_tokens = tf.convert_to_tensor(
            start_tokens, dtype=tf.int32, name="start_tokens"
        )
        self.end_token = tf.convert_to_tensor(
            end_token, dtype=tf.int32, name="end_token"
        )
        if self.start_tokens.shape.ndims != 1:
            raise ValueError("start_tokens must be a vector")
        self._batch_size = tf.size(start_tokens)
        if self.end_token.shape.ndims != 0:
            raise ValueError("end_token must be a scalar")
        self.start_inputs = self.embedding_fn(self.start_tokens)

        finished = tf.tile([False], [self._batch_size])
        return (finished, self.start_inputs)

    def sample(self, time, outputs, state):
        """sample for GreedyEmbeddingHelper."""
        del time, state  # unused by sample_fn
        # Outputs are logits, use argmax to get the most probable id
        if not isinstance(outputs, tf.Tensor):
            raise TypeError(
                "Expected outputs to be a single Tensor, got: %s" % type(outputs)
            )
        sample_ids = tf.argmax(outputs, axis=-1, output_type=tf.int32)
        return sample_ids

    def next_inputs(self, time, outputs, state, sample_ids):
        """next_inputs_fn for GreedyEmbeddingHelper."""
        del time, outputs  # unused by next_inputs_fn
        finished = tf.equal(sample_ids, self.end_token)
        all_finished = tf.reduce_all(finished)
        next_inputs = tf.cond(
            all_finished,
            # If we're finished, the next_inputs value doesn't matter
            lambda: self.start_inputs,
            lambda: self.embedding_fn(sample_ids),
        )
        return (finished, next_inputs, state)
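
# Illustrative sketch (the toy embedding matrix, start token id 1 and end token
# id 2 are assumptions): greedy decoding setup, where the argmax id of each
# step is looked up in `embeddings` to produce the next input.
#
# >>> embeddings = tf.random.uniform([10, 4])
# >>> sampler = GreedyEmbeddingSampler()
# >>> finished, start_inputs = sampler.initialize(
# ...     embeddings, start_tokens=tf.fill([2], 1), end_token=2)
# >>> start_inputs.shape
# TensorShape([2, 4])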


class SampleEmbeddingSampler(GreedyEmbeddingSampler):
    """An inference sampler that randomly samples from the output distribution.

    Uses sampling (from a distribution) instead of argmax and passes the
    result through an embedding layer to get the next input.
    """

    @typechecked
    def __init__(
        self,
        embedding_fn: Optional[Callable] = None,
        softmax_temperature: Optional[TensorLike] = None,
        seed: Optional[TensorLike] = None,
    ):
        """Initializer.

        Args:
          embedding_fn: (Optional) A callable that takes a vector tensor of
            `ids` (argmax ids). The returned tensor will be passed to the
            decoder input.
          softmax_temperature: (Optional) `float32` scalar, value to divide the
            logits by before computing the softmax. Larger values (above 1.0)
            result in more random samples, while smaller values push the
            sampling distribution towards the argmax. Must be strictly greater
            than 0. Defaults to 1.0.
          seed: (Optional) The sampling seed.

        Raises:
          ValueError: if `start_tokens` is not a 1D tensor or `end_token` is
            not a scalar.
        """
        super().__init__(embedding_fn)
        self.softmax_temperature = softmax_temperature
        self.seed = seed

    def sample(self, time, outputs, state):
        """sample for SampleEmbeddingHelper."""
        del time, state  # unused by sample_fn
        # Outputs are logits, we sample instead of argmax (greedy).
        if not isinstance(outputs, tf.Tensor):
            raise TypeError(
                "Expected outputs to be a single Tensor, got: %s" % type(outputs)
            )
        if self.softmax_temperature is None:
            logits = outputs
        else:
            logits = outputs / self.softmax_temperature

        return categorical_sample(logits=logits, seed=self.seed)


class InferenceSampler(Sampler):
    """An inference sampler that uses a custom sampling function."""

    @typechecked
    def __init__(
        self,
        sample_fn: Callable,
        sample_shape: TensorLike,
        sample_dtype: types.AcceptableDTypes,
        end_fn: Callable,
        next_inputs_fn: Optional[Callable] = None,
    ):
        """Initializer.

        Args:
          sample_fn: A callable that takes `outputs` and emits tensor
            `sample_ids`.
          sample_shape: Either a list of integers, or a 1-D Tensor of type
            `int32`, the shape of each sample in the batch returned by
            `sample_fn`.
          sample_dtype: the dtype of the sample returned by `sample_fn`.
          end_fn: A callable that takes `sample_ids` and emits a `bool` vector
            shaped `[batch_size]` indicating whether each sample is an end
            token.
          next_inputs_fn: (Optional) A callable that takes `sample_ids` and
            returns the next batch of inputs. If not provided, `sample_ids` is
            used as the next batch of inputs.
        """
        self.sample_fn = sample_fn
        self.sample_shape = tf.TensorShape(sample_shape)
        self.sample_dtype = sample_dtype
        self.end_fn = end_fn
        self.next_inputs_fn = next_inputs_fn
        self._batch_size = None

    @property
    def batch_size(self):
        if self._batch_size is None:
            raise ValueError("batch_size accessed before initialize was called")
        return self._batch_size

    @property
    def sample_ids_shape(self):
        return self.sample_shape

    @property
    def sample_ids_dtype(self):
        return self.sample_dtype

    def initialize(self, start_inputs):
        self.start_inputs = tf.convert_to_tensor(start_inputs, name="start_inputs")
        self._batch_size = tf.shape(start_inputs)[0]
        finished = tf.tile([False], [self._batch_size])
        return (finished, self.start_inputs)

    def sample(self, time, outputs, state):
        del time, state  # unused by sample
        return self.sample_fn(outputs)

    def next_inputs(self, time, outputs, state, sample_ids):
        del time, outputs  # unused by next_inputs
        if self.next_inputs_fn is None:
            next_inputs = sample_ids
        else:
            next_inputs = self.next_inputs_fn(sample_ids)
        finished = self.end_fn(sample_ids)
        return (finished, next_inputs, state)
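
# Illustrative sketch (the sampling rule, stopping rule and shapes are
# assumptions): an InferenceSampler that feeds the rounded sigmoid of the
# outputs back as the next inputs and finishes a sequence once every sampled
# feature is zero.
#
# >>> sampler = InferenceSampler(
# ...     sample_fn=lambda outputs: tf.round(tf.sigmoid(outputs)),
# ...     sample_shape=[4],
# ...     sample_dtype=tf.float32,
# ...     end_fn=lambda sample_ids: tf.reduce_all(
# ...         tf.equal(sample_ids, 0.0), axis=-1),
# ... )
# >>> finished, start_inputs = sampler.initialize(tf.zeros([2, 4]))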


# The following sample functions (_call_sampler, bernoulli_sample,
# categorical_sample) mimic TensorFlow Probability distribution semantics.
def _call_sampler(sample_n_fn, sample_shape, name=None):
    """Reshapes vector of samples."""
    with tf.name_scope(name or "call_sampler"):
        sample_shape = tf.convert_to_tensor(
            sample_shape, dtype=tf.int32, name="sample_shape"
        )
        # Ensure sample_shape is a vector (vs just a scalar).
        pad = tf.cast(tf.equal(tf.rank(sample_shape), 0), tf.int32)
        sample_shape = tf.reshape(
            sample_shape,
            tf.pad(tf.shape(sample_shape), paddings=[[pad, 0]], constant_values=1),
        )
        samples = sample_n_fn(tf.reduce_prod(sample_shape))
        batch_event_shape = tf.shape(samples)[1:]
        final_shape = tf.concat([sample_shape, batch_event_shape], 0)
        return tf.reshape(samples, final_shape)


def bernoulli_sample(
    probs=None, logits=None, dtype=tf.int32, sample_shape=(), seed=None
):
    """Samples from Bernoulli distribution."""
    if probs is None:
        probs = tf.sigmoid(logits, name="probs")
    else:
        probs = tf.convert_to_tensor(probs, name="probs")
    batch_shape_tensor = tf.shape(probs)

    def _sample_n(n):
        """Sample vector of Bernoullis."""
        new_shape = tf.concat([[n], batch_shape_tensor], 0)
        uniform = tf.random.uniform(new_shape, seed=seed, dtype=probs.dtype)
        return tf.cast(tf.less(uniform, probs), dtype)

    return _call_sampler(_sample_n, sample_shape)


def categorical_sample(logits, dtype=tf.int32, sample_shape=(), seed=None):
    """Samples from categorical distribution."""
    logits = tf.convert_to_tensor(logits, name="logits")
    event_size = tf.shape(logits)[-1]
    batch_shape_tensor = tf.shape(logits)[:-1]

    def _sample_n(n):
        """Sample vector of categoricals."""
        if logits.shape.ndims == 2:
            logits_2d = logits
        else:
            logits_2d = tf.reshape(logits, [-1, event_size])
        sample_dtype = tf.int64 if logits.dtype.size > 4 else tf.int32
        draws = tf.random.categorical(logits_2d, n, dtype=sample_dtype, seed=seed)
        draws = tf.reshape(tf.transpose(draws), tf.concat([[n], batch_shape_tensor], 0))
        return tf.cast(draws, dtype)

    return _call_sampler(_sample_n, sample_shape)
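
# Illustrative sketch (toy probabilities and logits): the helpers above behave
# like `tfp.distributions.Bernoulli(...).sample(shape)` and
# `tfp.distributions.Categorical(...).sample(shape)` without depending on
# TensorFlow Probability.
#
# >>> coin_flips = bernoulli_sample(probs=0.5, sample_shape=4, seed=1)
# >>> coin_flips.shape
# TensorShape([4])
# >>> draws = categorical_sample(logits=tf.zeros([2, 3]))
# >>> draws.shape
# TensorShape([2])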


def _unstack_ta(inp):
    return tf.TensorArray(
        dtype=inp.dtype, size=tf.shape(inp)[0], element_shape=inp.shape[1:]
    ).unstack(inp)


def _check_sequence_is_right_padded(mask, time_major):
    """Returns an Assert operation checking that the mask tensor is right
    padded."""
    if time_major:
        mask = tf.transpose(mask)
    sequence_length = tf.math.reduce_sum(tf.cast(mask, tf.int32), axis=1)
    max_seq_length = tf.shape(mask)[1]
    right_padded_mask = tf.sequence_mask(
        sequence_length, maxlen=max_seq_length, dtype=tf.bool
    )
    all_equal = tf.math.equal(mask, right_padded_mask)

    condition = tf.math.reduce_all(all_equal)
    error_message = "The input sequence should be right padded."

    return tf.Assert(condition, [error_message])
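
# Illustrative sketch (toy mask): a batch-major mask that is right padded
# passes the assertion, while a mask with a gap such as [True, False, True]
# would trigger it with the message "The input sequence should be right
# padded."
#
# >>> mask = tf.constant([[True, True, False], [True, True, True]])
# >>> assert_op = _check_sequence_is_right_padded(mask, time_major=False)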