Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/v1/input_lib.py: 36% of 166 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed inputs."""

from tensorflow.python.data.experimental.ops import cardinality as cardinality_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.distribute import input_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.types import data as data_types
from tensorflow.python.util.deprecation import deprecated


class DistributedDatasetV1(input_lib.DistributedDataset):
  """Distributed dataset that supports prefetching to multiple devices."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               num_replicas_in_sync=None,
               input_context=None,
               options=None):
    self._input_workers = input_workers
    super(DistributedDatasetV1, self).__init__(
        input_workers,
        strategy,
        dataset,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context,
        options=options)

  def make_one_shot_iterator(self):
    """Get a one time use iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use `for ... in dataset:` to iterate
    over the dataset or `iter` to create an iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_one_shot_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for DistributedDatasetV1."""
    # Graph mode with a one-shot iterator is disabled because we have to call
    # `initialize` on the iterator, which is only required when using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    return self._get_iterator()

  def make_initializable_iterator(self):
    """Get an initializable iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use
    `tf.compat.v1.data.make_initializable_iterator(dataset)` to create an
    initializable iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_initializable_iterator()

  def _make_initializable_iterator(self, shared_name=None):  # pylint: disable=unused-argument
    """Get an initializable iterator for DistributedDatasetV1."""
    # Eager mode generates already-initialized iterators, so we cannot create
    # an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `iter()` instead.")
    return self._get_iterator()

  def _get_iterator(self):
    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                    self._input_workers,
                                                    self._options)
    cardinality = input_lib._cardinality(self._cloned_datasets[0])  # pylint: disable=protected-access
    iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
                                     self._strategy, cardinality,
                                     self._enable_get_next_as_optional)
    iterator._element_spec = self.element_spec  # pylint: disable=protected-access

    # When async eager is enabled, the iterator may not finish initializing
    # before it is passed to a multi-device function, so add a sync point here
    # to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  # pylint: disable=non-iterator-returned
  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")

  # pylint: enable=non-iterator-returned
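# NOTE: The block below is an illustrative sketch added for documentation; it
# is not part of the original module. It shows the two iterator entry points
# described in the docstrings above: `iter()` (via `for ... in dataset`) under
# eager execution, and `make_initializable_iterator()` plus an explicit
# initializer run in graph mode. The strategy, toy dataset, and helper name
# are assumptions.
def _example_distributed_dataset_v1_iteration():
  """Hedged sketch: iterating a distributed dataset in eager vs. graph mode."""
  import tensorflow as tf  # Local import; an assumption for the sketch.

  strategy = tf.distribute.MirroredStrategy()
  dataset = tf.data.Dataset.range(8).batch(4)
  dist_dataset = strategy.experimental_distribute_dataset(dataset)

  if tf.executing_eagerly():
    # Eager mode: iterators are created (and initialized) by `iter()`.
    for per_replica_batch in dist_dataset:
      del per_replica_batch
  else:
    # Graph mode: create an initializable iterator and run its initializer
    # before the first `get_next()`.
    iterator = dist_dataset.make_initializable_iterator()
    with tf.compat.v1.Session() as sess:
      sess.run(iterator.initializer)
      sess.run(iterator.get_next())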

class DistributedDatasetsFromFunctionV1(
    input_lib.DistributedDatasetsFromFunction):
  """Inputs created from dataset function."""

  def _make_initializable_iterator(self, shared_name=None):
    """Get an initializable iterator for DistributedDatasetsFromFunctionV1."""
    del shared_name  # Unused
    # Eager mode generates already-initialized iterators, so we cannot create
    # an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `iter()` instead.")
    return self._get_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for iterating over DistributedDatasetsFromFunctionV1."""
    # Graph mode with a one-shot iterator is disabled because we have to call
    # `initialize` on the iterator, which is only required when using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    return self._get_iterator()

  def _get_iterator(self):
    iterators = _create_iterators_per_worker(self._datasets,
                                             self._input_workers, self._options)
    cardinality = input_lib._cardinality(self._datasets[0])  # pylint: disable=protected-access
    iterator = DistributedIteratorV1(self._input_workers, iterators,
                                     self._strategy, cardinality,
                                     self._enable_get_next_as_optional)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access

    # When async eager is enabled, the iterator may not finish initializing
    # before it is passed to a multi-device function, so add a sync point here
    # to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  # pylint: disable=non-iterator-returned
  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")

  # pylint: enable=non-iterator-returned
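# NOTE: Illustrative sketch, not part of the original module. A "dataset
# function" as consumed by the class above receives a
# `tf.distribute.InputContext` and returns one `tf.data.Dataset` per worker;
# it is typically passed to `strategy.distribute_datasets_from_function`. The
# sharding and batch-size arithmetic below are assumptions about a typical
# definition.
def _example_dataset_fn(input_context):
  """Hedged sketch: a per-worker dataset function using `InputContext`."""
  import tensorflow as tf  # Local import; an assumption for the sketch.

  global_batch_size = 64
  # Each replica reads global_batch_size / num_replicas_in_sync examples.
  per_replica_batch_size = input_context.get_per_replica_batch_size(
      global_batch_size)
  dataset = tf.data.Dataset.range(1024)
  # Shard by input pipeline so each worker reads a disjoint slice of the data.
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  return dataset.batch(per_replica_batch_size)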

class DistributedIteratorV1(input_lib.DistributedIteratorBase):
  """Input Iterator for a distributed dataset."""

  # We need a private initializer method for re-initializing multidevice
  # iterators when used with Keras training loops. If we don't reinitialize the
  # iterator we run into memory leak issues (b/123315763).
  @property
  def _initializer(self):
    init_ops = []
    for it in self._iterators:
      init_ops.extend(it.initialize())
    return control_flow_ops.group(init_ops)

  @deprecated(None, "Use the iterator's `initializer` property instead.")
  def initialize(self):
    """Initialize underlying iterators.

    Returns:
      A list of any initializer ops that should be run.
    """
    return self._initializer

  @property
  def initializer(self):
    """Returns a list of ops that initialize the iterator."""
    return self.initialize()

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_classes(self):
    return self._iterators[0].output_classes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_shapes(self):
    return self._iterators[0].output_shapes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_types(self):
    return self._iterators[0].output_types

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  def get_iterator(self, worker):
    for i, w in enumerate(self._input_workers.worker_devices):
      if worker == w:
        return self._iterators[i]
    return None

  @property
  def element_spec(self):
    """The type specification of an element of this iterator."""
    return self._element_spec
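# NOTE: Illustrative sketch, not part of the original module. It shows the
# graph-mode initialize/get_next cycle the iterator above supports: the
# `initializer` property groups the per-worker initializer ops, and
# `get_next()` raises `tf.errors.OutOfRangeError` once the underlying datasets
# are exhausted. The helper name and `step_fn` are assumptions.
def _example_graph_mode_iterator_loop(iterator, step_fn):
  """Hedged sketch: driving a DistributedIteratorV1 in a tf.compat.v1 session."""
  import tensorflow as tf  # Local import; an assumption for the sketch.

  next_element = iterator.get_next()  # Build the fetch op once.
  train_op = step_fn(next_element)
  with tf.compat.v1.Session() as sess:
    sess.run(iterator.initializer)  # Runs every per-worker initializer op.
    while True:
      try:
        sess.run(train_op)
      except tf.errors.OutOfRangeError:
        break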

class DatasetIterator(DistributedIteratorV1):
  """Iterator created from input dataset."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               num_replicas_in_sync=None,
               input_context=None):
    """Make an iterator for the dataset on given devices.

    If `num_replicas_in_sync` is not None, we split each batch of the dataset
    into `num_replicas_in_sync` smaller batches, to be distributed among that
    worker's replicas, so that the batch size for a global step (across all
    workers and replicas) is as expected.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      num_replicas_in_sync: Optional integer. If this is not None, the value is
        used to decide how to rebatch datasets into smaller batches so that the
        total batch size for each step (across all workers and replicas) adds up
        to `dataset`'s batch size.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
    """
    dist_dataset = DistributedDatasetV1(
        dataset,
        input_workers,
        strategy,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context)
    # pylint: disable=protected-access
    worker_iterators = _create_iterators_per_worker(
        dist_dataset._cloned_datasets, input_workers)
    super(DatasetIterator,
          self).__init__(input_workers, worker_iterators, strategy,
                         dist_dataset.cardinality,
                         dist_dataset._enable_get_next_as_optional)
    self._element_spec = dist_dataset.element_spec
    # pylint: enable=protected-access
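# NOTE: Illustrative sketch, not part of the original module. It spells out
# the rebatching arithmetic described in `DatasetIterator.__init__`: when
# `num_replicas_in_sync` is set, each dataset batch is split into that many
# smaller batches so that one global step still consumes the dataset's full
# batch size. The concrete numbers are assumptions chosen for the example.
def _example_rebatching_arithmetic():
  """Hedged sketch: global versus per-replica batch sizes."""
  dataset_batch_size = 64    # Batch size produced by the input `tf.data.Dataset`.
  num_replicas_in_sync = 4   # Replicas stepping together under the strategy.
  per_replica_batch_size = dataset_batch_size // num_replicas_in_sync  # == 16
  # Across all replicas, one global step consumes 4 * 16 = 64 examples.
  assert per_replica_batch_size * num_replicas_in_sync == dataset_batch_size
  return per_replica_batch_size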

class InputFunctionIterator(DistributedIteratorV1):
  """Iterator created from input function."""

  def __init__(self, input_fn, input_workers, input_contexts, strategy):
    """Make an iterator for input provided via an input function.

    Currently implements PER_WORKER mode, in which the `input_fn` is called
    once on each worker.

    TODO(priyag): Add other replication modes.

    Args:
      input_fn: Input function that returns a `tf.data.Dataset` object.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `input_fn`. Length and order should match worker order in
        `worker_device_pairs`.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
    """
    assert isinstance(input_workers, input_lib.InputWorkers)
    if input_workers.num_workers != len(input_contexts):
      raise ValueError("Number of input workers (%d) is not same as number of "
                       "input_contexts (%d)" %
                       (input_workers.num_workers, len(input_contexts)))

    iterators = []
    for i, ctx in enumerate(input_contexts):
      worker = input_workers.worker_devices[i]
      with ops.device(worker):
        result = input_fn(ctx)
        devices = input_workers.compute_devices_for_worker(i)
        if isinstance(result, data_types.DatasetV2):
          iterator = _SingleWorkerDatasetIterator(result, worker, devices)
        elif callable(result):
          iterator = _SingleWorkerCallableIterator(result, worker, devices)
        else:
          raise ValueError(
              "input_fn must return a tf.data.Dataset or a callable.")
        iterators.append(iterator)

    super(InputFunctionIterator, self).__init__(
        input_workers,
        iterators,
        strategy,
        cardinality=cardinality_lib.UNKNOWN,
        enable_get_next_as_optional=False)
    self._enable_get_next_as_optional = False
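# NOTE: Illustrative sketch, not part of the original module. As the
# constructor above checks, an `input_fn` may return either a
# `tf.data.Dataset` (wrapped in `_SingleWorkerDatasetIterator`) or a callable
# producing tensors (wrapped in `_SingleWorkerCallableIterator`). Both
# definitions below are assumptions about typical usage.
def _example_dataset_returning_input_fn(input_context):
  """Hedged sketch: `input_fn` returning a per-worker `tf.data.Dataset`."""
  import tensorflow as tf  # Local import; an assumption for the sketch.
  return tf.data.Dataset.range(100).shard(input_context.num_input_pipelines,
                                          input_context.input_pipeline_id)


def _example_callable_returning_input_fn(input_context):
  """Hedged sketch: `input_fn` returning a callable that yields tensors."""
  import tensorflow as tf  # Local import; an assumption for the sketch.
  del input_context  # Unused in this sketch.
  # Each call to the returned function must produce the next element.
  return lambda: tf.constant([0.0, 1.0])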

class _SingleWorkerDatasetIterator(input_lib._SingleWorkerDatasetIteratorBase):  # pylint: disable=protected-access
  """Iterator for a single DistributedDatasetV1 instance."""

  def _make_iterator(self):
    """Make appropriate iterator on the dataset."""
    with ops.device(self._worker):
      if self._options is not None:
        self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
            self._dataset,
            self._devices,
            max_buffer_size=self._options.experimental_per_replica_buffer_size,
            prefetch_buffer_size=self._options
            .experimental_per_replica_buffer_size)
      else:
        self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
            self._dataset,
            self._devices,
        )

  def initialize(self):
    """Initialize underlying iterator.

    In eager execution, this simply recreates the underlying iterator.
    In graph execution, it returns the initializer ops for the underlying
    iterator.

    Returns:
      A list of any initializer ops that should be run.
    """
    if ops.executing_eagerly_outside_functions():
      self._iterator._eager_reset()  # pylint: disable=protected-access
      return []
    else:
      return [self._iterator.initializer]

  @property
  def output_classes(self):
    return dataset_ops.get_legacy_output_classes(self._iterator)

  @property
  def output_shapes(self):
    return dataset_ops.get_legacy_output_shapes(self._iterator)

  @property
  def output_types(self):
    return dataset_ops.get_legacy_output_types(self._iterator)
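# NOTE: Illustrative sketch, not part of the original module. It shows the
# underlying `MultiDeviceIterator` behaviour the wrapper above relies on:
# elements are prefetched from the worker's dataset onto each consumer device,
# and `get_next()` returns one element per device. The buffer sizes mirror how
# `experimental_per_replica_buffer_size` is plumbed through `_make_iterator`;
# the helper name and arguments are assumptions.
def _example_multi_device_prefetch(dataset, devices, buffer_size=1):
  """Hedged sketch: prefetching one dataset onto several devices."""
  iterator = multi_device_iterator_ops.MultiDeviceIterator(
      dataset,
      devices,
      max_buffer_size=buffer_size,
      prefetch_buffer_size=buffer_size)
  # Returns a list with one element per entry in `devices`, in the same order.
  return iterator.get_next()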

class _SingleWorkerCallableIterator(object):
  """Iterator for a single tensor-returning callable."""

  def __init__(self, fn, worker, devices):
    self._fn = fn
    self._worker = worker
    self._devices = devices

  def get_next(self, device, name=None):
    """Get next element for the given device from the callable."""
    del device, name
    with ops.device(self._worker):
      return self._fn()

  def get_next_as_list(self, name=None):
    """Get next element from the callable."""
    del name
    with ops.device(self._worker):
      data_list = [self._fn() for _ in self._devices]
      return data_list

  def get_next_as_optional_list(self):
    with ops.device(self._worker):
      data_list = [
          optional_ops.Optional.from_value(self._fn()) for _ in self._devices
      ]
      return data_list

  def initialize(self):
    # TODO(petebu) Should this throw an exception instead?
    return []
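# NOTE: Illustrative sketch, not part of the original module. It shows the
# `Optional` contract that `get_next_as_optional_list` relies on: each element
# is wrapped in an optional so callers can test for a value instead of
# catching an out-of-range error. The wrapped value is an assumption.
def _example_optional_consumption():
  """Hedged sketch (eager mode): consuming a value from `Optional.from_value`."""
  optional = optional_ops.Optional.from_value(42)
  if optional.has_value():  # Scalar bool tensor; truthy check assumes eager mode.
    return optional.get_value()
  return None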

def _create_iterators_per_worker(worker_datasets, input_workers, options=None):
  """Create a multidevice iterator on each of the workers."""
  assert isinstance(input_workers, input_lib.InputWorkers)
  assert len(worker_datasets) == len(input_workers.worker_devices)
  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      iterator = _SingleWorkerDatasetIterator(
          worker_datasets[i],  # pylint: disable=protected-access
          worker,
          worker_devices,
          options)
      iterators.append(iterator)
  return iterators