Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/training_arrays_v1.py: 13%

252 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Part of the Keras training engine related to plain array data.""" 

16 

17import functools 

18 

19import numpy as np 

20import tensorflow.compat.v2 as tf 

21 

22from keras.src import backend 

23from keras.src import callbacks as cbks 

24from keras.src.distribute import distributed_training_utils_v1 

25from keras.src.engine import training_utils_v1 

26from keras.src.utils import io_utils 

27from keras.src.utils.generic_utils import make_batches 

28from keras.src.utils.generic_utils import slice_arrays 

29from keras.src.utils.mode_keys import ModeKeys 

30 

31# isort: off 

32from tensorflow.python.platform import tf_logging as logging 

33 

34 

35try: 

36 from scipy.sparse import issparse 

37except ImportError: 

38 issparse = None 

39 

40 

41def model_iteration( 

42 model, 

43 inputs, 

44 targets=None, 

45 sample_weights=None, 

46 batch_size=None, 

47 epochs=1, 

48 verbose=1, 

49 callbacks=None, 

50 val_inputs=None, 

51 val_targets=None, 

52 val_sample_weights=None, 

53 shuffle=True, 

54 initial_epoch=0, 

55 steps_per_epoch=None, 

56 validation_steps=None, 

57 validation_freq=1, 

58 mode=ModeKeys.TRAIN, 

59 validation_in_fit=False, 

60 prepared_feed_values_from_dataset=False, 

61 steps_name="steps", 

62 **kwargs, 

63): 

64 """Loop function for arrays of data with modes TRAIN/TEST/PREDICT. 

65 

66 Args: 

67 model: Keras Model instance. 

68 inputs: Either a list or dictionary of arrays, or a dataset instance. 

69 targets: List/dictionary of input arrays. 

70 sample_weights: Optional list of sample weight arrays. 

71 batch_size: Integer batch size or None if unknown. 

72 epochs: Number of times to iterate over the data 

73 verbose: 0, 1, or 2. Verbosity mode. 

74 0 = silent, 1 = progress bar, 2 = one line per epoch. 

75 Note that the progress bar is not particularly useful when 

76 logged to a file, so verbose=2 is recommended when not running 

77 interactively (eg, in a production environment). 

78 callbacks: List of callbacks to be called during training 

79 val_inputs: Either a list or dictionary of arrays, or a dataset 

80 instance. 

81 val_targets: List/dictionary of target arrays. 

82 val_sample_weights: Optional list of sample weight arrays. 

83 shuffle: Whether to shuffle the data at the beginning of each epoch 

84 concatenation of list the display names of the outputs of `f` and the 

85 list of display names of the outputs of `f_val`. 

86 initial_epoch: Epoch at which to start training (useful for resuming a 

87 previous training run) 

88 steps_per_epoch: Total number of steps (batches of samples) before 

89 declaring one epoch finished and starting the next epoch. Ignored with 

90 the default value of `None`. 

91 validation_steps: Number of steps to run validation for (only if doing 

92 validation from data tensors). Ignored with the default value of 

93 `None`. 

94 validation_freq: Only relevant if validation data is provided. Integer 

95 or `collections.abc.Container` instance (e.g. list, tuple, etc.). If 

96 an integer, specifies how many training epochs to run before a new 

97 validation run is performed, e.g. `validation_freq=2` runs validation 

98 every 2 epochs. If a Container, specifies the epochs on which to run 

99 validation, e.g. `validation_freq=[1, 2, 10]` runs validation at the 

100 end of the 1st, 2nd, and 10th epochs. 

101 mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT. 

102 validation_in_fit: if true, then this method is invoked from within 

103 training iteration (for validation). In the case where `val_inputs` is 

104 a dataset, this flag indicates that its iterator and feed values are 

105 already created so should properly reuse resources. 

106 prepared_feed_values_from_dataset: if True, `inputs` is a list of feed 

107 tensors returned from `_prepare_feed_values` call on the validation 

108 dataset, so do not call it again on `inputs`. Should only be used for 

109 inline validation (i.e., only if `validation_in_fit` is also True). 

110 steps_name: The string name of the steps argument, either `steps`, 

111 `validation_steps`, or `steps_per_epoch`. Only used for error message 

112 formatting. 

113 **kwargs: Additional arguments for backwards compatibility. 

