Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/batching.py: 43%
86 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Batching dataset transformations."""
16from tensorflow.python.data.ops import dataset_ops
17from tensorflow.python.data.ops import structured_function
18from tensorflow.python.data.util import convert
19from tensorflow.python.data.util import nest
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import sparse_tensor
23from tensorflow.python.framework import tensor_shape
24from tensorflow.python.framework import tensor_util
25from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
26from tensorflow.python.util import deprecation
27from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.dense_to_ragged_batch")
@deprecation.deprecated(None, "Use `tf.data.Dataset.ragged_batch` instead.")
def dense_to_ragged_batch(batch_size,
                          drop_remainder=False,
                          row_splits_dtype=dtypes.int64):
  """A transformation that batches ragged elements into `tf.RaggedTensor`s.

  Consecutive elements of the input dataset are merged into one output
  element.

  As with `tf.data.Dataset.batch`, every component of an output element gains
  a new leading dimension of size `batch_size`. The final element has leading
  size `N % batch_size` instead when `batch_size` does not evenly divide the
  number of input elements `N` and `drop_remainder` is `False`. Pass
  `drop_remainder=True` if downstream code needs every batch to have the same
  leading dimension; the short trailing batch is then not produced.

  In contrast to `tf.data.Dataset.batch`, the elements being combined do not
  need to share a shape:

  * A `tf.Tensor` element whose static `tf.TensorShape` is fully defined is
    batched as normal.
  * A `tf.Tensor` element whose static `tf.TensorShape` has one or more axes
    of unknown size (i.e., `shape[i]=None`) yields a `tf.RaggedTensor` that is
    ragged up to any of such dimensions.
  * A `tf.RaggedTensor` element, or an element of any other type, is batched
    as normal.

  Example:

  >>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
  >>> dataset = dataset.map(lambda x: tf.range(x))
  >>> dataset.element_spec.shape
  TensorShape([None])
  >>> dataset = dataset.apply(
  ...     tf.data.experimental.dense_to_ragged_batch(batch_size=2))
  >>> for batch in dataset:
  ...   print(batch)
  <tf.RaggedTensor [[], [0]]>
  <tf.RaggedTensor [[0, 1], [0, 1, 2]]>
  <tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]>

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in the case it has fewer than
      `batch_size` elements; the default behavior is not to drop the smaller
      batch.
    row_splits_dtype: The dtype that should be used for the `row_splits` of any
      new ragged tensors.  Existing `tf.RaggedTensor` elements do not have their
      row_splits dtype changed.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # All the work is delegated to the non-deprecated `Dataset` method.
    return dataset.ragged_batch(
        batch_size=batch_size,
        drop_remainder=drop_remainder,
        row_splits_dtype=row_splits_dtype)

  return _apply_fn
