Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/training_distributed_v1.py: 10%

310 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Part of the Keras training engine related to distributed training.""" 

16 

17import numpy as np 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src import backend 

21from keras.src import callbacks as cbks 

22from keras.src.distribute import distribute_coordinator_utils as dc 

23from keras.src.distribute import distributed_training_utils_v1 as dist_utils 

24from keras.src.engine import partial_batch_padding_handler as padding_util 

25from keras.src.engine import training_arrays_v1 

26from keras.src.engine import training_utils_v1 

27from keras.src.utils.generic_utils import Progbar 

28from keras.src.utils.mode_keys import ModeKeys 

29 

30# isort: off 

31from tensorflow.python.distribute import input_lib 

32from tensorflow.python.platform import tf_logging as logging 

33 

34 

35def _per_replica_execution_function(model, mode): 

36 exec_func = model._make_execution_function(mode) 

37 return ( 

38 exec_func.inputs, 

39 exec_func.outputs, 

40 exec_func.updates_op, 

41 exec_func.session_kwargs, 

42 ) 

43 

44 

def _build_model(strategy, model, mode, inputs, targets=None):
    """Prepare the distributed version of `model` for `mode`.

    Dispatches on `model._compile_distribution`: when it is truthy the model
    is handed to `dist_utils.clone_model_on_replicas`, otherwise to
    `dist_utils._build_distributed_network`.

    Args:
        strategy: distribution strategy forwarded to the helper.
        model: the model to build; must expose `_compile_distribution`.
        mode: mode key forwarded to the helper.
        inputs: model inputs forwarded to the helper.
        targets: optional targets forwarded to the helper.
    """
    if not model._compile_distribution:
        dist_utils._build_distributed_network(
            model, strategy, mode, inputs, targets
        )
        return

    dist_utils.clone_model_on_replicas(
        model, strategy, mode, inputs=inputs, targets=targets
    )

54 

55 

56def _make_train_step_fn(model, mode, strategy, output_labels): 

57 """Create step fn. 

58 

59 Args: 

60 model: a Keras Model instance. 

61 mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT. 

62 strategy: a `tf.distribute.Strategy` instance. 

63 output_labels: the output labels for the step function. 

64 

65 Returns: 

66 A step function to run by `tf.distribute.Strategy`. 