114 

115 Returns: 

116 - In TRAIN mode: `History` object. 

117 - In TEST mode: Evaluation metrics. 

118 - In PREDICT mode: Outputs of the Model called on inputs. 

119 

120 Raises: 

121 ValueError: in case of invalid arguments. 

122 """ 

123 # Backwards compatibility. 

124 if "steps" in kwargs: 

125 steps_per_epoch = kwargs.pop("steps") 

126 if kwargs: 

127 raise TypeError(f"Unknown arguments: {kwargs}") 

128 

129 # In case we were passed a dataset, we extract symbolic tensors from it. 

130 reset_dataset_after_each_epoch = False 

131 input_iterator = None 

132 is_dataset = isinstance( 

133 inputs, (tf.compat.v1.data.Dataset, tf.data.Dataset) 

134 ) 

135 # TODO(fchollet): consider moving `steps_per_epoch` inference to 

136 # _standardize_user_data and set reset_dataset_after_each_epoch as an 

137 # attribute on the dataset instance. 

138 if is_dataset: 

139 if steps_per_epoch is None: 

140 reset_dataset_after_each_epoch = True 

141 steps_per_epoch = training_utils_v1.infer_steps_for_dataset( 

142 model, 

143 inputs, 

144 steps_per_epoch, 

145 epochs=epochs, 

146 steps_name=steps_name, 

147 ) 

148 input_iterator = _get_iterator(inputs, model._distribution_strategy) 

149 

150 # Enter tf.distribute.Strategy scope. 

151 if model._distribution_strategy: 

152 scope = distributed_training_utils_v1.distributed_scope( 

153 strategy=model._distribution_strategy, 

154 learning_phase=(1 if mode == ModeKeys.TRAIN else 0), 

155 ) 

156 scope.__enter__() 

157 

158 use_steps = is_dataset or steps_per_epoch is not None 

159 do_validation = val_inputs is not None 

160 

161 # Prepare input data. 

162 inputs = input_iterator or inputs 

163 if validation_in_fit and prepared_feed_values_from_dataset: 

164 # When invoking validation in training loop, avoid creating iterator and 

165 # list of feed values for the same validation dataset multiple times 

166 # (which essentially would call `iterator.get_next()` that slows down 

167 # execution and leads to OOM errors eventually. 

168 ins = inputs 

169 else: 

170 ins = _prepare_feed_values(model, inputs, targets, sample_weights, mode) 

171 # `ins` is a function when a distribute strategy is used in Eager mode. 

172 # In that case `is_dataset` is True. The code branches that have 

173 # requirements about the type of `ins` do not trigger in the distributed 

174 # case. 

175 

176 if not is_dataset: 

177 num_samples_or_steps = _get_num_samples_or_steps( 

178 ins, batch_size, steps_per_epoch 

179 ) 

180 else: 

181 num_samples_or_steps = steps_per_epoch 

182 

183 # Update sample_weight_mode of the model if sample_weights is specified by 

184 # the user. We need to call this function after we have a handle on the 

185 # inputs (both numpy arrays and datasets) in order to determine if the user 

186 # has specified sample_weights. 

187 _update_sample_weight_mode(model, mode, ins) 

188 

189 # Get step function and loop type. As part of building the execution 

190 # function we recompile the metrics based on the updated 

191 # sample_weight_mode value. 

192 f = _make_execution_function(model, mode) 

193 

194 # Prepare validation data. Hold references to the iterator and the input 

195 # list to properly reinitialize and reuse in multiple validation passes. 

196 val_iterator = None 

197 if isinstance(val_inputs, (tf.compat.v1.data.Dataset, tf.data.Dataset)): 

198 if validation_steps is None: 

199 # Because we pass an iterator feed instead of a Dataset to the eval 

200 # model_iteration() call, it will not trigger the dataset-input path 

201 # that determines the number of steps required. To avoid this issue, 

202 # set validation_steps here if validation_steps is None. 

203 validation_steps = training_utils_v1.infer_steps_for_dataset( 

204 model, 

205 val_inputs, 

206 validation_steps, 

207 epochs=epochs, 

208 steps_name="validation_steps", 

209 ) 

210 val_iterator = _get_iterator(val_inputs, model._distribution_strategy) 

211 val_inputs = _prepare_feed_values( 

212 model, val_iterator, val_targets, val_sample_weights, ModeKeys.TEST 

213 ) 

214 # Get num steps for printing. 

215 val_samples_or_steps = validation_steps 

216 else: 

217 # Get num samples for printing. 

218 val_samples_or_steps = ( 

219 val_inputs and tf.nest.flatten(val_inputs)[0].shape[0] or None 

220 ) 

221 

222 if mode == ModeKeys.TRAIN and verbose: 

223 _print_train_info( 

224 num_samples_or_steps, val_samples_or_steps, is_dataset 

225 ) 

226 

227 # Configure callbacks. 

228 count_mode = "steps" if use_steps else "samples" 

229 callbacks = cbks.configure_callbacks( 

230 callbacks, 

231 model, 

232 do_validation=do_validation, 

233 batch_size=batch_size, 

234 epochs=epochs, 

235 steps_per_epoch=steps_per_epoch, 

236 samples=num_samples_or_steps, 

237 count_mode=count_mode, 

238 verbose=verbose, 

239 mode=mode, 

240 ) 

241 

242 # Find beforehand arrays that need sparse-to-dense conversion. 

243 if issparse is not None and not use_steps: 

244 indices_for_conversion_to_dense = [] 

245 feed = _get_model_feed(model, mode) 

246 for i, (input_data, feed_tensor) in enumerate(zip(ins, feed)): 

247 if issparse(input_data) and not backend.is_sparse(feed_tensor): 

248 indices_for_conversion_to_dense.append(i) 

249 

250 # Select aggregation method. 

251 if mode == ModeKeys.PREDICT: 

252 aggregator = training_utils_v1.OutputsAggregator( 

253 use_steps, 

254 num_samples=None if steps_per_epoch else num_samples_or_steps, 

255 steps=steps_per_epoch, 

256 ) 

257 else: 

258 aggregator = training_utils_v1.MetricsAggregator( 

259 use_steps, 

260 num_samples=None if steps_per_epoch else num_samples_or_steps, 

261 steps=steps_per_epoch, 

262 ) 

263 

264 if model._compile_distribution: 

265 distributed_training_utils_v1._copy_weights_to_distributed_model( 

266 model, mode 

267 ) 

268 

269 callbacks.model.stop_training = False 

270 callbacks._call_begin_hook(mode) 

271 

272 initial_epoch = model._maybe_load_initial_epoch_from_ckpt( 

273 initial_epoch, mode 

274 ) 

275 

276 for epoch in range(initial_epoch, epochs): 

277 if callbacks.model.stop_training: 

278 break 

279 

280 # Setup work for each epoch 

281 epoch_logs = {} 

282 if mode != ModeKeys.PREDICT: 

283 # Collecting and resetting metrics has non-zero cost and will 

284 # needlessly slow down model.predict. 

285 model.reset_metrics() 

286 if mode == ModeKeys.TRAIN: 

287 callbacks.on_epoch_begin(epoch, epoch_logs) 

288 

289 if use_steps: 

290 # Step-wise loop. 

291 if steps_per_epoch is None: 

292 # Loop over dataset until `OutOfRangeError` is raised. 

293 target_steps = np.inf 

294 else: 

295 # Loop over dataset for the specified number of steps. 

296 target_steps = steps_per_epoch 

297 

298 step = 0 

299 while step < target_steps: 

300 batch_logs = {"batch": step, "size": 1} 

301 callbacks._call_batch_hook(mode, "begin", step, batch_logs) 

302 

303 # Get outputs. 

304 try: 

305 # `ins` can be callable in tf.distribute.Strategy + eager 

306 # case. 

307 if not callable(ins) or ( 

308 model._distribution_strategy 

309 and not distributed_training_utils_v1.is_distributing_by_cloning( # noqa: E501 

310 model 

311 ) 

312 ): 

313 actual_inputs = ins 

314 else: 

315 actual_inputs = ins() 

316 batch_outs = f(actual_inputs) 

317 except tf.errors.OutOfRangeError: 

318 if is_dataset: 

319 # The dataset passed by the user ran out of batches. 

320 # Now we know the cardinality of the dataset. If 

321 # steps_per_epoch was specified, then running out of 

322 # data is unexpected, so we stop training and inform the 

323 # user. 

324 if steps_per_epoch: 

325 callbacks.model.stop_training = True 

326 logging.warning( 

327 "Your dataset ran out of data; interrupting " 

328 "training. Make sure that your dataset can " 

329 "generate at least `%s * epochs` batches (in " 

330 "this case, %d batches). You may need to use " 

331 "the repeat() function when building your " 

332 "dataset." 

333 % (steps_name, steps_per_epoch * epochs) 

334 ) 

335 elif step > 0: 

336 steps_per_epoch = step 

337 aggregator.steps = steps_per_epoch 

338 else: 

339 # We ran out of batches while the user passed an 

340 # iterator (legacy). 

341 callbacks.model.stop_training = True 

342 logging.warning( 

343 "Your dataset iterator ran out of data; " 

344 "interrupting training. Make sure that your " 

345 "iterator can generate at least `%s * epochs` " 

346 "batches (in this case, %d batches). You may need " 

347 "to use the repeat() function when building your " 

348 "dataset." % (steps_name, steps_per_epoch * epochs) 

349 ) 

350 break 

351 

352 if not isinstance(batch_outs, list): 

353 batch_outs = [batch_outs] 

354 

355 if model._distribution_strategy: 

356 batch_outs = distributed_training_utils_v1._per_replica_aggregate_batch( # noqa: E501 

357 model._distribution_strategy, batch_outs, model, mode 

358 ) 

359 

360 # Aggregate results. 

361 if step == 0: 

362 aggregator.create(batch_outs) 

363 aggregator.aggregate(batch_outs) 

364 

365 # Callbacks batch end. 

366 batch_logs = callbacks.make_logs( 

367 model, batch_logs, batch_outs, mode 

368 ) 

369 callbacks._call_batch_hook(mode, "end", step, batch_logs) 

370 step += 1 

371 

372 if callbacks.model.stop_training: 

373 break 

374 else: 

375 # Sample-wise loop. 

376 index_array = np.arange(num_samples_or_steps) 

377 if shuffle == "batch": 

378 index_array = training_utils_v1.batch_shuffle( 

379 index_array, batch_size 

380 ) 

381 elif shuffle: 

382 np.random.shuffle(index_array) 

383 batches = make_batches(num_samples_or_steps, batch_size) 

384 for batch_index, (batch_start, batch_end) in enumerate(batches): 

385 batch_ids = index_array[batch_start:batch_end] 

386 # Slice into a batch. 

387 if len(batches) == 1: 

388 # If we only have one batch, do not slice. This takes care 

389 # of composite tensors in non-Dataset modes; we currently 

390 # don't support slicing them. 

391 # TODO(b/133517906): Add slicing support. 

392 ins_batch = ins 

393 else: 

394 try: 

395 if ins and isinstance(ins[-1], int): 

396 # Do not slice the training phase flag. 

397 ins_batch = slice_arrays(ins[:-1], batch_ids) + [ 

398 ins[-1] 

399 ] 

400 else: 

401 ins_batch = slice_arrays(ins, batch_ids) 

402 except TypeError: 

403 raise TypeError( 

404 "TypeError while preparing batch. " 

405 "If using HDF5 input data, " 

406 'pass shuffle="batch".' 

407 ) 

408 

409 # Sparse to dense conversion. 

410 if issparse is not None: 

411 for i in indices_for_conversion_to_dense: 

412 ins_batch[i] = ins_batch[i].toarray() 

413 

414 # Callbacks batch_begin. 

415 batch_logs = {"batch": batch_index, "size": len(batch_ids)} 

416 callbacks._call_batch_hook( 

417 mode, "begin", batch_index, batch_logs 

418 ) 

419 

420 # Get outputs. 

421 batch_outs = f(ins_batch) 

422 if not isinstance(batch_outs, list): 

423 batch_outs = [batch_outs] 

424 

425 # Aggregate results. 

426 if batch_index == 0: 

427 aggregator.create(batch_outs) 

428 aggregator.aggregate(batch_outs, batch_start, batch_end) 

429 

430 # Callbacks batch end. 

431 batch_logs = callbacks.make_logs( 

432 model, batch_logs, batch_outs, mode 

433 ) 

434 callbacks._call_batch_hook(mode, "end", batch_index, batch_logs) 

435 

436 if callbacks.model.stop_training: 

437 break 

438 

439 aggregator.finalize() 

440 results = aggregator.results 

441 epoch_logs = callbacks.make_logs(model, epoch_logs, results, mode) 

442 if len(results) == 1: 

443 results = results[0] 

444 

445 # Run the test loop every `validation_freq` epochs during training. 

446 if ( 

447 do_validation 

448 and training_utils_v1.should_run_validation(validation_freq, epoch) 

449 and not callbacks.model.stop_training 

450 ): 

451 

452 if model._compile_distribution: 

453 # Since we create a new clone from the original model we need to 

454 # copy the weights back to the original model before we can run 

455 # validation. 

456 distributed_training_utils_v1._copy_weights_to_original_model( 

457 model, ModeKeys.TRAIN 

458 ) 

459 

460 val_results = model_iteration( 

461 model, 

462 val_inputs, 

463 targets=val_targets, 

464 sample_weights=val_sample_weights, 

465 batch_size=batch_size, 

466 steps_per_epoch=validation_steps, 

467 callbacks=callbacks, 

468 verbose=0, 

469 mode=ModeKeys.TEST, 

470 validation_in_fit=True, 

471 prepared_feed_values_from_dataset=(val_iterator is not None), 

472 steps_name="validation_steps", 

473 ) 

474 if not isinstance(val_results, list): 

475 val_results = [val_results] 

476 epoch_logs = callbacks.make_logs( 

477 model, epoch_logs, val_results, mode, prefix="val_" 

478 ) 

479 if val_iterator and epoch < epochs - 1: 

480 _reinitialize_iterator( 

481 val_iterator, model._distribution_strategy 

482 ) 

483 

484 if mode == ModeKeys.TRAIN: 

485 # Epochs only apply to `fit`. 

486 callbacks.on_epoch_end(epoch, epoch_logs) 

487 

488 # Reinitialize dataset iterator for the next epoch. 

489 if reset_dataset_after_each_epoch and epoch < epochs - 1: 

490 _reinitialize_iterator(input_iterator, model._distribution_strategy) 

491 

492 model._successful_loop_finish = True 

493 callbacks._call_end_hook(mode) 

494 

495 if model._distribution_strategy: 

496 if model._compile_distribution: 

497 # TODO(priyag, psv): Copy back metrics to the original model as 

498 # well? 

499 distributed_training_utils_v1._copy_weights_to_original_model( 

500 model, mode 

501 ) 

502 scope.__exit__(None, None, None) 

503 

504 if mode == ModeKeys.TRAIN: 

505 return model.history 

506 return results 

507 

508 

509def _get_model_feed(model, mode): 

510 if mode == ModeKeys.PREDICT: 

511 feed = model._feed_inputs 

512 else: 

513 feed = ( 

514 model._feed_inputs 

515 + model._feed_targets 

516 + model._feed_sample_weights 

517 ) 

518 return feed 

519 

520 

521def _print_train_info(num_samples_or_steps, val_samples_or_steps, is_dataset): 

522 increment = "steps" if is_dataset else "samples" 

523 msg = f"Train on {num_samples_or_steps} {increment}" 

524 if val_samples_or_steps: 

525 msg += f", validate on {val_samples_or_steps} {increment}" 

526 io_utils.print_msg(msg) 

527 

528 

529def _get_num_samples_or_steps(ins, batch_size, steps_per_epoch): 

530 """Returns total number of samples when training in batch mode or steps.""" 

531 if steps_per_epoch: 

532 return steps_per_epoch 

533 return training_utils_v1.check_num_samples( 

534 ins, batch_size, steps_per_epoch, "steps_per_epoch" 

535 ) 

536 

537 

538def _prepare_feed_values(model, inputs, targets, sample_weights, mode): 

539 """Prepare feed values to the model execution function. 