@tf_export("data.experimental.dense_to_sparse_batch")
@deprecation.deprecated(None, "Use `tf.data.Dataset.sparse_batch` instead.")
def dense_to_sparse_batch(batch_size, row_shape):
  """A transformation that batches ragged elements into `tf.sparse.SparseTensor`s.

  Similar to `Dataset.padded_batch()`, this transformation merges several
  consecutive dataset elements, whose shapes may differ, into one element.
  The output element consists of three components (`indices`, `values`, and
  `dense_shape`) that together form a `tf.sparse.SparseTensor` holding the
  same data. `row_shape` gives the dense shape of each row of that
  `tf.sparse.SparseTensor`; the effective batch size is prepended to it.
  For example:

  ```python
  # NOTE: The following examples use `{ ... }` to represent the
  # contents of a dataset.
  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

  a.apply(tf.data.experimental.dense_to_sparse_batch(
      batch_size=2, row_shape=[6])) ==
  {
      ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],  # indices
       ['a', 'b', 'c', 'a', 'b'],                 # values
       [2, 6]),                                   # dense_shape
      ([[0, 0], [0, 1], [0, 2], [0, 3]],
       ['a', 'b', 'c', 'd'],
       [1, 6])
  }
  ```

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like object
      representing the equivalent dense shape of a row in the resulting
      `tf.sparse.SparseTensor`. Each element of this dataset must have the same
      rank as `row_shape`, and must have size less than or equal to `row_shape`
      in each dimension.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # All the work is delegated to the non-deprecated `Dataset` method.
    return dataset.sparse_batch(batch_size=batch_size, row_shape=row_shape)

  return _apply_fn
@deprecation.deprecated(None, "Use `tf.data.experimental.map_and_batch()`")
@tf_export(v1=["data.experimental.map_and_batch_with_legacy_function"])
def map_and_batch_with_legacy_function(map_func,
                                       batch_size,
                                       num_parallel_batches=None,
                                       drop_remainder=False,
                                       num_parallel_calls=None):
  """Fused implementation of `map` and `batch`.

  NOTE: This is an escape hatch for existing uses of `map_and_batch` that do not
  work with V2 functions. New uses are strongly discouraged and existing uses
  should migrate to `map_and_batch`, as this method is only exported in the V1
  API.

  Args:
    map_func: A function mapping a nested structure of tensors to another
      nested structure of tensors.
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of batches to create in parallel. On one hand,
      higher values can help mitigate the effect of stragglers. On the other
      hand, higher values can increase contention if CPU is scarce.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in case its size is smaller than
      desired; the default behavior is not to drop the smaller batch.
    num_parallel_calls: (Optional.) A scalar integer `tf.Tensor`,
      representing the number of elements to process in parallel. If not
      specified, `batch_size * num_parallel_batches` elements will be processed
      in parallel, or `batch_size` elements when `num_parallel_batches` is not
      specified either. If the value `tf.data.AUTOTUNE` is used, then
      the number of parallel calls is set dynamically based on available CPU.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
      specified.
  """
  if num_parallel_calls is None:
    # Derive the parallelism from the batch geometry when the caller did not
    # state it explicitly.
    if num_parallel_batches is None:
      num_parallel_calls = batch_size
    else:
      num_parallel_calls = batch_size * num_parallel_batches
  elif num_parallel_batches is not None:
    # The two knobs are alternative ways of expressing the same setting.
    raise ValueError(
        "`map_and_batch_with_legacy_function` allows only one of "
        "`num_parallel_batches` and "
        "`num_parallel_calls` to be set, but "
        f"`num_parallel_batches` was set to {num_parallel_batches} "
        f"and `num_parallel_calls` was set to {num_parallel_calls}.")

  def _apply_fn(dataset):
    return _MapAndBatchDataset(dataset, map_func, batch_size,
                               num_parallel_calls, drop_remainder,
                               use_legacy_function=True)

  return _apply_fn
@deprecation.deprecated(
    None,
    "Use `tf.data.Dataset.map(map_func, num_parallel_calls)` followed by "
    "`tf.data.Dataset.batch(batch_size, drop_remainder)`. Static tf.data "
    "optimizations will take care of using the fused implementation.")
@tf_export("data.experimental.map_and_batch")
def map_and_batch(map_func,
                  batch_size,
                  num_parallel_batches=None,
                  drop_remainder=False,
                  num_parallel_calls=None):
  """Fused implementation of `map` and `batch`.

  Maps `map_func` across `batch_size` consecutive elements of this dataset
  and then combines them into a batch. Functionally, it is equivalent to `map`
  followed by `batch`. This API is temporary and deprecated since input pipeline
  optimization now fuses consecutive `map` and `batch` operations automatically.

  Args:
    map_func: A function mapping a nested structure of tensors to another
      nested structure of tensors.
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of batches to create in parallel. On one hand,
      higher values can help mitigate the effect of stragglers. On the other
      hand, higher values can increase contention if CPU is scarce.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in case its size is smaller than
      desired; the default behavior is not to drop the smaller batch.
    num_parallel_calls: (Optional.) A scalar integer `tf.Tensor`,
      representing the number of elements to process in parallel. If not
      specified, `batch_size * num_parallel_batches` elements will be processed
      in parallel, or `batch_size` elements when `num_parallel_batches` is not
      specified either. If the value `tf.data.AUTOTUNE` is used, then
      the number of parallel calls is set dynamically based on available CPU.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
      specified.
  """
  if num_parallel_calls is None:
    # Derive the parallelism from the batch geometry when the caller did not
    # state it explicitly.
    if num_parallel_batches is None:
      num_parallel_calls = batch_size
    else:
      num_parallel_calls = batch_size * num_parallel_batches
  elif num_parallel_batches is not None:
    # The two knobs are alternative ways of expressing the same setting.
    raise ValueError(
        "`map_and_batch` allows only one of `num_parallel_batches` and "
        "`num_parallel_calls` to be set, but "
        f"`num_parallel_batches` was set to {num_parallel_batches} "
        f"and `num_parallel_calls` was set to {num_parallel_calls}.")

  def _apply_fn(dataset):
    return _MapAndBatchDataset(dataset, map_func, batch_size,
                               num_parallel_calls, drop_remainder)

  return _apply_fn
@deprecation.deprecated(None, "Use `tf.data.Dataset.unbatch()`.")
@tf_export("data.experimental.unbatch")
def unbatch():
  """Splits elements of a dataset into multiple elements on the batch dimension.

  Suppose the dataset's elements have shape `[B, a0, a1, ...]`, with `B`
  allowed to differ from one input element to the next. Each such element is
  replaced in the unbatched dataset by `B` consecutive elements of shape
  `[a0, a1, ...]`.

  ```python
  # NOTE: The following example uses `{ ... }` to represent the contents
  # of a dataset.
  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

  a.unbatch() == {
      'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
  ```

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # All the work is delegated to the non-deprecated `Dataset` method.
    return dataset.unbatch()

  return _apply_fn