67 """ 

68 

69 def _step_fn(ctx, inputs): 

70 """A step fn that returns update ops.""" 

71 if isinstance(inputs, (tuple, list)) and len(inputs) == 2: 

72 inputs, targets = inputs 

73 else: 

74 targets = None 

75 

76 # When input feature is a dictionary of tensors, dictionary is 

77 # flattended to an array and passed as a model input. This results in 

78 # input mismatch when model input layer names are not sorted in 

79 # alphabetical order as `nest.flatten()`sorts dictionary elements by 

80 # keys. As so, transform input tensors into an array and order it along 

81 # `model._feed_input_names`. 

82 if isinstance(inputs, dict): 

83 inputs = [ 

84 inputs[input_name] for input_name in model._feed_input_names 

85 ] 

86 

87 _build_model(strategy, model, mode, inputs, targets) 

88 

89 ( 

90 grouped_inputs, 

91 grouped_outputs, 

92 grouped_updates, 

93 grouped_session_args, 

94 ) = strategy.extended.call_for_each_replica( 

95 _per_replica_execution_function, 

96 args=(dist_utils.get_distributed_model(model, mode), mode), 

97 ) 

98 ( 

99 all_inputs, 

100 all_outputs, 

101 all_updates, 

102 all_session_args, 

103 ) = dist_utils.unwrap_values( 

104 strategy, 

105 grouped_inputs, 

106 grouped_outputs, 

107 grouped_updates, 

108 grouped_session_args, 

109 ) 

110 combined_fn = backend.function( 

111 all_inputs, 

112 all_outputs, 

113 updates=all_updates, 

114 name="distributed_" + str(mode) + "_function", 

115 **all_session_args 

116 ) 

117 

118 for label, output in zip(output_labels, combined_fn.outputs): 

119 if label == "loss": 

120 reduce_op = tf.distribute.ReduceOp.SUM 

121 else: 

122 # We reduce all other metrics using mean for now. This is 

123 # temporary workaround until new metrics are in place. 

124 reduce_op = tf.distribute.ReduceOp.MEAN 

125 ctx.set_last_step_output(label, output, reduce_op) 

126 

127 # TODO(priyag, sourabhbajaj): Ignoring these things from the 

128 # combined_fn: feed_dict, session kwargs, run options, run_metadata for 

129 # now. These should be handled appropriately 

130 return combined_fn.updates_op 

131 

132 return _step_fn 

133 

134 

def experimental_tpu_fit_loop(
    model,
    dataset,
    epochs=100,
    verbose=1,
    callbacks=None,
    initial_epoch=0,
    steps_per_epoch=None,
    val_dataset=None,
    validation_steps=None,
    validation_freq=1,
):
    """Fit loop for training with TPU tf.distribute.Strategy.

    Args:
        model: Keras Model instance.
        dataset: Dataset that returns inputs and targets
        epochs: Number of times to iterate over the data
        verbose: Integer, Verbosity mode, 0, 1 or 2
        callbacks: List of callbacks to be called during training
        initial_epoch: Epoch at which to start training
            (useful for resuming a previous training run)
        steps_per_epoch: Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. Required: this loop cannot infer it, so `None`
            is rejected.
        val_dataset: Dataset for validation data.
        validation_steps: Number of steps to run validation for
            (only if doing validation from data tensors).
            Validation is skipped when this is `None` or 0.
        validation_freq: Only relevant if validation data is provided. Integer
            or `collections.abc.Container` instance (e.g. list, tuple, etc.).
            If an integer, specifies how many training epochs to run before a
            new validation run is performed, e.g. `validation_freq=2` runs
            validation every 2 epochs. If a Container, specifies the epochs on
            which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
            validation at the end of the 1st, 2nd, and 10th epochs.

    Returns:
        `model.history`.

    Raises:
        ValueError: if `steps_per_epoch` is `None`.
    """
    # TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
    if steps_per_epoch is None:
        # Fail with a clear message instead of a TypeError from `min()`.
        raise ValueError(
            "Number of steps could not be inferred from the data, "
            "please pass the steps_per_epoch argument."
        )

    mode = ModeKeys.TRAIN

    current_strategy = model._distribution_strategy
    iteration_value = min(
        steps_per_epoch, current_strategy.extended.steps_per_run
    )
    steps_per_run = backend.variable(
        value=iteration_value, dtype="int32", name="steps_per_run"
    )

    iterator = dist_utils.get_iterator(dataset, current_strategy)

    # Using `with` guarantees the scope is exited even if training raises.
    with dist_utils.distributed_scope(
        strategy=current_strategy, learning_phase=1
    ):
        out_labels = model.metrics_names or []

        step_fn = _make_train_step_fn(
            model, ModeKeys.TRAIN, current_strategy, out_labels
        )

        # Add initial dummy values for loss and other metric tensors.
        initial_loop_values = {}
        initial_loop_values["loss"] = tf.constant(1e7)
        for m in model._get_training_eval_metrics():
            tensor = m.result()
            initial_loop_values[m.name] = tf.zeros(tensor.shape, tensor.dtype)

        ctx = current_strategy.extended.experimental_run_steps_on_iterator(
            step_fn,
            iterator,
            iterations=steps_per_run,
            initial_loop_values=initial_loop_values,
        )
        train_op = ctx.run_op
        output_tensors = ctx.last_step_outputs

        do_validation = bool(validation_steps)

        if model._compile_distribution:
            dist_utils._copy_weights_to_distributed_model(model, mode)

        callbacks = cbks.configure_callbacks(
            callbacks,
            model,
            do_validation=do_validation,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            verbose=verbose,
            count_mode="steps",
            mode=mode,
        )

        # Calculate the steps each time on the device: as many full runs of
        # `steps_per_run` as fit in an epoch, plus one shorter run for any
        # remainder.
        steps_to_run = [current_strategy.extended.steps_per_run] * (
            steps_per_epoch // current_strategy.extended.steps_per_run
        )
        if steps_per_epoch % current_strategy.extended.steps_per_run:
            steps_to_run.append(
                steps_per_epoch % current_strategy.extended.steps_per_run
            )
        target_steps = len(steps_to_run)

        callbacks._call_begin_hook(mode)

        initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
            initial_epoch, mode
        )

        for epoch in range(initial_epoch, epochs):
            dist_utils._reset_metrics(model)
            callbacks.on_epoch_begin(epoch)
            epoch_logs = {}
            step_index = 0
            prev_step_count = None
            current_step = 0
            while current_step < target_steps:
                step_count = steps_to_run[current_step]
                batch_logs = {
                    "batch": step_index,
                    "size": 1,
                    "num_steps": step_count,
                }
                callbacks._call_batch_hook(
                    mode, "begin", step_index, batch_logs
                )
                # Only update the on-device step counter when it changes.
                if prev_step_count is None or step_count != prev_step_count:
                    backend.get_session().run(
                        steps_per_run.assign(step_count)
                    )
                    prev_step_count = step_count
                try:
                    _, outputs = backend.batch_get_value(
                        [train_op, output_tensors]
                    )
                except tf.errors.OutOfRangeError:
                    # The count is passed as a logging argument; applying
                    # `% steps_per_epoch * epochs` directly to the message
                    # would format with `steps_per_epoch` alone and then
                    # repeat the whole string `epochs` times.
                    logging.warning(
                        "Your dataset iterator ran out of data; "
                        "interrupting training. Make sure that your dataset "
                        "can generate at least `steps_per_epoch * epochs` "
                        "batches (in this case, %d batches).",
                        steps_per_epoch * epochs,
                    )
                    break

                batch_logs.update(outputs)
                callbacks._call_batch_hook(mode, "end", step_index, batch_logs)
                step_index = step_index + step_count
                current_step += 1

                if callbacks.model.stop_training:
                    break

            if do_validation and training_utils_v1.should_run_validation(
                validation_freq, epoch
            ):
                logging.info("Running validation at fit epoch: %s", epoch)

                if model._compile_distribution:
                    # Since we create a new clone from the original model we
                    # need to copy the weights back to the original model
                    # before we can run validation.
                    dist_utils._copy_weights_to_original_model(
                        model, ModeKeys.TRAIN
                    )

                val_outs = experimental_tpu_test_loop(
                    model,
                    val_dataset,
                    steps=validation_steps,
                    verbose=verbose,
                    callbacks=callbacks,
                )
                if not isinstance(val_outs, list):
                    val_outs = [val_outs]
                # Same labels assumed.
                for label, val_out in zip(out_labels, val_outs):
                    epoch_logs["val_" + label] = val_out

            callbacks.on_epoch_end(epoch, epoch_logs)
            if callbacks.model.stop_training:
                break
        model._successful_loop_finish = True
        callbacks._call_end_hook(mode)

        if model._compile_distribution:
            # Copy the weights back from the replicated model to the original
            # model.
            dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)

    return model.history

326 

327 

def experimental_tpu_test_loop(
    model, dataset, verbose=0, steps=None, callbacks=None
):
    """Test loop for evaluating with TPU tf.distribute.Strategy.

    Args:
        model: Keras Model instance.
        dataset: Dataset for input data.
        verbose: Integer, Verbosity mode 0 or 1.
        steps: Total number of steps (batches of samples)
            before declaring predictions finished.
            Required: `None` is rejected because the number of steps is
            not inferred here.
        callbacks: List of callbacks to be called during training

    Returns:
        Scalar loss (if the model has a single output and no metrics)
        or list of scalars (if the model has multiple outputs
        and/or metrics). The attribute `model.metrics_names` will give you
        the display labels for the outputs. The loss is averaged over the
        steps that actually ran; if no step ran it is left at `0.0`.

    Raises:
        ValueError: if `steps` is `None`.
    """
    # Validate up front so nothing is built, no scope is entered and no
    # callback hook fires for a call that cannot proceed.
    if steps is None:
        raise ValueError(
            "Number of steps could not be inferred from the data, "
            "please pass the steps argument."
        )
    target_steps = steps

    mode = ModeKeys.TEST
    current_strategy = model._distribution_strategy
    iterator = dist_utils.get_iterator(dataset, current_strategy)

    # Using `with` guarantees the scope is exited even if evaluation raises.
    with dist_utils.distributed_scope(
        strategy=current_strategy, learning_phase=0
    ):
        out_labels = model.metrics_names

        def _test_step_fn(inputs):
            """A fn that returns output of single test step."""
            if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
                inputs, targets = inputs
            else:
                targets = None

            tf.distribute.get_replica_context().merge_call(
                _build_model, args=(model, mode, inputs, targets)
            )

            (_, outputs, updates, _) = _per_replica_execution_function(
                dist_utils.get_distributed_model(model, mode), mode
            )
            # Tie the returned outputs to the update op.
            with tf.control_dependencies([updates]):
                return [tf.identity(out) for out in outputs]

        test_input_data = iterator.get_next()
        per_replica_outputs = current_strategy.run(
            _test_step_fn, args=(test_input_data,)
        )
        output_tensors = {}
        for label, output in zip(out_labels, per_replica_outputs):
            if label == "loss":
                reduce_op = tf.distribute.ReduceOp.SUM
            else:
                # We reduce all other metrics using mean for now. This is a
                # temporary workaround until new metrics are in place.
                reduce_op = tf.distribute.ReduceOp.MEAN
            output_tensors[label] = current_strategy.reduce(
                reduce_op, output, axis=None
            )
        test_op = tf.group(list(output_tensors.values()))

        if verbose >= 1:
            progbar = Progbar(target=steps)

        if model._compile_distribution:
            dist_utils._copy_weights_to_distributed_model(model, mode)

        dist_utils._reset_metrics(model)

        callbacks = cbks.configure_callbacks(
            callbacks,
            model,
            do_validation=False,
            epochs=1,
            steps_per_epoch=steps,
            verbose=verbose,
            count_mode="steps",
            mode=ModeKeys.TEST,
        )
        callbacks._call_begin_hook(mode)

        outs = [0.0] * len(model.metrics_names)

        current_step = 0
        while current_step < target_steps:
            batch_logs = {"batch": current_step, "size": 1}
            callbacks._call_batch_hook(mode, "begin", current_step, batch_logs)
            try:
                _, batch_outs = backend.batch_get_value(
                    [test_op, output_tensors]
                )
            except tf.errors.OutOfRangeError:
                warning_msg = (
                    "Make sure that your dataset can generate at least "
                    "`steps` batches (in this case, {} batches).".format(steps)
                )

                logging.warning(
                    "Your dataset iterator ran out of data; "
                    "interrupting evaluation. " + warning_msg
                )
                # Average the loss over the steps that really ran.
                target_steps = current_step
                break
            for i, label in enumerate(model.metrics_names):
                if i == 0:
                    # Loss is stateless metrics.
                    outs[i] += batch_outs[label]
                else:
                    # For all stateful metrics, the aggregation is handled by
                    # mirrored vars.
                    outs[i] = batch_outs[label]

            batch_logs = callbacks.make_logs(model, batch_logs, outs, mode)
            callbacks._call_batch_hook(mode, "end", current_step, batch_logs)
            if verbose == 1:
                progbar.update(current_step + 1)
            current_step += 1

        if verbose >= 1:
            # Progress bar finishes at the end.
            progbar.update(target_steps)
        callbacks._call_end_hook(mode)

    # `target_steps` is 0 when `steps` is 0 or the iterator was exhausted on
    # the very first step; skip the division rather than raise
    # ZeroDivisionError.
    if outs and target_steps:
        outs[0] /= target_steps

    if len(outs) == 1:
        return outs[0]
    return outs

469 

470 

def experimental_tpu_predict_loop(
    model, dataset, verbose=0, steps=None, callbacks=None
):
    """Predict loop for predicting with TPU tf.distribute.Strategy.

    Args:
        model: Keras Model instance.
        dataset: Dataset for input data.
        verbose: Integer, Verbosity mode 0 or 1.
        steps: Total number of steps (batches of samples)
            before declaring `_predict_loop` finished.
            Required: `None` is rejected because the number of steps is
            not inferred here.
        callbacks: List of callbacks to be called during training

    Returns:
        Array of predictions (if the model has a single output)
        or list of arrays of predictions
        (if the model has multiple outputs).

    Raises:
        ValueError: if `steps` is `None`.
    """
    # Validate up front so the dataset is not rewritten, no scope is entered
    # and no callback hook fires for a call that cannot proceed.
    if steps is None:
        raise ValueError(
            "Number of steps could not be inferred from the data, "
            "please pass the steps argument."
        )
    target_steps = steps

    mode = ModeKeys.PREDICT
    dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset)
    padding_handler = None
    if not dataset_fully_shaped:
        # TODO(hongjunchoi): Investigate whether operations from
        # PartialBatchPaddingHandler are unnecessarily pruned out
        # during graph optimization.
        padding_handler = padding_util.PartialBatchPaddingHandler(
            model._feed_output_shapes
        )
        batch_size, _, prefetch_buffer = input_lib._get_dataset_attributes(
            dataset
        )
        padding_handler.padded_batch_size = batch_size
        padding_handler.padding_mask = dataset.reduce(
            padding_handler.padding_mask, padding_handler.update_mask
        )

        dataset = dataset.map(padding_handler.pad_batch)
        dataset = dataset.unbatch()
        # Upon this point, it is guaranteed that the dataset does not
        # have partial batches. Thus, we set `drop_remainder=True` to
        # get static shape information about the elements in the dataset.
        dataset = dataset.batch(batch_size, drop_remainder=True)

        if prefetch_buffer is not None:
            dataset = dataset.prefetch(prefetch_buffer)

    current_strategy = model._distribution_strategy
    iterator = dist_utils.get_iterator(dataset, current_strategy)

    # Using `with` guarantees the scope is exited even if prediction raises.
    with dist_utils.distributed_scope(
        strategy=current_strategy, learning_phase=0
    ):

        def _predict_step_fn(inputs):
            """A fn that returns output of single prediction step."""

            tf.distribute.get_replica_context().merge_call(
                _build_model, args=(model, mode, inputs)
            )

            (_, outputs, updates, _) = _per_replica_execution_function(
                dist_utils.get_distributed_model(model, mode), mode
            )

            # Tie the returned outputs to the update op.
            with tf.control_dependencies([updates]):
                return [tf.identity(out) for out in outputs]

        # TODO(hongjunchoi): When numpy array is passed as an input to
        # `predict()` use numpy arrays directly to avoid cumulating
        # unnecessary input pipeline ops.
        predict_input_data = iterator.get_next()
        per_replica_outputs = current_strategy.run(
            _predict_step_fn, args=(predict_input_data,)
        )
        output_tensors = dist_utils.flatten_per_replica_values(
            current_strategy, per_replica_outputs
        )
        # Build the grouping op once. Creating it inside the loop would add
        # a new op to the graph on every prediction step.
        predict_ops = tf.group(output_tensors)

        if verbose >= 1:
            progbar = Progbar(target=steps)

        if model._compile_distribution:
            dist_utils._copy_weights_to_distributed_model(model, mode)

        dist_utils._reset_metrics(model)

        callbacks = cbks.configure_callbacks(
            callbacks,
            model,
            do_validation=False,
            epochs=1,
            steps_per_epoch=steps,
            verbose=verbose,
            count_mode="steps",
            mode=mode,
        )
        callbacks._call_begin_hook(mode)

        # Since we do not know how many samples we will see, we cannot
        # pre-allocate the returned Numpy arrays. Instead, we store one array
        # per batch seen and concatenate them upon returning.
        num_model_outputs = len(model.output_names)
        unconcatenated_outs = [[] for _ in range(num_model_outputs)]

        current_step = 0
        while current_step < target_steps:
            batch_logs = {"batch": current_step, "size": 1}
            callbacks._call_batch_hook(mode, "begin", current_step, batch_logs)
            try:
                _, batch_outs = backend.batch_get_value(
                    [predict_ops, output_tensors]
                )

            except tf.errors.OutOfRangeError:
                warning_msg = (
                    "Make sure that your dataset can generate at least "
                    "`steps` batches (in this case, {} batches).".format(steps)
                )

                logging.warning(
                    "Your dataset iterator ran out of data; "
                    "interrupting evaluation. " + warning_msg
                )
                break

            # TODO(priyag): maybe need to unwrap the outputs first for
            # MirroredStrategy.
            # `batch_outs` is sliced in consecutive runs of
            # `num_replicas_in_sync` values, one run per model output.
            for i in range(num_model_outputs):
                output_start_index = i * current_strategy.num_replicas_in_sync
                output_end_index = (
                    output_start_index + current_strategy.num_replicas_in_sync
                )
                single_model_output = batch_outs[
                    output_start_index:output_end_index
                ]
                unconcatenated_outs[i].extend(single_model_output)

            batch_logs = callbacks.make_logs(
                model, batch_logs, batch_outs, mode
            )
            callbacks._call_batch_hook(mode, "end", current_step, batch_logs)
            if verbose == 1:
                progbar.update(current_step + 1)
            current_step += 1

        if verbose >= 1:
            # Progress bar finishes at the end.
            progbar.update(current_step)

        callbacks._call_end_hook(mode)

    if len(unconcatenated_outs) == 1:
        prediction_result = np.concatenate(unconcatenated_outs[0], axis=0)
    else:
        prediction_result = [
            np.concatenate(out, axis=0) for out in unconcatenated_outs
        ]

    if padding_handler:
        prediction_result = padding_handler.apply_mask(prediction_result)

    return prediction_result

645 

646 

class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
    """Training loop for distribution strategy with single worker."""

    def fit(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        **kwargs
    ):
        """Fit loop for Distribution Strategies.

        Validates the callbacks and inputs, standardizes the training (and
        optional validation) data, then delegates either to
        `experimental_tpu_fit_loop` (TPU strategy outside eager execution) or
        to `training_arrays_v1.fit_loop`. Extra `**kwargs` are accepted and
        not used.

        Returns:
            Whatever the selected fit loop returns.

        Raises:
            ValueError: if `validation_split` is set without
                `validation_data`, or if on a TPU strategy the number of
                steps per epoch cannot be inferred.
        """
        dist_utils.validate_callbacks(
            input_callbacks=callbacks, optimizer=model.optimizer
        )
        dist_utils.validate_inputs(x, y)

        batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
            model._distribution_strategy,
            x,
            batch_size,
            steps_per_epoch,
            ModeKeys.TRAIN,
            validation_split=validation_split,
        )
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x
        )

        # Keyword arguments common to both standardization passes below.
        standardize_options = {
            "sample_weight": sample_weight,
            "class_weight": class_weight,
            "batch_size": batch_size,
            "validation_split": validation_split,
            "shuffle": shuffle,
        }
        dataset = model._distribution_standardize_user_data(
            x, y, epochs=epochs, **standardize_options
        )
        if not dist_utils.is_distributing_by_cloning(model):
            with model._distribution_strategy.scope():
                dataset, _, _ = model._standardize_user_data(
                    dataset, **standardize_options
                )

        val_dataset = None
        if validation_data:
            (
                val_inputs,
                val_targets,
                val_weights,
            ) = training_utils_v1.unpack_validation_data(validation_data)
            dist_utils.validate_inputs(val_inputs, val_targets)
            _, validation_steps = dist_utils.process_batch_and_step_size(
                model._distribution_strategy,
                val_inputs,
                batch_size,
                validation_steps,
                ModeKeys.TEST,
            )

            val_dataset = model._distribution_standardize_user_data(
                val_inputs,
                val_targets,
                sample_weight=val_weights,
                class_weight=None,
                batch_size=batch_size,
                validation_split=validation_split,
                shuffle=shuffle,
                allow_partial_batch=True,
            )
        elif validation_split:
            raise ValueError(
                "validation_split argument is not supported with "
                "distribution strategies."
            )

        on_tpu = backend.is_tpu_strategy(model._distribution_strategy)
        if on_tpu:
            steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
                model,
                dataset,
                steps_per_epoch,
                epochs,
                steps_name="steps_per_epoch",
            )
            if steps_per_epoch is None:
                raise ValueError(
                    "Number of steps could not be inferred from the data, "
                    "please pass the steps_per_epoch argument."
                )

        # Keyword arguments understood by both fit loops. Built only now so
        # that it carries the possibly re-inferred `steps_per_epoch`.
        loop_options = {
            "epochs": epochs,
            "verbose": verbose,
            "callbacks": callbacks,
            "initial_epoch": initial_epoch,
            "steps_per_epoch": steps_per_epoch,
            "validation_steps": validation_steps,
            "validation_freq": validation_freq,
        }

        if on_tpu and not tf.executing_eagerly():
            # Run TPU training in a custom loop in graph mode.
            return experimental_tpu_fit_loop(
                model, dataset, val_dataset=val_dataset, **loop_options
            )

        return training_arrays_v1.fit_loop(
            model,
            dataset,
            batch_size=batch_size,
            val_inputs=val_dataset,
            shuffle=shuffle,
            steps_name="steps_per_epoch",
            **loop_options
        )

    def evaluate(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        verbose=1,
        sample_weight=None,
        steps=None,
        callbacks=None,
        **kwargs
    ):
        """Evaluate loop for Distribution Strategies.

        Validates and standardizes the inputs, then delegates either to
        `experimental_tpu_test_loop` (TPU strategy outside eager execution)
        or to `training_arrays_v1.test_loop`. Extra `**kwargs` are accepted
        and not used.

        Returns:
            Whatever the selected test loop returns.

        Raises:
            ValueError: if on a TPU strategy the number of steps cannot be
                inferred.
        """
        dist_utils.validate_inputs(x, y)
        batch_size, steps = dist_utils.process_batch_and_step_size(
            model._distribution_strategy, x, batch_size, steps, ModeKeys.TEST
        )
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        dataset = model._distribution_standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            allow_partial_batch=True,
        )

        on_tpu = backend.is_tpu_strategy(model._distribution_strategy)
        if on_tpu:
            steps = training_utils_v1.infer_steps_for_dataset(
                model, dataset, steps, steps_name="steps"
            )
            if steps is None:
                raise ValueError(
                    "Number of steps could not be inferred from the data, "
                    "please pass the steps argument."
                )

        # Built after inference so it carries the final `steps` value.
        loop_options = {
            "verbose": verbose,
            "steps": steps,
            "callbacks": callbacks,
        }

        if on_tpu and not tf.executing_eagerly():
            # Run TPU evaluation in a custom loop in graph mode.
            return experimental_tpu_test_loop(model, dataset, **loop_options)

        return training_arrays_v1.test_loop(
            model, inputs=dataset, batch_size=batch_size, **loop_options
        )

    def predict(
        self,
        model,
        x,
        batch_size=None,
        verbose=0,
        steps=None,
        callbacks=None,
        **kwargs
    ):
        """Predict loop for Distribution Strategies.

        Validates and standardizes the inputs, then delegates either to
        `experimental_tpu_predict_loop` (TPU strategy outside eager
        execution) or to `training_arrays_v1.predict_loop`. Extra `**kwargs`
        are accepted and not used.

        Returns:
            Whatever the selected predict loop returns.

        Raises:
            ValueError: if on a TPU strategy the number of steps cannot be
                inferred.
        """
        dist_utils.validate_inputs(x=x, y=None)
        batch_size, steps = dist_utils.process_batch_and_step_size(
            model._distribution_strategy, x, batch_size, steps, ModeKeys.PREDICT
        )
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        dataset = model._distribution_standardize_user_data(
            x, batch_size=batch_size, allow_partial_batch=True
        )

        on_tpu = backend.is_tpu_strategy(model._distribution_strategy)
        if on_tpu:
            steps = training_utils_v1.infer_steps_for_dataset(
                model, dataset, steps, steps_name="steps"
            )
            if steps is None:
                raise ValueError(
                    "Number of steps could not be inferred from the data, "
                    "please pass the steps argument."
                )

        # Built after inference so it carries the final `steps` value.
        loop_options = {
            "verbose": verbose,
            "steps": steps,
            "callbacks": callbacks,
        }

        if on_tpu and not tf.executing_eagerly():
            return experimental_tpu_predict_loop(
                model, dataset, **loop_options
            )

        return training_arrays_v1.predict_loop(
            model, dataset, batch_size=batch_size, **loop_options
        )

884 

885 

886def _train_with_multi_worker(method): 

887 """Decorator handles multi worker training with distribution strategy.""" 

888 

889 def wrapper(model, **kwargs): 

890 def _worker_fn(_): 

891 callbacks = kwargs.pop("callbacks", None) 

892 filtered_callbacks = dist_utils.filter_distributed_callbacks( 

893 callbacks, model 

894 ) 

895 kwargs["callbacks"] = filtered_callbacks 

896 return method(model, **kwargs) 

897 

898 return dc.run_distribute_coordinator( 

899 _worker_fn, model._distribution_strategy 

900 ) 

901 

902 return wrapper 

903 

904 

class DistributionMultiWorkerTrainingLoop(training_utils_v1.TrainingLoop):
    """Training loop for distribution strategy with multiple worker.

    Wraps a single-worker loop: `fit` and `evaluate` are routed through
    `_train_with_multi_worker`, while `predict` is forwarded directly.
    """

    def __init__(self, single_worker_loop):
        """Store the single-worker loop that the methods delegate to."""
        self._single_worker_loop = single_worker_loop

    def fit(self, *args, **kwargs):
        """Run the single-worker `fit` under the multi-worker wrapper."""
        distributed_fit = _train_with_multi_worker(
            self._single_worker_loop.fit
        )
        return distributed_fit(*args, **kwargs)

    def evaluate(self, *args, **kwargs):
        """Run the single-worker `evaluate` under the multi-worker wrapper."""
        distributed_evaluate = _train_with_multi_worker(
            self._single_worker_loop.evaluate
        )
        return distributed_evaluate(*args, **kwargs)

    def predict(self, *args, **kwargs):
        """Forward to the single-worker `predict` unchanged."""
        # Currently predict is still using the single worker implementation.
        single_worker_predict = self._single_worker_loop.predict
        return single_worker_predict(*args, **kwargs)

924