540 

541 Args: 

542 model: Model to prepare feed values for. 

543 inputs: List or dict of model inputs. 

544 targets: Optional list of model targets. 

545 sample_weights: Optional list of sample weight arrays. 

546 mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT. 

547 

548 Returns: 

549 Feed values for the model in the given mode. 

550 """ 

551 if model._distribution_strategy: 

552 if isinstance(inputs, (tf.compat.v1.data.Dataset, tf.data.Dataset)): 

553 inputs = distributed_training_utils_v1.get_iterator( 

554 inputs, model._distribution_strategy 

555 ) 

556 

557 def get_distributed_inputs(): 

558 return distributed_training_utils_v1._prepare_feed_values( 

559 model, inputs, targets, sample_weights, mode 

560 ) 

561 

562 # In the eager case, we want to call the input method per step, so 

563 # return a lambda from here that can be called. Note that this is 

564 # applicable only in Distribution Strategy case as it follows the same 

565 # code path for both eager and graph modes. 

566 # TODO(priyag,omalleyt): Either we should move the training DS with 

567 # IteratorBase to use training_generator code path, or figure out how to 

568 # set a symbolic Iterator out of a Dataset when in eager mode. 

569 if tf.executing_eagerly(): 

570 return get_distributed_inputs 

571 else: 

572 return get_distributed_inputs() 

573 

574 if isinstance( 

575 inputs, 

576 ( 

577 tf.compat.v1.data.Dataset, 

578 tf.data.Dataset, 

579 tf.compat.v1.data.Iterator, 

580 ), 

581 ): 

582 inputs, targets, sample_weights = model._standardize_user_data( 

583 inputs, extract_tensors_from_dataset=True 

584 ) 

585 

586 inputs = training_utils_v1.ModelInputs(inputs).as_list() 

587 targets = list(targets or []) 

588 sample_weights = list(sample_weights or []) 

589 ins = inputs + targets + sample_weights 

590 if mode == ModeKeys.TRAIN and not isinstance( 

591 backend.symbolic_learning_phase(), int 

592 ): 

593 ins += [True] # Add learning phase value. 

594 return ins 

595 

596 

597def _get_iterator(inputs, distribution_strategy=None): 

598 if distribution_strategy: 

599 return distributed_training_utils_v1.get_iterator( 

600 inputs, distribution_strategy 

601 ) 

602 return training_utils_v1.get_iterator(inputs) 

603 

604 

605def _reinitialize_iterator(iterator, distribution_strategy=None): 

606 if distribution_strategy: 

607 distributed_training_utils_v1.initialize_iterator( 

608 iterator, distribution_strategy 

609 ) 

610 else: 

611 training_utils_v1.initialize_iterator(iterator) 

612 

613 

614def _make_execution_function(model, mode): 

615 """Makes function to run one step of model execution.""" 

616 if model._distribution_strategy: 

617 return distributed_training_utils_v1._make_execution_function( 

618 model, mode 

619 ) 

620 return model._make_execution_function(mode) 

621 

622 

623def _update_sample_weight_mode(model, mode, inputs): 

624 """Updates the sample_weight_mode of a given model.""" 

625 # Add a quick return to prevent us from calling model._feed_targets that 

626 # accesses certain model properties that may not be set in the `PREDICT` 

627 # mode. 

628 if mode == ModeKeys.PREDICT: 

629 return 

630 

631 sample_weights = None 

632 # `inputs` is the model's inputs + targets + sample_weights + 

633 # learning phase placeholder if specified. To update the sample_weight_mode 

634 # we need to determine if the user has passed sample weights as part of the 

635 # input. 

636 if not callable(inputs): 

637 sample_weights = inputs[ 

638 len(model._feed_inputs) + len(model._feed_targets) : 

639 ] 

640 has_learning_phase_pl = mode == ModeKeys.TRAIN and not isinstance( 

641 backend.symbolic_learning_phase(), int 

642 ) 

643 if has_learning_phase_pl: 

644 sample_weights = sample_weights[:-1] 

645 model._update_sample_weight_modes(sample_weights=sample_weights) 

646 

647 # Call the DistributionStrategy specific function to update the 

648 # sample_weight_mode on the model. 

649 if model._distribution_strategy: 

650 distributed_training_utils_v1._update_sample_weight_modes( 

651 model, mode, sample_weights 

652 ) 

653 

654 

655# For backwards compatibility for internal users of these loops. 

656fit_loop = functools.partial(model_iteration, mode=ModeKeys.TRAIN) 

657test_loop = functools.partial( 

658 model_iteration, mode=ModeKeys.TEST, shuffle=False 

659) 

660predict_loop = functools.partial( 

661 model_iteration, mode=ModeKeys.PREDICT, shuffle=False 

662) 

663 

664 

665class ArrayLikeTrainingLoop(training_utils_v1.TrainingLoop): 

666 """TrainingLoop that handle inputs like array. 

