Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training_generator_v1.py: 17%

238 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Part of the Keras training engine related to Python generators of array data. 

16""" 

17# pylint: disable=protected-access 

18 

19import functools 

20import math 

21 

22import numpy as np 

23 

24from tensorflow.python.data.ops import dataset_ops 

25from tensorflow.python.data.ops import iterator_ops 

26from tensorflow.python.eager import context 

27from tensorflow.python.framework import errors 

28from tensorflow.python.keras import backend 

29from tensorflow.python.keras import callbacks as cbks 

30from tensorflow.python.keras.engine import training_utils 

31from tensorflow.python.keras.engine import training_utils_v1 

32from tensorflow.python.keras.utils import data_utils 

33from tensorflow.python.keras.utils import generic_utils 

34from tensorflow.python.keras.utils.mode_keys import ModeKeys 

35from tensorflow.python.platform import tf_logging as logging 

36from tensorflow.python.types import data as data_types 

37from tensorflow.python.util import nest 

38 

39 

def model_iteration(model,
                    data,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    validation_freq=1,
                    class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=False,
                    initial_epoch=0,
                    mode=ModeKeys.TRAIN,
                    batch_size=None,
                    steps_name='steps',
                    **kwargs):
  """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

  Any enqueuer started for `data` is stopped, and any eager learning-phase
  scope entered here is exited, even when an exception escapes the loop.

  Args:
    model: Keras Model instance.
    data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x, y)` or
      `(x, y, sample_weights)`) or a generator or
      `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
    steps_per_epoch: Total number of steps (batches of samples) before
      declaring one epoch finished and starting the next epoch. Ignored with
      the default value of `None`.
    epochs: Number of times to iterate over the data.
    verbose: 0, 1, or 2. Verbosity mode.
      0 = silent, 1 = progress bar, 2 = one line per epoch.
      Note that the progress bar is not particularly useful when
      logged to a file, so verbose=2 is recommended when not running
      interactively (eg, in a production environment).
    callbacks: List of callbacks to be called during training.
    validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
      `(x, y)` or `(x, y, sample_weights)`) or a generator or
      `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
    validation_steps: Total number of steps (batches of samples) before
      declaring validation finished.
    validation_freq: Only relevant if validation data is provided. Integer or
      `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
      integer, specifies how many training epochs to run before a new
      validation run is performed, e.g. `validation_freq=2` runs
      validation every 2 epochs. If a Container, specifies the epochs on
      which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
      validation at the end of the 1st, 2nd, and 10th epochs.
    class_weight: Dictionary mapping class indices to a weight for the class.
    max_queue_size: Integer. Maximum size for the generator queue. If
      unspecified, `max_queue_size` will default to 10.
    workers: Integer. Maximum number of processes to spin up when using
      process-based threading. If unspecified, `workers` will default to 1. If
      0, will execute the generator on the main thread.
    use_multiprocessing: Boolean. If `True`, use process-based threading. If
      unspecified, `use_multiprocessing` will default to `False`. Note that
      because this implementation relies on multiprocessing, you should not
      pass non-picklable arguments to the generator as they can't be passed
      easily to children processes.
    shuffle: Boolean. Whether to shuffle the order of the batches at the
      beginning of each epoch. Only used with instances of `Sequence`
      (`keras.utils.Sequence`). Has no effect when `steps_per_epoch` is not
      `None`.
    initial_epoch: Epoch at which to start training (useful for resuming a
      previous training run).
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    batch_size: Integer batch size or None if unknown. Will only be used if
      `data` is in NumPy/Tensor format.
    steps_name: The string name of the steps argument, either `steps`,
      `validation_steps`, or `steps_per_epoch`. Only used for error message
      formatting.
    **kwargs: Additional arguments for backwards compatibility. `steps` is
      accepted as an alias for `steps_per_epoch`.

  Returns:
    - In TRAIN mode: `History` object.
    - In TEST mode: Evaluation metrics.
    - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
    ValueError: in case of invalid arguments.
  """
  if 'steps' in kwargs:
    steps_per_epoch = kwargs['steps']

  # Determine the number of steps per epoch and whether we should reset the
  # dataset at the end of each epoch.
  reset_dataset_after_each_epoch = False
  original_dataset = None
  is_dataset = isinstance(data, (data_types.DatasetV2, data_types.DatasetV1))
  if is_dataset:
    original_dataset = data
    if steps_per_epoch is None:
      reset_dataset_after_each_epoch = True
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, data, steps_per_epoch, epochs=epochs, steps_name=steps_name)

  # Convert to a format that supports `next(generator)`.
  generator, steps_per_epoch = convert_to_generator_like(
      data,
      steps_per_epoch=steps_per_epoch,
      batch_size=batch_size,
      epochs=epochs - initial_epoch,
      shuffle=shuffle)

  do_validation = validation_data is not None
  is_sequence = isinstance(generator, data_utils.Sequence)
  _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                      steps_per_epoch, validation_data, validation_steps, mode,
                      kwargs)

  batch_function = _make_execution_function(
      model, mode, class_weight=class_weight)

  # Both resources below are released in the `finally` clause so that an
  # exception raised mid-loop does not leave worker threads/processes running
  # or the global learning phase overridden.
  enqueuer = None
  learning_phase_scope = None
  try:
    # Create the queue for the generator.
    if not is_dataset:
      generator, enqueuer = _make_enqueued_generator(
          generator,
          workers=workers,
          use_multiprocessing=use_multiprocessing,
          max_queue_size=max_queue_size,
          shuffle=shuffle)

    num_samples_or_steps, use_steps = _get_num_samples_or_steps(
        data, steps_per_epoch)

    count_mode = 'steps' if use_steps else 'samples'
    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=do_validation,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        samples=num_samples_or_steps,
        count_mode=count_mode,
        verbose=verbose,
        mode=mode)

    if mode == ModeKeys.PREDICT:
      aggregator = training_utils_v1.OutputsAggregator(
          True, steps=steps_per_epoch)
    else:
      aggregator = training_utils_v1.MetricsAggregator(
          True, steps=steps_per_epoch)

    if context.executing_eagerly() and model.run_eagerly:
      scope = backend.eager_learning_phase_scope(
          1 if mode == ModeKeys.TRAIN else 0)
      scope.__enter__()
      # Only record the scope once it has actually been entered, so the
      # `finally` clause never exits a scope that failed to enter.
      learning_phase_scope = scope

    callbacks.model.stop_training = False
    callbacks._call_begin_hook(mode)

    initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
        initial_epoch, mode)

    for epoch in range(initial_epoch, epochs):
      if callbacks.model.stop_training:
        break

      # Setup work for each epoch.
      model.reset_metrics()
      epoch_logs = {}
      if mode == ModeKeys.TRAIN:
        callbacks.on_epoch_begin(epoch, epoch_logs)

      if steps_per_epoch is None:
        # Loop over dataset until `OutOfRangeError` is raised.
        target_steps = np.inf
      else:
        # Loop over dataset for the specified number of steps.
        target_steps = steps_per_epoch

      step = 0
      while step < target_steps:
        batch_data = _get_next_batch(generator)
        if batch_data is None:
          if is_dataset:
            # The dataset passed by the user ran out of batches.
            # Now we know the cardinality of the dataset.
            # If steps_per_epoch was specified, then running out of data is
            # unexpected, so we stop training and inform the user.
            if steps_per_epoch:
              callbacks.model.stop_training = True
              logging.warning(
                  'Your dataset ran out of data; interrupting training. '
                  'Make sure that your dataset can generate at least '
                  '`%s * epochs` batches (in this case, %d batches). '
                  'You may need to use the repeat() function when '
                  'building your dataset.'
                  % (steps_name, steps_per_epoch * epochs))
            elif step > 0:
              steps_per_epoch = step
              aggregator.steps = steps_per_epoch
          else:
            # We ran out of batches while the user passed an iterator (legacy).
            callbacks.model.stop_training = True
            logging.warning(
                'Your dataset iterator ran out of data; '
                'interrupting training. Make sure that your iterator '
                'can generate at least `%s * epochs` '
                'batches (in this case, %d batches). You may need to '
                'use the repeat() function when building your '
                'dataset.' % (steps_name, steps_per_epoch * epochs))
          break

        # `batch_size` used for validation data if validation
        # data is NumPy/EagerTensors.
        batch_size = int(nest.flatten(batch_data)[0].shape[0])

        # Callbacks batch begin.
        batch_logs = {'batch': step, 'size': batch_size}
        callbacks._call_batch_hook(mode, 'begin', step, batch_logs)

        is_deferred = not model._is_compiled
        batch_outs = batch_function(*batch_data)
        if not isinstance(batch_outs, list):
          batch_outs = [batch_outs]

        if step == 0:
          aggregator.create(batch_outs)

          if is_deferred:
            # Set callbacks params. We do this here when model is compiled only
            # in the first iteration of this loop (deferred build scenario).
            cbks.set_callback_parameters(
                callbacks,
                model,
                do_validation=do_validation,
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                samples=num_samples_or_steps,
                verbose=verbose,
                mode=mode)

        # Aggregate results.
        aggregator.aggregate(batch_outs)

        # Callbacks batch end.
        batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, 'end', step, batch_logs)
        step += 1

        if callbacks.model.stop_training:
          break

      aggregator.finalize()
      results = aggregator.results
      epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
      if len(results) == 1:
        results = results[0]

      # Run the test loop every epoch during training.
      if (do_validation and
          training_utils_v1.should_run_validation(validation_freq, epoch) and
          not callbacks.model.stop_training):
        val_results = model_iteration(
            model,
            validation_data,
            steps_per_epoch=validation_steps,
            batch_size=batch_size,
            class_weight=class_weight,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            max_queue_size=max_queue_size,
            callbacks=callbacks,
            verbose=verbose,
            mode=ModeKeys.TEST,
            steps_name='validation_steps')

        if not isinstance(val_results, list):
          val_results = [val_results]
        epoch_logs = cbks.make_logs(
            model, epoch_logs, val_results, mode, prefix='val_')

      if mode == ModeKeys.TRAIN:
        # Epochs only apply to `fit`.
        callbacks.on_epoch_end(epoch, epoch_logs)

      # Recreate dataset iterator for the next epoch.
      if reset_dataset_after_each_epoch and epoch < epochs - 1:
        generator = dataset_ops.make_one_shot_iterator(original_dataset)

    model._successful_loop_finish = True
    callbacks._call_end_hook(mode)
  finally:
    # Same release order as the success path: stop the enqueuer first, then
    # restore the learning phase.
    if enqueuer is not None:
      enqueuer.stop()
    if learning_phase_scope is not None:
      learning_phase_scope.__exit__(None, None, None)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results

338 

339 

# Backwards-compatible, mode-specific entry points over `model_iteration`.
# The evaluation and prediction variants default to `shuffle=False`
# (callers may still override any bound keyword).
fit_generator = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
evaluate_generator = functools.partial(
    model_iteration, shuffle=False, mode=ModeKeys.TEST)
predict_generator = functools.partial(
    model_iteration, shuffle=False, mode=ModeKeys.PREDICT)

346 

347 

348def _get_next_batch(generator): 

349 """Retrieves the next batch of input data.""" 

350 try: 

351 generator_output = next(generator) 

352 except (StopIteration, errors.OutOfRangeError): 

353 return None 

354 

355 if not isinstance(generator_output, tuple): 

356 # Always wrap in a tuple. 

357 generator_output = (generator_output,) 

358 if len(generator_output) not in [1, 2, 3]: 

359 raise ValueError( 

360 'Output of generator should be a tuple of 1 or 2 or 3 ' 

361 'elements: (input,) or (input, target) or ' 

362 '(input, target, sample_weights). Received {}'.format(generator_output)) 

363 return generator_output 

364 

365 

def _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                        steps_per_epoch, validation_data, validation_steps,
                        mode, kwargs):
  """Raises errors if arguments are invalid.

  A generator combined with multiprocessing and several workers only triggers
  a logged warning; every other check raises.

  Args:
    is_sequence: Boolean, whether data is a `keras.utils.data_utils.Sequence`
      instance.
    is_dataset: Boolean, whether data is a dataset instance.
    use_multiprocessing: Boolean. If `True`, use process-based threading. If
      unspecified, `use_multiprocessing` will default to `False`. Note that
      because this implementation relies on multiprocessing, you should not pass
      non-picklable arguments to the generator as they can't be passed easily to
      children processes.
    workers: Integer. Maximum number of processes to spin up when using
      process-based threading. If unspecified, `workers` will default to 1. If
      0, will execute the generator on the main thread.
    steps_per_epoch: Total number of steps (batches of samples) before declaring
      one epoch finished and starting the next epoch. Ignored with the default
      value of `None`.
    validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x,
      y)` or `(x, y, sample_weights)`) or a generator or
      `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
    validation_steps: Total number of steps (batches of samples) before
      declaring validation finished.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    kwargs: Additional arguments for backwards compatibility. Only `steps` is
      accepted.

  Raises:
    ValueError: If `steps_per_epoch` or `validation_steps` are not passed
      for data types that require them, or if unrecognized keyword
      arguments are passed.
  """
  if not is_sequence and use_multiprocessing and workers > 1:
    # Log the message text directly; the logger expects a string message,
    # not an exception instance.
    logging.warning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence`'
                    ' class.')

  if steps_per_epoch is None and not is_dataset:
    arg_name = 'steps_per_epoch' if mode == ModeKeys.TRAIN else 'steps'
    raise ValueError('Please specify the number of steps via the '
                     '`{}` argument.'.format(arg_name))

  val_gen = (
      data_utils.is_generator_or_sequence(validation_data) or
      isinstance(validation_data, iterator_ops.IteratorBase))
  # A `Sequence` knows its own length, so only other generator-likes need an
  # explicit `validation_steps`.
  if (val_gen and not isinstance(validation_data, data_utils.Sequence) and
      not validation_steps):
    raise ValueError('Please specify the `validation_steps` argument.')

  invalid_kwargs = [k for k in kwargs if k != 'steps']
  if invalid_kwargs:
    raise ValueError('Invalid arguments passed: {}'.format(invalid_kwargs))

421 

422 

def convert_to_generator_like(data,
                              batch_size=None,
                              steps_per_epoch=None,
                              epochs=1,
                              shuffle=False):
  """Make a generator out of NumPy or EagerTensor inputs.

  Generators, `keras.utils.data_utils.Sequence` objects and iterators are
  returned unchanged. `DatasetV2` inputs are wrapped in a one-shot iterator.
  Array inputs are turned into a Python generator that yields batches for
  `epochs` passes over the data.

  Args:
    data: Either a generator or `keras.utils.data_utils.Sequence` object or
      `Dataset`, `Iterator`, or a {1,2,3}-tuple of NumPy arrays or EagerTensors.
      If a tuple, the elements represent `(x, y, sample_weights)` and may be
      `None` or `[None]`.
    batch_size: Used when creating a generator out of tuples of NumPy arrays or
      EagerTensors.
    steps_per_epoch: Steps of the generator to run each epoch. If `None` the
      number of steps will be read from the data (for
      `keras.utils.data_utils.Sequence` types).
    epochs: Total number of epochs to run. Only used for array inputs, where
      it is the number of passes the returned generator makes over the data.
    shuffle: Whether the data should be shuffled. Only used for array inputs,
      where the sample order is reshuffled at the start of every epoch.

  Returns:
    A tuple `(generator_like, steps_per_epoch)`:
      - `generator_like` is the input itself for generators, `Sequence`s and
        iterators, a one-shot iterator for `DatasetV2` inputs, or a Python
        generator of batches for array inputs.
      - `steps_per_epoch` is the argument passed through unchanged, except
        that it is filled in from `len(data)` for a `Sequence` when `None`,
        and recomputed as `ceil(num_samples / batch_size)` for array inputs.

  Raises:
    ValueError: If `batch_size` is not provided for NumPy or EagerTensor
      inputs, or if no arrays remain once `None` entries are dropped.
  """
  if isinstance(data, tuple):
    # Scrub `Nones` that might have been passed for `targets`, `sample_weights`.
    data = tuple(
        ele for ele in data if not all(e is None for e in nest.flatten(ele)))

  if data_utils.is_generator_or_sequence(data) or isinstance(
      data, iterator_ops.IteratorBase):
    if isinstance(data, data_utils.Sequence):
      if steps_per_epoch is None:
        steps_per_epoch = len(data)
    return data, steps_per_epoch
  if isinstance(data, data_types.DatasetV2):
    return dataset_ops.make_one_shot_iterator(data), steps_per_epoch

  # Create generator from NumPy or EagerTensor Input.
  flat_inputs = nest.flatten(data)
  if not flat_inputs:
    # Previously this surfaced as an opaque `IndexError`.
    raise ValueError(
        'Expected at least one NumPy array or EagerTensor as input, but all '
        'provided inputs were `None`.')
  num_samples = int(flat_inputs[0].shape[0])
  if batch_size is None:
    raise ValueError(
        'When passing input data as arrays, do not specify '
        '`steps_per_epoch`/`steps` argument. Please use `batch_size` instead.')
  steps_per_epoch = int(math.ceil(num_samples / batch_size))

  def _gen(data):
    """Makes a generator out of a structure of NumPy/EagerTensors."""
    index_array = np.arange(num_samples)
    for _ in range(epochs):
      if shuffle:
        np.random.shuffle(index_array)
      batches = generic_utils.make_batches(num_samples, batch_size)
      for batch_start, batch_end in batches:
        batch_ids = index_array[batch_start:batch_end]
        # `flat_inputs` is the already-flattened `data`; reusing it avoids
        # re-flattening the same structure for every batch.
        flat_batch_data = training_utils.slice_arrays(
            flat_inputs, batch_ids, contiguous=(not shuffle))
        yield nest.pack_sequence_as(data, flat_batch_data)

  return _gen(data), steps_per_epoch

486 

487 

def _make_enqueued_generator(generator,
                             workers=1,
                             use_multiprocessing=False,
                             max_queue_size=10,
                             shuffle=False):
  """Optionally wraps `generator` in a started enqueuer for prefetching.

  Args:
    generator: A Python generator or `keras.utils.data_utils.Sequence`.
    workers: Number of enqueuer workers. With a value of 0 no enqueuer is
      created and elements are produced on the calling thread.
    use_multiprocessing: Forwarded to the enqueuer.
    max_queue_size: Forwarded to `enqueuer.start`.
    shuffle: Forwarded to the `OrderedEnqueuer` used for `Sequence` inputs.

  Returns:
    A tuple `(output_generator, enqueuer)`, where `enqueuer` is `None` when no
    enqueuer was started. Callers are responsible for stopping a returned
    enqueuer.
  """
  is_sequence = isinstance(generator, data_utils.Sequence)

  if workers > 0:
    if is_sequence:
      queue = data_utils.OrderedEnqueuer(
          generator, use_multiprocessing=use_multiprocessing, shuffle=shuffle)
    else:
      queue = data_utils.GeneratorEnqueuer(
          generator, use_multiprocessing=use_multiprocessing)
    queue.start(workers=workers, max_queue_size=max_queue_size)
    return queue.get(), queue

  # No workers requested: iterate directly on the calling thread.
  if is_sequence:
    return data_utils.iter_sequence_infinite(generator), None
  return generator, None

511 

512 

def _make_execution_function(model, mode, class_weight=None):
  """Returns a callable that runs `model` on a single batch for `mode`.

  TRAIN uses `model.train_on_batch` (with `class_weight` bound) and TEST uses
  `model.test_on_batch`. Any other mode uses a wrapper around
  `model.predict_on_batch` that accepts and ignores `y`/`sample_weights`, so
  1-, 2- and 3-tuples from a generator can all be splatted into it. For every
  mode except PREDICT, `reset_metrics=False` is bound so stateful metrics keep
  accumulating across batch-level calls.
  """

  def predict_on_batch(x, y=None, sample_weights=None):  # pylint: disable=unused-argument
    return model.predict_on_batch(x)

  if mode == ModeKeys.TRAIN:
    run_batch = functools.partial(
        model.train_on_batch, class_weight=class_weight)
  elif mode == ModeKeys.TEST:
    run_batch = model.test_on_batch
  else:
    run_batch = predict_on_batch

  if mode == ModeKeys.PREDICT:
    return run_batch
  return functools.partial(run_batch, reset_metrics=False)

532 

533 

def _get_num_samples_or_steps(data, steps_per_epoch):
  """Returns `(count, use_steps)` for configuring callbacks.

  If the first flattened element of `data` has a `shape`, its leading
  dimension is returned as a sample count together with `False`. Otherwise
  `steps_per_epoch` is returned together with `True` (steps count mode).
  """
  first_input = nest.flatten(data)[0]
  if not hasattr(first_input, 'shape'):
    return steps_per_epoch, True
  return int(first_input.shape[0]), False

540 

541 

class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
  """Training loop for Python generator and `Sequence` inputs.

  Unlike `GeneratorLikeTrainingLoop`, this class only handles inputs where x,
  y and sample_weight are fused into the single `x` argument (e.g. a generator
  producing `(x, y)` or `(x, y, sample_weight)` tuples). Separately passed
  `y`/`sample_weight` are checked via
  `training_utils_v1.check_generator_arguments`.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          max_queue_size=10,
          workers=1,
          use_multiprocessing=False):
    """Fits `model` on the generator/`Sequence` `x` via `fit_generator`.

    `batch_size` is only validated here; it is not forwarded to the loop.
    """
    model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
    training_utils_v1.check_generator_arguments(
        y, sample_weight, validation_split=validation_split)
    return fit_generator(
        model,
        x,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        class_weight=class_weight,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               max_queue_size=10,
               workers=1,
               use_multiprocessing=False):
    """Evaluates `model` on the generator/`Sequence` `x` via `evaluate_generator`.

    `batch_size` is only validated here; it is not forwarded to the loop.
    """
    model._validate_or_infer_batch_size(batch_size, steps, x)
    training_utils_v1.check_generator_arguments(y, sample_weight)
    return evaluate_generator(
        model,
        x,
        steps=steps,
        verbose=verbose,
        callbacks=callbacks,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              max_queue_size=10,
              workers=1,
              use_multiprocessing=False):
    """Runs `model` predictions on the generator/`Sequence` `x`.

    Delegates to `predict_generator`. `batch_size` is only validated here; it
    is not forwarded to the loop.
    """
    model._validate_or_infer_batch_size(batch_size, steps, x)
    return predict_generator(
        model,
        x,
        steps=steps,
        verbose=verbose,
        callbacks=callbacks,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing)

