Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/training_generator_v1.py: 15%

233 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Part of the Keras training engine related to Python generators of array data. 

16""" 

17 

18import functools 

19import math 

20 

21import numpy as np 

22import tensorflow.compat.v2 as tf 

23 

24from keras.src import backend 

25from keras.src import callbacks as cbks 

26from keras.src.engine import training_utils 

27from keras.src.engine import training_utils_v1 

28from keras.src.utils import data_utils 

29from keras.src.utils import generic_utils 

30from keras.src.utils.mode_keys import ModeKeys 

31 

32# isort: off 

33from tensorflow.python.platform import tf_logging as logging 

34 

35 

def model_iteration(
    model,
    data,
    steps_per_epoch=None,
    epochs=1,
    verbose=1,
    callbacks=None,
    validation_data=None,
    validation_steps=None,
    validation_freq=1,
    class_weight=None,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
    shuffle=False,
    initial_epoch=0,
    mode=ModeKeys.TRAIN,
    batch_size=None,
    steps_name="steps",
    **kwargs,
):
    """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

    Args:
      model: Keras Model instance.
      data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x, y)` or
        `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      steps_per_epoch: Total number of steps (batches of samples) before
        declaring one epoch finished and starting the next epoch. Ignored with
        the default value of `None`.
      epochs: Number of times to iterate over the data.
      verbose: 0, 1, or 2. Verbosity mode.
        0 = silent, 1 = progress bar, 2 = one line per epoch.
        Note that the progress bar is not particularly useful when
        logged to a file, so verbose=2 is recommended when not running
        interactively (eg, in a production environment).
      callbacks: List of callbacks to be called during training.
      validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
        `(x, y)` or `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      validation_steps: Total number of steps (batches of samples) before
        declaring validation finished.
      validation_freq: Only relevant if validation data is provided. Integer
        or `collections.abc.Container` instance (e.g. list, tuple, etc.). If
        an integer, specifies how many training epochs to run before a new
        validation run is performed, e.g. `validation_freq=2` runs validation
        every 2 epochs. If a Container, specifies the epochs on which to run
        validation, e.g. `validation_freq=[1, 2, 10]` runs validation at the
        end of the 1st, 2nd, and 10th epochs.
      class_weight: Dictionary mapping class indices to a weight for the
        class.
      max_queue_size: Integer. Maximum size for the generator queue. If
        unspecified, `max_queue_size` will default to 10.
      workers: Integer. Maximum number of processes to spin up when using
        process-based threading. If unspecified, `workers` will default to 1.
        If 0, will execute the generator on the main thread.
      use_multiprocessing: Boolean. If `True`, use process-based threading. If
        unspecified, `use_multiprocessing` will default to `False`. Note that
        because this implementation relies on multiprocessing, you should not
        pass non-picklable arguments to the generator as they can't be passed
        easily to children processes.
      shuffle: Boolean. Whether to shuffle the order of the batches at the
        beginning of each epoch. Only used with instances of `Sequence`
        (`keras.utils.Sequence`). Has no effect when `steps_per_epoch` is not
        `None`.
      initial_epoch: Epoch at which to start training (useful for resuming a
        previous training run).
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      batch_size: Integer batch size or None if unknown. Will only be used if
        `data` is in NumPy/Tensor format.
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.
      **kwargs: Additional arguments for backwards compatibility. `steps` is
        accepted as an alias for `steps_per_epoch`.

    Returns:
      - In TRAIN mode: `History` object.
      - In TEST mode: Evaluation metrics.
      - In PREDICT mode: Outputs of the Model called on inputs.

    Raises:
      ValueError: in case of invalid arguments.

    Note:
      The generator enqueuer (if one was started) is stopped and the eager
      learning-phase scope (if one was entered) is exited even when the loop
      raises, so background workers are not left running after an error or
      interrupt.
    """
    if "steps" in kwargs:
        steps_per_epoch = kwargs["steps"]

    # Determine the number of steps per epoch and whether we should reset the
    # dataset at the end of each epoch.
    reset_dataset_after_each_epoch = False
    original_dataset = None
    is_dataset = isinstance(data, (tf.data.Dataset, tf.compat.v1.data.Dataset))
    if is_dataset:
        original_dataset = data
        if steps_per_epoch is None:
            # No explicit step count: try to infer one, and recreate the
            # iterator at the end of every epoch (see the bottom of the loop).
            reset_dataset_after_each_epoch = True
            steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
                model,
                data,
                steps_per_epoch,
                epochs=epochs,
                steps_name=steps_name,
            )

    # Convert to a format that supports `next(generator)`.
    generator, steps_per_epoch = convert_to_generator_like(
        data,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        epochs=epochs - initial_epoch,
        shuffle=shuffle,
    )

    do_validation = validation_data is not None
    is_sequence = isinstance(generator, data_utils.Sequence)
    _validate_arguments(
        is_sequence,
        is_dataset,
        use_multiprocessing,
        workers,
        steps_per_epoch,
        validation_data,
        validation_steps,
        mode,
        kwargs,
    )

    batch_function = _make_execution_function(
        model, mode, class_weight=class_weight
    )

    # Both resources are released in the `finally` clause below so that an
    # exception inside the loop does not leave enqueuer workers running or the
    # learning-phase scope entered.
    enqueuer = None
    learning_phase_scope = None
    try:
        # Create the queue for the generator.
        if not is_dataset:
            generator, enqueuer = _make_enqueued_generator(
                generator,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                max_queue_size=max_queue_size,
                shuffle=shuffle,
            )

        num_samples_or_steps, use_steps = _get_num_samples_or_steps(
            data, steps_per_epoch
        )

        count_mode = "steps" if use_steps else "samples"
        callbacks = cbks.configure_callbacks(
            callbacks,
            model,
            do_validation=do_validation,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            batch_size=batch_size,
            samples=num_samples_or_steps,
            count_mode=count_mode,
            verbose=verbose,
            mode=mode,
        )

        if mode == ModeKeys.PREDICT:
            aggregator = training_utils_v1.OutputsAggregator(
                True, steps=steps_per_epoch
            )
        else:
            aggregator = training_utils_v1.MetricsAggregator(
                True, steps=steps_per_epoch
            )

        if tf.executing_eagerly() and model.run_eagerly:
            scope = backend.eager_learning_phase_scope(
                1 if mode == ModeKeys.TRAIN else 0
            )
            scope.__enter__()
            # Record the scope only after it has been entered, so `finally`
            # never exits a scope that was not entered.
            learning_phase_scope = scope

        callbacks.model.stop_training = False
        callbacks._call_begin_hook(mode)

        initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
            initial_epoch, mode
        )

        for epoch in range(initial_epoch, epochs):
            if callbacks.model.stop_training:
                break

            # Setup work for each epoch.
            model.reset_metrics()
            epoch_logs = {}
            if mode == ModeKeys.TRAIN:
                callbacks.on_epoch_begin(epoch, epoch_logs)

            if steps_per_epoch is None:
                # Loop over dataset until `OutOfRangeError` is raised.
                target_steps = np.inf
            else:
                # Loop over dataset for the specified number of steps.
                target_steps = steps_per_epoch

            step = 0
            while step < target_steps:
                batch_data = _get_next_batch(generator)
                if batch_data is None:
                    if is_dataset:
                        # The dataset passed by the user ran out of batches.
                        # Now we know the cardinality of the dataset. If
                        # steps_per_epoch was specified, then running out of
                        # data is unexpected, so we stop training and inform
                        # the user.
                        if steps_per_epoch:
                            callbacks.model.stop_training = True
                            logging.warning(
                                "Your dataset ran out of data; interrupting "
                                "training. Make sure that your dataset can "
                                "generate at least `%s * epochs` batches (in "
                                "this case, %d batches). You may need to use "
                                "the repeat() function when building your "
                                "dataset." % (steps_name, steps_per_epoch * epochs)
                            )
                        elif step > 0:
                            steps_per_epoch = step
                            aggregator.steps = steps_per_epoch
                    else:
                        # We ran out of batches while the user passed an
                        # iterator (legacy).
                        callbacks.model.stop_training = True
                        logging.warning(
                            "Your dataset iterator ran out of data; "
                            "interrupting training. Make sure that your "
                            "iterator can generate at least `%s * epochs` "
                            "batches (in this case, %d batches). You may need "
                            "to use the repeat() function when building your "
                            "dataset." % (steps_name, steps_per_epoch * epochs)
                        )
                    break

                # `batch_size` used for validation data if validation
                # data is NumPy/EagerTensors.
                batch_size = int(tf.nest.flatten(batch_data)[0].shape[0])

                # Callbacks batch begin.
                batch_logs = {"batch": step, "size": batch_size}
                callbacks._call_batch_hook(mode, "begin", step, batch_logs)

                is_deferred = not model._is_compiled
                batch_outs = batch_function(*batch_data)
                if not isinstance(batch_outs, list):
                    batch_outs = [batch_outs]

                if step == 0:
                    aggregator.create(batch_outs)

                    if is_deferred:
                        # Set callbacks params. We do this here when model is
                        # compiled only in the first iteration of this loop
                        # (deferred build scenario).
                        cbks.set_callback_parameters(
                            callbacks,
                            model,
                            do_validation=do_validation,
                            batch_size=batch_size,
                            epochs=epochs,
                            steps_per_epoch=steps_per_epoch,
                            samples=num_samples_or_steps,
                            verbose=verbose,
                            mode=mode,
                        )

                # Aggregate results.
                aggregator.aggregate(batch_outs)

                # Callbacks batch end.
                batch_logs = callbacks.make_logs(
                    model, batch_logs, batch_outs, mode
                )
                callbacks._call_batch_hook(mode, "end", step, batch_logs)
                step += 1

                if callbacks.model.stop_training:
                    break

            aggregator.finalize()
            results = aggregator.results
            epoch_logs = callbacks.make_logs(model, epoch_logs, results, mode)
            if len(results) == 1:
                results = results[0]

            # Run the test loop every epoch during training.
            if (
                do_validation
                and training_utils_v1.should_run_validation(
                    validation_freq, epoch
                )
                and not callbacks.model.stop_training
            ):
                val_results = model_iteration(
                    model,
                    validation_data,
                    steps_per_epoch=validation_steps,
                    batch_size=batch_size,
                    class_weight=class_weight,
                    workers=workers,
                    use_multiprocessing=use_multiprocessing,
                    max_queue_size=max_queue_size,
                    callbacks=callbacks,
                    verbose=verbose,
                    mode=ModeKeys.TEST,
                    steps_name="validation_steps",
                )

                if not isinstance(val_results, list):
                    val_results = [val_results]
                epoch_logs = callbacks.make_logs(
                    model, epoch_logs, val_results, mode, prefix="val_"
                )

            if mode == ModeKeys.TRAIN:
                # Epochs only apply to `fit`.
                callbacks.on_epoch_end(epoch, epoch_logs)

            # Recreate dataset iterator for the next epoch.
            if reset_dataset_after_each_epoch and epoch < epochs - 1:
                generator = tf.compat.v1.data.make_one_shot_iterator(
                    original_dataset
                )

        model._successful_loop_finish = True
        callbacks._call_end_hook(mode)
    finally:
        # Same release order as the success path: enqueuer first, then scope.
        if enqueuer is not None:
            enqueuer.stop()
        if learning_phase_scope is not None:
            learning_phase_scope.__exit__(None, None, None)

    if mode == ModeKeys.TRAIN:
        return model.history
    return results

373 

374 

# Backwards-compatible public entry points. Each one is `model_iteration`
# with its `mode` fixed; the evaluation and prediction variants additionally
# default `shuffle` to False (callers may still override it).
fit_generator = functools.partial(model_iteration, mode=ModeKeys.TRAIN)

evaluate_generator = functools.partial(
    model_iteration,
    mode=ModeKeys.TEST,
    shuffle=False,
)

predict_generator = functools.partial(
    model_iteration,
    mode=ModeKeys.PREDICT,
    shuffle=False,
)

383 

384 

385def _get_next_batch(generator): 

386 """Retrieves the next batch of input data.""" 

387 try: 

388 generator_output = next(generator) 

389 except (StopIteration, tf.errors.OutOfRangeError): 

390 return None 

391 

392 if not isinstance(generator_output, tuple): 

393 # Always wrap in a tuple. 

394 generator_output = (generator_output,) 

395 if len(generator_output) not in [1, 2, 3]: 

396 raise ValueError( 

397 "Output of generator should be a tuple of 1 or 2 or 3 " 

398 "elements: (input,) or (input, target) or " 

399 "(input, target, sample_weights). Received {}".format( 

400 generator_output 

401 ) 

402 ) 

403 return generator_output 

404 

405 

def _validate_arguments(
    is_sequence,
    is_dataset,
    use_multiprocessing,
    workers,
    steps_per_epoch,
    validation_data,
    validation_steps,
    mode,
    kwargs,
):
    """Checks the loop arguments and raises on invalid combinations.

    Args:
      is_sequence: Boolean, whether data is a `keras.utils.data_utils.Sequence`
        instance.
      is_dataset: Boolean, whether data is a dataset instance.
      use_multiprocessing: Boolean. If `True`, use process-based threading.
      workers: Integer. Maximum number of processes to spin up when using
        process-based threading.
      steps_per_epoch: Total number of steps (batches of samples) per epoch,
        or `None`.
      validation_data: Validation input: a tuple of NumPy/Tensor inputs, a
        generator, a `keras.utils.data_utils.Sequence`, an Eager Iterator or
        a Dataset (or `None`).
      validation_steps: Total number of validation steps, or `None`.
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      kwargs: Additional arguments for backwards compatibility; only `steps`
        is accepted.

    Raises:
      ValueError: If a step count is missing for non-dataset input, if
        `validation_steps` is missing for a generator/iterator (non-Sequence)
        validation input, or if any keyword other than `steps` is present.
    """
    # Plain generators may yield duplicate data when consumed from several
    # processes; only warn (do not fail).
    if not is_sequence and use_multiprocessing and workers > 1:
        logging.warning(
            UserWarning(
                "Using a generator with `use_multiprocessing=True`"
                " and multiple workers may duplicate your data."
                " Please consider using the `keras.utils.Sequence`"
                " class."
            )
        )

    if steps_per_epoch is None and not is_dataset:
        if mode == ModeKeys.TRAIN:
            steps_arg = "steps_per_epoch"
        else:
            steps_arg = "steps"
        raise ValueError(
            f"Please specify the number of steps via the `{steps_arg}` argument."
        )

    validation_is_stream = data_utils.is_generator_or_sequence(
        validation_data
    ) or isinstance(validation_data, tf.data.Iterator)
    # A Sequence knows its own length; other streams need an explicit count.
    needs_validation_steps = validation_is_stream and not isinstance(
        validation_data, data_utils.Sequence
    )
    if needs_validation_steps and not validation_steps:
        raise ValueError("Please specify the `validation_steps` argument.")

    unexpected_keys = [k for k in kwargs if k != "steps"]
    if unexpected_keys:
        raise ValueError(f"Invalid arguments passed: {unexpected_keys}")

477 

478 

def convert_to_generator_like(
    data, batch_size=None, steps_per_epoch=None, epochs=1, shuffle=False
):
    """Make a generator out of NumPy or EagerTensor inputs.

    Generators, `Sequence`s and `tf.data.Iterator`s are returned unchanged; a
    `tf.data.Dataset` is wrapped in a one-shot iterator; anything else (a
    structure of NumPy arrays / EagerTensors) becomes a Python generator that
    yields mini-batches of `batch_size` samples for `epochs` epochs.

    Args:
      data: Either a generator or `keras.utils.data_utils.Sequence` object or
        `Dataset`, `Iterator`, or a {1,2,3}-tuple of NumPy arrays or
        EagerTensors. If a tuple, the elements represent `(x, y,
        sample_weights)` and may be `None` or `[None]`.
      batch_size: Used when creating a generator out of tuples of NumPy arrays
        or EagerTensors.
      steps_per_epoch: Steps of the generator to run each epoch. If `None` the
        number of steps will be read from the data (for
        `keras.utils.data_utils.Sequence` types). For array inputs it is
        recomputed as `ceil(num_samples / batch_size)`.
      epochs: Total number of epochs to run.
      shuffle: Whether the data should be shuffled.

    Returns:
      A `(generator_like, steps_per_epoch)` tuple, where `generator_like` is a
      generator, `keras.utils.data_utils.Sequence`, or `Iterator`.

    Raises:
      ValueError: If `batch_size` is not provided for NumPy or EagerTensor
        inputs.
    """
    if isinstance(data, tuple):
        # Drop entries (e.g. targets / sample weights) whose leaves are all
        # `None`.
        data = tuple(
            element
            for element in data
            if any(leaf is not None for leaf in tf.nest.flatten(element))
        )

    is_stream = data_utils.is_generator_or_sequence(data) or isinstance(
        data, tf.data.Iterator
    )
    if is_stream:
        if steps_per_epoch is None and isinstance(data, data_utils.Sequence):
            steps_per_epoch = len(data)
        return data, steps_per_epoch

    if isinstance(data, tf.data.Dataset):
        return tf.compat.v1.data.make_one_shot_iterator(data), steps_per_epoch

    # Array-like input: we batch it ourselves.
    num_samples = int(tf.nest.flatten(data)[0].shape[0])
    if batch_size is None:
        raise ValueError(
            "When passing input data as arrays, do not specify "
            "`steps_per_epoch`/`steps` argument. "
            "Please use `batch_size` instead."
        )
    steps_per_epoch = int(math.ceil(num_samples / batch_size))

    def _iterate_batches(structure):
        """Yields `structure` sliced into mini-batches, epoch after epoch."""
        order = np.arange(num_samples)
        for _ in range(epochs):
            if shuffle:
                np.random.shuffle(order)
            for start, stop in generic_utils.make_batches(
                num_samples, batch_size
            ):
                flat_batch = training_utils.slice_arrays(
                    tf.nest.flatten(structure),
                    order[start:stop],
                    contiguous=not shuffle,
                )
                yield tf.nest.pack_sequence_as(structure, flat_batch)

    return _iterate_batches(data), steps_per_epoch

548 

549 

def _make_enqueued_generator(
    generator,
    workers=1,
    use_multiprocessing=False,
    max_queue_size=10,
    shuffle=False,
):
    """Wraps `generator` in a buffered queue when background workers are used.

    Args:
      generator: A Python generator or a `keras.utils.data_utils.Sequence`.
      workers: Number of workers; when not positive, no enqueuer is started.
      use_multiprocessing: Passed through to the enqueuer.
      max_queue_size: Passed to `enqueuer.start`.
      shuffle: Passed to `OrderedEnqueuer` (Sequence inputs only).

    Returns:
      An `(output_generator, enqueuer)` pair. `enqueuer` is `None` when
      `workers` is not positive; otherwise the caller is responsible for
      stopping it.
    """
    is_sequence = isinstance(generator, data_utils.Sequence)

    if workers > 0:
        if is_sequence:
            queue = data_utils.OrderedEnqueuer(
                generator,
                use_multiprocessing=use_multiprocessing,
                shuffle=shuffle,
            )
        else:
            queue = data_utils.GeneratorEnqueuer(
                generator, use_multiprocessing=use_multiprocessing
            )
        queue.start(workers=workers, max_queue_size=max_queue_size)
        return queue.get(), queue

    # No workers: consume the input directly on the calling thread.
    if is_sequence:
        return data_utils.iter_sequence_infinite(generator), None
    return generator, None

579 

580 

def _make_execution_function(model, mode, class_weight=None):
    """Returns a callable that runs `model` on a single batch for `mode`.

    TRAIN maps to `model.train_on_batch` (with `class_weight`), TEST to
    `model.test_on_batch`, and any other mode to a wrapper around
    `model.predict_on_batch`. For every mode except PREDICT the callable is
    bound with `reset_metrics=False` so stateful metrics keep accumulating
    across batch-level calls.
    """
    if mode == ModeKeys.TRAIN:
        step_fn = functools.partial(
            model.train_on_batch, class_weight=class_weight
        )
    elif mode == ModeKeys.TEST:
        step_fn = model.test_on_batch
    else:

        def predict_on_batch(x, y=None, sample_weights=None):
            # Accepts (and ignores) targets/weights so that 1-, 2- or
            # 3-tuples from the generator can be splatted in uniformly.
            return model.predict_on_batch(x)

        step_fn = predict_on_batch

    if mode == ModeKeys.PREDICT:
        return step_fn
    return functools.partial(step_fn, reset_metrics=False)

600 

601 

def _get_num_samples_or_steps(data, steps_per_epoch):
    """Returns `(count, use_steps)` for progress reporting.

    If the first flattened element of `data` has a `shape` attribute, its
    leading dimension is returned as a sample count with `use_steps=False`.
    Otherwise `steps_per_epoch` is returned with `use_steps=True`.
    """
    first_input = tf.nest.flatten(data)[0]
    if not hasattr(first_input, "shape"):
        return steps_per_epoch, True
    return int(first_input.shape[0]), False

609 

610 

class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
    """Training loop for Python generators and `Sequence` objects.

    The input is a Python generator or a `keras.utils.data_utils.Sequence`.

    Unlike `GeneratorLikeTrainingLoop`, this loop only handles inputs whose
    x, y and sample weights are fused into the single `x` argument. The
    separate `y`/`sample_weight` arguments are screened by
    `training_utils_v1.check_generator_arguments` in `fit` and `evaluate`.
    """

    def fit(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
    ):
        """Runs the TRAIN loop over `x` through `fit_generator`."""
        model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
        training_utils_v1.check_generator_arguments(
            y, sample_weight, validation_split=validation_split
        )
        loop_options = dict(
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            validation_data=validation_data,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
            class_weight=class_weight,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            shuffle=shuffle,
            initial_epoch=initial_epoch,
            steps_name="steps_per_epoch",
        )
        return fit_generator(model, x, **loop_options)

    def evaluate(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        verbose=1,
        sample_weight=None,
        steps=None,
        callbacks=None,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
    ):
        """Runs the TEST loop over `x` through `evaluate_generator`."""
        model._validate_or_infer_batch_size(batch_size, steps, x)
        training_utils_v1.check_generator_arguments(y, sample_weight)
        loop_options = dict(
            steps=steps,
            verbose=verbose,
            callbacks=callbacks,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
        )
        return evaluate_generator(model, x, **loop_options)

    def predict(
        self,
        model,
        x,
        batch_size=None,
        verbose=0,
        steps=None,
        callbacks=None,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
    ):
        """Runs the PREDICT loop over `x` through `predict_generator`."""
        model._validate_or_infer_batch_size(batch_size, steps, x)
        loop_options = dict(
            steps=steps,
            verbose=verbose,
            callbacks=callbacks,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
        )
        return predict_generator(model, x, **loop_options)

716 

717 

class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
    """A non-distributed Dataset or iterator in eager execution.

    Batches are pulled on the calling thread (`workers=0`). Extra keyword
    arguments are accepted for interface compatibility and ignored.
    """

    def fit(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        **kwargs,
    ):
        """Runs the TRAIN loop over the dataset/iterator `x`."""
        model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
        # Make sure that y, sample_weights, validation_split are not passed.
        training_utils_v1.validate_dataset_input(
            x, y, sample_weight, validation_split
        )
        x_is_dataset = isinstance(
            x, (tf.compat.v1.data.Dataset, tf.data.Dataset)
        )
        if shuffle and x_is_dataset:
            training_utils_v1.verify_dataset_shuffled(x)

        return fit_generator(
            model,
            x,
            workers=0,
            steps_name="steps_per_epoch",
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            initial_epoch=initial_epoch,
            verbose=verbose,
            callbacks=callbacks,
            class_weight=class_weight,
            shuffle=shuffle,
            validation_data=validation_data,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
        )

    def evaluate(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        verbose=1,
        sample_weight=None,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        """Runs the TEST loop over the dataset/iterator `x`."""
        model._validate_or_infer_batch_size(batch_size, steps, x)
        # Make sure that y and sample_weights are not passed.
        training_utils_v1.validate_dataset_input(x, y, sample_weight)
        return evaluate_generator(
            model,
            x,
            workers=0,
            steps=steps,
            verbose=verbose,
            callbacks=callbacks,
        )

    def predict(
        self,
        model,
        x,
        batch_size=None,
        verbose=0,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        """Runs the PREDICT loop over the dataset/iterator `x`."""
        model._validate_or_infer_batch_size(batch_size, steps, x)
        return predict_generator(
            model,
            x,
            workers=0,
            steps=steps,
            verbose=verbose,
            callbacks=callbacks,
        )

812 

813 

class GeneratorLikeTrainingLoop(training_utils_v1.TrainingLoop):
    """TrainingLoop that handles inputs like a Python generator.

    This is the default handler for most input data types: symbolic tensors,
    NumPy array-likes, and Datasets and iterators in graph mode (since they
    generate symbolic tensors). It is used to handle models with
    `run_eagerly=True`. Inputs are standardized by the model and then batched
    on the calling thread (`workers=0`).
    """

    def fit(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        **kwargs,
    ):
        """Standardizes the inputs, sets up validation data, then trains."""
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x
        )
        x, y, sample_weights = model._standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size,
            check_steps=True,
            steps_name="steps_per_epoch",
            steps=steps_per_epoch,
            validation_split=validation_split,
            shuffle=shuffle,
        )

        if validation_data:
            validation_data = model._prepare_validation_data(
                validation_data, batch_size, validation_steps
            )
        elif validation_split and 0.0 < validation_split < 1.0:
            # Carve the validation set out of the (standardized) training set.
            (
                x,
                y,
                sample_weights,
                val_x,
                val_y,
                val_sample_weights,
            ) = training_utils_v1.split_training_and_validation_data(
                x, y, sample_weights, validation_split
            )
            validation_data = (val_x, val_y, val_sample_weights)
        elif validation_steps:
            raise ValueError(
                "`validation_steps` should not be specified if "
                "`validation_data` is None."
            )

        training_data = (x, y, sample_weights)
        return fit_generator(
            model,
            training_data,
            workers=0,
            steps_name="steps_per_epoch",
            steps_per_epoch=steps_per_epoch,
            batch_size=batch_size,
            epochs=epochs,
            initial_epoch=initial_epoch,
            verbose=verbose,
            callbacks=callbacks,
            shuffle=shuffle,
            validation_data=validation_data,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
        )

    def evaluate(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        verbose=1,
        sample_weight=None,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        """Standardizes the inputs, then runs the TEST loop."""
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        x, y, sample_weights = model._standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            check_steps=True,
            steps_name="steps",
            steps=steps,
        )
        eval_data = (x, y, sample_weights)
        return evaluate_generator(
            model,
            eval_data,
            workers=0,
            steps=steps,
            batch_size=batch_size,
            verbose=verbose,
            callbacks=callbacks,
        )

    def predict(
        self,
        model,
        x,
        batch_size=None,
        verbose=0,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        """Standardizes the inputs, then runs the PREDICT loop."""
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        standardized = model._standardize_user_data(
            x, check_steps=True, steps_name="steps", steps=steps
        )
        x, _, _ = standardized
        return predict_generator(
            model,
            x,
            workers=0,
            steps=steps,
            batch_size=batch_size,
            verbose=verbose,
            callbacks=callbacks,
        )

954