class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that batches ragged dense elements into `tf.sparse.SparseTensor`s."""

  def __init__(self, input_dataset, batch_size, row_shape):
    """Wraps `input_dataset` in a `dense_to_sparse_batch_dataset` op.

    See `tf.data.experimental.dense_to_sparse_batch()` for more details.

    Args:
      input_dataset: The dataset whose elements are batched. Its legacy output
        types must be a single `DType`, i.e. one component per element.
      batch_size: Forwarded unchanged to the underlying dataset op.
      row_shape: Dense shape of one row; it is appended to an unknown batch
        dimension to form the shape of the element spec.

    Raises:
      TypeError: If the legacy output types of `input_dataset` are not a
        single `DType`.
    """
    component_types = dataset_ops.get_legacy_output_types(input_dataset)
    if not isinstance(component_types, dtypes.DType):
      raise TypeError("`dense_to_sparse_batch` requires an input dataset whose "
                      "elements have a single component, but the given dataset "
                      "has the following component types: "
                      f"{component_types}.")

    self._input_dataset = input_dataset
    self._batch_size = batch_size
    self._row_shape = row_shape

    # The batch dimension is left unknown (`None`) in the element spec.
    batched_shape = tensor_shape.TensorShape([None]).concatenate(row_shape)
    self._element_spec = sparse_tensor.SparseTensorSpec(batched_shape,
                                                        component_types)

    # The element spec is assigned before `_flat_structure` is read, exactly
    # as in the original statement order.
    variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        batch_size,
        row_shape=convert.partial_shape_to_tensor(row_shape),
        **self._flat_structure)
    super().__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    """The `SparseTensorSpec` computed in `__init__`."""
    return self._element_spec
class _MapAndBatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over a batch of elements."""

  def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
               drop_remainder, use_legacy_function=False):
    """Wraps `input_dataset` in a fused `map_and_batch_dataset` op.

    Args:
      input_dataset: The dataset whose elements are mapped and batched.
      map_func: The function to apply; it is wrapped in a
        `StructuredFunctionWrapper` bound to `input_dataset`.
      batch_size: Batch size; converted to a `tf.int64` tensor.
      num_parallel_calls: Parallelism setting; converted to a `tf.int64`
        tensor.
      drop_remainder: Whether to drop a short final batch; converted to a
        `tf.bool` tensor.
      use_legacy_function: Forwarded to `StructuredFunctionWrapper`.
    """
    self._input_dataset = input_dataset

    self._map_func = structured_function.StructuredFunctionWrapper(
        map_func,
        "tf.data.experimental.map_and_batch()",
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    self._batch_size_t = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._num_parallel_calls_t = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    self._drop_remainder_t = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

    # The batch dimension gets a static size only when `drop_remainder` is
    # statically known to be true.
    # NOTE(mrry): the static value of `drop_remainder` may otherwise be `None`
    # (unknown statically) or `False` (explicitly retaining the remainder);
    # in both cases the batch dimension stays unknown.
    if tensor_util.constant_value(self._drop_remainder_t):
      static_batch_size = tensor_util.constant_value(self._batch_size_t)
    else:
      static_batch_size = None

    # pylint: disable=protected-access
    self._element_spec = nest.map_structure(
        lambda spec: spec._batch(static_batch_size),
        self._map_func.output_structure)
    # pylint: enable=protected-access

    mapped_fn = self._map_func.function
    variant_tensor = ged_ops.map_and_batch_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        mapped_fn.captured_inputs,
        f=mapped_fn,
        batch_size=self._batch_size_t,
        num_parallel_calls=self._num_parallel_calls_t,
        drop_remainder=self._drop_remainder_t,
        preserve_cardinality=True,
        **self._flat_structure)
    super().__init__(input_dataset, variant_tensor)

  def _functions(self):
    """Returns the wrapped map function as a one-element list."""
    return [self._map_func]

  @property
  def element_spec(self):
    """The batched element spec computed in `__init__`."""
    return self._element_spec