637 

638 

class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
  """Loop for a non-distributed Dataset or iterator under eager execution.

  All three entry points run the shared generator loop with `workers=0`, so
  batches are pulled on the calling thread without an enqueuer.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Fits `model` on the dataset/iterator `x`."""
    model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
    # Targets, sample weights and validation_split must not be given
    # separately from the dataset.
    training_utils_v1.validate_dataset_input(x, y, sample_weight,
                                             validation_split)
    is_dataset_input = isinstance(x,
                                  (data_types.DatasetV1, data_types.DatasetV2))
    if is_dataset_input and shuffle:
      training_utils_v1.verify_dataset_shuffled(x)

    return fit_generator(
        model,
        x,
        steps_name='steps_per_epoch',
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        initial_epoch=initial_epoch,
        shuffle=shuffle,
        class_weight=class_weight,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        callbacks=callbacks,
        verbose=verbose,
        workers=0)

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Evaluates `model` on the dataset/iterator `x`."""
    model._validate_or_infer_batch_size(batch_size, steps, x)
    # Targets and sample weights must not be given separately from the dataset.
    training_utils_v1.validate_dataset_input(x, y, sample_weight)
    return evaluate_generator(
        model,
        x,
        steps=steps,
        callbacks=callbacks,
        verbose=verbose,
        workers=0)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Runs predictions of `model` on the dataset/iterator `x`."""
    model._validate_or_infer_batch_size(batch_size, steps, x)
    return predict_generator(
        model,
        x,
        steps=steps,
        callbacks=callbacks,
        verbose=verbose,
        workers=0)

