Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training_distributed_v1.py: 12% (317 statements)

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to distributed training."""
# pylint: disable=protected-access

import numpy as np

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.distribute import distribute_coordinator_utils as dc
from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils
from tensorflow.python.keras.engine import partial_batch_padding_handler as padding_util
from tensorflow.python.keras.engine import training_arrays_v1
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import tf_logging as logging
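# Editorial overview: this module provides the TF1 (graph-mode) Keras
# distributed training, evaluation, and prediction loops used when a model
# is compiled under a `tf.distribute` strategy. It contains custom TPU loops
# for graph mode and the `TrainingLoop` subclasses that handle single-worker
# and multi-worker dispatch.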

def _per_replica_execution_function(model, mode):
  """Return the per-replica execution function components for `mode`."""
  exec_func = model._make_execution_function(mode)
  return (exec_func.inputs, exec_func.outputs, exec_func.updates_op,
          exec_func.session_kwargs)

def _build_model(strategy, model, mode, inputs, targets=None):
  """Build the distributed model for `mode` under the given strategy."""
  if model._compile_distribution:
    dist_utils.clone_model_on_replicas(
        model, strategy, mode, inputs=inputs, targets=targets)
  else:
    dist_utils._build_distributed_network(model, strategy, mode, inputs,
                                          targets)

def _make_train_step_fn(model, mode, strategy, output_labels):
  """Create step fn.

  Args:
    model: a Keras Model instance.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    strategy: a `tf.distribute.Strategy` instance.
    output_labels: the output labels for the step function.

  Returns:
    A step function to run by `tf.distribute.Strategy`.
  """

  def _step_fn(ctx, inputs):
    """A step fn that returns update ops."""
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs
    else:
      targets = None

    # When the input feature is a dictionary of tensors, the dictionary is
    # flattened to an array and passed as a model input. This results in an
    # input mismatch when the model input layer names are not sorted in
    # alphabetical order, since `nest.flatten()` sorts dictionary elements by
    # keys. As such, transform the input tensors into an array and order it
    # along `model._feed_input_names`.
    if isinstance(inputs, dict):
      inputs = [inputs[input_name] for input_name in model._feed_input_names]

    _build_model(strategy, model, mode, inputs, targets)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = strategy.extended.call_for_each_replica(
         _per_replica_execution_function,
         args=(dist_utils.get_distributed_model(model, mode), mode))
    (all_inputs, all_outputs, all_updates,
     all_session_args) = dist_utils.unwrap_values(strategy, grouped_inputs,
                                                  grouped_outputs,
                                                  grouped_updates,
                                                  grouped_session_args)
    combined_fn = backend.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_' + str(mode) + '_function',
        **all_session_args)

    for label, output in zip(output_labels, combined_fn.outputs):
      if label == 'loss':
        reduce_op = ds_reduce_util.ReduceOp.SUM
      else:
        # We reduce all other metrics using mean for now. This is a temporary
        # workaround until new metrics are in place.
        reduce_op = ds_reduce_util.ReduceOp.MEAN
      ctx.set_last_step_output(label, output, reduce_op)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These
    # should be handled appropriately.
    return combined_fn.updates_op

  return _step_fn
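# Illustrative usage sketch (editorial, not part of the original module): the
# step function built above is meant to be driven by the TF1-era
# `strategy.extended.experimental_run_steps_on_iterator` API, exactly as done
# in `experimental_tpu_fit_loop` below. Names such as `train_iterator` are
# hypothetical.
#
#   step_fn = _make_train_step_fn(
#       model, ModeKeys.TRAIN, strategy, model.metrics_names)
#   ctx = strategy.extended.experimental_run_steps_on_iterator(
#       step_fn, train_iterator, iterations=steps_per_run)
#   # `ctx.run_op` runs the multi-step loop; `ctx.last_step_outputs` holds
#   # the per-label outputs registered via `ctx.set_last_step_output`.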

def experimental_tpu_fit_loop(model,
                              dataset,
                              epochs=100,
                              verbose=1,
                              callbacks=None,
                              initial_epoch=0,
                              steps_per_epoch=None,
                              val_dataset=None,
                              validation_steps=None,
                              validation_freq=1):
  """Fit loop for training with TPU tf.distribute.Strategy.

  Args:
    model: Keras Model instance.
    dataset: Dataset that returns inputs and targets.
    epochs: Number of times to iterate over the data.
    verbose: Integer. Verbosity mode: 0, 1, or 2.
    callbacks: List of callbacks to be called during training.
    initial_epoch: Epoch at which to start training
      (useful for resuming a previous training run).
    steps_per_epoch: Total number of steps (batches of samples)
      before declaring one epoch finished and starting the
      next epoch. Ignored with the default value of `None`.
    val_dataset: Dataset for validation data.
    validation_steps: Number of steps to run validation for
      (only if doing validation from data tensors).
      Ignored with the default value of `None`.
    validation_freq: Only relevant if validation data is provided. Integer or
      `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
      integer, specifies how many training epochs to run before a new
      validation run is performed, e.g. `validation_freq=2` runs
      validation every 2 epochs. If a Container, specifies the epochs on
      which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
      validation at the end of the 1st, 2nd, and 10th epochs.

  Returns:
    The `History` object recorded during training (i.e. `model.history`).

  Raises:
    ValueError: in case of invalid arguments.
  """
  mode = ModeKeys.TRAIN

  current_strategy = model._distribution_strategy
  iteration_value = min(steps_per_epoch,
                        current_strategy.extended.steps_per_run)
  steps_per_run = backend.variable(
      value=iteration_value,
      dtype='int32',
      name='steps_per_run')
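  # Note: `steps_per_run` is a variable rather than a constant because the
  # epoch loop below re-assigns it (`steps_per_run.assign(step_count)`) so
  # that a final partial chunk of an epoch can run fewer steps on the device
  # than the strategy's configured `steps_per_run`.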

  # TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=1)
  scope.__enter__()

  out_labels = model.metrics_names or []

  step_fn = _make_train_step_fn(model, ModeKeys.TRAIN, current_strategy,
                                out_labels)

  # Add initial dummy values for loss and other metric tensors.
  initial_loop_values = {}
  initial_loop_values['loss'] = constant_op.constant(1e7)
  for m in model._get_training_eval_metrics():
    tensor = m.result()
    initial_loop_values[m.name] = array_ops.zeros(tensor.shape, tensor.dtype)

  ctx = current_strategy.extended.experimental_run_steps_on_iterator(
      step_fn, iterator, iterations=steps_per_run,
      initial_loop_values=initial_loop_values)
  train_op = ctx.run_op
  output_tensors = ctx.last_step_outputs

  do_validation = bool(validation_steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      verbose=verbose,
      count_mode='steps',
      mode=mode)

  # Calculate the number of steps to run on the device each time.
  steps_to_run = ([current_strategy.extended.steps_per_run] *
                  (steps_per_epoch //
                   current_strategy.extended.steps_per_run))
  if steps_per_epoch % current_strategy.extended.steps_per_run:
    steps_to_run.append(
        steps_per_epoch % current_strategy.extended.steps_per_run)
  target_steps = len(steps_to_run)
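  # For example (illustrative): with `steps_per_epoch=10` and
  # `steps_per_run=4`, `steps_to_run` is `[4, 4, 2]`, so each epoch issues
  # three host-side calls, the last one covering the 2 leftover steps.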

  callbacks._call_begin_hook(mode)

  initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode)

  for epoch in range(initial_epoch, epochs):
    dist_utils._reset_metrics(model)
    callbacks.on_epoch_begin(epoch)
    epoch_logs = {}
    step_index = 0
    prev_step_count = None
    current_step = 0
    while current_step < target_steps:
      step_count = steps_to_run[current_step]
      batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
      callbacks._call_batch_hook(mode, 'begin', step_index, batch_logs)
      if prev_step_count is None or step_count != prev_step_count:
        backend.get_session().run(steps_per_run.assign(step_count))
        prev_step_count = step_count
      try:
        _, outputs = backend.batch_get_value([train_op, output_tensors])
      except errors.OutOfRangeError:
        logging.warning('Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
        break

      batch_logs.update(outputs)
      callbacks._call_batch_hook(mode, 'end', step_index, batch_logs)
      step_index = step_index + step_count
      current_step += 1

    if callbacks.model.stop_training:
      break

    if (do_validation and
        training_utils_v1.should_run_validation(validation_freq, epoch)):
      logging.info('Running validation at fit epoch: %s', epoch)

      if model._compile_distribution:
        # Since we create a new clone from the original model we need to copy
        # the weights back to the original model before we can run validation.
        dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)

      val_outs = experimental_tpu_test_loop(  # pylint: disable=undefined-variable
          model,
          val_dataset,
          steps=validation_steps,
          verbose=verbose,
          callbacks=callbacks)
      if not isinstance(val_outs, list):
        val_outs = [val_outs]
      # Same labels assumed.
      for label, val_out in zip(out_labels, val_outs):
        epoch_logs['val_' + label] = val_out

    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
      break
  model._successful_loop_finish = True
  callbacks._call_end_hook(mode)

  if model._compile_distribution:
    # Copy the weights back from the replicated model to the original model.
    dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)
  scope.__exit__(None, None, None)
  return model.history
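# Illustrative sketch (hypothetical, not part of the original module): a
# direct invocation of the fit loop above, assuming `model` was compiled
# under a TPU strategy in graph mode and `train_ds`/`val_ds` are hypothetical
# datasets:
#
#   history = experimental_tpu_fit_loop(
#       model, train_ds, epochs=10, steps_per_epoch=100,
#       val_dataset=val_ds, validation_steps=20)
#
# In practice it is reached through `DistributionSingleWorkerTrainingLoop.fit`
# below rather than called directly.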

def experimental_tpu_test_loop(model,
                               dataset,
                               verbose=0,
                               steps=None,
                               callbacks=None):
  """Test loop for evaluating with TPU tf.distribute.Strategy.

  Args:
    model: Keras Model instance.
    dataset: Dataset for input data.
    verbose: Integer. Verbosity mode: 0 or 1.
    steps: Total number of steps (batches of samples)
      before declaring evaluation finished.
      Ignored with the default value of `None`.
    callbacks: List of callbacks to be called during evaluation.

  Returns:
    Scalar loss (if the model has a single output and no metrics)
    or list of scalars (if the model has multiple outputs
    and/or metrics). The attribute `model.metrics_names` will give you
    the display labels for the outputs.
  """
  mode = ModeKeys.TEST
  current_strategy = model._distribution_strategy
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=0)
  scope.__enter__()

  out_labels = model.metrics_names

  def _test_step_fn(inputs):
    """A fn that returns output of single test step."""
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs
    else:
      targets = None

    (distribute_lib.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs, targets)))

    (_, outputs, updates, _) = _per_replica_execution_function(
        dist_utils.get_distributed_model(model, mode), mode)
    with ops.control_dependencies([updates]):
      return [array_ops.identity(out) for out in outputs]

  test_input_data = iterator.get_next()
  per_replica_outputs = current_strategy.run(
      _test_step_fn, args=(test_input_data,))
  output_tensors = {}
  for label, output in zip(out_labels, per_replica_outputs):
    if label == 'loss':
      reduce_op = ds_reduce_util.ReduceOp.SUM
    else:
      # We reduce all other metrics using mean for now. This is a temporary
      # workaround until new metrics are in place.
      reduce_op = ds_reduce_util.ReduceOp.MEAN
    output_tensors[label] = current_strategy.reduce(reduce_op, output,
                                                    axis=None)
  test_op = control_flow_ops.group(list(output_tensors.values()))
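  # Aggregation note (editorial): per batch, `loss` is summed across replicas
  # while other metrics are averaged. The summed per-batch losses are then
  # accumulated into `outs[0]` below and divided by the number of completed
  # steps at the end of the loop, producing a mean loss over the batches.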

  if verbose >= 1:
    progbar = Progbar(target=steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  dist_utils._reset_metrics(model)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=False,
      epochs=1,
      steps_per_epoch=steps,
      verbose=verbose,
      count_mode='steps',
      mode=ModeKeys.TEST)
  callbacks._call_begin_hook(mode)

  outs = [0.] * len(model.metrics_names)
  if steps is not None:
    target_steps = steps
  else:
    raise ValueError('Number of steps could not be inferred from the data, '
                     'please pass the steps argument.')

  current_step = 0
  while current_step < target_steps:
    batch_logs = {'batch': current_step, 'size': 1}
    callbacks._call_batch_hook(mode, 'begin', current_step, batch_logs)
    try:
      _, batch_outs = backend.batch_get_value([test_op, output_tensors])
    except errors.OutOfRangeError:
      warning_msg = (
          'Make sure that your dataset can generate at least '
          '`steps` batches (in this case, {} batches).'.format(steps))

      logging.warning('Your dataset iterator ran out of data; '
                      'interrupting evaluation. ' + warning_msg)
      target_steps = current_step
      break
    for i, label in enumerate(model.metrics_names):
      if i == 0:
        # Loss is a stateless metric; accumulate it across batches.
        outs[i] += batch_outs[label]
      else:
        # For all stateful metrics, the aggregation is handled by mirrored
        # variables.
        outs[i] = batch_outs[label]

    batch_logs = cbks.make_logs(model, batch_logs, outs, mode)
    callbacks._call_batch_hook(mode, 'end', current_step, batch_logs)
    if verbose == 1:
      progbar.update(current_step + 1)
    current_step += 1

  if verbose >= 1:
    # Progress bar finishes at the end.
    progbar.update(target_steps)
  callbacks._call_end_hook(mode)

  scope.__exit__(None, None, None)
  if len(outs) > 0:
    # Average the accumulated loss over the number of completed steps.
    outs[0] /= target_steps

  if len(outs) == 1:
    return outs[0]
  return outs

def experimental_tpu_predict_loop(model,
                                  dataset,
                                  verbose=0,
                                  steps=None,
                                  callbacks=None):
  """Predict loop for predicting with TPU tf.distribute.Strategy.

  Args:
    model: Keras Model instance.
    dataset: Dataset for input data.
    verbose: Integer. Verbosity mode: 0 or 1.
    steps: Total number of steps (batches of samples)
      before declaring `_predict_loop` finished.
      Ignored with the default value of `None`.
    callbacks: List of callbacks to be called during prediction.

  Returns:
    Array of predictions (if the model has a single output)
    or list of arrays of predictions
    (if the model has multiple outputs).
  """
  mode = ModeKeys.PREDICT
  dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset)
  padding_handler = None
  if not dataset_fully_shaped:
    # TODO(hongjunchoi): Investigate whether operations from
    # PartialBatchPaddingHandler are unnecessarily pruned out
    # during graph optimization.
    padding_handler = padding_util.PartialBatchPaddingHandler(
        model._feed_output_shapes)
    batch_size, _, prefetch_buffer = input_lib._get_dataset_attributes(dataset)
    padding_handler.padded_batch_size = batch_size
    padding_handler.padding_mask = dataset.reduce(padding_handler.padding_mask,
                                                  padding_handler.update_mask)

    dataset = dataset.map(padding_handler.pad_batch)
    dataset = dataset.unbatch()
    # At this point, it is guaranteed that the dataset does not
    # have partial batches. Thus, we set `drop_remainder=True` to
    # get static shape information about the elements in the dataset.
    dataset = dataset.batch(batch_size, drop_remainder=True)

    if prefetch_buffer is not None:
      dataset = dataset.prefetch(prefetch_buffer)
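  # Editorial note: the pad -> unbatch -> rebatch(drop_remainder=True)
  # sequence above gives every element a fully static shape, which TPU
  # execution requires; rows added by the padding are tracked in
  # `padding_handler.padding_mask` and stripped from the final predictions
  # via `padding_handler.apply_mask` at the end of this function.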

  current_strategy = model._distribution_strategy
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=0)
  scope.__enter__()

  def _predict_step_fn(inputs):
    """A fn that returns output of single prediction step."""

    (distribute_lib.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs)))

    (_, outputs, updates, _) = _per_replica_execution_function(
        dist_utils.get_distributed_model(model, mode), mode)

    with ops.control_dependencies([updates]):
      return [array_ops.identity(out) for out in outputs]

  # TODO(hongjunchoi): When a numpy array is passed as an input to
  # `predict()`, use the numpy arrays directly to avoid accumulating
  # unnecessary input pipeline ops.
  predict_input_data = iterator.get_next()
  per_replica_outputs = current_strategy.run(
      _predict_step_fn, args=(predict_input_data,))
  output_tensors = dist_utils.flatten_per_replica_values(
      current_strategy, per_replica_outputs)

  if verbose >= 1:
    progbar = Progbar(target=steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  dist_utils._reset_metrics(model)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=False,
      epochs=1,
      steps_per_epoch=steps,
      verbose=verbose,
      count_mode='steps',
      mode=mode)
  callbacks._call_begin_hook(mode)

  # Since we do not know how many samples we will see, we cannot pre-allocate
  # the returned Numpy arrays. Instead, we store one array per batch seen
  # and concatenate them upon returning.
  num_model_outputs = len(model.output_names)
  unconcatenated_outs = [[] for _ in range(num_model_outputs)]
  if steps is not None:
    target_steps = steps
  else:
    raise ValueError('Number of steps could not be inferred from the data, '
                     'please pass the steps argument.')

  current_step = 0
  while current_step < target_steps:
    batch_logs = {'batch': current_step, 'size': 1}
    callbacks._call_batch_hook(mode, 'begin', current_step, batch_logs)
    try:
      predict_ops = control_flow_ops.group(output_tensors)
      _, batch_outs = backend.batch_get_value([predict_ops, output_tensors])

    except errors.OutOfRangeError:
      warning_msg = (
          'Make sure that your dataset can generate at least '
          '`steps` batches (in this case, {} batches).'.format(steps))

      logging.warning('Your dataset iterator ran out of data; '
                      'interrupting prediction. ' + warning_msg)
      break

    # TODO(priyag): maybe need to unwrap the outputs first for MirroredStrategy.
    for i in range(num_model_outputs):
      output_start_index = i * current_strategy.num_replicas_in_sync
      output_end_index = (
          output_start_index + current_strategy.num_replicas_in_sync)
      single_model_output = batch_outs[output_start_index:output_end_index]
      unconcatenated_outs[i].extend(single_model_output)
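    # For example (illustrative): with 2 model outputs and 2 replicas in
    # sync, `batch_outs` is laid out as [out0_replica0, out0_replica1,
    # out1_replica0, out1_replica1], so output `i` occupies the slice
    # [i * num_replicas : (i + 1) * num_replicas].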

    batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
    callbacks._call_batch_hook(mode, 'end', current_step, batch_logs)
    if verbose == 1:
      progbar.update(current_step + 1)
    current_step += 1

  if verbose >= 1:
    # Progress bar finishes at the end.
    progbar.update(current_step)

  callbacks._call_end_hook(mode)

  scope.__exit__(None, None, None)

  if len(unconcatenated_outs) == 1:
    prediction_result = np.concatenate(unconcatenated_outs[0], axis=0)
  else:
    prediction_result = [
        np.concatenate(out, axis=0) for out in unconcatenated_outs
    ]

  if padding_handler:
    prediction_result = padding_handler.apply_mask(prediction_result)

  return prediction_result

class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
  """Training loop for distribution strategy with a single worker."""

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Fit loop for Distribution Strategies."""
    dist_utils.validate_callbacks(input_callbacks=callbacks,
                                  optimizer=model.optimizer)
    dist_utils.validate_inputs(x, y)

    batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
        model._distribution_strategy,
        x,
        batch_size,
        steps_per_epoch,
        ModeKeys.TRAIN,
        validation_split=validation_split)
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps_per_epoch, x)
    dataset = model._distribution_standardize_user_data(
        x, y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        batch_size=batch_size,
        validation_split=validation_split,
        shuffle=shuffle,
        epochs=epochs)
    if not dist_utils.is_distributing_by_cloning(model):
      with model._distribution_strategy.scope():
        (dataset, _, _) = model._standardize_user_data(
            dataset,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size,
            validation_split=validation_split,
            shuffle=shuffle)

    val_dataset = None
    if validation_data:
      val_x, val_y, val_sample_weights = (
          training_utils_v1.unpack_validation_data(validation_data))
      dist_utils.validate_inputs(val_x, val_y)
      _, validation_steps = dist_utils.process_batch_and_step_size(
          model._distribution_strategy, val_x, batch_size, validation_steps,
          ModeKeys.TEST)

      val_dataset = model._distribution_standardize_user_data(
          val_x, val_y,
          sample_weight=val_sample_weights,
          class_weight=None,
          batch_size=batch_size,
          validation_split=validation_split,
          shuffle=shuffle,
          allow_partial_batch=True)
    elif validation_split:
      raise ValueError('validation_split argument is not supported with '
                       'distribution strategies.')

    if backend.is_tpu_strategy(model._distribution_strategy):
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
      if steps_per_epoch is None:
        raise ValueError('Number of steps could not be inferred from the data, '
                         'please pass the steps_per_epoch argument.')

      if not context.executing_eagerly():
        # Run TPU training in a custom loop in graph mode.
        return experimental_tpu_fit_loop(
            model,
            dataset,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            val_dataset=val_dataset,
            initial_epoch=initial_epoch,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            validation_freq=validation_freq)

    return training_arrays_v1.fit_loop(
        model,
        dataset,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        val_inputs=val_dataset,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Evaluate loop for Distribution Strategies."""
    dist_utils.validate_inputs(x, y)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        model._distribution_strategy, x, batch_size, steps, ModeKeys.TEST)
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    dataset = model._distribution_standardize_user_data(
        x, y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        allow_partial_batch=True)

    if backend.is_tpu_strategy(model._distribution_strategy):
      steps = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps, steps_name='steps')
      if steps is None:
        raise ValueError('Number of steps could not be inferred from the data, '
                         'please pass the steps argument.')

      if not context.executing_eagerly():
        # Run TPU evaluation in a custom loop in graph mode.
        return experimental_tpu_test_loop(
            model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)

    return training_arrays_v1.test_loop(
        model,
        inputs=dataset,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Predict loop for Distribution Strategies."""
    dist_utils.validate_inputs(x=x, y=None)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        model._distribution_strategy, x, batch_size, steps, ModeKeys.PREDICT)
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    dataset = model._distribution_standardize_user_data(
        x,
        batch_size=batch_size,
        allow_partial_batch=True)
    if backend.is_tpu_strategy(model._distribution_strategy):
      steps = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps, steps_name='steps')
      if steps is None:
        raise ValueError('Number of steps could not be inferred from the data, '
                         'please pass the steps argument.')
      if not context.executing_eagerly():
        return experimental_tpu_predict_loop(
            model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
    return training_arrays_v1.predict_loop(
        model,
        dataset,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)

def _train_with_multi_worker(method):
  """Decorator that handles multi-worker training with distribution strategy."""

  def wrapper(model, **kwargs):
    def _worker_fn(_):
      callbacks = kwargs.pop('callbacks', None)
      filtered_callbacks = dist_utils.filter_distributed_callbacks(
          callbacks, model)
      kwargs['callbacks'] = filtered_callbacks
      return method(model, **kwargs)

    return dc.run_distribute_coordinator(
        _worker_fn,
        model._distribution_strategy)

  return wrapper
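# Editorial note: `wrapper` re-enters `method` via
# `dc.run_distribute_coordinator`, with the callback list filtered first so
# that callbacks intended to run only on the chief worker (e.g. checkpointing
# callbacks) are not duplicated on every worker.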

class DistributionMultiWorkerTrainingLoop(training_utils_v1.TrainingLoop):
  """Training loop for distribution strategy with multiple workers."""

  def __init__(self, single_worker_loop):
    self._single_worker_loop = single_worker_loop

  def fit(self, *args, **kwargs):
    return _train_with_multi_worker(self._single_worker_loop.fit)(
        *args, **kwargs)

  def evaluate(self, *args, **kwargs):
    return _train_with_multi_worker(self._single_worker_loop.evaluate)(
        *args, **kwargs)

  def predict(self, *args, **kwargs):
    # Currently, predict still uses the single-worker implementation.
    return self._single_worker_loop.predict(*args, **kwargs)
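# Illustrative usage sketch (hypothetical, not part of the original module):
#
#   single_loop = DistributionSingleWorkerTrainingLoop()
#   multi_loop = DistributionMultiWorkerTrainingLoop(single_loop)
#   multi_loop.fit(model, x=train_ds, epochs=5, steps_per_epoch=100)
#
# `fit` and `evaluate` re-run the wrapped single-worker loop inside
# `dc.run_distribute_coordinator`, while `predict` falls through to the
# single-worker path directly.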