# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Batching dataset transformations."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.dense_to_ragged_batch")
@deprecation.deprecated(None, "Use `tf.data.Dataset.ragged_batch` instead.")
def dense_to_ragged_batch(batch_size,
                          drop_remainder=False,
                          row_splits_dtype=dtypes.int64):
  """A transformation that batches ragged elements into `tf.RaggedTensor`s.

  This transformation combines multiple consecutive elements of the input
  dataset into a single element.

  Like `tf.data.Dataset.batch`, the components of the resulting element will
  have an additional outer dimension, which will be `batch_size` (or
  `N % batch_size` for the last element if `batch_size` does not divide the
  number of input elements `N` evenly and `drop_remainder` is `False`). If
  your program depends on the batches having the same outer dimension, you
  should set the `drop_remainder` argument to `True` to prevent the smaller
  batch from being produced.

  Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
  different shapes:

  * If an input element is a `tf.Tensor` whose static `tf.TensorShape` is
    fully defined, then it is batched as normal.
  * If an input element is a `tf.Tensor` whose static `tf.TensorShape`
    contains one or more axes with unknown size (i.e., `shape[i]=None`), then
    the output will contain a `tf.RaggedTensor` that is ragged in each such
    dimension.
  * If an input element is a `tf.RaggedTensor` or any other type, then it is
    batched as normal.

  Example:

  >>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
  >>> dataset = dataset.map(lambda x: tf.range(x))
  >>> dataset.element_spec.shape
  TensorShape([None])
  >>> dataset = dataset.apply(
  ...     tf.data.experimental.dense_to_ragged_batch(batch_size=2))
  >>> for batch in dataset:
  ...   print(batch)
  <tf.RaggedTensor [[], [0]]>
  <tf.RaggedTensor [[0, 1], [0, 1, 2]]>
  <tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]>

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in the case it has fewer than
      `batch_size` elements; the default behavior is not to drop the smaller
      batch.
    row_splits_dtype: The dtype that should be used for the `row_splits` of
      any new ragged tensors. Existing `tf.RaggedTensor` elements do not have
      their row_splits dtype changed.

  Returns:
    Dataset: A `Dataset`.
  """

  def _apply_fn(dataset):
    return dataset.ragged_batch(batch_size, drop_remainder, row_splits_dtype)

  return _apply_fn


@tf_export("data.experimental.dense_to_sparse_batch")
@deprecation.deprecated(None, "Use `tf.data.Dataset.sparse_batch` instead.")
def dense_to_sparse_batch(batch_size, row_shape):
  """A transformation that batches ragged elements into `tf.sparse.SparseTensor`s.

  Like `Dataset.padded_batch()`, this transformation combines multiple
  consecutive elements of the dataset, which might have different
  shapes, into a single element. The resulting element has three
  components (`indices`, `values`, and `dense_shape`), which
  comprise a `tf.sparse.SparseTensor` that represents the same data. The
  `row_shape` represents the dense shape of each row in the
  resulting `tf.sparse.SparseTensor`, to which the effective batch size is
  prepended. For example:

  ```python
  # NOTE: The following examples use `{ ... }` to represent the
  # contents of a dataset.
  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

  a.apply(tf.data.experimental.dense_to_sparse_batch(
      batch_size=2, row_shape=[6])) ==
  {
      ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],  # indices
       ['a', 'b', 'c', 'a', 'b'],                 # values
       [2, 6]),                                   # dense_shape
      ([[0, 0], [0, 1], [0, 2], [0, 3]],
       ['a', 'b', 'c', 'd'],
       [1, 6])
  }
  ```
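
  The same idea as a runnable sketch (the input dataset below is
  illustrative, not part of this API; the shapes follow the pseudocode
  above):

  ```python
  # Elements [1], [2, 2], [3, 3, 3] of varying length, each with rank 1.
  dataset = tf.data.Dataset.range(1, 4).map(lambda x: tf.fill([x], x))
  dataset = dataset.apply(tf.data.experimental.dense_to_sparse_batch(
      batch_size=2, row_shape=[4]))
  # The first element batches [1] and [2, 2] into a sparse tensor with
  # dense_shape [2, 4]; the second holds [3, 3, 3] with dense_shape [1, 4].
  ```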

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like object
      representing the equivalent dense shape of a row in the resulting
      `tf.sparse.SparseTensor`. Each element of this dataset must have the
      same rank as `row_shape`, and must have size less than or equal to
      `row_shape` in each dimension.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return dataset.sparse_batch(batch_size, row_shape)

  return _apply_fn


@deprecation.deprecated(None, "Use `tf.data.experimental.map_and_batch()`.")
@tf_export(v1=["data.experimental.map_and_batch_with_legacy_function"])
def map_and_batch_with_legacy_function(map_func,
                                       batch_size,
                                       num_parallel_batches=None,
                                       drop_remainder=False,
                                       num_parallel_calls=None):
  """Fused implementation of `map` and `batch`.

  NOTE: This is an escape hatch for existing uses of `map_and_batch` that do
  not work with V2 functions. New uses are strongly discouraged and existing
  uses should migrate to `map_and_batch`, as this method will be removed in
  V2.
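
  A minimal usage sketch (the dataset and `map_func` here are illustrative,
  not part of the API; in TF2 this symbol is available under
  `tf.compat.v1.data.experimental`):

  ```python
  dataset = tf.data.Dataset.range(8)
  dataset = dataset.apply(
      tf.data.experimental.map_and_batch_with_legacy_function(
          map_func=lambda x: x * 2, batch_size=4))
  # Yields two batches: [0, 2, 4, 6] and [8, 10, 12, 14].
  ```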

  Args:
    map_func: A function mapping a nested structure of tensors to another
      nested structure of tensors.
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of batches to create in parallel. On one hand,
      higher values can help mitigate the effect of stragglers. On the other
      hand, higher values can increase contention if CPU is scarce.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in case its size is smaller
      than desired; the default behavior is not to drop the smaller batch.
    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
      representing the number of elements to process in parallel. If not
      specified, `batch_size * num_parallel_batches` elements will be
      processed in parallel. If the value `tf.data.AUTOTUNE` is used, then
      the number of parallel calls is set dynamically based on available CPU.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
      specified.
  """

  if num_parallel_batches is None and num_parallel_calls is None:
    num_parallel_calls = batch_size
  elif num_parallel_batches is not None and num_parallel_calls is None:
    num_parallel_calls = batch_size * num_parallel_batches
  elif num_parallel_batches is not None and num_parallel_calls is not None:
    raise ValueError(
        "`map_and_batch_with_legacy_function` allows only one of "
        "`num_parallel_batches` and "
        "`num_parallel_calls` to be set, but "
        f"`num_parallel_batches` was set to {num_parallel_batches} "
        f"and `num_parallel_calls` was set to {num_parallel_calls}.")

  def _apply_fn(dataset):
    return _MapAndBatchDataset(dataset, map_func, batch_size,
                               num_parallel_calls, drop_remainder,
                               use_legacy_function=True)

  return _apply_fn


@deprecation.deprecated(
    None,
    "Use `tf.data.Dataset.map(map_func, num_parallel_calls)` followed by "
    "`tf.data.Dataset.batch(batch_size, drop_remainder)`. Static tf.data "
    "optimizations will take care of using the fused implementation.")
@tf_export("data.experimental.map_and_batch")
def map_and_batch(map_func,
                  batch_size,
                  num_parallel_batches=None,
                  drop_remainder=False,
                  num_parallel_calls=None):
  """Fused implementation of `map` and `batch`.

  Maps `map_func` across `batch_size` consecutive elements of this dataset
  and then combines them into a batch. Functionally, it is equivalent to
  `map` followed by `batch`. This API is deprecated because tf.data's static
  optimizations now fuse consecutive `map` and `batch` operations
  automatically.
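
  For example, a minimal sketch (the dataset and map function here are
  illustrative, not part of the API):

  ```python
  dataset = tf.data.Dataset.range(10)
  dataset = dataset.apply(
      tf.data.experimental.map_and_batch(
          map_func=lambda x: x * 2, batch_size=4, drop_remainder=True))
  # Yields [0, 2, 4, 6] and [8, 10, 12, 14]; the remainder [16, 18] is
  # dropped because `drop_remainder=True`.
  ```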

  Args:
    map_func: A function mapping a nested structure of tensors to another
      nested structure of tensors.
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of batches to create in parallel. On one hand,
      higher values can help mitigate the effect of stragglers. On the other
      hand, higher values can increase contention if CPU is scarce.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in case its size is smaller
      than desired; the default behavior is not to drop the smaller batch.
    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
      representing the number of elements to process in parallel. If not
      specified, `batch_size * num_parallel_batches` elements will be
      processed in parallel. If the value `tf.data.AUTOTUNE` is used, then
      the number of parallel calls is set dynamically based on available CPU.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
      specified.
  """

  if num_parallel_batches is None and num_parallel_calls is None:
    num_parallel_calls = batch_size
  elif num_parallel_batches is not None and num_parallel_calls is None:
    num_parallel_calls = batch_size * num_parallel_batches
  elif num_parallel_batches is not None and num_parallel_calls is not None:
    raise ValueError(
        "`map_and_batch` allows only one of `num_parallel_batches` and "
        "`num_parallel_calls` to be set, but "
        f"`num_parallel_batches` was set to {num_parallel_batches} "
        f"and `num_parallel_calls` was set to {num_parallel_calls}.")

  def _apply_fn(dataset):
    return _MapAndBatchDataset(dataset, map_func, batch_size,
                               num_parallel_calls, drop_remainder)

  return _apply_fn


@deprecation.deprecated(None, "Use `tf.data.Dataset.unbatch()`.")
@tf_export("data.experimental.unbatch")
def unbatch():
  """Splits elements of a dataset into multiple elements on the batch dimension.

  For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
  where `B` may vary for each input element, then for each element in the
  dataset, the unbatched dataset will contain `B` consecutive elements
  of shape `[a0, a1, ...]`.

  ```python
  # NOTE: The following example uses `{ ... }` to represent the contents
  # of a dataset.
  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

  a.unbatch() == {
      'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
  ```
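
  The same behavior as a runnable sketch (the input tensor is illustrative):

  ```python
  # Two elements of shape [3] each.
  dataset = tf.data.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]])
  dataset = dataset.apply(tf.data.experimental.unbatch())
  # Yields the scalars 1, 2, 3, 4, 5, 6.
  ```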

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return dataset.unbatch()

  return _apply_fn


class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):

299 """A `Dataset` that batches ragged dense elements into `tf.sparse.SparseTensor`s.""" 

300 

301 def __init__(self, input_dataset, batch_size, row_shape): 

302 """See `Dataset.dense_to_sparse_batch()` for more details.""" 

303 if not isinstance( 

304 dataset_ops.get_legacy_output_types(input_dataset), dtypes.DType): 

305 raise TypeError("`dense_to_sparse_batch` requires an input dataset whose " 

306 "elements have a single component, but the given dataset " 

307 "has the following component types: " 

308 f"{dataset_ops.get_legacy_output_types(input_dataset)}.") 

309 self._input_dataset = input_dataset 

310 self._batch_size = batch_size 

311 self._row_shape = row_shape 

312 self._element_spec = sparse_tensor.SparseTensorSpec( 

313 tensor_shape.TensorShape([None]).concatenate(self._row_shape), 

314 dataset_ops.get_legacy_output_types(input_dataset)) 

315 

316 variant_tensor = ged_ops.dense_to_sparse_batch_dataset( 

317 self._input_dataset._variant_tensor, # pylint: disable=protected-access 

318 self._batch_size, 

319 row_shape=convert.partial_shape_to_tensor(self._row_shape), 

320 **self._flat_structure) 

321 super(_DenseToSparseBatchDataset, self).__init__(input_dataset, 

322 variant_tensor) 

323 

324 @property 

325 def element_spec(self): 

326 return self._element_spec 

327 

328 

329class _MapAndBatchDataset(dataset_ops.UnaryDataset): 

330 """A `Dataset` that maps a function over a batch of elements.""" 

331 

332 def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls, 

333 drop_remainder, use_legacy_function=False): 

334 self._input_dataset = input_dataset 

335 

336 self._map_func = structured_function.StructuredFunctionWrapper( 

337 map_func, 

338 "tf.data.experimental.map_and_batch()", 

339 dataset=input_dataset, 

340 use_legacy_function=use_legacy_function) 

341 self._batch_size_t = ops.convert_to_tensor( 

342 batch_size, dtype=dtypes.int64, name="batch_size") 

343 self._num_parallel_calls_t = ops.convert_to_tensor( 

344 num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") 

345 self._drop_remainder_t = ops.convert_to_tensor( 

346 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 

347 

348 constant_drop_remainder = tensor_util.constant_value(self._drop_remainder_t) 

349 # pylint: disable=protected-access 

350 if constant_drop_remainder: 

351 # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically) 

352 # or `False` (explicitly retaining the remainder). 

353 # pylint: disable=g-long-lambda 

354 self._element_spec = nest.map_structure( 

355 lambda component_spec: component_spec._batch( 

356 tensor_util.constant_value(self._batch_size_t)), 

357 self._map_func.output_structure) 

358 else: 

359 self._element_spec = nest.map_structure( 

360 lambda component_spec: component_spec._batch(None), 

361 self._map_func.output_structure) 

362 # pylint: enable=protected-access 

363 variant_tensor = ged_ops.map_and_batch_dataset( 

364 self._input_dataset._variant_tensor, # pylint: disable=protected-access 

365 self._map_func.function.captured_inputs, 

366 f=self._map_func.function, 

367 batch_size=self._batch_size_t, 

368 num_parallel_calls=self._num_parallel_calls_t, 

369 drop_remainder=self._drop_remainder_t, 

370 preserve_cardinality=True, 

371 **self._flat_structure) 

372 super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor) 

373 

374 def _functions(self): 

375 return [self._map_func] 

376 

377 @property 

378 def element_spec(self): 

379 return self._element_spec