Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/batch_op.py: 34%
53 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The implementation of `tf.data.Dataset.batch`."""
17import warnings
19from tensorflow.python.data.ops import dataset_ops
20from tensorflow.python.data.ops import debug_mode
21from tensorflow.python.data.util import nest
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_util
25from tensorflow.python.ops import gen_dataset_ops
def _batch(input_dataset,
           batch_size,
           drop_remainder=False,
           num_parallel_calls=None,
           deterministic=None,
           name=None):
  """See `Dataset.batch` for details.

  Dispatches to `_ParallelBatchDataset` when parallelism was requested and
  debug mode is off; otherwise falls back to the serial `_BatchDataset`.
  """
  serial = num_parallel_calls is None or debug_mode.DEBUG_MODE
  if not serial:
    return _ParallelBatchDataset(
        input_dataset,
        batch_size,
        drop_remainder,
        num_parallel_calls,
        deterministic,
        name=name)
  # `deterministic` only matters for the parallel implementation; warn the
  # caller if it would silently be ignored (debug mode intentionally forces
  # the serial path, so no warning there).
  if deterministic is not None and not debug_mode.DEBUG_MODE:
    warnings.warn("The `deterministic` argument has no effect unless the "
                  "`num_parallel_calls` argument is specified.")
  return _BatchDataset(input_dataset, batch_size, drop_remainder, name=name)
class _BatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that batches contiguous elements from its input."""

  def __init__(self, input_dataset, batch_size, drop_remainder, name=None):
    """See `Dataset.batch()` for details."""
    self._input_dataset = input_dataset
    self._batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

    # The output batch dimension is statically known only when the remainder
    # is statically known to be dropped AND the batch size itself is a static
    # constant; in every other case the leading dimension stays unknown
    # (`None`). Note `constant_value` returns `None` for values that are not
    # known statically.
    # pylint: disable=protected-access
    static_batch_size = None
    if tensor_util.constant_value(self._drop_remainder):
      static_batch_size = tensor_util.constant_value(self._batch_size)
    self._structure = nest.map_structure(
        lambda spec: spec._batch(static_batch_size),
        input_dataset.element_spec)

    self._name = name
    variant_tensor = gen_dataset_ops.batch_dataset_v2(
        input_dataset._variant_tensor,
        batch_size=self._batch_size,
        drop_remainder=self._drop_remainder,
        **self._common_args)
    super().__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._structure
class _ParallelBatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that batches contiguous elements from its input in parallel."""

  def __init__(self,
               input_dataset,
               batch_size,
               drop_remainder,
               num_parallel_calls,
               deterministic,
               name=None):
    """See `Dataset.batch()` for details."""
    self._input_dataset = input_dataset
    self._batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    # Map the tri-state `deterministic` flag onto the string attr the kernel
    # expects: None -> "default" (runtime decides), otherwise truthiness.
    if deterministic is None:
      self._deterministic = "default"
    else:
      self._deterministic = "true" if deterministic else "false"

    # The output batch dimension is statically known only when the remainder
    # is statically known to be dropped AND the batch size itself is a static
    # constant; otherwise the leading dimension stays unknown (`None`). Note
    # `constant_value` returns `None` for values not known statically.
    # pylint: disable=protected-access
    static_batch_size = None
    if tensor_util.constant_value(self._drop_remainder):
      static_batch_size = tensor_util.constant_value(self._batch_size)
    self._structure = nest.map_structure(
        lambda spec: spec._batch(static_batch_size),
        input_dataset.element_spec)

    self._name = name
    variant_tensor = gen_dataset_ops.parallel_batch_dataset(
        input_dataset._variant_tensor,
        batch_size=self._batch_size,
        num_parallel_calls=self._num_parallel_calls,
        drop_remainder=self._drop_remainder,
        deterministic=self._deterministic,
        **self._common_args)
    super().__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._structure