Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/distribute.py: 30%

107 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Distribution Strategy-related dataset transformations.""" 

16 

17from tensorflow.python.data.ops import dataset_ops 

18from tensorflow.python.data.ops.options import ExternalStatePolicy 

19from tensorflow.python.data.util import nest 

20from tensorflow.python.framework import constant_op 

21from tensorflow.python.framework import dtypes 

22from tensorflow.python.framework import ops 

23from tensorflow.python.framework import tensor_shape 

24from tensorflow.python.framework import tensor_util 

25from tensorflow.python.ops import array_ops 

26from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops 

27from tensorflow.python.types import data as data_types 

28from tensorflow.python.util.tf_export import tf_export 

29 

30SHARD_HINT = -1 

31tf_export("data.experimental.SHARD_HINT").export_constant( 

32 __name__, "SHARD_HINT") 
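# Usage note: `SHARD_HINT` is typically passed as both arguments to
# `Dataset.shard(SHARD_HINT, SHARD_HINT)` so that the tf.data runtime (for
# example, tf.data service sharding) can substitute the actual shard count and
# index at run time.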

33 

34 

35class _AutoShardDataset(dataset_ops.UnaryDataset): 

36 """A `Dataset` that shards the `Dataset` automatically. 

37 

38 This dataset takes in an existing dataset and tries to automatically figure 

39 out how to shard the dataset in a multi-worker scenario using graph rewrites. 

40 

41 If the AutoShardPolicy is set to FILE, it walks up the dataset graph until 

42 it finds a reader dataset, then inserts a ShardDataset op before that node 

43 so that each worker only sees some files. 

44 

45 If the AutoShardPolicy is set to DATA, it inserts a ShardDataset op at the 

46 end of the input pipeline, before any terminal PrefetchDataset if there is 

47 one. Additionally, if there is a RebatchDatasetV2 in the input pipeline, it 

48 is rewritten to the legacy RebatchDataset for correctness reasons, since 

49 RebatchDatasetV2 is incompatible with data sharding. 

50 

51 If the AutoShardPolicy is set to AUTO, it tries to do file-based sharding. 

52 If it cannot find a reader dataset, it falls back to doing data-based 

53 sharding. 

54 

55 If the AutoShardPolicy is set to OFF, it does nothing. 

56 
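  For example (an illustrative sketch, not taken from this module), the policy
  that drives this rewrite is normally chosen through the dataset options:

  ```python
  options = tf.data.Options()
  options.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.DATA)
  dataset = dataset.with_options(options)
  ```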

57 Attributes: 

58 num_workers: Total number of workers to shard this dataset across. 

59 index: The current worker index (out of the total number of workers) this 

60 dataset is for. 

61 num_replicas: The total number of replicas across all workers. This is used 

62 only when sharding by data (either DATA or AUTO) in order to rewrite 

63 RebatchDatasetV2 to RebatchDataset. 

64 

65 Raises: 

66 NotFoundError: If we cannot find a suitable reader dataset to begin 

67 automatically sharding the dataset. 

68 """ 

69 

70 def __init__(self, input_dataset, num_workers, index, num_replicas=None): 

71 self._input_dataset = input_dataset 

72 

73 self._element_spec = input_dataset.element_spec 
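    # The auto_shard_dataset op performs the graph rewrite described in the
    # class docstring, using the auto-shard policy read from the input
    # dataset's options below.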

74 variant_tensor = ged_ops.auto_shard_dataset( 

75 self._input_dataset._variant_tensor, # pylint: disable=protected-access 

76 num_workers=num_workers, 

77 index=index, 

78 auto_shard_policy=int( 

79 input_dataset.options().experimental_distribute.auto_shard_policy), 

80 num_replicas=num_replicas, 

81 **self._flat_structure) 

82 super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor) 

83 

84 @property 

85 def element_spec(self): 

86 return self._element_spec 

87 

88 

89def _AutoShardDatasetV1(input_dataset, num_workers, index, num_replicas=None): # pylint: disable=invalid-name 

90 return dataset_ops.DatasetV1Adapter( 

91 _AutoShardDataset(input_dataset, num_workers, index, num_replicas)) 

92 

93 

94class _LegacyRebatchDataset(dataset_ops.UnaryDataset): 

95 """A `Dataset` that divides its input batches into `num_replicas` sub-batches. 

96 

97 For each batch in the input dataset, _LegacyRebatchDataset will produce 

98 `num_replicas` smaller batches whose sizes add up to the original batch size. 

99 

100 For example: 

101 

102 ```python 

103 ds = tf.data.Dataset.range(8) 

104 ds = ds.batch(4) 

105 ds = _LegacyRebatchDataset(ds, num_replicas=3) 

106 for elem in ds: 

107 print(elem) 

108 >> [0, 1], [2, 3], [], [4, 5], [6, 7], [] 

109 ``` 

110 """ 

111 

112 def __init__(self, input_dataset, num_replicas): 

113 """Creates a _LegacyRebatchDataset. 

114 

115 Args: 

116 input_dataset: `Dataset` to rebatch. 

117 num_replicas: A `tf.int64` scalar, representing the number of sub-batches 

118 to split each batch from `input_dataset` into. 

119 """ 

120 

121 def recalculate_batch_size(type_spec): 

122 """Recalculates the output_shape after dividing it by num_replicas.""" 

123 output_shape = type_spec._to_legacy_output_shapes() # pylint: disable=protected-access 

124 if not isinstance(output_shape, tensor_shape.TensorShape): 

125 return None 

126 

127 # If the rank of the output shape is unknown, set the batch dimension to unknown. 

128 if output_shape.rank is None: 

129 return None 

130 

131 if len(output_shape) < 1: 

132 raise ValueError( 

133 "Invalid `input_dataset`. Expected a dataset whose elements " 

134 "have rank >= 1 but found a dataset whose elements are scalars. " 

135 "Fix the issue by adding the `batch` transformation to the " 

136 "dataset.") 

137 output_dims = [d.value for d in output_shape.dims] 

138 

139 if output_dims[0] is not None and output_dims[0] % num_replicas == 0: 

140 return output_dims[0] // num_replicas 

141 

142 # Set the batch dimension to unknown. If the global batch size is not 

143 # evenly divisible by num_replicas, the sub-batches may have different sizes. 

144 return None 

145 

146 def rebatch(type_spec): 

147 # pylint: disable=protected-access 

148 batch_size = recalculate_batch_size(type_spec) 
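      # Rebuild the component spec with the per-replica batch dimension
      # (None when the per-replica size cannot be determined statically).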

149 return type_spec._unbatch()._batch(batch_size) 

150 # pylint: enable=protected-access 

151 

152 self._element_spec = nest.map_structure( 

153 rebatch, dataset_ops.get_structure(input_dataset)) 

154 

155 # The auto_shard rewrite assumes that normalize_to_dense comes before 

156 # rebatch_dataset. 

157 # LINT.IfChange 

158 input_dataset = dataset_ops.normalize_to_dense(input_dataset) 

159 variant_tensor = ged_ops.rebatch_dataset( 

160 input_dataset._variant_tensor, # pylint: disable=protected-access 

161 num_replicas=num_replicas, 

162 **self._flat_structure) 

163 # LINT.ThenChange(//tensorflow/core/grappler/optimizers/data/auto_shard.cc) 

164 super(_LegacyRebatchDataset, self).__init__(input_dataset, variant_tensor) 

165 

166 @property 

167 def element_spec(self): 

168 return self._element_spec 

169 

170 

171class _RemoteDataset(dataset_ops.DatasetSource): 

172 """Creates a dataset on a given `device` given a graph def.""" 

173 

174 def __init__(self, graph_def, device, element_spec): 

175 self._elem_spec = element_spec 

176 with ops.device(device): 

177 variant_tensor = ged_ops.dataset_from_graph(graph_def) 

178 super(_RemoteDataset, self).__init__(variant_tensor) 

179 

180 @property 

181 def element_spec(self): 

182 return self._elem_spec 

183 

184 

185def replicate(dataset, devices): 

186 """A transformation that replicates `dataset` onto a list of devices. 

187 

188 Args: 

189 dataset: A `tf.data.Dataset` object. 

190 devices: A list of devices to replicate the dataset on. 

191 

192 Returns: 

193 A dictionary mapping device name to a dataset on that device. 
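
  Example (an illustrative sketch; the device strings are placeholders):

  ```python
  per_device = replicate(dataset, ["/device:CPU:0", "/device:GPU:0"])
  dataset_on_gpu = per_device["/device:GPU:0"]
  ```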

194 """ 

195 if not isinstance(dataset, data_types.DatasetV2): 

196 raise TypeError( 

197 f"Invalid `dataset`. Expected a `tf.data.Dataset` object but " 

198 f"got {type(dataset)}.") 

199 

200 # pylint: disable=protected-access 

201 dataset_device = dataset._variant_tensor.device 

202 

203 datasets = {} 

204 if len(devices) == 1 and devices[0] == dataset_device: 

205 datasets[devices[0]] = dataset 

206 return datasets 

207 

208 with ops.colocate_with(dataset._variant_tensor): 

209 dataset = dataset._apply_debug_options() 

210 graph_def = dataset._as_serialized_graph( 

211 strip_device_assignment=True, 

212 external_state_policy=ExternalStatePolicy.WARN) 
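    # The graph is serialized once (with device placements stripped) and then
    # deserialized onto each target device via _RemoteDataset below.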

213 for device in devices: 

214 ds = _RemoteDataset(graph_def, device, dataset.element_spec) 

215 datasets[device] = ds 

216 return datasets 

217 

218 

219def batch_sizes_for_worker(global_batch_size, num_workers, 

220 num_replicas_per_worker, worker_index): 

221 """Determines how to rebatch a dataset for the given worker. 

222 

223 Given the global batch size, number of workers, number of replicas per worker, 

224 and worker index, returns the correct batch sizes for rebatching a dataset 

225 on worker `worker_index` of `num_workers`, such that each global step (across 

226 all workers and replicas) will consume global_batch_size elements. The 

227 returned value should be passed as the `batch_sizes` input parameter to 

228 `tf.data.experimental.rebatch()`. The returned batch sizes meet the following 

229 constraints: 

230 

231 Let G = global_batch_size, W = num_workers, R = num_replicas_per_worker 

232 (A) for any worker, len(batch_sizes) = W * R 

233 (B) for any worker, sum(batch_sizes) == G 

234 (C) for any global step (i.e. R iterations on each worker), the sum of the 

235 batch sizes consumed by replicas across all workers is G. 

236 (D) the batch sizes of any two replicas differ by at most one. 

237 

238 For example, suppose we have G = 7, W = 2, R = 2, and suppose we have two 

239 files which each contain 7 elements: 

240 

241 ```python 

242 # WORKER 0 

243 batch_sizes_0 = batch_sizes_for_worker(global_batch_size=7, 

244 num_workers=2, 

245 num_replicas_per_worker=2, 

246 worker_index=0) 

247 print(batch_sizes_0) 

248 >> [2, 2, 2, 1] 

249 

250 dataset_0 = tf.data.Dataset.from_tensor_slices(["file_a", "file_b"]) 

251 dataset_0 = dataset_0.shard(num_shards=2, index=0) 

252 dataset_0 = dataset_0.batch(7) 

253 dataset_0 = dataset_0.apply(tf.data.experimental.rebatch(batch_sizes_0)) 

254 for elem in dataset_0: 

255 print(elem) 

256 >> [[A0, A1], [A2, A3], [A4, A5], [A6]] 

257 

258 # WORKER 1 

259 batch_sizes_1 = batch_sizes_for_worker(global_batch_size=7, 

260 num_workers=2, 

261 num_replicas_per_worker=2, 

262 worker_index=1) 

263 print(batch_sizes_1) 

264 >> [2, 1, 2, 2] 

265 

266 dataset_1 = tf.data.Dataset.from_tensor_slices(["file_a", "file_b"]) 

267 dataset_1 = dataset_1.shard(num_shards=2, index=1) 

268 dataset_1 = dataset_1.batch(7) 

269 dataset_1 = dataset_1.apply(tf.data.experimental.rebatch(batch_sizes_1)) 

270 for elem in dataset_1: 

271 print(elem) 

272 >> [[B0, B1], [B2], [B3, B4], [B5, B6]] 

273 ``` 

274 

275 The above example will produce the following elements: 

276 

277 Step 1: 

278 Worker 0 Replica 0: [A0, A1] 

279 Worker 0 Replica 1: [A2, A3] 

280 Worker 1 Replica 0: [B0, B1] 

281 Worker 1 Replica 1: [B2] 

282 Total batch size = 7 

283 

284 Step 2: 

285 Worker 0 Replica 0: [A4, A5] 

286 Worker 0 Replica 1: [A6] 

287 Worker 1 Replica 0: [B3, B4] 

288 Worker 1 Replica 1: [B5, B6] 

289 Total batch size = 7 

290 

291 Args: 

292 global_batch_size: A `tf.int64` scalar, representing the global batch size. 

293 num_workers: An integer representing the number of workers the dataset will 

294 be distributed across. 

295 num_replicas_per_worker: An integer representing the number of replicas per 

296 worker. All workers are assumed to have the same number of replicas. 

297 worker_index: An integer index of the worker to be rebatched. 

298 

299 Returns: 

300 A `tf.int64` vector, representing the batch sizes to rebatch the dataset 

301 into. 

302 """ 

303 # Constraint (A) 

304 num_subbatches = num_workers * num_replicas_per_worker 

305 

306 offset = worker_index * num_replicas_per_worker 

307 

308 const_value = tensor_util.constant_value(global_batch_size) 

309 if const_value is not None: 

310 # Use the constant global batch size for further calculations 

311 global_batch_size = const_value 

312 

313 # Let B = global_batch_size and N = W * R. Constraints (B) and (D) jointly 

314 # mean that each subbatch has size floor(B/N) or ceil(B/N). Namely, of the N 

315 # subbatches a batch is split into, B - N * floor(B/N) of them will have size 

316 # ceil(B/N), and the rest will have size floor(B/N). 

317 floor = global_batch_size // num_subbatches 

318 num_ceil = global_batch_size - (num_subbatches * floor) 
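  # Worked example (illustrative): with global_batch_size=7 and
  # num_subbatches=4, floor=1 and num_ceil=3, so worker 0's subbatch sizes are
  # [2, 2, 2, 1] before the offset rotation applied below.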

319 

320 # For worker 0, we assign the first num_ceil subbatches to have size 

321 # ceil(B/N), and the remainder to have size floor(B/N). The other workers will 

322 # each be offset by R * worker_index in order to meet constraint (C). 

323 if const_value is not None: 

324 # If the global batch size is a known constant value, we return a constant 

325 # tensor directly instead of manipulating it with TF ops. This allows for 

326 # better downstream shape inference. 

327 worker_0 = [floor + 1] * num_ceil + [floor] * (num_subbatches - num_ceil) 

328 return ops.convert_to_tensor( 

329 worker_0[offset:] + worker_0[:offset], 

330 dtype=dtypes.int64, 

331 name="batch_sizes") 

332 

333 worker_0 = array_ops.ones(num_subbatches, dtype=dtypes.int64) 

334 worker_0 = floor * worker_0 + array_ops.concat([ 

335 array_ops.ones(num_ceil, dtype=dtypes.int64), 

336 array_ops.zeros(num_subbatches - num_ceil, dtype=dtypes.int64) 

337 ], 

338 axis=0) 

339 

340 return array_ops.concat([worker_0[offset:], worker_0[:offset]], axis=0) 

341 

342 

343def compute_batch_size(dataset): 

344 """An operation that returns the batch size of the dataset. 

345 

346 This op tries to infer the batch size statically by walking up the dataset 

347 tree from the final dataset node and returning the batch size of the first 

348 batching dataset (such as from .batch() and .padded_batch()) that it 

349 encounters. This differs from using the `element_spec` of a dataset in that it 

350 does not account for partial batches. 

351 

352 This operation may fail if it encounters contradictory batch sizes (for 

353 example, if the dataset is created by zipping together two datasets with 

354 different batch sizes), if there are no explicit batching transformations, or 

355 if there are operations downstream from the batching transformation that may 

356 modify its batch size. In these cases, it returns -1. 

357 

358 Args: 

359 dataset: A `tf.data.Dataset` object. 

360 

361 Returns: 

362 A `tf.int64` Tensor representing the batch size of the dataset sans partial 

363 batches. If this cannot be inferred statically, the value of this tensor 

364 will be -1. 
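
  For example (an illustrative sketch):

  ```python
  ds = tf.data.Dataset.range(32).batch(8, drop_remainder=True)
  print(compute_batch_size(ds))  # tf.Tensor(8, shape=(), dtype=int64)
  ```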

365 """ 

366 

367 def get_static_batch_dim(type_spec): 

368 try: 

369 output_shape = type_spec._to_legacy_output_shapes() # pylint: disable=protected-access 

370 except NotImplementedError: 

371 return None 

372 if not isinstance(output_shape, tensor_shape.TensorShape): 

373 return None 

374 if output_shape.rank is None: 

375 return None 

376 return output_shape.dims[0].value 

377 
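  # For example, a component with TensorShape([32, 224, 224, 3]) contributes 32
  # to `batch_dims`, while an unknown leading dimension contributes None.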

378 batch_dims = [ 

379 get_static_batch_dim(type_spec) 

380 for type_spec in nest.flatten(dataset_ops.get_structure(dataset)) 

381 ] 

382 

383 if all(d is not None for d in batch_dims): 

384 

385 if all(d == batch_dims[0] for d in batch_dims): 

386 # If all batch dimensions are known and equal, return that directly. 

387 batch_dim = batch_dims[0] 

388 else: 

389 # If all batch dimensions are known but not all equal, return -1. 

390 batch_dim = -1 

391 

392 return constant_op.constant( 

393 batch_dim, dtype=dtypes.int64, name="static_batch_size") 

394 

395 # If any batch dimensions are unknown, use compute_batch_size op. 

396 return ged_ops.compute_batch_size(dataset._variant_tensor) # pylint: disable=protected-access 

397 

398 

399_AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__