711 

712 

class GeneratorLikeTrainingLoop(training_utils_v1.TrainingLoop):
  """Training loop that handles inputs like a Python generator.

  This is the default handler for most input data types, including symbolic
  tensors, NumPy array-likes, and Datasets and iterators in graph mode (since
  they produce symbolic tensors). It is used to handle models with
  `run_eagerly=True`. Inputs are first standardized with
  `model._standardize_user_data` and then fed to the shared generator loop as
  an `(x, y, sample_weights)` tuple with `workers=0`.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Standardizes the inputs, sets up validation data, and fits `model`."""
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps_per_epoch, x)
    x, y, sample_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps_per_epoch',
        steps=steps_per_epoch,
        validation_split=validation_split,
        shuffle=shuffle)

    # Validation data comes from, in order of precedence: explicit
    # `validation_data`, or a split carved off the training data.
    if validation_data:
      validation_data = model._prepare_validation_data(
          validation_data, batch_size, validation_steps)
    elif validation_split and 0. < validation_split < 1.:
      (x, y, sample_weights, val_x, val_y, val_sample_weights
      ) = training_utils_v1.split_training_and_validation_data(
          x, y, sample_weights, validation_split)
      validation_data = (val_x, val_y, val_sample_weights)
    elif validation_steps:
      raise ValueError('`validation_steps` should not be specified if '
                       '`validation_data` is None.')

    training_data = (x, y, sample_weights)
    return fit_generator(
        model,
        training_data,
        steps_name='steps_per_epoch',
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        epochs=epochs,
        initial_epoch=initial_epoch,
        shuffle=shuffle,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        callbacks=callbacks,
        verbose=verbose,
        workers=0)

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Standardizes the inputs and evaluates `model` on them."""
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    x, y, sample_weights = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        check_steps=True,
        steps_name='steps',
        steps=steps)
    evaluation_data = (x, y, sample_weights)
    return evaluate_generator(
        model,
        evaluation_data,
        steps=steps,
        batch_size=batch_size,
        callbacks=callbacks,
        verbose=verbose,
        workers=0)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Standardizes the inputs and runs predictions of `model` on them."""
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    standardized = model._standardize_user_data(
        x, check_steps=True, steps_name='steps', steps=steps)
    x, _, _ = standardized
    return predict_generator(
        model,
        x,
        steps=steps,
        batch_size=batch_size,
        callbacks=callbacks,
        verbose=verbose,
        workers=0)