Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training_arrays_v1.py: 14%

255 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Part of the Keras training engine related to plain array data.""" 

16# pylint: disable=protected-access 

17 

18import functools 

19 

20import numpy as np 

21 

22from tensorflow.python.data.ops import iterator_ops 

23from tensorflow.python.eager import context 

24from tensorflow.python.framework import errors 

25from tensorflow.python.keras import backend 

26from tensorflow.python.keras import callbacks as cbks 

27from tensorflow.python.keras.distribute import distributed_training_utils_v1 

28from tensorflow.python.keras.engine import training_utils_v1 

29from tensorflow.python.keras.utils.generic_utils import make_batches 

30from tensorflow.python.keras.utils.generic_utils import slice_arrays 

31from tensorflow.python.keras.utils.mode_keys import ModeKeys 

32from tensorflow.python.platform import tf_logging as logging 

33from tensorflow.python.types import data as data_types 

34from tensorflow.python.util import nest 

35 

36try: 

37 from scipy.sparse import issparse # pylint: disable=g-import-not-at-top 

38except ImportError: 

39 issparse = None 

40 

41 

def model_iteration(model,
                    inputs,
                    targets=None,
                    sample_weights=None,
                    batch_size=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    val_inputs=None,
                    val_targets=None,
                    val_sample_weights=None,
                    shuffle=True,
                    initial_epoch=0,
                    steps_per_epoch=None,
                    validation_steps=None,
                    validation_freq=1,
                    mode=ModeKeys.TRAIN,
                    validation_in_fit=False,
                    prepared_feed_values_from_dataset=False,
                    steps_name='steps',
                    **kwargs):
  """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

  Args:
    model: Keras Model instance.
    inputs: Either a list or dictionary of arrays, or a dataset instance.
    targets: List/dictionary of target arrays.
    sample_weights: Optional list of sample weight arrays.
    batch_size: Integer batch size or None if unknown.
    epochs: Number of times to iterate over the data.
    verbose: 0, 1, or 2. Verbosity mode.
      0 = silent, 1 = progress bar, 2 = one line per epoch.
      Note that the progress bar is not particularly useful when
      logged to a file, so verbose=2 is recommended when not running
      interactively (eg, in a production environment).
    callbacks: List of callbacks to be called during training.
    val_inputs: Either a list or dictionary of arrays, or a dataset instance.
    val_targets: List/dictionary of target arrays.
    val_sample_weights: Optional list of sample weight arrays.
    shuffle: Whether to shuffle the data at the beginning of each epoch.
      The string 'batch' shuffles via `training_utils_v1.batch_shuffle`
      (the batch-preparation error message recommends it for HDF5 input).
      Only consulted by the sample-wise (non-steps) loop.
    initial_epoch: Epoch at which to start training (useful for resuming a
      previous training run).
    steps_per_epoch: Total number of steps (batches of samples) before
      declaring one epoch finished and starting the next epoch. Ignored with
      the default value of `None`. When `inputs` is a dataset and this is
      `None`, it is inferred from the dataset and the dataset iterator is
      re-initialized after every epoch.
    validation_steps: Number of steps to run validation for (only if doing
      validation from data tensors). Ignored with the default value of
      `None`. Inferred from `val_inputs` when that is a dataset.
    validation_freq: Only relevant if validation data is provided. Integer or
      `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
      integer, specifies how many training epochs to run before a new
      validation run is performed, e.g. `validation_freq=2` runs
      validation every 2 epochs. If a Container, specifies the epochs on
      which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
      validation at the end of the 1st, 2nd, and 10th epochs.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    validation_in_fit: if true, then this method is invoked from within
      training iteration (for validation). In the case where `val_inputs` is
      a dataset, this flag indicates that its iterator and feed values are
      already created so should properly reuse resources.
    prepared_feed_values_from_dataset: if True, `inputs` is a list of feed
      tensors returned from `_prepare_feed_values` call on the validation
      dataset, so do not call it again on `inputs`. Should only be used for
      inline validation (i.e., only if `validation_in_fit` is also True).
    steps_name: The string name of the steps argument, either `steps`,
      `validation_steps`, or `steps_per_epoch`. Only used for error message
      formatting.
    **kwargs: Additional arguments for backwards compatibility. Only `steps`
      (an alias for `steps_per_epoch`) is accepted.

  Returns:
    - In TRAIN mode: `History` object.
    - In TEST mode: Evaluation metrics.
    - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
    ValueError: in case of invalid arguments.
    TypeError: if unknown keyword arguments are passed, or if a batch cannot
      be sliced out of the input arrays.
  """
  # Backwards compatibility.
  if 'steps' in kwargs:
    steps_per_epoch = kwargs.pop('steps')
  if kwargs:
    raise TypeError('Unknown arguments: %s' % (kwargs,))

  # In case we were passed a dataset, we extract symbolic tensors from it.
  reset_dataset_after_each_epoch = False
  input_iterator = None
  is_dataset = isinstance(inputs,
                          (data_types.DatasetV1, data_types.DatasetV2))
  # TODO(fchollet): consider moving `steps_per_epoch` inference to
  # _standardize_user_data and set reset_dataset_after_each_epoch as an
  # attribute on the dataset instance.
  if is_dataset:
    if steps_per_epoch is None:
      reset_dataset_after_each_epoch = True
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, inputs, steps_per_epoch, epochs=epochs, steps_name=steps_name)
    input_iterator = _get_iterator(inputs, model._distribution_strategy)

  # Enter tf.distribute.Strategy scope. It is exited in the `finally` clause
  # below so that an exception raised mid-loop does not leave it entered.
  scope = None
  if model._distribution_strategy:
    scope = distributed_training_utils_v1.distributed_scope(
        strategy=model._distribution_strategy,
        learning_phase=(1 if mode == ModeKeys.TRAIN else 0))
    scope.__enter__()

  try:
    use_steps = is_dataset or steps_per_epoch is not None
    do_validation = val_inputs is not None

    # Prepare input data.
    inputs = input_iterator or inputs
    if validation_in_fit and prepared_feed_values_from_dataset:
      # When invoking validation in training loop, avoid creating iterator and
      # list of feed values for the same validation dataset multiple times
      # (which essentially would call `iterator.get_next()` that slows down
      # execution and leads to OOM errors eventually.
      ins = inputs
    else:
      ins = _prepare_feed_values(model, inputs, targets, sample_weights, mode)
      # `ins` is a function when a distribute strategy is used in Eager mode.
      # In that case `is_dataset` is True. The code branches that have
      # requirements about the type of `ins` do not trigger in the
      # distributed case.

    if not is_dataset:
      num_samples_or_steps = _get_num_samples_or_steps(ins, batch_size,
                                                       steps_per_epoch)
    else:
      num_samples_or_steps = steps_per_epoch

    # Update sample_weight_mode of the model if sample_weights is specified by
    # the user. We need to call this function after we have a handle on the
    # inputs (both numpy arrays and datasets) in order to determine if the
    # user has specified sample_weights.
    _update_sample_weight_mode(model, mode, ins)

    # Get step function and loop type. As part of building the execution
    # function we recompile the metrics based on the updated
    # sample_weight_mode value.
    f = _make_execution_function(model, mode)

    # Prepare validation data. Hold references to the iterator and the input
    # list to properly reinitialize and reuse in multiple validation passes.
    val_iterator = None
    if isinstance(val_inputs, (data_types.DatasetV1, data_types.DatasetV2)):
      if validation_steps is None:
        # Because we pass an iterator feed instead of a Dataset to the eval
        # model_iteration() call, it will not trigger the dataset-input path
        # that determines the number of steps required. To avoid this issue,
        # set validation_steps here if validation_steps is None.
        validation_steps = training_utils_v1.infer_steps_for_dataset(
            model,
            val_inputs,
            validation_steps,
            epochs=epochs,
            steps_name='validation_steps')
      val_iterator = _get_iterator(val_inputs, model._distribution_strategy)
      val_inputs = _prepare_feed_values(
          model, val_iterator, val_targets, val_sample_weights, ModeKeys.TEST)
      # Get num steps for printing.
      val_samples_or_steps = validation_steps
    else:
      # Get num samples for printing.
      val_samples_or_steps = val_inputs and nest.flatten(
          val_inputs)[0].shape[0] or None

    if mode == ModeKeys.TRAIN and verbose:
      _print_train_info(num_samples_or_steps, val_samples_or_steps, is_dataset)

    # Configure callbacks.
    count_mode = 'steps' if use_steps else 'samples'
    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=do_validation,
        batch_size=batch_size,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        samples=num_samples_or_steps,
        count_mode=count_mode,
        verbose=verbose,
        mode=mode)

    # Find beforehand arrays that need sparse-to-dense conversion.
    if issparse is not None and not use_steps:
      indices_for_conversion_to_dense = []
      feed = _get_model_feed(model, mode)
      for i, (input_data, feed_tensor) in enumerate(zip(ins, feed)):
        if issparse(input_data) and not backend.is_sparse(feed_tensor):
          indices_for_conversion_to_dense.append(i)

    # Select aggregation method.
    if mode == ModeKeys.PREDICT:
      aggregator = training_utils_v1.OutputsAggregator(
          use_steps,
          num_samples=None if steps_per_epoch else num_samples_or_steps,
          steps=steps_per_epoch)
    else:
      aggregator = training_utils_v1.MetricsAggregator(
          use_steps,
          num_samples=None if steps_per_epoch else num_samples_or_steps,
          steps=steps_per_epoch)

    if model._compile_distribution:
      distributed_training_utils_v1._copy_weights_to_distributed_model(
          model, mode)

    callbacks.model.stop_training = False
    callbacks._call_begin_hook(mode)

    initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
        initial_epoch, mode)

    for epoch in range(initial_epoch, epochs):
      if callbacks.model.stop_training:
        break

      # Setup work for each epoch
      epoch_logs = {}
      if mode != ModeKeys.PREDICT:
        # Collecting and resetting metrics has non-zero cost and will
        # needlessly slow down model.predict.
        model.reset_metrics()
      if mode == ModeKeys.TRAIN:
        callbacks.on_epoch_begin(epoch, epoch_logs)

      if use_steps:
        # Step-wise loop.
        if steps_per_epoch is None:
          # Loop over dataset until `OutOfRangeError` is raised.
          target_steps = np.inf
        else:
          # Loop over dataset for the specified number of steps.
          target_steps = steps_per_epoch

        step = 0
        while step < target_steps:
          batch_logs = {'batch': step, 'size': 1}
          callbacks._call_batch_hook(mode, 'begin', step, batch_logs)

          # Get outputs.
          try:
            # `ins` can be callable in tf.distribute.Strategy + eager case.
            if not callable(ins) or (model._distribution_strategy and
                                     not distributed_training_utils_v1
                                     .is_distributing_by_cloning(model)):
              actual_inputs = ins
            else:
              actual_inputs = ins()
            batch_outs = f(actual_inputs)
          except errors.OutOfRangeError:
            if is_dataset:
              # The dataset passed by the user ran out of batches.
              # Now we know the cardinality of the dataset.
              # If steps_per_epoch was specified, then running out of data is
              # unexpected, so we stop training and inform the user.
              if steps_per_epoch:
                callbacks.model.stop_training = True
                logging.warning(
                    'Your dataset ran out of data; interrupting training. '
                    'Make sure that your dataset can generate at least '
                    '`%s * epochs` batches (in this case, %d batches). '
                    'You may need to use the repeat() function when '
                    'building your dataset.'
                    % (steps_name, steps_per_epoch * epochs))
              elif step > 0:
                steps_per_epoch = step
                aggregator.steps = steps_per_epoch
            else:
              # We ran out of batches while the user passed an iterator
              # (legacy).
              callbacks.model.stop_training = True
              logging.warning(
                  'Your dataset iterator ran out of data; '
                  'interrupting training. Make sure that your iterator '
                  'can generate at least `%s * epochs` '
                  'batches (in this case, %d batches). You may need to '
                  'use the repeat() function when building your '
                  'dataset.' % (steps_name, steps_per_epoch * epochs))
            break

          if not isinstance(batch_outs, list):
            batch_outs = [batch_outs]

          if model._distribution_strategy:
            batch_outs = (
                distributed_training_utils_v1._per_replica_aggregate_batch(
                    model._distribution_strategy, batch_outs, model, mode))

          # Aggregate results.
          if step == 0:
            aggregator.create(batch_outs)
          aggregator.aggregate(batch_outs)

          # Callbacks batch end.
          batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
          callbacks._call_batch_hook(mode, 'end', step, batch_logs)
          step += 1

          if callbacks.model.stop_training:
            break
      else:
        # Sample-wise loop.
        index_array = np.arange(num_samples_or_steps)
        if shuffle == 'batch':
          index_array = training_utils_v1.batch_shuffle(
              index_array, batch_size)
        elif shuffle:
          np.random.shuffle(index_array)
        batches = make_batches(num_samples_or_steps, batch_size)
        for batch_index, (batch_start, batch_end) in enumerate(batches):
          batch_ids = index_array[batch_start:batch_end]
          # Slice into a batch.
          if len(batches) == 1:
            # If we only have one batch, do not slice. This takes care of
            # composite tensors in non-Dataset modes; we currently don't
            # support slicing them.
            # TODO(b/133517906): Add slicing support.
            ins_batch = ins
          else:
            try:
              if ins and isinstance(ins[-1], int):
                # Do not slice the training phase flag.
                ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
              else:
                ins_batch = slice_arrays(ins, batch_ids)
            except TypeError:
              raise TypeError('TypeError while preparing batch. '
                              'If using HDF5 input data, '
                              'pass shuffle="batch".')

          # Sparse to dense conversion.
          if issparse is not None:
            for i in indices_for_conversion_to_dense:
              ins_batch[i] = ins_batch[i].toarray()

          # Callbacks batch_begin.
          batch_logs = {'batch': batch_index, 'size': len(batch_ids)}
          callbacks._call_batch_hook(mode, 'begin', batch_index, batch_logs)

          # Get outputs.
          batch_outs = f(ins_batch)
          if not isinstance(batch_outs, list):
            batch_outs = [batch_outs]

          # Aggregate results.
          if batch_index == 0:
            aggregator.create(batch_outs)
          aggregator.aggregate(batch_outs, batch_start, batch_end)

          # Callbacks batch end.
          batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
          callbacks._call_batch_hook(mode, 'end', batch_index, batch_logs)

          if callbacks.model.stop_training:
            break

      aggregator.finalize()
      results = aggregator.results
      epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
      if len(results) == 1:
        results = results[0]

      # Run the test loop every `validation_freq` epochs during training.
      if (do_validation and
          training_utils_v1.should_run_validation(validation_freq, epoch) and
          not callbacks.model.stop_training):

        if model._compile_distribution:
          # Since we create a new clone from the original model we need to
          # copy the weights back to the original model before we can run
          # validation.
          distributed_training_utils_v1._copy_weights_to_original_model(
              model, ModeKeys.TRAIN)

        val_results = model_iteration(
            model,
            val_inputs,
            targets=val_targets,
            sample_weights=val_sample_weights,
            batch_size=batch_size,
            steps_per_epoch=validation_steps,
            callbacks=callbacks,
            verbose=0,
            mode=ModeKeys.TEST,
            validation_in_fit=True,
            prepared_feed_values_from_dataset=(val_iterator is not None),
            steps_name='validation_steps')
        if not isinstance(val_results, list):
          val_results = [val_results]
        epoch_logs = cbks.make_logs(
            model, epoch_logs, val_results, mode, prefix='val_')
        if val_iterator is not None and epoch < epochs - 1:
          _reinitialize_iterator(val_iterator, model._distribution_strategy)

      if mode == ModeKeys.TRAIN:
        # Epochs only apply to `fit`.
        callbacks.on_epoch_end(epoch, epoch_logs)

      # Reinitialize dataset iterator for the next epoch.
      if reset_dataset_after_each_epoch and epoch < epochs - 1:
        _reinitialize_iterator(input_iterator, model._distribution_strategy)

    model._successful_loop_finish = True
    callbacks._call_end_hook(mode)

    if model._distribution_strategy and model._compile_distribution:
      # TODO(priyag, psv): Copy back metrics to the original model as well?
      distributed_training_utils_v1._copy_weights_to_original_model(model, mode)
  finally:
    # Always leave the strategy scope, even if the loop raised.
    if scope is not None:
      scope.__exit__(None, None, None)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results

453 

454 

def _get_model_feed(model, mode):
  """Returns the model feed tensors the execution function is fed with.

  Args:
    model: Keras Model instance.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

  Returns:
    `model._feed_inputs` in PREDICT mode. In any other mode, the concatenation
    of the feed inputs, feed targets and feed sample weights, in that order.
  """
  feed = model._feed_inputs
  if mode != ModeKeys.PREDICT:
    # Targets and sample weights are only fed when a loss is computed.
    feed = feed + model._feed_targets + model._feed_sample_weights
  return feed

462 

463 

464def _print_train_info(num_samples_or_steps, val_samples_or_steps, is_dataset): 

465 increment = 'steps' if is_dataset else 'samples' 

466 msg = 'Train on {0} {increment}'.format( 

467 num_samples_or_steps, increment=increment) 

468 if val_samples_or_steps: 

469 msg += ', validate on {0} {increment}'.format( 

470 val_samples_or_steps, increment=increment) 

471 print(msg) 

472 

473 

474def _get_num_samples_or_steps(ins, batch_size, steps_per_epoch): 

475 """Returns total number of samples (when training in batch mode) or steps.""" 

476 if steps_per_epoch: 

477 return steps_per_epoch 

478 return training_utils_v1.check_num_samples(ins, batch_size, steps_per_epoch, 

479 'steps_per_epoch') 

480 

481 

def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
  """Prepare feed values to the model execution function.

  Args:
    model: Model to prepare feed values for.
    inputs: List or dict of model inputs.
    targets: Optional list of model targets.
    sample_weights: Optional list of sample weight arrays.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

  Returns:
    Feed values for the model in the given mode. When the model has a
    distribution strategy and execution is eager, a zero-argument callable
    that produces the feed values is returned instead.
  """
  strategy = model._distribution_strategy
  if strategy:
    if isinstance(inputs, (data_types.DatasetV1, data_types.DatasetV2)):
      inputs = distributed_training_utils_v1.get_iterator(inputs, strategy)

    def get_distributed_inputs():
      return distributed_training_utils_v1._prepare_feed_values(
          model, inputs, targets, sample_weights, mode)

    # In the eager case the input method has to be called once per step, so a
    # callable is handed back; graph mode evaluates it right away. This path
    # is shared by eager and graph execution under a Distribution Strategy.
    # TODO(priyag,omalleyt): Either we should move the training DS with
    # IteratorBase to use training_generator code path, or figure out how to
    # set a symbolic Iterator out of a Dataset when in eager mode.
    if context.executing_eagerly():
      return get_distributed_inputs
    return get_distributed_inputs()

  dataset_like_types = (data_types.DatasetV1, data_types.DatasetV2,
                        iterator_ops.Iterator)
  if isinstance(inputs, dataset_like_types):
    inputs, targets, sample_weights = model._standardize_user_data(
        inputs, extract_tensors_from_dataset=True)

  feed_values = training_utils_v1.ModelInputs(inputs).as_list()
  feed_values = feed_values + list(targets or []) + list(sample_weights or [])

  needs_learning_phase = (
      mode == ModeKeys.TRAIN and
      not isinstance(backend.symbolic_learning_phase(), int))
  if needs_learning_phase:
    feed_values.append(True)  # Add learning phase value.
  return feed_values

530 

531 

def _get_iterator(inputs, distribution_strategy=None):
  """Returns an iterator over `inputs`.

  Uses the distribution-aware helper when a strategy is given, otherwise the
  plain `training_utils_v1.get_iterator`.
  """
  if not distribution_strategy:
    return training_utils_v1.get_iterator(inputs)
  return distributed_training_utils_v1.get_iterator(
      inputs, distribution_strategy)

537 

538 

def _reinitialize_iterator(iterator, distribution_strategy=None):
  """Re-initializes `iterator` so it can be consumed again.

  Dispatches to the distribution-aware helper when a strategy is given.
  Returns None.
  """
  if not distribution_strategy:
    training_utils_v1.initialize_iterator(iterator)
    return
  distributed_training_utils_v1.initialize_iterator(
      iterator, distribution_strategy)

545 

546 

547def _make_execution_function(model, mode): 

548 """Makes function to run one step of model execution.""" 

549 if model._distribution_strategy: 

550 return distributed_training_utils_v1._make_execution_function(model, mode) 

551 return model._make_execution_function(mode) 

552 

553 

def _update_sample_weight_mode(model, mode, inputs):
  """Updates the sample_weight_mode of a given model.

  Args:
    model: Keras Model instance.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    inputs: Flat feed list (model inputs + targets + sample weights, plus a
      trailing learning-phase value in TRAIN mode when the symbolic learning
      phase is not an int), or a callable in the distributed eager case.
  """
  # Return early so that model._feed_targets, which accesses model properties
  # that may not be set in `PREDICT` mode, is never touched.
  if mode == ModeKeys.PREDICT:
    return

  sample_weights = None
  if not callable(inputs):
    # Everything after the inputs and targets is sample weights, except a
    # possible trailing learning-phase placeholder value.
    first_weight_index = len(model._feed_inputs) + len(model._feed_targets)
    sample_weights = inputs[first_weight_index:]
    if mode == ModeKeys.TRAIN and not isinstance(
        backend.symbolic_learning_phase(), int):
      sample_weights = sample_weights[:-1]
    model._update_sample_weight_modes(sample_weights=sample_weights)

  # Let the DistributionStrategy-specific helper update the model as well.
  if model._distribution_strategy:
    distributed_training_utils_v1._update_sample_weight_modes(
        model, mode, sample_weights)

580 

# For backwards compatibility for internal users of these loops: aliases of
# `model_iteration` with the mode pre-bound. Evaluation and prediction also
# pre-bind `shuffle=False`.
fit_loop = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
test_loop = functools.partial(model_iteration,
                              mode=ModeKeys.TEST,
                              shuffle=False)
predict_loop = functools.partial(model_iteration,
                                 mode=ModeKeys.PREDICT,
                                 shuffle=False)

587 

588 

class ArrayLikeTrainingLoop(training_utils_v1.TrainingLoop):
  """TrainingLoop that handles array-like inputs.

  This is the default handler for most input data types: symbolic tensors,
  Numpy array-like data, and Datasets and iterators in graph mode (since they
  generate symbolic tensors). It is used for models with
  `run_eagerly` = False.
  """

  def fit(self, model, x=None, y=None, batch_size=None, epochs=1, verbose=1,
          callbacks=None, validation_split=0., validation_data=None,
          shuffle=True, class_weight=None, sample_weight=None,
          initial_epoch=0, steps_per_epoch=None, validation_steps=None,
          validation_freq=1, **kwargs):
    """Standardizes the training/validation data and runs `fit_loop`.

    `kwargs` is accepted but not used here.

    Raises:
      ValueError: if `validation_steps` is given without validation data or a
        usable `validation_split`.
    """
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps_per_epoch, x)

    x, y, sample_weights = model._standardize_user_data(
        x, y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps_per_epoch',
        steps=steps_per_epoch,
        validation_split=validation_split,
        shuffle=shuffle)

    val_x = val_y = val_sample_weights = None
    if validation_data:
      val_x, val_y, val_sample_weights = model._prepare_validation_data(
          validation_data, batch_size, validation_steps)
    elif validation_split and 0. < validation_split < 1.:
      split = training_utils_v1.split_training_and_validation_data(
          x, y, sample_weights, validation_split)
      x, y, sample_weights, val_x, val_y, val_sample_weights = split
    elif validation_steps:
      raise ValueError('`validation_steps` should not be specified if '
                       '`validation_data` is None.')

    return fit_loop(
        model,
        inputs=x, targets=y, sample_weights=sample_weights,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        val_inputs=val_x, val_targets=val_y,
        val_sample_weights=val_sample_weights,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        steps_name='steps_per_epoch')

  def evaluate(self, model, x=None, y=None, batch_size=None, verbose=1,
               sample_weight=None, steps=None, callbacks=None, **kwargs):
    """Standardizes the evaluation data and runs `test_loop`.

    `kwargs` is accepted but not used here.
    """
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    x, y, sample_weights = model._standardize_user_data(
        x, y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps',
        steps=steps)
    return test_loop(
        model,
        inputs=x, targets=y, sample_weights=sample_weights,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)

  def predict(self, model, x, batch_size=None, verbose=0, steps=None,
              callbacks=None, **kwargs):
    """Standardizes the prediction inputs and runs `predict_loop`.

    `kwargs` is accepted but not used here.
    """
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    standardized_x = model._standardize_user_data(
        x, check_steps=True, steps_name='steps', steps=steps)[0]
    return predict_loop(
        model,
        standardized_x,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)