667 

668 This is the default handler for most of the input data types, includes 

669 symbolic tensors or Numpy array-like, Datasets and iterators in graph mode 

670 (since they generate symbolic tensors). This Function is used to handle 

671 model with `run_eagerly` = False. 

672 """ 

673 

674 def fit( 

675 self, 

676 model, 

677 x=None, 

678 y=None, 

679 batch_size=None, 

680 epochs=1, 

681 verbose=1, 

682 callbacks=None, 

683 validation_split=0.0, 

684 validation_data=None, 

685 shuffle=True, 

686 class_weight=None, 

687 sample_weight=None, 

688 initial_epoch=0, 

689 steps_per_epoch=None, 

690 validation_steps=None, 

691 validation_freq=1, 

692 **kwargs, 

693 ): 

694 batch_size = model._validate_or_infer_batch_size( 

695 batch_size, steps_per_epoch, x 

696 ) 

697 

698 x, y, sample_weights = model._standardize_user_data( 

699 x, 

700 y, 

701 sample_weight=sample_weight, 

702 class_weight=class_weight, 

703 batch_size=batch_size, 

704 check_steps=True, 

705 steps_name="steps_per_epoch", 

706 steps=steps_per_epoch, 

707 validation_split=validation_split, 

708 shuffle=shuffle, 

709 ) 

710 

711 if validation_data: 

712 val_x, val_y, val_sample_weights = model._prepare_validation_data( 

713 validation_data, batch_size, validation_steps 

714 ) 

715 elif validation_split and 0.0 < validation_split < 1.0: 

716 ( 

717 x, 

718 y, 

719 sample_weights, 

720 val_x, 

721 val_y, 

722 val_sample_weights, 

723 ) = training_utils_v1.split_training_and_validation_data( 

724 x, y, sample_weights, validation_split 

725 ) 

726 else: 

727 if validation_steps: 

728 raise ValueError( 

729 "`validation_steps` should not be specified if " 

730 "`validation_data` is None." 

731 ) 

732 val_x, val_y, val_sample_weights = None, None, None 

733 

734 return fit_loop( 

735 model, 

736 inputs=x, 

737 targets=y, 

738 sample_weights=sample_weights, 

739 batch_size=batch_size, 

740 epochs=epochs, 

741 verbose=verbose, 

742 callbacks=callbacks, 

743 val_inputs=val_x, 

744 val_targets=val_y, 

745 val_sample_weights=val_sample_weights, 

746 shuffle=shuffle, 

747 initial_epoch=initial_epoch, 

748 steps_per_epoch=steps_per_epoch, 

749 validation_steps=validation_steps, 

750 validation_freq=validation_freq, 

751 steps_name="steps_per_epoch", 

752 ) 

753 

754 def evaluate( 

755 self, 

756 model, 

757 x=None, 

758 y=None, 

759 batch_size=None, 

760 verbose=1, 

761 sample_weight=None, 

762 steps=None, 

763 callbacks=None, 

764 **kwargs, 

765 ): 

766 batch_size = model._validate_or_infer_batch_size(batch_size, steps, x) 

767 x, y, sample_weights = model._standardize_user_data( 

768 x, 

769 y, 

770 sample_weight=sample_weight, 

771 batch_size=batch_size, 

772 check_steps=True, 

773 steps_name="steps", 

774 steps=steps, 

775 ) 

776 return test_loop( 

777 model, 

778 inputs=x, 

779 targets=y, 

780 sample_weights=sample_weights, 

781 batch_size=batch_size, 

782 verbose=verbose, 

783 steps=steps, 

784 callbacks=callbacks, 

785 ) 

786 

787 def predict( 

788 self, 

789 model, 

790 x, 

791 batch_size=None, 

792 verbose=0, 

793 steps=None, 

794 callbacks=None, 

795 **kwargs, 

796 ): 

797 batch_size = model._validate_or_infer_batch_size(batch_size, steps, x) 

798 x, _, _ = model._standardize_user_data( 

799 x, check_steps=True, steps_name="steps", steps=steps 

800 ) 

801 return predict_loop( 

802 model, 

803 x, 

804 batch_size=batch_size, 

805 verbose=verbose, 

806 steps=steps, 

807 callbacks=callbacks, 

808 ) 

809