Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/training_v1.py: 18%

1053 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""V1 Training-related part of the Keras engine.""" 

16import collections 

17import warnings 

18 

19import numpy as np 

20import tensorflow.compat.v2 as tf 

21 

22from keras.src import backend 

23from keras.src import losses 

24from keras.src import metrics as metrics_module 

25from keras.src import optimizers 

26from keras.src.distribute import distributed_training_utils 

27from keras.src.distribute import distributed_training_utils_v1 

28from keras.src.engine import base_layer 

29from keras.src.engine import training as training_lib 

30from keras.src.engine import training_arrays_v1 

31from keras.src.engine import training_distributed_v1 

32from keras.src.engine import training_eager_v1 

33from keras.src.engine import training_generator_v1 

34from keras.src.engine import training_utils 

35from keras.src.engine import training_utils_v1 

36from keras.src.mixed_precision import loss_scale_optimizer 

37from keras.src.optimizers import optimizer_v1 

38from keras.src.optimizers.legacy import optimizer_v2 

39from keras.src.saving.legacy import saving_utils 

40from keras.src.saving.legacy.saved_model import model_serialization 

41from keras.src.utils import data_utils 

42from keras.src.utils import layer_utils 

43from keras.src.utils import losses_utils 

44from keras.src.utils import tf_inspect 

45from keras.src.utils import tf_utils 

46from keras.src.utils.mode_keys import ModeKeys 

47 

48# isort: off 

49from tensorflow.python.platform import tf_logging as logging 

50 

51try: 

52 from scipy.sparse import issparse 

53except ImportError: 

54 issparse = None 

55 

56 

57class Model(training_lib.Model): 

58 """A model groups layers into an object with training & inference features. 

59 

60 There are two ways to instantiate a `Model`: 

61 

62 1 - With the "functional API", where you start from `Input`, 

63 you chain layer calls to specify the model's forward pass, 

64 and finally you create your model from inputs and outputs: 

65 

66 ```python 

67 import tensorflow as tf 

68 

69 inputs = tf.keras.Input(shape=(3,)) 

70 x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs) 

71 outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x) 

72 model = tf.keras.Model(inputs=inputs, outputs=outputs) 

73 ``` 

74 

75 2 - By subclassing the `Model` class: in that case, you should define your 

76 layers in `__init__` and you should implement the model's forward pass 

77 in `call`. 

78 

79 ```python 

80 import tensorflow as tf 

81 

82 class MyModel(tf.keras.Model): 

83 

84 def __init__(self): 

85 super().__init__() 

86 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 

87 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 

88 

89 def call(self, inputs): 

90 x = self.dense1(inputs) 

91 return self.dense2(x) 

92 

93 model = MyModel() 

94 ``` 

95 

96 If you subclass `Model`, you can optionally have 

97 a `training` argument (boolean) in `call`, which you can use to specify 

98 a different behavior in training and inference: 

99 

100 ```python 

101 import tensorflow as tf 

102 

103 class MyModel(tf.keras.Model): 

104 

105 def __init__(self): 

106 super().__init__() 

107 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 

108 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 

109 self.dropout = tf.keras.layers.Dropout(0.5) 

110 

111 def call(self, inputs, training=False): 

112 x = self.dense1(inputs) 

113 if training: 

114 x = self.dropout(x, training=training) 

115 return self.dense2(x) 

116 

117 model = MyModel() 

118 ``` 

119 """ 

120 

121 def __init__(self, *args, **kwargs): 

122 super().__init__(*args, **kwargs) 

123 # initializing _distribution_strategy here since it is possible to call 

124 # predict on a model without compiling it. 

125 self._distribution_strategy = None 

126 self._compile_time_distribution_strategy = None 

127 if ( 

128 tf.compat.v1.executing_eagerly_outside_functions() 

129 and tf.distribute.has_strategy() 

130 ): 

131 self._set_strategy(tf.distribute.get_strategy()) 

132 

133 # This flag is used to track if the user is using the deprecated path of 

134 # passing distribution strategy to compile rather than creating the 

135 # model under distribution strategy scope. 

136 self._compile_distribution = False 

137 

138 self._run_eagerly = None 

139 self._experimental_run_tf_function = ( 

140 tf.compat.v1.executing_eagerly_outside_functions() 

141 ) 

142 

143 self._v1_compile_was_called = False 

144 

145 def _init_batch_counters(self): 

146 pass # Batch counters should not be created in legacy graph mode. 

147 

148 @tf.__internal__.tracking.no_automatic_dependency_tracking 

149 def _set_strategy(self, strategy): 

150 self._compile_time_distribution_strategy = strategy 

151 

152 def get_weights(self): 

153 """Retrieves the weights of the model. 

154 

155 Returns: 

156 A flat list of Numpy arrays. 

157 """ 

158 strategy = ( 

159 self._distribution_strategy 

160 or self._compile_time_distribution_strategy 

161 ) 

162 if strategy: 

163 with strategy.scope(): 

164 return base_layer.Layer.get_weights(self) 

165 return base_layer.Layer.get_weights(self) 

166 

167 def load_weights(self, filepath, by_name=False, skip_mismatch=False): 

168 """Loads all layer weights, either from a TensorFlow or an HDF5 file. 

169 

170 If `by_name` is False, weights are loaded based on the network's 

171 topology. This means the architecture should be the same as when the 

172 weights were saved. Note that layers that don't have weights are not 

173 taken into account in the topological ordering, so adding or removing 

174 layers is fine as long as they don't have weights. 

175 

176 If `by_name` is True, weights are loaded into layers only if they share 

177 the same name. This is useful for fine-tuning or transfer-learning 

178 models where some of the layers have changed. 

179 

180 Only topological loading (`by_name=False`) is supported when loading 

181 weights from the TensorFlow format. Note that topological loading 

182 differs slightly between TensorFlow and HDF5 formats for user-defined 

183 classes inheriting from `tf.keras.Model`: HDF5 loads based on a 

184 flattened list of weights, while the TensorFlow format loads based on 

185 the object-local names of attributes to which layers are assigned in the 

186 `Model`'s constructor. 

187 

188 Args: 

189 filepath: String, path to the weights file to load. For weight files 

190 in TensorFlow format, this is the file prefix (the same as was 

191 passed to `save_weights`). 

192 by_name: Boolean, whether to load weights by name or by topological 

193 order. Only topological loading is supported for weight files in 

194 TensorFlow format. 

195 skip_mismatch: Boolean, whether to skip loading of layers where 

196 there is a mismatch in the number of weights, or a mismatch in 

197 the shape of the weight (only valid when `by_name=True`). 

198 

199 Returns: 

200 When loading a weight file in TensorFlow format, returns the same 

201 status object as `tf.train.Checkpoint.restore`. When graph building, 

202 restore ops are run automatically as soon as the network is built 

203 (on first call for user-defined classes inheriting from `Model`, 

204 immediately if it is already built). 

205 

206 When loading weights in HDF5 format, returns `None`. 

207 

208 Raises: 

209 ImportError: If h5py is not available and the weight file is in HDF5 

210 format. 

211 ValueError: If `skip_mismatch` is set to `True` when `by_name` is 

212 `False`. 

213 """ 

214 if backend.is_tpu_strategy(self._distribution_strategy): 

215 if self._distribution_strategy.extended.steps_per_run > 1 and ( 

216 not saving_utils.is_hdf5_filepath(filepath) 

217 ): 

218 raise ValueError( 

219 "Load weights is not yet supported with TPUStrategy " 

220 "with steps_per_run greater than 1." 

221 ) 

222 return super().load_weights( 

223 filepath, by_name=by_name, skip_mismatch=skip_mismatch 

224 ) 

225 
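A minimal usage sketch of the two weight formats described in the docstring above (illustrative only, not part of the original file; the checkpoint paths and layer names are hypothetical, and the HDF5 path requires `h5py`):

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation="relu", input_shape=(3,), name="hidden"),
    tf.keras.layers.Dense(1, name="out"),
])

# TensorFlow checkpoint format: `filepath` is a prefix, and `load_weights`
# returns the same status object as `tf.train.Checkpoint.restore`.
model.save_weights("/tmp/example_ckpt")
status = model.load_weights("/tmp/example_ckpt")
status.assert_consumed()

# HDF5 format: topological loading by default, or matching layers by name.
model.save_weights("/tmp/example_weights.h5")
model.load_weights("/tmp/example_weights.h5", by_name=True)
```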

226 @tf.__internal__.tracking.no_automatic_dependency_tracking 

227 def compile( 

228 self, 

229 optimizer="rmsprop", 

230 loss=None, 

231 metrics=None, 

232 loss_weights=None, 

233 sample_weight_mode=None, 

234 weighted_metrics=None, 

235 target_tensors=None, 

236 distribute=None, 

237 **kwargs, 

238 ): 

239 """Configures the model for training. 

240 

241 Args: 

242 optimizer: String (name of optimizer) or optimizer instance. 

243 See `tf.keras.optimizers`. 

244 loss: String (name of objective function), objective function or 

245 `tf.keras.losses.Loss` instance. See `tf.keras.losses`. An 

246 objective function is any callable with the signature 

247 `scalar_loss = fn(y_true, y_pred)`. If the model has multiple 

248 outputs, you can use a different loss on each output by passing 

249 a dictionary or a list of losses. The loss value that will be 

250 minimized by the model will then be the sum of all individual 

251 losses. 

252 metrics: List of metrics to be evaluated by the model during 

253 training and testing. Typically you will use 

254 `metrics=['accuracy']`. To specify different metrics for 

255 different outputs of a multi-output model, you could also pass a 

256 dictionary, such as `metrics={'output_a': 'accuracy', 

257 'output_b': ['accuracy', 'mse']}`. You can also pass a list 

258 (len = len(outputs)) of lists of metrics such as 

259 `metrics=[['accuracy'], ['accuracy', 'mse']]` or 

260 `metrics=['accuracy', ['accuracy', 'mse']]`. 

261 loss_weights: Optional list or dictionary specifying scalar 

262 coefficients (Python floats) to weight the loss contributions 

263 of different model outputs. 

264 The loss value that will be minimized by the model 

265 will then be the *weighted sum* of all individual losses, 

266 weighted by the `loss_weights` coefficients. 

267 If a list, it is expected to have a 1:1 mapping 

268 to the model's outputs. If a dict, it is expected to map 

269 output names (strings) to scalar coefficients. 

270 sample_weight_mode: If you need to do timestep-wise 

271 sample weighting (2D weights), set this to `"temporal"`. 

272 `None` defaults to sample-wise weights (1D). 

273 If the model has multiple outputs, you can use a different 

274 `sample_weight_mode` on each output by passing a 

275 dictionary or a list of modes. Defaults to `None`. 

276 weighted_metrics: List of metrics to be evaluated and weighted 

277 by sample_weight or class_weight during training and testing. 

278 target_tensors: By default, Keras will create placeholders for the 

279 model's target, which will be fed with the target data during 

280 training. If instead you would like to use your own 

281 target tensors (in turn, Keras will not expect external 

282 Numpy data for these targets at training time), you 

283 can specify them via the `target_tensors` argument. It can be 

284 a single tensor (for a single-output model), a list of tensors, 

285 or a dict mapping output names to target tensors. 

286 distribute: NOT SUPPORTED IN TF 2.0, please create and compile the 

287 model under distribution strategy scope instead of passing it to 

288 compile. 

289 **kwargs: Any additional arguments. 

290 

291 Raises: 

292 ValueError: In case of invalid arguments for 

293 `optimizer`, `loss`, `metrics` or `sample_weight_mode`. 

294 """ 

295 self._assert_built_as_v1() 

296 self._run_eagerly = kwargs.pop("run_eagerly", None) 

297 self._experimental_run_tf_function = kwargs.pop( 

298 "experimental_run_tf_function", True 

299 ) 

300 self._v1_compile_was_called = True 

301 

302 # Prepare Session arguments (legacy). 

303 kwargs.pop("cloning", None) # Legacy DistStrat argument, never used. 

304 self._from_serialized = kwargs.pop("from_serialized", False) 

305 allowed_kwargs = {"feed_dict", "fetches", "options", "run_metadata"} 

306 unknown_kwargs = set(kwargs.keys()) - allowed_kwargs 

307 if unknown_kwargs: 

308 raise TypeError( 

309 f"Invalid keyword argument(s) in `compile`: {unknown_kwargs}" 

310 ) 

311 self._function_kwargs = kwargs 

312 if self._function_kwargs: 

313 self._experimental_run_tf_function = False 

314 if self.run_eagerly: 

315 raise ValueError( 

316 "Session keyword arguments are not supported " 

317 "when `run_eagerly=True`. You passed the following " 

318 "Session arguments: %s" % (self._function_kwargs,) 

319 ) 

320 

321 self._set_optimizer(optimizer) 

322 is_any_keras_optimizer_v1 = any( 

323 ( 

324 isinstance(opt, optimizer_v1.Optimizer) 

325 and not isinstance(opt, optimizer_v1.TFOptimizer) 

326 ) 

327 for opt in tf.nest.flatten(self.optimizer) 

328 ) 

329 

330 if ( 

331 is_any_keras_optimizer_v1 

332 and tf.compat.v1.executing_eagerly_outside_functions() 

333 ): 

334 raise ValueError( 

335 "`tf.compat.v1.keras` Optimizer (", 

336 optimizer, 

337 ") is " 

338 "not supported when eager execution is enabled. Use a " 

339 "`tf.keras` Optimizer instead, or disable eager " 

340 "execution.", 

341 ) 

342 

343 if ( 

344 target_tensors is not None 

345 ) or not tf.compat.v1.executing_eagerly_outside_functions(): 

346 # Fallback out of things that aren't supported with v2 loops 

347 self._experimental_run_tf_function = False 

348 

349 if distribute is not None: 

350 if ( 

351 tf.__internal__.tf2.enabled() 

352 or self._experimental_run_tf_function 

353 ): 

354 raise ValueError( 

355 "Distribute argument in compile is not available in TF 2.0 " 

356 "please create the model under the distribution strategy " 

357 "scope." 

358 ) 

359 logging.warning( 

360 "Distribute argument in compile is deprecated please " 

361 "create the model under the distribution strategy scope." 

362 ) 

363 self._distribution_strategy = distribute 

364 self._compile_distribution = True 

365 else: 

366 if tf.distribute.has_strategy(): 

367 # When the user builds the model in the DS scope and in a 

368 # cross-replica context, we want the distribution strategy to be 

369 # set; but when building the replica copies of the models 

370 # internally, we should not compile with the distribution 

371 # strategy and should instead use the default compilation path. 

372 if tf.distribute.in_cross_replica_context(): 

373 self._distribution_strategy = tf.distribute.get_strategy() 

374 

375 if isinstance( 

376 self._distribution_strategy, 

377 tf.compat.v1.distribute.experimental.ParameterServerStrategy, 

378 ): 

379 raise NotImplementedError( 

380 "`tf.compat.v1.distribute.experimental.ParameterServerStrategy`" 

381 " currently only works with the tf.Estimator API" 

382 ) 

383 

384 if isinstance( 

385 self._distribution_strategy, 

386 tf.distribute.experimental.ParameterServerStrategy, 

387 ): 

388 raise NotImplementedError( 

389 "`tf.distribute.experimental.ParameterServerStrategy` is only " 

390 "supported in TF2." 

391 ) 

392 

393 if not self._experimental_run_tf_function: 

394 self._validate_compile_param_for_distribution_strategy( 

395 self.run_eagerly, 

396 sample_weight_mode, 

397 target_tensors, 

398 weighted_metrics, 

399 ) 

400 # We've disabled automatic dependency tracking for this method, but do 

401 # want to add a checkpoint dependency on the optimizer if it's 

402 # trackable. 

403 if isinstance(self.optimizer, tf.__internal__.tracking.Trackable): 

404 self._track_trackable( 

405 self.optimizer, name="optimizer", overwrite=True 

406 ) 

407 self.loss = loss or {} 

408 self.loss_weights = loss_weights 

409 self.sample_weight_mode = sample_weight_mode 

410 self._compile_metrics = metrics or [] 

411 self._compile_weighted_metrics = weighted_metrics 

412 if self.run_eagerly and target_tensors is not None: 

413 raise ValueError( 

414 "target_tensors argument is not supported when " 

415 "running a model eagerly." 

416 ) 

417 

418 # _training_endpoints is a list of _TrainingEndpoint objects, each of which 

419 # has all the model output/target/loss and related metadata. 

420 self._training_endpoints = [] 

421 

422 # Used to freeze the behavior of the Model once `compile` has been 

423 # called. 

424 self._compiled_trainable_state = self._get_trainable_state() 

425 

426 # Set tf.distribute.Strategy specific parameters. 

427 self._distributed_model_cache = {} 

428 self._distributed_function_cache = {} 

429 

430 # Clear any `_eager_losses` that was added. 

431 self._clear_losses() 

432 

433 if ( 

434 not tf.executing_eagerly() 

435 and self._distribution_strategy is not None 

436 ): 

437 # Ensures a Session is created and configured correctly for 

438 # Distribution Strategy. 

439 backend.configure_and_create_distributed_session( 

440 self._distribution_strategy 

441 ) 

442 # Initialize model metric attributes. 

443 self._init_metric_attributes() 

444 if not self.built or not self.inputs or not self.outputs: 

445 # Model is not compilable because it does not know its number of 

446 # inputs and outputs, nor their shapes and names. We will compile 

447 # after the first time the model gets called on training data. 

448 return 

449 self._is_compiled = True 

450 base_layer.keras_api_gauge.get_cell("compile").set(True) 

451 

452 # Prepare list of loss functions, same size of model outputs. 

453 self.loss_functions = training_utils_v1.prepare_loss_functions( 

454 self.loss, self.output_names 

455 ) 

456 

457 target_tensors = self._process_target_tensor_for_compile(target_tensors) 

458 

459 for o, n, l, t in zip( 

460 self.outputs, self.output_names, self.loss_functions, target_tensors 

461 ): 

462 endpoint = _TrainingEndpoint(o, n, l) 

463 endpoint.create_training_target(t, run_eagerly=self.run_eagerly) 

464 self._training_endpoints.append(endpoint) 

465 

466 # Prepare list loss weights, same size of model outputs. 

467 training_utils_v1.prepare_loss_weights( 

468 self._training_endpoints, loss_weights 

469 ) 

470 

471 # Initialization for Eager mode execution. 

472 if self.run_eagerly: 

473 self._compile_eagerly(metrics, weighted_metrics, sample_weight_mode) 

474 return 

475 

476 with backend.get_graph().as_default(): 

477 # Save all metric attributes per output of the model. 

478 self._cache_output_metric_attributes(metrics, weighted_metrics) 

479 

480 # Set metric attributes on model. 

481 self._set_metric_attributes() 

482 

483 # Invoke metric functions (unweighted) for all the outputs. 

484 self._handle_metrics( 

485 self.outputs, 

486 targets=self._targets, 

487 skip_target_masks=self._prepare_skip_target_masks(), 

488 masks=self._prepare_output_masks(), 

489 ) 

490 

491 # Prepare sample weight modes. List with the same length as model 

492 # outputs. 

493 training_utils_v1.prepare_sample_weight_modes( 

494 self._training_endpoints, sample_weight_mode 

495 ) 

496 

497 # Creates the model loss and weighted metrics sub-graphs. 

498 self._compile_weights_loss_and_weighted_metrics() 

499 

500 # Functions for train, test and predict will 

501 # be compiled lazily when required. 

502 # This saves time when the user is not using all functions. 

503 self.train_function = None 

504 self.test_function = None 

505 self.predict_function = None 

506 

507 # Collected trainable weights, sorted in topological order. 

508 self._collected_trainable_weights = self.trainable_weights 

509 

510 # Validate all variables were correctly created in distribution 

511 # scope. 

512 if self._distribution_strategy and not self._compile_distribution: 

513 for v in self.variables: 

514 strategy = self._distribution_strategy 

515 if not strategy.extended.variable_created_in_scope(v): 

516 raise ValueError( 

517 "Variable (%s) was not created in the distribution " 

518 "strategy scope of (%s). It is most likely due to " 

519 "not all layers or the model or optimizer being " 

520 "created outside the distribution strategy scope. " 

521 "Try to make sure your code looks similar " 

522 "to the following.\n" 

523 "with strategy.scope():\n" 

524 " model=_create_model()\n" 

525 " model.compile(...)" % (v, strategy) 

526 ) 

527 
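For reference, a hedged sketch of a `compile` call exercising the multi-output `loss`, `loss_weights`, and `metrics` arguments documented above (the output names `output_a`/`output_b` are illustrative only):

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(8,))
output_a = tf.keras.layers.Dense(1, name="output_a")(inputs)
output_b = tf.keras.layers.Dense(3, activation="softmax", name="output_b")(inputs)
model = tf.keras.Model(inputs, [output_a, output_b])

model.compile(
    optimizer="rmsprop",
    # Per-output losses keyed by output name; the total loss minimized by
    # the model is their weighted sum using `loss_weights`.
    loss={"output_a": "mse", "output_b": "categorical_crossentropy"},
    loss_weights={"output_a": 1.0, "output_b": 0.5},
    metrics={"output_a": ["mae"], "output_b": ["accuracy"]},
)
```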

528 @tf.__internal__.tracking.no_automatic_dependency_tracking 

529 def _init_distributed_function_cache_if_not_compiled(self): 

530 if not hasattr(self, "_distributed_function_cache"): 

531 self._distributed_function_cache = {} 

532 

533 @property 

534 def metrics(self): 

535 """Returns the model's metrics added using `compile`, `add_metric` 

536 APIs.""" 

537 metrics = [] 

538 if self._is_compiled: 

539 if not hasattr(self, "_v1_compile_was_called"): 

540 # See b/155687393 for more details, the model is created as a v2 

541 # instance but converted to v1. Fall back to the base Model to 

542 # retrieve the metrics. 

543 return super().metrics 

544 metrics += self._compile_metric_functions 

545 metrics.extend(self._metrics) 

546 metrics.extend( 

547 _get_metrics_from_layers( 

548 list(self._flatten_layers(include_self=False, recursive=False)) 

549 ) 

550 ) 

551 return metrics 

552 

553 @property 

554 def metrics_names(self): 

555 """Returns the model's display labels for all outputs.""" 

556 

557 # This property includes all output names including `loss` and 

558 # per-output losses for backward compatibility. 

559 metrics_names = ["loss"] 

560 if self._is_compiled: 

561 if not hasattr(self, "_v1_compile_was_called"): 

562 # See b/155687393 for more details, the model is created as a v2 

563 # instance but converted to v1. Fall back to the base Model to 

564 # retrieve the metric names. 

565 return super().metrics_names 

566 

567 # Add output loss metric names to the metric names list. 

568 if len(self._training_endpoints) > 1: 

569 metrics_names.extend( 

570 [ 

571 e.loss_name() 

572 for e in self._training_endpoints 

573 if not e.should_skip_target() 

574 ] 

575 ) 

576 

577 # Add all metric names. 

578 metrics_names += [m.name for m in self.metrics] 

579 return metrics_names 

580 

581 @property 

582 def run_eagerly(self): 

583 """Settable attribute indicating whether the model should run eagerly. 

584 

585 Running eagerly means that your model will be run step by step, 

586 like Python code. Your model might run slower, but it should become 

587 easier for you to debug it by stepping into individual layer calls. 

588 

589 By default, we will attempt to compile your model to a static graph to 

590 deliver the best execution performance. 

591 

592 Returns: 

593 Boolean, whether the model should run eagerly. 

594 """ 

595 if self._run_eagerly is True and not tf.executing_eagerly(): 

596 raise ValueError( 

597 "You can only set `run_eagerly=True` if eager execution " 

598 "is enabled." 

599 ) 

600 if not self.dynamic: 

601 if self._run_eagerly is None: 

602 # Respect `tf.config.run_functions_eagerly` unless 

603 # `run_eagerly` was explicitly passed to `compile`. 

604 return tf.config.functions_run_eagerly() 

605 else: 

606 return self._run_eagerly 

607 else: 

608 if not tf.executing_eagerly(): 

609 raise ValueError( 

610 "Your model contains layers that can only be " 

611 "successfully run in eager execution (layers " 

612 "constructed with `dynamic=True`). " 

613 "You must enable eager execution with " 

614 "`tf.enable_eager_execution()`." 

615 ) 

616 if self._run_eagerly is False: 

617 # TODO(fchollet): consider using py_func to enable this. 

618 raise ValueError( 

619 "Your model contains layers that can only be " 

620 "successfully run in eager execution (layers " 

621 "constructed with `dynamic=True`). " 

622 "You cannot set `run_eagerly=False`." 

623 ) 

624 return tf.executing_eagerly() 

625 

626 @run_eagerly.setter 

627 def run_eagerly(self, value): 

628 self._run_eagerly = value 

629 

630 def _select_training_loop(self, inputs): 

631 """Select training loop for fit/eval/predict based on the inputs.""" 

632 # TODO(kaftan) or TODO(scottzhu): This check should eventually be nicely 

633 # integrated into the data adapters in the v2 loop. We can't do this yet 

634 # because we currently have to fall back for unhandled data types. 

635 if isinstance(inputs, (tf.compat.v1.data.Iterator, tf.data.Iterator)): 

636 raise ValueError( 

637 "For performance reasons Keras `fit`, `evaluate` and" 

638 "`predict` accept tf.data `Datasets` as input but not " 

639 "iterators that have been manually generated from " 

640 "Datasets by users. Please directly pass in the " 

641 "original `Dataset` object instead of passing in " 

642 "`iter(dataset)`." 

643 ) 

644 

645 # Case 1: distribution strategy. 

646 if self._distribution_strategy: 

647 if self._in_multi_worker_mode(): 

648 return training_distributed_v1.DistributionMultiWorkerTrainingLoop( # noqa: E501 

649 training_distributed_v1.DistributionSingleWorkerTrainingLoop() # noqa: E501 

650 ) 

651 else: 

652 return ( 

653 training_distributed_v1.DistributionSingleWorkerTrainingLoop() # noqa: E501 

654 ) 

655 

656 # Case 2: generator-like. Input is Python generator, or Sequence object, 

657 # or a non-distributed Dataset or iterator in eager execution. 

658 if data_utils.is_generator_or_sequence(inputs): 

659 return training_generator_v1.GeneratorOrSequenceTrainingLoop() 

660 if training_utils_v1.is_eager_dataset_or_iterator(inputs): 

661 return training_generator_v1.EagerDatasetOrIteratorTrainingLoop() 

662 

663 # Case 3: Symbolic tensors or Numpy array-like. 

664 # This includes Datasets and iterators in graph mode (since they 

665 # generate symbolic tensors). 

666 if self.run_eagerly: 

667 return training_generator_v1.GeneratorLikeTrainingLoop() 

668 else: 

669 return training_arrays_v1.ArrayLikeTrainingLoop() 

670 

671 def fit( 

672 self, 

673 x=None, 

674 y=None, 

675 batch_size=None, 

676 epochs=1, 

677 verbose=1, 

678 callbacks=None, 

679 validation_split=0.0, 

680 validation_data=None, 

681 shuffle=True, 

682 class_weight=None, 

683 sample_weight=None, 

684 initial_epoch=0, 

685 steps_per_epoch=None, 

686 validation_steps=None, 

687 validation_freq=1, 

688 max_queue_size=10, 

689 workers=1, 

690 use_multiprocessing=False, 

691 **kwargs, 

692 ): 

693 """Trains the model for a fixed number of epochs (dataset iterations). 

694 

695 Args: 

696 x: Input data. It could be: 

697 - A Numpy array (or array-like), or a list of arrays 

698 (in case the model has multiple inputs). 

699 - A TensorFlow tensor, or a list of tensors 

700 (in case the model has multiple inputs). 

701 - A dict mapping input names to the corresponding array/tensors, 

702 if the model has named inputs. 

703 - A `tf.data` dataset. Should return a tuple 

704 of either `(inputs, targets)` or 

705 `(inputs, targets, sample_weights)`. 

706 - A generator or `keras.utils.Sequence` returning `(inputs, 

707 targets)` or `(inputs, targets, sample weights)`. 

708 y: Target data. Like the input data `x`, 

709 it could be either Numpy array(s) or TensorFlow tensor(s). 

710 It should be consistent with `x` (you cannot have Numpy inputs and 

711 tensor targets, or inversely). If `x` is a dataset, generator, 

712 or `keras.utils.Sequence` instance, `y` should 

713 not be specified (since targets will be obtained from `x`). 

714 batch_size: Integer or `None`. 

715 Number of samples per gradient update. 

716 If unspecified, `batch_size` will default to 32. 

717 Do not specify the `batch_size` if your data is in the 

718 form of symbolic tensors, datasets, 

719 generators, or `keras.utils.Sequence` instances (since they 

720 generate batches). 

721 epochs: Integer. Number of epochs to train the model. 

722 An epoch is an iteration over the entire `x` and `y` 

723 data provided. 

724 Note that in conjunction with `initial_epoch`, 

725 `epochs` is to be understood as "final epoch". 

726 The model is not trained for a number of iterations 

727 given by `epochs`, but merely until the epoch 

728 of index `epochs` is reached. 

729 verbose: 0, 1, or 2. Verbosity mode. 

730 0 = silent, 1 = progress bar, 2 = one line per epoch. 

731 Note that the progress bar is not particularly useful when 

732 logged to a file, so verbose=2 is recommended when not running 

733 interactively (e.g., in a production environment). 

734 callbacks: List of `keras.callbacks.Callback` instances. 

735 List of callbacks to apply during training. 

736 See `tf.keras.callbacks`. 

737 validation_split: Float between 0 and 1. 

738 Fraction of the training data to be used as validation data. 

739 The model will set apart this fraction of the training data, 

740 will not train on it, and will evaluate 

741 the loss and any model metrics 

742 on this data at the end of each epoch. 

743 The validation data is selected from the last samples 

744 in the `x` and `y` data provided, before shuffling. This 

745 argument is not supported when `x` is a dataset, generator or 

746 `keras.utils.Sequence` instance. 

747 validation_data: Data on which to evaluate 

748 the loss and any model metrics at the end of each epoch. 

749 The model will not be trained on this data. 

750 `validation_data` will override `validation_split`. 

751 `validation_data` could be: 

752 - tuple `(x_val, y_val)` of Numpy arrays or tensors 

753 - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays 

754 - dataset 

755 For the first two cases, `batch_size` must be provided. 

756 For the last case, `validation_steps` could be provided. 

757 shuffle: Boolean (whether to shuffle the training data 

758 before each epoch) or str (for 'batch'). 

759 'batch' is a special option for dealing with the 

760 limitations of HDF5 data; it shuffles in batch-sized chunks. 

761 Has no effect when `steps_per_epoch` is not `None`. 

762 class_weight: Optional dictionary mapping class indices (integers) 

763 to a weight (float) value, used for weighting the loss function 

764 (during training only). 

765 This can be useful to tell the model to 

766 "pay more attention" to samples from 

767 an under-represented class. 

768 sample_weight: Optional Numpy array of weights for 

769 the training samples, used for weighting the loss function 

770 (during training only). You can either pass a flat (1D) 

771 Numpy array with the same length as the input samples 

772 (1:1 mapping between weights and samples), 

773 or in the case of temporal data, 

774 you can pass a 2D array with shape 

775 `(samples, sequence_length)`, 

776 to apply a different weight to every timestep of every sample. 

777 In this case you should make sure to specify 

778 `sample_weight_mode="temporal"` in `compile()`. This argument is 

779 not supported when `x` is a dataset, generator, or 

780 `keras.utils.Sequence` instance, instead provide the 

781 sample_weights as the third element of `x`. 

782 initial_epoch: Integer. 

783 Epoch at which to start training 

784 (useful for resuming a previous training run). 

785 steps_per_epoch: Integer or `None`. 

786 Total number of steps (batches of samples) 

787 before declaring one epoch finished and starting the 

788 next epoch. When training with input tensors such as 

789 TensorFlow data tensors, the default `None` is equal to 

790 the number of samples in your dataset divided by 

791 the batch size, or 1 if that cannot be determined. If x is a 

792 `tf.data` dataset, and 'steps_per_epoch' 

793 is None, the epoch will run until the input dataset is 

794 exhausted. This argument is not supported with array inputs. 

795 validation_steps: Only relevant if `validation_data` is provided and 

796 is a `tf.data` dataset. Total number of steps (batches of 

797 samples) to draw before stopping when performing validation at 

798 the end of every epoch. If 'validation_steps' is None, 

799 validation will run until the `validation_data` dataset is 

800 exhausted. In the case of an infinite dataset, it will run into an 

801 infinite loop. If 'validation_steps' is specified and only part 

802 of the dataset will be consumed, the evaluation will start from 

803 the beginning of the dataset at each epoch. This ensures that 

804 the same validation samples are used every time. 

805 validation_freq: Only relevant if validation data is provided. 

806 Integer or `collections.abc.Container` instance (e.g. list, 

807 tuple, etc.). If an integer, specifies how many training epochs 

808 to run before a new validation run is performed, e.g. 

809 `validation_freq=2` runs validation every 2 epochs. If a 

810 Container, specifies the epochs on which to run validation, e.g. 

811 `validation_freq=[1, 2, 10]` runs validation at the end of the 

812 1st, 2nd, and 10th epochs. 

813 max_queue_size: Integer. Used for generator or 

814 `keras.utils.Sequence` input only. Maximum size for the 

815 generator queue. If unspecified, `max_queue_size` will default 

816 to 10. 

817 workers: Integer. Used for generator or `keras.utils.Sequence` input 

818 only. Maximum number of processes to spin up 

819 when using process-based threading. If unspecified, `workers` 

820 will default to 1. If 0, will execute the generator on the main 

821 thread. 

822 use_multiprocessing: Boolean. Used for generator or 

823 `keras.utils.Sequence` input only. If `True`, use process-based 

824 threading. If unspecified, `use_multiprocessing` will default to 

825 `False`. Note that because this implementation relies on 

826 multiprocessing, you should not pass non-picklable arguments to 

827 the generator as they can't be passed easily to child 

828 processes. 

829 **kwargs: Used for backwards compatibility. 

830 

831 Returns: 

832 A `History` object. Its `History.history` attribute is 

833 a record of training loss values and metrics values 

834 at successive epochs, as well as validation loss values 

835 and validation metrics values (if applicable). 

836 

837 Raises: 

838 RuntimeError: If the model was never compiled. 

839 ValueError: In case of mismatch between the provided input data 

840 and what the model expects. 

841 """ 

842 self._assert_built_as_v1() 

843 base_layer.keras_api_gauge.get_cell("fit").set(True) 

844 # Legacy support 

845 if "nb_epoch" in kwargs: 

846 logging.warning( 

847 "The `nb_epoch` argument in `fit` has been renamed `epochs`." 

848 ) 

849 epochs = kwargs.pop("nb_epoch") 

850 if kwargs: 

851 raise TypeError("Unrecognized keyword arguments: " + str(kwargs)) 

852 self._assert_compile_was_called() 

853 self._check_call_args("fit") 

854 

855 func = self._select_training_loop(x) 

856 return func.fit( 

857 self, 

858 x=x, 

859 y=y, 

860 batch_size=batch_size, 

861 epochs=epochs, 

862 verbose=verbose, 

863 callbacks=callbacks, 

864 validation_split=validation_split, 

865 validation_data=validation_data, 

866 shuffle=shuffle, 

867 class_weight=class_weight, 

868 sample_weight=sample_weight, 

869 initial_epoch=initial_epoch, 

870 steps_per_epoch=steps_per_epoch, 

871 validation_steps=validation_steps, 

872 validation_freq=validation_freq, 

873 max_queue_size=max_queue_size, 

874 workers=workers, 

875 use_multiprocessing=use_multiprocessing, 

876 ) 

877 
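A self-contained sketch of a typical `fit` call with NumPy data (illustrative; the array shapes and hyperparameters are arbitrary):

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(10,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="rmsprop", loss="mse", metrics=["mae"])

x = np.random.random((100, 10)).astype("float32")
y = np.random.random((100, 1)).astype("float32")

history = model.fit(
    x,
    y,
    batch_size=32,
    epochs=3,
    validation_split=0.2,  # last 20% of x/y held out for validation
    verbose=2,             # one line per epoch, e.g. when logging to a file
)
print(history.history["loss"], history.history["val_loss"])
```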

878 def evaluate( 

879 self, 

880 x=None, 

881 y=None, 

882 batch_size=None, 

883 verbose=1, 

884 sample_weight=None, 

885 steps=None, 

886 callbacks=None, 

887 max_queue_size=10, 

888 workers=1, 

889 use_multiprocessing=False, 

890 ): 

891 """Returns the loss value & metrics values for the model in test mode. 

892 

893 Computation is done in batches (see the `batch_size` arg). 

894 

895 Args: 

896 x: Input data. It could be: 

897 - A Numpy array (or array-like), or a list of arrays 

898 (in case the model has multiple inputs). 

899 - A TensorFlow tensor, or a list of tensors 

900 (in case the model has multiple inputs). 

901 - A dict mapping input names to the corresponding array/tensors, 

902 if the model has named inputs. 

903 - A `tf.data` dataset. 

904 - A generator or `keras.utils.Sequence` instance. 

905 y: Target data. Like the input data `x`, 

906 it could be either Numpy array(s) or TensorFlow tensor(s). 

907 It should be consistent with `x` (you cannot have Numpy inputs and 

908 tensor targets, or inversely). 

909 If `x` is a dataset, generator or 

910 `keras.utils.Sequence` instance, `y` should not be specified 

911 (since targets will be obtained from the iterator/dataset). 

912 batch_size: Integer or `None`. 

913 Number of samples per batch of computation. 

914 If unspecified, `batch_size` will default to 32. 

915 Do not specify the `batch_size` if your data is in the 

916 form of symbolic tensors, dataset, 

917 generators, or `keras.utils.Sequence` instances (since they 

918 generate batches). 

919 verbose: 0 or 1. Verbosity mode. 

920 0 = silent, 1 = progress bar. 

921 sample_weight: Optional Numpy array of weights for 

922 the test samples, used for weighting the loss function. 

923 You can either pass a flat (1D) 

924 Numpy array with the same length as the input samples 

925 (1:1 mapping between weights and samples), 

926 or in the case of temporal data, 

927 you can pass a 2D array with shape 

928 `(samples, sequence_length)`, 

929 to apply a different weight to every timestep of every sample. 

930 In this case you should make sure to specify 

931 `sample_weight_mode="temporal"` in `compile()`. This argument is 

932 not supported when `x` is a dataset, instead pass sample weights 

933 as the third element of `x`. 

934 steps: Integer or `None`. 

935 Total number of steps (batches of samples) 

936 before declaring the evaluation round finished. 

937 Ignored with the default value of `None`. 

938 If x is a `tf.data` dataset and `steps` is 

939 None, 'evaluate' will run until the dataset is exhausted. 

940 This argument is not supported with array inputs. 

941 callbacks: List of `keras.callbacks.Callback` instances. 

942 List of callbacks to apply during evaluation. 

943 See [callbacks](/api_docs/python/tf/keras/callbacks). 

944 max_queue_size: Integer. Used for generator or 

945 `keras.utils.Sequence` input only. Maximum size for the 

946 generator queue. If unspecified, `max_queue_size` will default 

947 to 10. 

948 workers: Integer. Used for generator or `keras.utils.Sequence` input 

949 only. Maximum number of processes to spin up when using 

950 process-based threading. If unspecified, `workers` will default 

951 to 1. If 0, will execute the generator on the main thread. 

952 use_multiprocessing: Boolean. Used for generator or 

953 `keras.utils.Sequence` input only. If `True`, use process-based 

954 threading. If unspecified, `use_multiprocessing` will default to 

955 `False`. Note that because this implementation relies on 

956 multiprocessing, you should not pass non-picklable arguments to 

957 the generator as they can't be passed easily to child 

958 processes. 

959 

960 Returns: 

961 Scalar test loss (if the model has a single output and no metrics) 

962 or list of scalars (if the model has multiple outputs 

963 and/or metrics). The attribute `model.metrics_names` will give you 

964 the display labels for the scalar outputs. 

965 

966 Raises: 

967 ValueError: in case of invalid arguments. 

968 """ 

969 self._assert_built_as_v1() 

970 base_layer.keras_api_gauge.get_cell("evaluate").set(True) 

971 self._assert_compile_was_called() 

972 self._check_call_args("evaluate") 

973 

974 func = self._select_training_loop(x) 

975 return func.evaluate( 

976 self, 

977 x=x, 

978 y=y, 

979 batch_size=batch_size, 

980 verbose=verbose, 

981 sample_weight=sample_weight, 

982 steps=steps, 

983 callbacks=callbacks, 

984 max_queue_size=max_queue_size, 

985 workers=workers, 

986 use_multiprocessing=use_multiprocessing, 

987 ) 

988 
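A short sketch pairing `evaluate` results with `metrics_names`, as suggested by the Returns section above (illustrative data only):

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="sgd", loss="mse", metrics=["mae"])

x = np.random.random((20, 4)).astype("float32")
y = np.random.random((20, 1)).astype("float32")

results = model.evaluate(x, y, batch_size=10, verbose=0)
# `metrics_names` gives the display labels for the returned scalars.
print(dict(zip(model.metrics_names, results)))
```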

989 def predict( 

990 self, 

991 x, 

992 batch_size=None, 

993 verbose=0, 

994 steps=None, 

995 callbacks=None, 

996 max_queue_size=10, 

997 workers=1, 

998 use_multiprocessing=False, 

999 ): 

1000 """Generates output predictions for the input samples. 

1001 

1002 Computation is done in batches (see the `batch_size` arg). 

1003 

1004 Args: 

1005 x: Input samples. It could be: 

1006 - A Numpy array (or array-like), or a list of arrays 

1007 (in case the model has multiple inputs). 

1008 - A TensorFlow tensor, or a list of tensors 

1009 (in case the model has multiple inputs). 

1010 - A `tf.data` dataset. 

1011 - A generator or `keras.utils.Sequence` instance. 

1012 batch_size: Integer or `None`. 

1013 Number of samples per batch of computation. 

1014 If unspecified, `batch_size` will default to 32. 

1015 Do not specify the `batch_size` if your data is in the 

1016 form of symbolic tensors, dataset, 

1017 generators, or `keras.utils.Sequence` instances (since they 

1018 generate batches). 

1019 verbose: Verbosity mode, 0 or 1. 

1020 steps: Total number of steps (batches of samples) 

1021 before declaring the prediction round finished. 

1022 Ignored with the default value of `None`. If x is a `tf.data` 

1023 dataset and `steps` is None, `predict` will 

1024 run until the input dataset is exhausted. 

1025 callbacks: List of `keras.callbacks.Callback` instances. 

1026 List of callbacks to apply during prediction. 

1027 See [callbacks](/api_docs/python/tf/keras/callbacks). 

1028 max_queue_size: Integer. Used for generator or 

1029 `keras.utils.Sequence` input only. Maximum size for the 

1030 generator queue. If unspecified, `max_queue_size` will default 

1031 to 10. 

1032 workers: Integer. Used for generator or `keras.utils.Sequence` input 

1033 only. Maximum number of processes to spin up when using 

1034 process-based threading. If unspecified, `workers` will default 

1035 to 1. If 0, will execute the generator on the main thread. 

1036 use_multiprocessing: Boolean. Used for generator or 

1037 `keras.utils.Sequence` input only. If `True`, use process-based 

1038 threading. If unspecified, `use_multiprocessing` will default to 

1039 `False`. Note that because this implementation relies on 

1040 multiprocessing, you should not pass non-picklable arguments to 

1041 the generator as they can't be passed easily to child 

1042 processes. 

1043 

1044 

1045 Returns: 

1046 Numpy array(s) of predictions. 

1047 

1048 Raises: 

1049 ValueError: In case of mismatch between the provided 

1050 input data and the model's expectations, 

1051 or in case a stateful model receives a number of samples 

1052 that is not a multiple of the batch size. 

1053 """ 

1054 self._assert_built_as_v1() 

1055 base_layer.keras_api_gauge.get_cell("predict").set(True) 

1056 self._check_call_args("predict") 

1057 

1058 func = self._select_training_loop(x) 

1059 return func.predict( 

1060 self, 

1061 x=x, 

1062 batch_size=batch_size, 

1063 verbose=verbose, 

1064 steps=steps, 

1065 callbacks=callbacks, 

1066 max_queue_size=max_queue_size, 

1067 workers=workers, 

1068 use_multiprocessing=use_multiprocessing, 

1069 ) 

1070 
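A minimal `predict` sketch (illustrative shapes; compiling is not required for inference):

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
samples = np.random.random((6, 3)).astype("float32")

predictions = model.predict(samples, batch_size=2, verbose=0)
print(predictions.shape)  # (6, 2): one prediction row per input sample
```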

1071 def reset_metrics(self): 

1072 """Resets the state of metrics.""" 

1073 metrics = self._get_training_eval_metrics() 

1074 for m in metrics: 

1075 m.reset_state() 

1076 

1077 # Reset metrics on all the distributed (cloned) models. 

1078 if self._distribution_strategy: 

1079 distributed_training_utils_v1._reset_metrics(self) 

1080 

1081 def train_on_batch( 

1082 self, 

1083 x, 

1084 y=None, 

1085 sample_weight=None, 

1086 class_weight=None, 

1087 reset_metrics=True, 

1088 ): 

1089 """Runs a single gradient update on a single batch of data. 

1090 

1091 Args: 

1092 x: Input data. It could be: 

1093 - A Numpy array (or array-like), or a list of arrays 

1094 (in case the model has multiple inputs). 

1095 - A TensorFlow tensor, or a list of tensors 

1096 (in case the model has multiple inputs). 

1097 - A dict mapping input names to the corresponding array/tensors, 

1098 if the model has named inputs. 

1099 - A `tf.data` dataset. 

1100 y: Target data. Like the input data `x`, it could be either Numpy 

1101 array(s) or TensorFlow tensor(s). It should be consistent with `x` 

1102 (you cannot have Numpy inputs and tensor targets, or inversely). 

1103 If `x` is a dataset, `y` should not be specified 

1104 (since targets will be obtained from the iterator). 

1105 sample_weight: Optional array of the same length as x, containing 

1106 weights to apply to the model's loss for each sample. In the case 

1107 of temporal data, you can pass a 2D array with shape (samples, 

1108 sequence_length), to apply a different weight to every timestep of 

1109 every sample. In this case you should make sure to specify 

1110 sample_weight_mode="temporal" in compile(). This argument is not 

1111 supported when `x` is a dataset. 

1112 class_weight: Optional dictionary mapping class indices (integers) 

1113 to a weight (float) to apply to the model's loss for the samples 

1114 from this class during training. This can be useful to tell the 

1115 model to "pay more attention" to samples from an under-represented 

1116 class. 

1117 reset_metrics: If `True`, the metrics returned will be only for this 

1118 batch. If `False`, the metrics will be statefully accumulated 

1119 across batches. 

1120 

1121 Returns: 

1122 Scalar training loss 

1123 (if the model has a single output and no metrics) 

1124 or list of scalars (if the model has multiple outputs 

1125 and/or metrics). The attribute `model.metrics_names` will give you 

1126 the display labels for the scalar outputs. 

1127 

1128 Raises: 

1129 ValueError: In case of invalid user-provided arguments. 

1130 """ 

1131 self._assert_compile_was_called() 

1132 self._check_call_args("train_on_batch") 

1133 

1134 # If at this point we are in the replica context, then it is okay to 

1135 # execute the Eager code path. The expected way to get here is to call 

1136 # `fit` that calls `train_on_batch` on each replica. 

1137 if ( 

1138 self._distribution_strategy 

1139 and tf.distribute.in_cross_replica_context() 

1140 ): 

1141 raise NotImplementedError( 

1142 "`train_on_batch` is not supported for models " 

1143 "distributed with tf.distribute.Strategy." 

1144 ) 

1145 # Validate and standardize user data. 

1146 x, y, sample_weights = self._standardize_user_data( 

1147 x, 

1148 y, 

1149 sample_weight=sample_weight, 

1150 class_weight=class_weight, 

1151 extract_tensors_from_dataset=True, 

1152 ) 

1153 

1154 # If `self._distribution_strategy` is True, then we are in a replica 

1155 # context at this point because of the check above. `train_on_batch` is 

1156 # being run for each replica by `self._distribution_strategy` and the 

1157 # same code path as Eager is expected to be taken. 

1158 if self.run_eagerly or self._distribution_strategy: 

1159 output_dict = training_eager_v1.train_on_batch( 

1160 self, 

1161 x, 

1162 y, 

1163 sample_weights=sample_weights, 

1164 output_loss_metrics=self._output_loss_metrics, 

1165 ) 

1166 outputs = ( 

1167 output_dict["total_loss"] 

1168 + output_dict["output_losses"] 

1169 + output_dict["metrics"] 

1170 ) 

1171 outputs = [_non_none_constant_value(v) for v in outputs] 

1172 else: 

1173 x = training_utils_v1.ModelInputs(x).as_list() 

1174 ins = x + list(y or []) + list(sample_weights or []) 

1175 

1176 if not isinstance(backend.symbolic_learning_phase(), int): 

1177 ins += [True] # Add learning phase value. 

1178 

1179 self._update_sample_weight_modes(sample_weights=sample_weights) 

1180 self._make_train_function() 

1181 outputs = self.train_function(ins) 

1182 

1183 if reset_metrics: 

1184 self.reset_metrics() 

1185 

1186 if len(outputs) == 1: 

1187 return outputs[0] 

1188 return outputs 

1189 
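A hedged sketch of a manual training loop built on `train_on_batch` (illustrative; a real loop would iterate over an actual dataset rather than random batches):

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(5,))])
model.compile(optimizer="sgd", loss="mse", metrics=["mae"])

for step in range(3):
    xb = np.random.random((8, 5)).astype("float32")
    yb = np.random.random((8, 1)).astype("float32")
    # With reset_metrics=True the returned loss/metric values cover only
    # this batch; with False they accumulate statefully across batches.
    loss_and_metrics = model.train_on_batch(xb, yb, reset_metrics=True)
    print(step, loss_and_metrics)
```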

1190 def test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True): 

1191 """Test the model on a single batch of samples. 

1192 

1193 Args: 

1194 x: Input data. It could be: 

1195 - A Numpy array (or array-like), or a list of arrays 

1196 (in case the model has multiple inputs). 

1197 - A TensorFlow tensor, or a list of tensors 

1198 (in case the model has multiple inputs). 

1199 - A dict mapping input names to the corresponding array/tensors, 

1200 if the model has named inputs. 

1201 - A `tf.data` dataset. 

1202 y: Target data. Like the input data `x`, 

1203 it could be either Numpy array(s) or TensorFlow tensor(s). 

1204 It should be consistent with `x` (you cannot have Numpy inputs and 

1205 tensor targets, or inversely). If `x` is a dataset `y` should 

1206 not be specified (since targets will be obtained from the 

1207 iterator). 

1208 sample_weight: Optional array of the same length as x, containing 

1209 weights to apply to the model's loss for each sample. 

1210 In the case of temporal data, you can pass a 2D array 

1211 with shape (samples, sequence_length), 

1212 to apply a different weight to every timestep of every sample. 

1213 In this case you should make sure to specify 

1214 sample_weight_mode="temporal" in compile(). This argument is not 

1215 supported when `x` is a dataset. 

1216 reset_metrics: If `True`, the metrics returned will be only for this 

1217 batch. If `False`, the metrics will be statefully accumulated 

1218 across batches. 

1219 

1220 Returns: 

1221 Scalar test loss (if the model has a single output and no metrics) 

1222 or list of scalars (if the model has multiple outputs 

1223 and/or metrics). The attribute `model.metrics_names` will give you 

1224 the display labels for the scalar outputs. 

1225 

1226 Raises: 

1227 ValueError: In case of invalid user-provided arguments. 

1228 """ 

1229 self._assert_compile_was_called() 

1230 self._check_call_args("test_on_batch") 

1231 

1232 if ( 

1233 self._distribution_strategy 

1234 and tf.distribute.in_cross_replica_context() 

1235 ): 

1236 raise NotImplementedError( 

1237 "`test_on_batch` is not supported for models " 

1238 "distributed with tf.distribute.Strategy." 

1239 ) 

1240 # Validate and standardize user data. 

1241 x, y, sample_weights = self._standardize_user_data( 

1242 x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True 

1243 ) 

1244 

1245 # If `self._distribution_strategy` is True, then we are in a replica 

1246 # context at this point. 

1247 if self.run_eagerly or self._distribution_strategy: 

1248 output_dict = training_eager_v1.test_on_batch( 

1249 self, 

1250 x, 

1251 y, 

1252 sample_weights=sample_weights, 

1253 output_loss_metrics=self._output_loss_metrics, 

1254 ) 

1255 outputs = ( 

1256 output_dict["total_loss"] 

1257 + output_dict["output_losses"] 

1258 + output_dict["metrics"] 

1259 ) 

1260 outputs = [_non_none_constant_value(v) for v in outputs] 

1261 else: 

1262 x = training_utils_v1.ModelInputs(x).as_list() 

1263 inputs = x + list(y or []) + list(sample_weights or []) 

1264 

1265 self._update_sample_weight_modes(sample_weights=sample_weights) 

1266 self._make_test_function() 

1267 outputs = self.test_function(inputs) 

1268 

1269 if reset_metrics: 

1270 self.reset_metrics() 

1271 

1272 if len(outputs) == 1: 

1273 return outputs[0] 

1274 return outputs 

1275 

1276 def predict_on_batch(self, x): 

1277 """Returns predictions for a single batch of samples. 

1278 

1279 Args: 

1280 x: Input data. It could be: 

1281 - A Numpy array (or array-like), or a list of arrays 

1282 (in case the model has multiple inputs). 

1283 - A TensorFlow tensor, or a list of tensors 

1284 (in case the model has multiple inputs). 

1285 - A `tf.data` dataset. 

1286 

1287 Returns: 

1288 Numpy array(s) of predictions. 

1289 

1290 Raises: 

1291 ValueError: In case of mismatch between given number of inputs and 

1292 expectations of the model. 

1293 """ 

1294 self._check_call_args("predict_on_batch") 

1295 

1296 if ( 

1297 self._distribution_strategy 

1298 and tf.distribute.in_cross_replica_context() 

1299 ): 

1300 raise NotImplementedError( 

1301 "`predict_on_batch` is not supported for models distributed " 

1302 "with tf.distribute.Strategy." 

1303 ) 

1304 # Validate and standardize user data. 

1305 inputs, _, _ = self._standardize_user_data( 

1306 x, extract_tensors_from_dataset=True 

1307 ) 

1308 # If `self._distribution_strategy` is True, then we are in a replica 

1309 # context at this point. 

1310 if self.run_eagerly or self._distribution_strategy: 

1311 inputs = training_utils_v1.cast_if_floating_dtype(inputs) 

1312 if isinstance(inputs, collections.abc.Sequence): 

1313 # Unwrap lists with only one input, as we do when training on 

1314 # batch 

1315 if len(inputs) == 1: 

1316 inputs = inputs[0] 

1317 

1318 return self(inputs) 

1319 

1320 self._make_predict_function() 

1321 outputs = self.predict_function(inputs) 

1322 

1323 if len(outputs) == 1: 

1324 return outputs[0] 

1325 return outputs 

1326 

1327 def fit_generator( 

1328 self, 

1329 generator, 

1330 steps_per_epoch=None, 

1331 epochs=1, 

1332 verbose=1, 

1333 callbacks=None, 

1334 validation_data=None, 

1335 validation_steps=None, 

1336 validation_freq=1, 

1337 class_weight=None, 

1338 max_queue_size=10, 

1339 workers=1, 

1340 use_multiprocessing=False, 

1341 shuffle=True, 

1342 initial_epoch=0, 

1343 ): 

1344 """Fits the model on data yielded batch-by-batch by a Python generator. 

1345 

1346 DEPRECATED: 

1347 `Model.fit` now supports generators, so there is no longer any need to 

1348 use this endpoint. 

1349 """ 

1350 warnings.warn( 

1351 "`model.fit_generator` is deprecated and " 

1352 "will be removed in a future version. " 

1353 "Please use `Model.fit`, which supports generators.", 

1354 stacklevel=2, 

1355 ) 

1356 return self.fit( 

1357 generator, 

1358 steps_per_epoch=steps_per_epoch, 

1359 epochs=epochs, 

1360 verbose=verbose, 

1361 callbacks=callbacks, 

1362 validation_data=validation_data, 

1363 validation_steps=validation_steps, 

1364 validation_freq=validation_freq, 

1365 class_weight=class_weight, 

1366 max_queue_size=max_queue_size, 

1367 workers=workers, 

1368 use_multiprocessing=use_multiprocessing, 

1369 shuffle=shuffle, 

1370 initial_epoch=initial_epoch, 

1371 ) 

1372 

1373 def evaluate_generator( 

1374 self, 

1375 generator, 

1376 steps=None, 

1377 callbacks=None, 

1378 max_queue_size=10, 

1379 workers=1, 

1380 use_multiprocessing=False, 

1381 verbose=0, 

1382 ): 

1383 """Evaluates the model on a data generator. 

1384 

1385 DEPRECATED: 

1386 `Model.evaluate` now supports generators, so there is no longer any 

1387 need to use this endpoint. 

1388 """ 

1389 warnings.warn( 

1390 "`Model.evaluate_generator` is deprecated and " 

1391 "will be removed in a future version. " 

1392 "Please use `Model.evaluate`, which supports generators.", 

1393 stacklevel=2, 

1394 ) 

1395 self._check_call_args("evaluate_generator") 

1396 

1397 return self.evaluate( 

1398 generator, 

1399 steps=steps, 

1400 max_queue_size=max_queue_size, 

1401 workers=workers, 

1402 use_multiprocessing=use_multiprocessing, 

1403 verbose=verbose, 

1404 callbacks=callbacks, 

1405 ) 

1406 

1407 def predict_generator( 

1408 self, 

1409 generator, 

1410 steps=None, 

1411 callbacks=None, 

1412 max_queue_size=10, 

1413 workers=1, 

1414 use_multiprocessing=False, 

1415 verbose=0, 

1416 ): 

1417 """Generates predictions for the input samples from a data generator. 

1418 

1419 DEPRECATED: 

1420 `Model.predict` now supports generators, so there is no longer any 

1421 need to use this endpoint. 

1422 """ 

1423 warnings.warn( 

1424 "`Model.predict_generator` is deprecated and " 

1425 "will be removed in a future version. " 

1426 "Please use `Model.predict`, which supports generators.", 

1427 stacklevel=2, 

1428 ) 

1429 return self.predict( 

1430 generator, 

1431 steps=steps, 

1432 max_queue_size=max_queue_size, 

1433 workers=workers, 

1434 use_multiprocessing=use_multiprocessing, 

1435 verbose=verbose, 

1436 callbacks=callbacks, 

1437 ) 

1438 

1439 def _check_call_args(self, method_name): 

1440 """Check that `call` has only one positional arg.""" 

1441 # Always allow first arg, regardless of arg name. 

1442 fullargspec = self._call_spec.full_argspec 

1443 if fullargspec.defaults: 

1444 positional_args = fullargspec.args[: -len(fullargspec.defaults)] 

1445 else: 

1446 positional_args = fullargspec.args 

1447 if "training" in positional_args: 

1448 positional_args.remove("training") 

1449 

1450 # self and first arg can be positional. 

1451 if len(positional_args) > 2: 

1452 extra_args = positional_args[2:] 

1453 raise ValueError( 

1454 "Models passed to `" 

1455 + method_name 

1456 + "` can only have `training` " 

1457 "and the first argument in `call` as positional arguments, " 

1458 "found: " + str(extra_args) + "." 

1459 ) 

1460 
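The rule enforced by `_check_call_args` can be shown with a standalone sketch based on `inspect`; the helper below only mirrors the check and is not the private implementation:

```python
import inspect

def check_call_args(call_fn, method_name):
    """Reject `call` signatures with extra positional arguments, as above."""
    spec = inspect.getfullargspec(call_fn)
    positional = spec.args[: -len(spec.defaults)] if spec.defaults else list(spec.args)
    if "training" in positional:
        positional.remove("training")
    if len(positional) > 2:  # `self` and the first input argument are allowed
        raise ValueError(
            f"Models passed to `{method_name}` can only have `training` and the "
            f"first argument in `call` as positional arguments, "
            f"found: {positional[2:]}."
        )

class GoodModel:
    def call(self, inputs, training=None):
        return inputs

class BadModel:
    def call(self, inputs, extra):
        return inputs

check_call_args(GoodModel.call, "evaluate_generator")  # passes silently
try:
    check_call_args(BadModel.call, "evaluate_generator")
except ValueError as err:
    print(err)  # reports ['extra']
```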

1461 def _set_optimizer(self, optimizer): 

1462 """Sets self.optimizer. 

1463 

1464 Sets self.optimizer to `optimizer`, potentially wrapping it with a 

1465 LossScaleOptimizer. 

1466 

1467 Args: 

1468 optimizer: The optimizer(s) to assign to self.optimizer. 

1469 """ 

1470 if isinstance(optimizer, (list, tuple)): 

1471 self.optimizer = [optimizers.get(opt) for opt in optimizer] 

1472 else: 

1473 self.optimizer = optimizers.get(optimizer) 

1474 

1475 if self._dtype_policy.name == "mixed_float16" and not isinstance( 

1476 self.optimizer, loss_scale_optimizer.LossScaleOptimizer 

1477 ): 

1478 if isinstance(self.optimizer, list): 

1479 raise ValueError( 

1480 'When the "mixed_float16" dtype policy is used, you ' 

1481 "can only pass a single optimizer. Using policy %s " 

1482 "and got optimizers: %s" % self._dtype_policy, 

1483 self.optimizer, 

1484 ) 

1485 if not isinstance(self.optimizer, optimizer_v2.OptimizerV2): 

1486 raise ValueError( 

1487 '"optimizer" must be an instance of ' 

1488 "tf.keras.optimizers.legacy.Optimizer when a dype policy " 

1489 "with a loss scale is used, but got: %s. Using policy: " 

1490 "%s" % (self.optimizer, self._dtype_policy) 

1491 ) 

1492 self.optimizer = loss_scale_optimizer.LossScaleOptimizer( 

1493 self.optimizer 

1494 ) 

1495 
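A sketch of the wrapping behaviour `_set_optimizer` implements for the `mixed_float16` policy, shown with the public mixed-precision API rather than the private helper (this assumes a TF release that exposes `tf.keras.optimizers.legacy`):

```python
import tensorflow as tf

opt = tf.keras.optimizers.legacy.SGD(learning_rate=0.1)
# Under the "mixed_float16" dtype policy the model wraps the optimizer so that
# gradients are loss-scaled before being unscaled for the weight update.
wrapped = tf.keras.mixed_precision.LossScaleOptimizer(opt)
print(type(wrapped).__name__)  # LossScaleOptimizer
print(wrapped.dynamic)         # True: dynamic loss scaling is the default
```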

1496 def _prepare_validation_data( 

1497 self, validation_data, batch_size, validation_steps 

1498 ): 

1499 """Unpack and check the validation data.""" 

1500 ( 

1501 val_x, 

1502 val_y, 

1503 val_sample_weights, 

1504 ) = training_utils_v1.unpack_validation_data(validation_data) 

1505 return self._standardize_user_data( 

1506 val_x, 

1507 val_y, 

1508 sample_weight=val_sample_weights, 

1509 batch_size=batch_size, 

1510 steps=validation_steps, 

1511 steps_name="validation_steps", 

1512 ) 

1513 

1514 def _validate_compile_param_for_distribution_strategy( 

1515 self, run_eagerly, sample_weight_mode, target_tensors, weighted_metrics 

1516 ): 

1517 # Validate that arguments passed by the user to `compile` are supported 

1518 # by tf.distribute.Strategy. 

1519 if self._distribution_strategy: 

1520 if sample_weight_mode: 

1521 raise NotImplementedError( 

1522 "sample_weight_mode is not supported with " 

1523 "tf.distribute.Strategy." 

1524 ) 

1525 if weighted_metrics: 

1526 raise NotImplementedError( 

1527 "weighted_metrics is not supported with " 

1528 "tf.distribute.Strategy." 

1529 ) 

1530 if target_tensors: 

1531 raise ValueError( 

1532 "target_tensors is not supported with " 

1533 "tf.distribute.Strategy." 

1534 ) 

1535 

1536 if run_eagerly: 

1537 raise ValueError( 

1538 "We currently do not support enabling `run_eagerly` with " 

1539 "distribution strategy." 

1540 ) 

1541 

1542 if distributed_training_utils_v1.is_distributing_by_cloning( 

1543 self 

1544 ) and (not self.built or not self.inputs or not self.outputs): 

1545 raise ValueError( 

1546 "We currently do not support distribution strategy with a " 

1547 "`Sequential` model that is created without `input_shape`/" 

1548 "`input_dim` set in its first layer or a subclassed model." 

1549 ) 

1550 

1551 def _process_target_tensor_for_compile(self, target_tensors): 

1552 if self.run_eagerly: 

1553 # target tensor is not supported with run_eagerly. Create a list 

1554 # with None as placeholder for each output. 

1555 return [None for _ in self.output_names] 

1556 

1557 if target_tensors is not None and not ( 

1558 isinstance(target_tensors, list) and target_tensors == [] 

1559 ): 

1560 if isinstance(target_tensors, list): 

1561 if len(target_tensors) != len(self.outputs): 

1562 raise ValueError( 

1563 "When passing a list as `target_tensors`, " 

1564 "it should have one entry per model output. " 

1565 "The model has %s outputs, " 

1566 "but you passed target_tensors=%s" 

1567 % (len(self.outputs), target_tensors) 

1568 ) 

1569 elif isinstance(target_tensors, dict): 

1570 unexpected_target_tensor_names = set( 

1571 target_tensors.keys() 

1572 ).difference(self.output_names) 

1573 if unexpected_target_tensor_names: 

1574 raise ValueError( 

1575 "Unknown entry in `target_tensors` dictionary: " 

1576 '"{name}". ' 

1577 "Only expected the following keys: {keys}".format( 

1578 name=unexpected_target_tensor_names, 

1579 keys=str(self.output_names), 

1580 ) 

1581 ) 

1582 tmp_target_tensors = [] 

1583 for name in self.output_names: 

1584 tmp_target_tensors.append(target_tensors.get(name, None)) 

1585 target_tensors = tmp_target_tensors 

1586 elif tf.is_tensor(target_tensors): 

1587 target_tensors = [target_tensors] 

1588 else: 

1589 raise TypeError( 

1590 "Expected `target_tensors` to be a list or tuple or " 

1591 "dict or a single tensor, but got:", 

1592 target_tensors, 

1593 ) 

1594 else: 

1595 # In case target tensor is empty or None, create a list with Nones 

1596 # that has same length as self.output_names. With that, the None 

1597 # check of target tensor can be skipped downstream. 

1598 target_tensors = [None for _ in self.output_names] 

1599 return target_tensors 

1600 
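A standalone sketch of the normalization `_process_target_tensor_for_compile` performs: whatever form `target_tensors` takes, the result has exactly one entry per model output. The helper and the names below are illustrative only:

```python
def normalize_target_tensors(target_tensors, output_names):
    """Return a list with one (possibly None) entry per output."""
    if not target_tensors:
        return [None for _ in output_names]
    if isinstance(target_tensors, dict):
        unexpected = set(target_tensors) - set(output_names)
        if unexpected:
            raise ValueError(
                f"Unknown entry in `target_tensors` dictionary: {unexpected}"
            )
        return [target_tensors.get(name) for name in output_names]
    if isinstance(target_tensors, list):
        if len(target_tensors) != len(output_names):
            raise ValueError("Expected one entry per model output.")
        return list(target_tensors)
    return [target_tensors]  # a single tensor-like value

print(normalize_target_tensors({"out_a": "t_a"}, ["out_a", "out_b"]))
# -> ['t_a', None]
```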

1601 def _compile_eagerly(self, metrics, weighted_metrics, sample_weight_mode): 

1602 # Prepare sample weight modes. List with the same length as model 

1603 # outputs. 

1604 training_utils_v1.prepare_sample_weight_modes( 

1605 self._training_endpoints, sample_weight_mode 

1606 ) 

1607 # Prepare sample weights. 

1608 self._prepare_sample_weights() 

1609 # Save all metric attributes per output of the model. 

1610 self._cache_output_metric_attributes(metrics, weighted_metrics) 

1611 self.total_loss = None 

1612 # Set metric attributes on model. 

1613 self._set_metric_attributes() 

1614 

1615 self._collected_trainable_weights = self.trainable_weights 

1616 

1617 def _update_sample_weight_modes(self, sample_weights=None): 

1618 """Updates sample weight modes based on training/eval inputs. 

1619 

1620 Sample weight placeholders will be created for all or no outputs 

1621 based on whether sample_weight is provided for any output. 

1622 

1623 If model contains `_sample_weight_modes` we check if the input 

1624 `sample_weights` corresponds to the sample weight modes. 

1625 1. Set sample weight mode to be 'temporal' for output i, if `compile` 

1626 sample_weight_mode was set to `temporal` and sample weight inputs 

1627 are given for one or more outputs. 

1628 2. Set sample weight mode to be 'samplewise' for output i, if 

1629 `compile` sample_weight_mode was not set and sample weight inputs 

1630 are given for one or more outputs. 

1631 3. Reset sample weight mode to None for output i if sample weight mode 

1632 was set but there is no sample weight input. 

1633 

1634 Args: 

1635 sample_weights: List of sample weights of the same length as model 

1636 outputs or None. 

1637 """ 

1638 if not self._is_compiled: 

1639 return 

1640 if sample_weights and any(s is not None for s in sample_weights): 

1641 for endpoint in self._training_endpoints: 

1642 endpoint.sample_weight_mode = ( 

1643 endpoint.sample_weight_mode or "samplewise" 

1644 ) 

1645 else: 

1646 for endpoint in self._training_endpoints: 

1647 endpoint.sample_weight_mode = None 

1648 
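A minimal sketch of the mode-update rule documented above: any provided sample weight promotes unset modes to `'samplewise'` while a compile-time `'temporal'` mode is preserved, and with no sample weights every mode resets to `None`. The helper is illustrative only:

```python
def update_sample_weight_modes(modes, sample_weights):
    """Per-output sample weight mode update, mirroring the rule above."""
    if sample_weights and any(w is not None for w in sample_weights):
        return [mode or "samplewise" for mode in modes]
    return [None for _ in modes]

# Output 0 had no compile-time mode, output 1 was compiled with "temporal".
print(update_sample_weight_modes([None, "temporal"], [[0.2, 0.8], None]))
# -> ['samplewise', 'temporal']
print(update_sample_weight_modes([None, "temporal"], None))
# -> [None, None]
```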

1649 def _recompile_weights_loss_and_weighted_metrics(self): 

1650 if not self._is_compiled: 

1651 return False 

1652 recompile = any( 

1653 e.sample_weights_mismatch() for e in self._training_endpoints 

1654 ) 

1655 

1656 if recompile: 

1657 self._compile_weights_loss_and_weighted_metrics() 

1658 return recompile 

1659 

1660 @tf.__internal__.tracking.no_automatic_dependency_tracking 

1661 def _compile_weights_loss_and_weighted_metrics(self, sample_weights=None): 

1662 """Compiles the model loss and weighted metric sub-graphs. 

1663 

1664 This may be used to set graph tensors as sample weights (instead of 

1665 creating placeholders). This functionality is necessary for 

1666 `tf.keras.estimator.model_to_estimator`, which calls Keras models in a 

1667 v1 graph, and creates iterator tensors for inputs, targets, and sample 

1668 weights. 

1669 

1670 Args: 

1671 sample_weights: List of tensors to use as the sample weights. Must be 

1672 the same length as the number of outputs. If left as `None`, 

1673 placeholders are used instead. 

1674 """ 

1675 with backend.get_graph().as_default(): 

1676 if sample_weights is not None: 

1677 self._update_sample_weight_modes(sample_weights) 

1678 self._prepare_sample_weights(sample_weights) 

1679 

1680 masks = self._prepare_output_masks() 

1681 

1682 # Compute weighted metrics. 

1683 self._handle_metrics( 

1684 self.outputs, 

1685 targets=self._targets, 

1686 skip_target_masks=self._prepare_skip_target_masks(), 

1687 sample_weights=self.sample_weights, 

1688 masks=masks, 

1689 return_weighted_metrics=True, 

1690 ) 

1691 

1692 # Compute total loss. 

1693 # Used to keep track of the total loss value (stateless). 

1694 # eg., total_loss = loss_weight_1 * output_1_loss_fn(...) + 

1695 # loss_weight_2 * output_2_loss_fn(...) + 

1696 # layer losses. 

1697 self.total_loss = self._prepare_total_loss(masks) 

1698 

1699 def _prepare_skip_target_masks(self): 

1700 """Boolean mask for whether target in output list should be skipped. 

1701 

1702 If the loss function corresponding to a model output is None, then this 

1703 output will be skipped during total loss calculation and feed targets 

1704 preparation. 

1705 

1706 Returns: 

1707 A boolean list for whether the corresponding target in the output list 

1708 should be skipped during loss calculation. 

1709 """ 

1710 return [l is None for l in self.loss_functions] 

1711 

1712 def _prepare_output_masks(self): 

1713 """Returns masks corresponding to model outputs.""" 

1714 return [getattr(x, "_keras_mask", None) for x in self.outputs] 

1715 

1716 def _prepare_total_loss(self, masks): 

1717 """Computes total loss from loss functions. 

1718 

1719 Args: 

1720 masks: List of mask values corresponding to each model output. 

1721 

1722 Returns: 

1723 The total loss as a scalar tensor (or `0.0` if there are no losses).

1724 

1725 Raises: 

1726 TypeError: If model run_eagerly is True. 

1727 """ 

1728 if self.run_eagerly: 

1729 raise TypeError( 

1730 "total loss can not be computed when compiled with " 

1731 "run_eagerly = True." 

1732 ) 

1733 loss_list = [] 

1734 with backend.name_scope("loss"): 

1735 for endpoint, mask in zip(self._training_endpoints, masks): 

1736 if endpoint.should_skip_target(): 

1737 continue 

1738 y_true = endpoint.training_target.target 

1739 y_pred = endpoint.output 

1740 loss_fn = endpoint.loss_fn 

1741 loss_weight = endpoint.loss_weight 

1742 loss_name = endpoint.loss_name() 

1743 sample_weight = endpoint.sample_weight 

1744 

1745 with backend.name_scope(loss_name): 

1746 if mask is not None: 

1747 mask = tf.cast(mask, y_pred.dtype) 

1748 # Update weights with mask. 

1749 if sample_weight is None: 

1750 sample_weight = mask 

1751 else: 

1752 # Update dimensions of weights to match with mask if 

1753 # possible. 

1754 ( 

1755 mask, 

1756 _, 

1757 sample_weight, 

1758 ) = losses_utils.squeeze_or_expand_dimensions( 

1759 mask, sample_weight=sample_weight 

1760 ) 

1761 

1762 if hasattr(loss_fn, "reduction"): 

1763 per_sample_losses = loss_fn.call(y_true, y_pred) 

1764 sample_weight = losses_utils.apply_valid_mask( 

1765 per_sample_losses, 

1766 sample_weight, 

1767 mask, 

1768 loss_fn.reduction, 

1769 ) 

1770 weighted_losses = losses_utils.compute_weighted_loss( 

1771 per_sample_losses, 

1772 sample_weight=sample_weight, 

1773 reduction=losses_utils.ReductionV2.NONE, 

1774 ) 

1775 loss_reduction = loss_fn.reduction 

1776 

1777 # `AUTO` loss reduction defaults to 

1778 # `SUM_OVER_BATCH_SIZE` for all compile use cases. 

1779 if loss_reduction == losses_utils.ReductionV2.AUTO: 

1780 loss_reduction = ( 

1781 losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE 

1782 ) 

1783 

1784 # Compute the stateless loss value. 

1785 output_loss = losses_utils.reduce_weighted_loss( 

1786 weighted_losses, reduction=loss_reduction 

1787 ) 

1788 else: 

1789 # Compute the stateless loss value for a custom loss 

1790 # class. Here we assume that the class takes care of 

1791 # loss reduction because if this class returns a vector 

1792 # value we cannot differentiate between the case where a

1793 # custom optimizer expects a vector loss value and the case

1794 # of an unreduced per-sample loss value.

1795 output_loss = loss_fn( 

1796 y_true, y_pred, sample_weight=sample_weight 

1797 ) 

1798 loss_reduction = ( 

1799 losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE 

1800 ) 

1801 

1802 if len(self.outputs) > 1: 

1803 # Keep track of stateful result tensor for the loss. 

1804 endpoint.output_loss_metric(output_loss) 

1805 

1806 # Scale output loss for distribution. For custom losses we 

1807 # assume reduction was mean. 

1808 if ( 

1809 loss_reduction 

1810 == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE 

1811 ): 

1812 output_loss = losses_utils.scale_loss_for_distribution( 

1813 output_loss 

1814 ) 

1815 

1816 loss_list.append(loss_weight * output_loss) 

1817 if not loss_list and not self.losses: 

1818 raise ValueError( 

1819 "The model cannot be compiled " 

1820 "because it has no loss to optimize." 

1821 ) 

1822 

1823 # Add regularization penalties and other layer-specific losses. 

1824 custom_losses = self.get_losses_for(None) + self.get_losses_for( 

1825 self.inputs 

1826 ) 

1827 if custom_losses: 

1828 total_custom_loss = tf.add_n( 

1829 losses_utils.cast_losses_to_common_dtype(custom_losses) 

1830 ) 

1831 loss_list.append( 

1832 losses_utils.scale_loss_for_distribution(total_custom_loss) 

1833 ) 

1834 

1835 loss_list = losses_utils.cast_losses_to_common_dtype(loss_list) 

1836 if loss_list: 

1837 total_loss = tf.add_n(loss_list) 

1838 else: 

1839 total_loss = 0.0 

1840 return total_loss 

1841 
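A NumPy sketch of the total-loss assembly above: per-sample losses are weighted, reduced with SUM_OVER_BATCH_SIZE (weighted sum divided by the number of samples), scaled by the per-output loss weight, and summed together with any regularization losses. All numbers are illustrative:

```python
import numpy as np

def sum_over_batch_size(per_sample_losses, sample_weight=None):
    """Weighted sum of per-sample losses divided by the batch size."""
    weighted = (per_sample_losses if sample_weight is None
                else per_sample_losses * sample_weight)
    return weighted.sum() / per_sample_losses.shape[0]

out1 = np.array([0.2, 0.4, 0.6])      # per-sample losses for output 1
out2 = np.array([1.0, 3.0, 2.0])      # per-sample losses for output 2
weights2 = np.array([1.0, 1.0, 0.5])  # sample weights for output 2

loss1 = sum_over_batch_size(out1)            # 0.4
loss2 = sum_over_batch_size(out2, weights2)  # 5/3
regularization = 0.01                        # layer-specific penalties

total_loss = 1.0 * loss1 + 0.5 * loss2 + regularization
print(round(total_loss, 4))  # 1.2433
```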

1842 def _get_callback_model(self): 

1843 """Returns the Callback Model for this Model.""" 

1844 

1845 if hasattr(self, "_replicated_model") and self._replicated_model: 

1846 # When using training_distributed, we set the callback model 

1847 # to an instance of the `DistributedModel` that we create in 

1848 # the `compile` call. The `DistributedModel` is initialized 

1849 # with the first replicated model. We need to set the callback 

1850 # model to a DistributedModel to allow us to override saving 

1851 # and loading weights when we checkpoint the model during training. 

1852 return self._replicated_model 

1853 if hasattr(self, "callback_model") and self.callback_model: 

1854 return self.callback_model 

1855 return self 

1856 

1857 @tf.__internal__.tracking.no_automatic_dependency_tracking 

1858 def _make_callback_model(self, grouped_model): 

1859 first_replicated_model = self._distribution_strategy.unwrap( 

1860 grouped_model 

1861 )[0] 

1862 # We initialize the callback model with the first replicated model. 

1863 self._replicated_model = DistributedCallbackModel( 

1864 first_replicated_model 

1865 ) 

1866 self._replicated_model.set_original_model(self) 

1867 

1868 def _validate_or_infer_batch_size(self, batch_size, steps, x): 

1869 """Validates that `batch_size` provided is consistent with InputLayer. 

1870 

1871 It's possible that the user specified a static batch size in their 

1872 InputLayer. If so, this method checks the provided `batch_size` and `x` 

1873 arguments are consistent with this static batch size. Also, if 

1874 `batch_size` is `None`, this method will attempt to infer the batch size 

1875 from the static batch size of the InputLayer. Lastly, ValueError will be 

1876 raised if `x` is a tf.data.Dataset and `batch_size` is specified as we 

1877 expect users to provide batched datasets. 

1878 

1879 Args: 

1880 batch_size: The batch_size provided as an argument to 

1881 fit/evaluate/predict. 

1882 steps: The steps provided as an argument to fit/evaluate/predict. 

1883 x: The data passed as `x` to fit/evaluate/predict. 

1884 

1885 Returns: 

1886 The validated batch_size, auto-inferred from the first layer if not 

1887 provided. 

1888 """ 

1889 if isinstance( 

1890 x, (tf.compat.v1.data.Dataset, tf.data.Dataset, data_utils.Sequence) 

1891 ) or tf_inspect.isgenerator(x): 

1892 if batch_size is not None: 

1893 raise ValueError( 

1894 "The `batch_size` argument must not be specified for the " 

1895 "given input type. Received input: " 

1896 "{}, batch_size: {}".format(x, batch_size) 

1897 ) 

1898 return 

1899 

1900 # Avoids the override in Sequential.layers which filters Input layers. 

1901 # (Which are often the very layers that we're after.) 

1902 layers = self._flatten_layers(include_self=False, recursive=False) 

1903 first_layer = next(layers, None) 

1904 if first_layer: 

1905 # The per-replica static batch size. 

1906 static_batch_size = training_utils.get_static_batch_size( 

1907 first_layer 

1908 ) 

1909 if static_batch_size is not None: 

1910 

1911 # Determine number of times the user-supplied batch size will be 

1912 # split. 

1913 if ( 

1914 self._distribution_strategy 

1915 and distributed_training_utils.global_batch_size_supported( 

1916 self._distribution_strategy 

1917 ) 

1918 ): 

1919 num_splits_for_ds = ( 

1920 self._distribution_strategy.num_replicas_in_sync 

1921 ) 

1922 else: 

1923 num_splits_for_ds = 1 

1924 

1925 # Check `batch_size` argument is consistent with InputLayer. 

1926 if batch_size is not None: 

1927 if batch_size % num_splits_for_ds != 0: 

1928 raise ValueError( 

1929 "The `batch_size` argument ({}) must be divisible " 

1930 "the by number of replicas ({})".format( 

1931 batch_size, num_splits_for_ds 

1932 ) 

1933 ) 

1934 per_replica_batch_size = batch_size // num_splits_for_ds 

1935 

1936 if per_replica_batch_size != static_batch_size: 

1937 raise ValueError( 

1938 "The `batch_size` argument value {} is " 

1939 "incompatible with the specified batch size of " 

1940 "your Input Layer: {}".format( 

1941 per_replica_batch_size, static_batch_size 

1942 ) 

1943 ) 

1944 

1945 # Check Dataset/Iterator batch size is consistent with 

1946 # InputLayer. 

1947 if isinstance( 

1948 x, 

1949 ( 

1950 tf.data.Dataset, 

1951 tf.compat.v1.data.Iterator, 

1952 tf.data.Iterator, 

1953 ), 

1954 ): 

1955 ds_batch_size = tf.compat.v1.Dimension( 

1956 tf.nest.flatten(tf.compat.v1.data.get_output_shapes(x))[ 

1957 0 

1958 ][0] 

1959 ).value 

1960 if ds_batch_size is not None: 

1961 if ds_batch_size % num_splits_for_ds != 0: 

1962 raise ValueError( 

1963 "The batch output shape of your `Dataset` {} " 

1964 "cannot be divisible by number of " 

1965 "replicas {}".format( 

1966 ds_batch_size, num_splits_for_ds 

1967 ) 

1968 ) 

1969 

1970 ds_per_replica_batch_size = ( 

1971 ds_batch_size // num_splits_for_ds 

1972 ) 

1973 if ds_per_replica_batch_size != static_batch_size: 

1974 raise ValueError( 

1975 "The batch output shape of your `Dataset` is " 

1976 "{}, which is incompatible with the specified " 

1977 "batch size of your Input Layer: {}".format( 

1978 ds_per_replica_batch_size, static_batch_size 

1979 ) 

1980 ) 

1981 

1982 # Set inferred batch size from the InputLayer. 

1983 if steps is None: 

1984 batch_size = static_batch_size * num_splits_for_ds 

1985 

1986 if batch_size is None and steps is None: 

1987 # Backwards compatibility 

1988 batch_size = 32 

1989 return batch_size 

1990 
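A standalone sketch of the divisibility and consistency rules above: the global `batch_size` must split evenly across replicas, the per-replica size must match any static batch size declared on the InputLayer, and an unset batch size is inferred from the InputLayer (or falls back to 32). Dataset handling is omitted and the helper name is illustrative:

```python
def validate_or_infer_batch_size(batch_size, steps, static_batch_size, num_replicas=1):
    """Sketch of the checks and inference performed above."""
    if static_batch_size is not None:
        if batch_size is not None:
            if batch_size % num_replicas != 0:
                raise ValueError("batch_size must be divisible by the number of replicas")
            if batch_size // num_replicas != static_batch_size:
                raise ValueError("batch_size is incompatible with the InputLayer batch size")
        elif steps is None:
            # Infer the global batch size from the InputLayer.
            batch_size = static_batch_size * num_replicas
    if batch_size is None and steps is None:
        batch_size = 32  # historical default
    return batch_size

print(validate_or_infer_batch_size(None, None, static_batch_size=16, num_replicas=2))  # 32
print(validate_or_infer_batch_size(None, None, static_batch_size=None))                # 32
```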

1991 def _prepare_sample_weights(self, sample_weights=None): 

1992 """Sets sample weight attribute on the model.""" 

1993 # List with the same length as model outputs. 

1994 if sample_weights is not None: 

1995 if len(sample_weights) != len(self._training_endpoints): 

1996 raise ValueError( 

1997 "Provided sample weights must have same length as the " 

1998 "number of outputs. Expected: {}, got: {}.".format( 

1999 len(self._training_endpoints), len(sample_weights) 

2000 ) 

2001 ) 

2002 else: 

2003 sample_weights = [None] * len(self._training_endpoints) 

2004 for endpoint, weight in zip(self._training_endpoints, sample_weights): 

2005 endpoint.populate_sample_weight(weight, endpoint.sample_weight_mode) 

2006 

2007 def _cache_output_metric_attributes(self, metrics, weighted_metrics): 

2008 """Caches metric name and function attributes for every model output.""" 

2009 output_shapes = [] 

2010 for output in self.outputs: 

2011 if output is None or output.shape.rank is None: 

2012 output_shapes.append(None) 

2013 else: 

2014 output_shapes.append(output.shape.as_list()) 

2015 self._per_output_metrics = ( 

2016 training_utils_v1.collect_per_output_metric_info( 

2017 metrics, 

2018 self.output_names, 

2019 output_shapes, 

2020 self.loss_functions, 

2021 from_serialized=self._from_serialized, 

2022 ) 

2023 ) 

2024 self._per_output_weighted_metrics = ( 

2025 training_utils_v1.collect_per_output_metric_info( 

2026 weighted_metrics, 

2027 self.output_names, 

2028 output_shapes, 

2029 self.loss_functions, 

2030 from_serialized=self._from_serialized, 

2031 is_weighted=True, 

2032 ) 

2033 ) 

2034 

2035 def _add_unique_metric_name(self, metric_name, metric_fn, output_index): 

2036 """Makes the metric name unique. 

2037 

2038 If there are multiple outputs for which the metrics are calculated, 

2039 the metric names have to be made unique by appending an integer. 

2040 

2041 Args: 

2042 metric_name: Metric name that corresponds to the metric specified by 

2043 the user. For example: 'acc'. 

2044 metric_fn: The Metric object. 

2045 output_index: The index of the model output for which the metric name 

2046 is being added. 

2047 

2048 Returns: 

2049 string, name of the model's unique metric name 

2050 """ 

2051 # For multi-output models, prepend the output names to the metric name. 

2052 if len(self.output_names) > 1: 

2053 # If we're loading from an already-serialized model, we've already 

2054 # prepended the output name, and we don't want to do it again. 

2055 # 

2056 # Alternatively, we may be receiving a stateless metric (e.g. the 

2057 # string "accuracy") rather than a `Metric` object, in which case we 

2058 # want to prepend the output name even if we are loading a 

2059 # serialized model. 

2060 if not getattr(metric_fn, "_from_serialized", False): 

2061 metric_name = f"{self.output_names[output_index]}_{metric_name}" 

2062 

2063 j = 1 

2064 base_metric_name = metric_name 

2065 while metric_name in self.metrics_names: 

2066 metric_name = "%s_%d" % (base_metric_name, j) 

2067 j += 1 

2068 

2069 return metric_name 

2070 
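A sketch of the uniquification rule above: for multi-output models the output name is prepended, then an integer suffix resolves clashes with names already in `metrics_names`. The helper and example names are illustrative:

```python
def make_unique_metric_name(metric_name, output_name, existing_names, multi_output=True):
    """Standalone sketch of the naming rule used above."""
    if multi_output:
        metric_name = f"{output_name}_{metric_name}"
    base, j = metric_name, 1
    while metric_name in existing_names:
        metric_name = "%s_%d" % (base, j)
        j += 1
    return metric_name

existing = ["loss", "dense_acc"]
print(make_unique_metric_name("acc", "dense", existing))    # dense_acc_1
print(make_unique_metric_name("acc", "dense_1", existing))  # dense_1_acc
```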

2071 def _init_metric_attributes(self): 

2072 """Initialized model metric attributes.""" 

2073 # List of stateful metric functions. Used for resetting metric state 

2074 # during training/eval. 

2075 self._compile_metric_functions = [] 

2076 

2077 def _set_per_output_metric_attributes(self, metrics_dict, output_index): 

2078 """Sets the metric attributes on the model for the given output. 

2079 

2080 Args: 

2081 metrics_dict: A dict with metric names as keys and metric fns as 

2082 values. 

2083 output_index: The index of the model output for which the metric 

2084 attributes are added. 

2085 

2086 Returns: 

2087 Metrics dict updated with unique metric names as keys. 

2088 """ 

2089 updated_metrics_dict = collections.OrderedDict() 

2090 for metric_name, metric_fn in metrics_dict.items(): 

2091 metric_name = self._add_unique_metric_name( 

2092 metric_name, metric_fn, output_index 

2093 ) 

2094 

2095 # Update the name on the metric class to be the unique generated 

2096 # name. 

2097 metric_fn._name = metric_name 

2098 updated_metrics_dict[metric_name] = metric_fn 

2099 # Keep track of metric name and function. 

2100 self._compile_metric_functions.append(metric_fn) 

2101 return updated_metrics_dict 

2102 

2103 def _set_metric_attributes(self): 

2104 """Sets the metric attributes on the model for all the model outputs.""" 

2105 updated_per_output_metrics = [] 

2106 updated_per_output_weighted_metrics = [] 

2107 for i, endpoint in enumerate(self._training_endpoints): 

2108 if endpoint.should_skip_target(): 

2109 updated_per_output_metrics.append(self._per_output_metrics[i]) 

2110 updated_per_output_weighted_metrics.append( 

2111 self._per_output_weighted_metrics[i] 

2112 ) 

2113 continue 

2114 updated_per_output_metrics.append( 

2115 self._set_per_output_metric_attributes( 

2116 self._per_output_metrics[i], i 

2117 ) 

2118 ) 

2119 updated_per_output_weighted_metrics.append( 

2120 self._set_per_output_metric_attributes( 

2121 self._per_output_weighted_metrics[i], i 

2122 ) 

2123 ) 

2124 

2125 # Create a metric wrapper for each output loss. This computes mean of an 

2126 # output loss across mini-batches (irrespective of how we reduce within 

2127 # a batch). 

2128 if len(self._training_endpoints) > 1: 

2129 for endpoint in self._training_endpoints: 

2130 if not endpoint.should_skip_target(): 

2131 endpoint.output_loss_metric = metrics_module.Mean( 

2132 name=endpoint.loss_name() 

2133 ) 

2134 

2135 self._per_output_metrics = updated_per_output_metrics 

2136 self._per_output_weighted_metrics = updated_per_output_weighted_metrics 

2137 

2138 def _handle_per_output_metrics( 

2139 self, metrics_dict, y_true, y_pred, mask, weights=None 

2140 ): 

2141 """Calls metric functions for a single output. 

2142 

2143 Args: 

2144 metrics_dict: A dict with metric names as keys and metric fns as 

2145 values. 

2146 y_true: Target output. 

2147 y_pred: Predicted output. 

2148 mask: Computed mask value for the current output. 

2149 weights: Weights to be applied on the current output. 

2150 

2151 Returns: 

2152 A list of metric result tensors. 

2153 """ 

2154 metric_results = [] 

2155 for metric_name, metric_fn in metrics_dict.items(): 

2156 with backend.name_scope(metric_name): 

2157 metric_result = training_utils_v1.call_metric_function( 

2158 metric_fn, y_true, y_pred, weights=weights, mask=mask 

2159 ) 

2160 metric_results.append(metric_result) 

2161 return metric_results 

2162 

2163 def _handle_metrics( 

2164 self, 

2165 outputs, 

2166 targets=None, 

2167 skip_target_masks=None, 

2168 sample_weights=None, 

2169 masks=None, 

2170 return_weighted_metrics=False, 

2171 return_weighted_and_unweighted_metrics=False, 

2172 ): 

2173 """Handles calling metric functions. 

2174 

2175 Args: 

2176 outputs: List of outputs (predictions). 

2177 targets: List of targets. 

2178 skip_target_masks: Optional. List of boolean for whether the 

2179 corresponding target should be ignored or not. 

2180 sample_weights: Optional list of sample weight arrays. 

2181 masks: List of computed output mask values. 

2182 return_weighted_metrics: Flag that indicates whether weighted metrics 

2183 should be computed instead of unweighted metrics. This flag is 

2184 ignored when `return_weighted_and_unweighted_metrics` is enabled. 

2185 return_weighted_and_unweighted_metrics: Flag that is used to indicate 

2186 whether both weighted and unweighted metrics should be computed. 

2187 When this is not enabled, we use `return_weighted_metrics` param to 

2188 indicate whether weighted or unweighted metrics should be returned. 

2189 

2190 Returns: 

2191 A list of metric result tensors. 

2192 """ 

2193 # TODO(scottzhu): Update this to use the new training_endpoints. 

2194 # Currently the eager and graph logic is a bit different.

2195 skip_target_masks = skip_target_masks or [False] * len(outputs) 

2196 metric_results = [] 

2197 with backend.name_scope("metrics"): 

2198 # Invoke all metrics added using `compile`. 

2199 for i in range(len(outputs)): 

2200 if skip_target_masks[i]: 

2201 continue 

2202 output = outputs[i] if outputs else None 

2203 target = targets[i] if targets else None 

2204 output_mask = masks[i] if masks else None 

2205 

2206 if ( 

2207 return_weighted_and_unweighted_metrics 

2208 or not return_weighted_metrics 

2209 ): 

2210 metric_results.extend( 

2211 self._handle_per_output_metrics( 

2212 self._per_output_metrics[i], 

2213 target, 

2214 output, 

2215 output_mask, 

2216 ) 

2217 ) 

2218 if ( 

2219 return_weighted_and_unweighted_metrics 

2220 or return_weighted_metrics 

2221 ): 

2222 metric_results.extend( 

2223 self._handle_per_output_metrics( 

2224 self._per_output_weighted_metrics[i], 

2225 target, 

2226 output, 

2227 output_mask, 

2228 weights=sample_weights[i] 

2229 if sample_weights 

2230 else None, 

2231 ) 

2232 ) 

2233 return metric_results 

2234 

2235 def _check_trainable_weights_consistency(self): 

2236 """Check trainable weights count consistency. 

2237 

2238 This will raise a warning if `trainable_weights` and 

2239 `_collected_trainable_weights` are inconsistent (i.e. have different 

2240 number of parameters). 

2241 Inconsistency will typically arise when one modifies `model.trainable` 

2242 without calling `model.compile` again. 

2243 """ 

2244 if not hasattr(self, "_collected_trainable_weights"): 

2245 return 

2246 

2247 if len(self.trainable_weights) != len( 

2248 self._collected_trainable_weights 

2249 ): 

2250 logging.log_first_n( 

2251 logging.WARN, 

2252 "Discrepancy between trainable weights and collected" 

2253 " trainable weights, did you set `model.trainable`" 

2254 " without calling `model.compile` after ?", 

2255 1, 

2256 ) 

2257 

2258 def _make_train_function(self): 

2259 has_recompiled = self._recompile_weights_loss_and_weighted_metrics() 

2260 self._check_trainable_weights_consistency() 

2261 if isinstance(self.optimizer, list): 

2262 raise ValueError( 

2263 "The `optimizer` in `compile` should be a single optimizer." 

2264 ) 

2265 # If we have re-compiled the loss/weighted metric sub-graphs then create 

2266 # train function even if one exists already. This is because 

2267 # `_feed_sample_weights` list has been updated on re-compile. 

2268 if getattr(self, "train_function", None) is None or has_recompiled: 

2269 # Restore the compiled trainable state. 

2270 current_trainable_state = self._get_trainable_state() 

2271 self._set_trainable_state(self._compiled_trainable_state) 

2272 

2273 inputs = ( 

2274 self._feed_inputs 

2275 + self._feed_targets 

2276 + self._feed_sample_weights 

2277 ) 

2278 if not isinstance(backend.symbolic_learning_phase(), int): 

2279 inputs += [backend.symbolic_learning_phase()] 

2280 

2281 with backend.get_graph().as_default(): 

2282 with backend.name_scope("training"): 

2283 # Training updates 

2284 updates = self.optimizer.get_updates( 

2285 params=self._collected_trainable_weights, 

2286 loss=self.total_loss, 

2287 ) 

2288 # Unconditional updates 

2289 updates += self.get_updates_for(None) 

2290 # Conditional updates relevant to this model 

2291 updates += self.get_updates_for(self.inputs) 

2292 

2293 metrics = self._get_training_eval_metrics() 

2294 metrics_tensors = [ 

2295 m._call_result 

2296 for m in metrics 

2297 if hasattr(m, "_call_result") 

2298 ] 

2299 

2300 with backend.name_scope("training"): 

2301 # Gets loss and metrics. Updates weights at each call. 

2302 fn = backend.function( 

2303 inputs, 

2304 [self.total_loss] + metrics_tensors, 

2305 updates=updates, 

2306 name="train_function", 

2307 **self._function_kwargs, 

2308 ) 

2309 setattr(self, "train_function", fn) 

2310 

2311 # Restore the current trainable state 

2312 self._set_trainable_state(current_trainable_state) 

2313 

2314 def _make_test_function(self): 

2315 has_recompiled = self._recompile_weights_loss_and_weighted_metrics() 

2316 # If we have re-compiled the loss/weighted metric sub-graphs then create 

2317 # test function even if one exists already. This is because 

2318 # `_feed_sample_weights` list has been updated on re-compile. 

2319 if getattr(self, "test_function", None) is None or has_recompiled: 

2320 inputs = ( 

2321 self._feed_inputs 

2322 + self._feed_targets 

2323 + self._feed_sample_weights 

2324 ) 

2325 

2326 with backend.get_graph().as_default(): 

2327 metrics = self._get_training_eval_metrics() 

2328 metrics_tensors = [ 

2329 m._call_result 

2330 for m in metrics 

2331 if hasattr(m, "_call_result") 

2332 ] 

2333 

2334 with backend.name_scope("evaluation"): 

2335 updates = self.state_updates 

2336 # Return loss and metrics, no gradient updates. 

2337 # Does update the network states. 

2338 fn = backend.function( 

2339 inputs, 

2340 [self.total_loss] + metrics_tensors, 

2341 updates=updates, 

2342 name="test_function", 

2343 **self._function_kwargs, 

2344 ) 

2345 setattr(self, "test_function", fn) 

2346 

2347 def _make_predict_function(self): 

2348 if not hasattr(self, "predict_function"): 

2349 self.predict_function = None 

2350 if self.predict_function is None: 

2351 inputs = self._feed_inputs 

2352 # Gets network outputs. Does not update weights. 

2353 # Does update the network states. 

2354 kwargs = getattr(self, "_function_kwargs", {}) 

2355 with backend.name_scope(ModeKeys.PREDICT): 

2356 self.predict_function = backend.function( 

2357 inputs, 

2358 self.outputs, 

2359 updates=self.state_updates, 

2360 name="predict_function", 

2361 **kwargs, 

2362 ) 

2363 

2364 def _make_execution_function(self, mode): 

2365 if mode == ModeKeys.TRAIN: 

2366 self._make_train_function() 

2367 return self.train_function 

2368 if mode == ModeKeys.TEST: 

2369 self._make_test_function() 

2370 return self.test_function 

2371 if mode == ModeKeys.PREDICT: 

2372 self._make_predict_function() 

2373 return self.predict_function 

2374 

2375 def _distribution_standardize_user_data( 

2376 self, 

2377 x, 

2378 y=None, 

2379 sample_weight=None, 

2380 class_weight=None, 

2381 batch_size=None, 

2382 validation_split=0.0, 

2383 shuffle=False, 

2384 epochs=1, 

2385 allow_partial_batch=False, 

2386 ): 

2387 """Runs validation checks on input and target data passed by the user. 

2388 

2389 This is called when using tf.distribute.Strategy to train, evaluate or 

2390 serve the model. 

2391 

2392 Args: 

2393 x: Input data. A numpy array or `tf.data` dataset. 

2394 y: Target data. A numpy array or None if x is a `tf.data` dataset. 

2395 sample_weight: An optional sample-weight array passed by the user to 

2396 weight the importance of each sample in `x`. 

2397 class_weight: An optional class-weight array by the user to 

2398 weight the importance of samples in `x` based on the class they 

2399 belong to, as conveyed by `y`. 

2400 batch_size: Integer batch size. If provided, it is used to run 

2401 additional validation checks on stateful models. 

2402 validation_split: Float between 0 and 1. 

2403 Fraction of the training data to be used as validation data. 

2404 shuffle: Boolean whether to shuffle the training data before each 

2405 epoch. 

2406 epochs: Integer epochs. If > 1, repeat the numpy training data epochs 

2407 times when converting to training dataset. 

2408 allow_partial_batch: Boolean whether a final partial batch is allowed

2409 (if False, all batches must have the same size).

2410 

2411 Returns: 

2412 Dataset instance. 

2413 

2414 Raises: 

2415 ValueError: In case of invalid user-provided data. 

2416 RuntimeError: If the model was never compiled. 

2417 """ 

2418 if class_weight: 

2419 raise NotImplementedError( 

2420 "`class_weight` is currently not supported " 

2421 "when using tf.distribute.Strategy." 

2422 ) 

2423 

2424 if ( 

2425 sample_weight is not None 

2426 and sample_weight.all() 

2427 and backend.is_tpu_strategy(self._distribution_strategy) 

2428 ): 

2429 raise NotImplementedError( 

2430 "`sample_weight` is currently not supported " 

2431 "when using TPUStrategy." 

2432 ) 

2433 

2434 # Validates `steps` and `shuffle` arguments right at the beginning 

2435 # since we use it to construct the dataset object. 

2436 # TODO(anjalisridhar): Remove this check once we refactor the 

2437 # _standardize_user_data code path. This check is already present 

2438 # elsewhere in the codebase. 

2439 if isinstance(x, tf.data.Dataset): 

2440 if shuffle: 

2441 training_utils_v1.verify_dataset_shuffled(x) 

2442 

2443 strategy = self._distribution_strategy 

2444 with strategy.scope(): 

2445 # We should be sure to call get_session() inside the 

2446 # strategy.scope() so the strategy can affect the session options. 

2447 if tf.compat.v1.executing_eagerly_outside_functions(): 

2448 session = None 

2449 else: 

2450 session = backend.get_session() 

2451 

2452 first_x_value = tf.nest.flatten(x)[0] 

2453 if isinstance(first_x_value, np.ndarray): 

2454 x = training_utils.list_to_tuple(x) 

2455 if y is not None: 

2456 y = training_utils.list_to_tuple(y) 

2457 if sample_weight is not None: 

2458 sample_weight = training_utils.list_to_tuple( 

2459 sample_weight 

2460 ) 

2461 in_tuple = (x, y, sample_weight) 

2462 else: 

2463 in_tuple = (x, y) 

2464 else: 

2465 in_tuple = x 

2466 

2467 ds = strategy.extended.experimental_make_numpy_dataset( 

2468 in_tuple, session=session 

2469 ) 

2470 if shuffle: 

2471 # We want a buffer size that is larger than the batch size 

2472 # provided by the user and provides sufficient randomness. 

2473 # Note that larger numbers introduce more memory usage based 

2474 # on the size of each sample. 

2475 ds = ds.shuffle(max(1024, batch_size * 8)) 

2476 if epochs > 1: 

2477 ds = ds.repeat(epochs) 

2478 

2479 # We need to use the drop_remainder argument to get a known 

2480 # static input shape which is required for TPUs. 

2481 drop_remainder = ( 

2482 not allow_partial_batch 

2483 and strategy.extended.experimental_require_static_shapes 

2484 ) 

2485 

2486 # TODO(b/131720208): We still drop remainder here if number of 

2487 # examples is divisible by batch size, as sometimes dynamic 

2488 # padder will time out with keras.metrics.CategoricalAccuracy() 

2489 # metric. 

2490 if backend.is_tpu_strategy(strategy) and not drop_remainder: 

2491 dataset_size = first_x_value.shape[0] 

2492 if dataset_size % batch_size == 0: 

2493 drop_remainder = True 

2494 

2495 x = ds.batch(batch_size, drop_remainder=drop_remainder) 

2496 else: 

2497 assert isinstance(x, tf.data.Dataset) 

2498 training_utils_v1.validate_dataset_input( 

2499 x, y, sample_weight, validation_split 

2500 ) 

2501 return x 

2502 
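A sketch of the dataset construction the NumPy branch above performs: pack `(x, y)`, shuffle with a buffer of at least 1024, repeat for multiple epochs, then batch with `drop_remainder` when static shapes are required. It uses the public `tf.data` API instead of the strategy-internal helper, and the sizes are illustrative:

```python
import numpy as np
import tensorflow as tf

x = np.random.rand(100, 4).astype("float32")
y = np.random.rand(100, 1).astype("float32")
batch_size, epochs, require_static_shapes = 8, 2, True

ds = tf.data.Dataset.from_tensor_slices((x, y))
ds = ds.shuffle(max(1024, batch_size * 8))
if epochs > 1:
    ds = ds.repeat(epochs)
ds = ds.batch(batch_size, drop_remainder=require_static_shapes)
print(ds.element_spec)  # static batch dimension of 8 when drop_remainder=True
```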

2503 def _standardize_user_data( 

2504 self, 

2505 x, 

2506 y=None, 

2507 sample_weight=None, 

2508 class_weight=None, 

2509 batch_size=None, 

2510 check_steps=False, 

2511 steps_name="steps", 

2512 steps=None, 

2513 validation_split=0.0, 

2514 shuffle=False, 

2515 extract_tensors_from_dataset=False, 

2516 ): 

2517 """Runs validation checks on input and target data passed by the user. 

2518 

2519 Also standardizes the data to lists of arrays, in order. 

2520 

2521 Also builds and compiles the model on the fly if it is a subclassed 

2522 model that has never been called before (and thus has no 

2523 inputs/outputs). 

2524 

2525 This is a purely internal method, subject to refactoring at any time. 

2526 

2527 Args: 

2528 x: Input data. It could be: 

2529 - A Numpy array (or array-like), or a list of arrays 

2530 (in case the model has multiple inputs). 

2531 - A TensorFlow tensor, or a list of tensors 

2532 (in case the model has multiple inputs). 

2533 - A dict mapping input names to the corresponding array/tensors, 

2534 if the model has named inputs. 

2535 - A `tf.data` dataset. 

2536 y: Target data. Like the input data `x`, 

2537 it could be either Numpy array(s) or TensorFlow tensor(s). 

2538 It should be consistent with `x` (you cannot have Numpy inputs and 

2539 tensor targets, or inversely). If `x` is a dataset, `y` should not 

2540 be specified (since targets will be obtained from the iterator). 

2541 sample_weight: An optional sample-weight array passed by the user to 

2542 weight the importance of each sample in `x`. 

2543 class_weight: An optional class-weight array by the user to 

2544 weight the importance of samples in `x` based on the class they 

2545 belong to, as conveyed by `y`. If both `sample_weight` and 

2546 `class_weight` are provided, the weights are multiplied. 

2547 batch_size: Integer batch size. If provided, it is used to run 

2548 additional validation checks on stateful models. 

2549 check_steps: boolean, True if we want to check for validity of `steps` 

2550 and False, otherwise. For example, when we are standardizing one 

2551 batch of data for train_on_batch/predict_on_batch/test_on_batch 

2552 APIs, `steps` value is not required and we should not check for its 

2553 validity in these cases. 

2554 steps_name: The public API's parameter name for `steps`. 

2555 steps: Integer or `None`. Total number of steps (batches of samples) 

2556 to execute. 

2557 validation_split: Float between 0 and 1. 

2558 Fraction of the training data to be used as validation data. 

2559 shuffle: Boolean whether to shuffle the training data before each 

2560 epoch. 

2561 extract_tensors_from_dataset: Boolean. When `x` is a dataset instance, 

2562 this indicates whether to extract actual tensors from the dataset or 

2563 instead output the dataset instance itself. 

2564 Set to True when calling from `train_on_batch`/etc. 

2565 

2566 Returns: 

2567 A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a 

2568 dict or not), target arrays, sample-weight arrays. If the model's 

2569 input and targets are symbolic, these lists are empty (since the model 

2570 takes no user-provided data, instead the data comes from the symbolic 

2571 inputs/targets). 

2572 

2573 Raises: 

2574 ValueError: In case of invalid user-provided data. 

2575 RuntimeError: If the model was never compiled. 

2576 """ 

2577 if isinstance(x, (tf.compat.v1.data.Dataset, tf.data.Dataset)): 

2578 # Graph mode dataset. We'll pass the dataset as-is (unless 

2579 # `extract_tensors_from_dataset` is True, in which case we extract 

2580 # the tensors from the dataset and output them).

2581 training_utils_v1.validate_dataset_input( 

2582 x, y, sample_weight, validation_split 

2583 ) 

2584 if shuffle: 

2585 training_utils_v1.verify_dataset_shuffled(x) 

2586 

2587 is_dataset = True 

2588 if extract_tensors_from_dataset: 

2589 # We do this for `train_on_batch`/etc. 

2590 ( 

2591 x, 

2592 y, 

2593 sample_weight, 

2594 ) = training_utils_v1.extract_tensors_from_dataset(x) 

2595 elif isinstance(x, tf.compat.v1.data.Iterator): 

2596 # Graph mode iterator. We extract the symbolic tensors. 

2597 training_utils_v1.validate_dataset_input( 

2598 x, y, sample_weight, validation_split 

2599 ) 

2600 iterator = x 

2601 x, y, sample_weight = training_utils_v1.unpack_iterator_input( 

2602 iterator 

2603 ) 

2604 is_dataset = True 

2605 else: 

2606 is_dataset = False 

2607 

2608 # Validates `steps` argument based on x's type. 

2609 if check_steps: 

2610 training_utils_v1.check_steps_argument(x, steps, steps_name) 

2611 

2612 # First, we build the model on the fly if necessary. 

2613 if not self.inputs: 

2614 all_inputs, y_input, dict_inputs = self._build_model_with_inputs( 

2615 x, y 

2616 ) 

2617 is_build_called = True 

2618 else: 

2619 all_inputs = [] 

2620 # Whether this is a subclassed model that expects dictionary inputs 

2621 # rather than list inputs (e.g. FeatureColumn-based models). 

2622 dict_inputs = isinstance(self.inputs, dict) 

2623 is_build_called = False 

2624 y_input = y 

2625 

2626 # Second, we compile the model on the fly if necessary, mostly for 

2627 # subclass models. 

2628 is_compile_called = False 

2629 if not self._is_compiled and self.optimizer: 

2630 self._compile_from_inputs(all_inputs, y_input, x, y) 

2631 is_compile_called = True 

2632 

2633 # In graph mode, if we had just set inputs and targets as symbolic 

2634 # tensors by invoking build and compile on the model respectively, we do 

2635 # not have to feed anything to the model. Model already has input and 

2636 # target data as part of the graph. Note: in this case, `any` and `all` 

2637 # are equivalent since we disallow mixed symbolic/value inputs. 

2638 

2639 # self.run_eagerly is not free to compute, so we want to reuse the 

2640 # value. 

2641 run_eagerly = self.run_eagerly 

2642 

2643 if ( 

2644 not run_eagerly 

2645 and is_build_called 

2646 and is_compile_called 

2647 and not is_dataset 

2648 and any(_is_symbolic_tensor(v) for v in all_inputs) 

2649 ): 

2650 return [], [], None 

2651 

2652 return self._standardize_tensors( 

2653 x, 

2654 y, 

2655 sample_weight, 

2656 run_eagerly=run_eagerly, 

2657 dict_inputs=dict_inputs, 

2658 is_dataset=is_dataset, 

2659 class_weight=class_weight, 

2660 batch_size=batch_size, 

2661 ) 

2662 

2663 def _standardize_tensors( 

2664 self, 

2665 x, 

2666 y, 

2667 sample_weight, 

2668 run_eagerly, 

2669 dict_inputs, 

2670 is_dataset, 

2671 class_weight=None, 

2672 batch_size=None, 

2673 ): 

2674 if run_eagerly: 

2675 # In eager mode, do not do shape validation 

2676 # since the network has no input nodes (placeholders) to be fed. 

2677 feed_input_names = self.input_names 

2678 feed_input_shapes = None 

2679 elif not self._is_graph_network: 

2680 # Case: symbolic-mode subclassed network. Do not do shape 

2681 # validation. 

2682 feed_input_names = self._feed_input_names 

2683 feed_input_shapes = None 

2684 else: 

2685 # Case: symbolic-mode graph network. 

2686 # In this case, we run extensive shape validation checks. 

2687 feed_input_names = self._feed_input_names 

2688 feed_input_shapes = self._feed_input_shapes 

2689 

2690 # Standardize the inputs. 

2691 if not isinstance(x, (tf.compat.v1.data.Dataset, tf.data.Dataset)): 

2692 # TODO(fchollet): run static checks with dataset output shape(s). 

2693 x = training_utils_v1.standardize_input_data( 

2694 x, 

2695 feed_input_names, 

2696 feed_input_shapes, 

2697 check_batch_axis=False, # Don't enforce the batch size. 

2698 exception_prefix="input", 

2699 ) 

2700 

2701 # Get typespecs for the input data and sanitize it if necessary. 

2702 # TODO(momernick): This should be capable of doing full input validation 

2703 # at all times - validate that this is so and refactor the 

2704 # standardization code. 

2705 if isinstance(x, tf.data.Dataset): 

2706 x_shapes = tf.data.experimental.get_structure(x) 

2707 if isinstance(x_shapes, tuple): 

2708 # If the output of a Dataset is a tuple, we assume it's either 

2709 # of the form (x_data, y_data) or (x_data, y_data, 

2710 # sample_weights). In either case, we only care about x_data 

2711 # here. 

2712 x_shapes = x_shapes[0] 

2713 else: 

2714 flat_inputs = tf.nest.flatten(x) 

2715 flat_expected_inputs = tf.nest.flatten(self.inputs) 

2716 converted_x = [] 

2717 for a, b in zip(flat_inputs, flat_expected_inputs): 

2718 converted_x.append(_convert_scipy_sparse_tensor(a, b)) 

2719 x = tf.nest.pack_sequence_as(x, converted_x) 

2720 

2721 # Convert ResourceVariables to tensors so nest.assert_same_structure 

2722 # below won't fail with Variable and Tensor. 

2723 x_tensors = tf_utils.convert_variables_to_tensors(x) 

2724 x_shapes = tf.nest.map_structure( 

2725 tf_utils.type_spec_from_value, x_tensors 

2726 ) 

2727 

2728 flat_inputs = tf.nest.flatten(x_shapes) 

2729 # Convert ResourceVariables to tensors so nest.assert_same_structure 

2730 # below won't fail with Variable and Tensor. 

2731 flat_expected_inputs = tf.nest.flatten( 

2732 tf_utils.convert_variables_to_tensors(self.inputs) 

2733 ) 

2734 for a, b in zip(flat_inputs, flat_expected_inputs): 

2735 tf.nest.assert_same_structure(a, b, expand_composites=True) 

2736 

2737 if y is not None: 

2738 # Prepare self._sample_weight_modes. List with the same length as 

2739 # model outputs. 

2740 training_utils_v1.prepare_sample_weight_modes( 

2741 self._training_endpoints, self.sample_weight_mode 

2742 ) 

2743 feed_output_names = self._feed_output_names 

2744 feed_sample_weight_modes = self._sample_weight_modes 

2745 if not self._is_graph_network: 

2746 feed_output_shapes = None 

2747 else: 

2748 feed_output_shapes = self._feed_output_shapes 

2749 

2750 # Standardize the outputs. 

2751 y = training_utils_v1.standardize_input_data( 

2752 y, 

2753 feed_output_names, 

2754 # Don't enforce target shapes to match output shapes. 

2755 # Precise checks will be run in 

2756 # `check_loss_and_target_compatibility`. 

2757 shapes=None, 

2758 check_batch_axis=False, # Don't enforce the batch size. 

2759 exception_prefix="target", 

2760 ) 

2761 

2762 # Generate sample-wise weight values given the `sample_weight` and 

2763 # `class_weight` arguments. 

2764 sample_weights = training_utils_v1.standardize_sample_weights( 

2765 sample_weight, feed_output_names 

2766 ) 

2767 class_weights = training_utils_v1.standardize_class_weights( 

2768 class_weight, feed_output_names 

2769 ) 

2770 

2771 sample_weights = [ 

2772 training_utils_v1.standardize_weights(ref, sw, cw, mode) 

2773 for (ref, sw, cw, mode) in zip( 

2774 y, sample_weights, class_weights, feed_sample_weight_modes 

2775 ) 

2776 ] 

2777 # Check that all arrays have the same length. 

2778 if not self._distribution_strategy: 

2779 training_utils_v1.check_array_lengths(x, y, sample_weights) 

2780 if self._is_graph_network and not run_eagerly: 

2781 # Additional checks to avoid users mistakenly using improper 

2782 # loss fns. 

2783 training_utils_v1.check_loss_and_target_compatibility( 

2784 y, self._feed_loss_fns, feed_output_shapes 

2785 ) 

2786 

2787 sample_weights, _, _ = training_utils.handle_partial_sample_weights( 

2788 y, sample_weights, feed_sample_weight_modes, check_all_flat=True 

2789 ) 

2790 else: 

2791 y = [] 

2792 sample_weights = None 

2793 

2794 if self.stateful and batch_size and not is_dataset: 

2795 # Check that for stateful networks, number of samples is a multiple 

2796 # of the static batch size. 

2797 if x[0].shape[0] % batch_size != 0: 

2798 raise ValueError( 

2799 "In a stateful network, " 

2800 "you should only pass inputs with " 

2801 "a number of samples that can be " 

2802 "divided by the batch size. Found: " 

2803 + str(x[0].shape[0]) 

2804 + " samples" 

2805 ) 

2806 

2807 # If dictionary inputs were provided, we return a dictionary as well. 

2808 if dict_inputs and not isinstance( 

2809 x, (tf.compat.v1.data.Dataset, tf.data.Dataset) 

2810 ): 

2811 x = dict(zip(feed_input_names, x)) 

2812 return x, y, sample_weights 

2813 

2814 def _build_model_with_inputs(self, inputs, targets): 

2815 """Build the model (set model inputs/outputs), mainly for subclass 

2816 model.""" 

2817 processed_inputs = [] 

2818 is_dict_inputs = False 

2819 orig_inputs = inputs 

2820 # We need to use `inputs` to set the model inputs. 

2821 # If input data is a dataset iterator in graph mode or if it is an eager 

2822 # iterator and only one batch of samples is required, we fetch the data 

2823 # tensors from the iterator and then standardize them. 

2824 if isinstance(inputs, (tf.compat.v1.data.Dataset, tf.data.Dataset)): 

2825 inputs, targets, _ = training_utils_v1.extract_tensors_from_dataset( 

2826 inputs 

2827 ) 

2828 # We type-check that `inputs` and `targets` are either single arrays 

2829 # or lists of arrays, and extract a flat list of inputs from the passed 

2830 # structure. 

2831 training_utils_v1.validate_input_types(inputs, orig_inputs) 

2832 

2833 if isinstance(inputs, (list, tuple)): 

2834 processed_inputs += list(inputs) 

2835 elif isinstance(inputs, dict): 

2836 is_dict_inputs = True 

2837 keys = sorted(inputs.keys()) 

2838 processed_inputs = [inputs[k] for k in keys] 

2839 else: 

2840 processed_inputs.append(inputs) 

2841 # Now that we have a flat set of inputs, we make sure that none of them 

2842 # are CompositeTensors or CompositeTensorValues of any type (or scipy 

2843 # sparse arrays, which we treat as SparseTensor values). We cannot 

2844 # safely infer input data from an arbitrary composite tensor, so we 

2845 # don't try - users should explicitly add composite tensor inputs to 

2846 # their subclassed models. 

2847 for input_tensor in processed_inputs: 

2848 if training_utils_v1.is_composite_or_composite_value( 

2849 input_tensor 

2850 ) and not isinstance(input_tensor, tf.Variable): 

2851 # TODO(b/132691975): Document subclass-model CT input handling. 

2852 raise ValueError( 

2853 "All SparseTensor and RaggedTensor inputs must be " 

2854 "explicitly declared using a keras.Input() with " 

2855 "sparse=True or ragged=True. We found an undeclared " 

2856 "input %s. For Sequential models, please add a " 

2857 "keras.Input() as your first Layer. For subclassed models, " 

2858 "please call self._set_inputs() on your input set, which " 

2859 "you can create using keras.Input() for each input to your " 

2860 "model." % (input_tensor,) 

2861 ) 

2862 # Build the model using the retrieved inputs (value or symbolic). 

2863 # If values are generated from a dataset, then in symbolic-mode 

2864 # placeholders will be created to match the value shapes. 

2865 if isinstance( 

2866 orig_inputs, 

2867 ( 

2868 tf.compat.v1.data.Dataset, 

2869 tf.data.Dataset, 

2870 tf.compat.v1.data.Iterator, 

2871 ), 

2872 ): 

2873 if not self.inputs: 

2874 # For subclassed models, a robust input spec is not available so 

2875 # we must cast to the model dtype. 

2876 inputs = training_utils_v1.cast_if_floating_dtype( 

2877 inputs, self.dtype 

2878 ) 

2879 

2880 def create_tensor_spec(t): 

2881 return tf.TensorSpec(t.shape, t.dtype) 

2882 

2883 cast_inputs = tf.nest.map_structure(create_tensor_spec, inputs) 

2884 elif training_utils_v1.has_tensors(inputs): 

2885 cast_inputs = training_utils_v1.cast_if_floating_dtype(inputs) 

2886 else: 

2887 cast_inputs = inputs 

2888 self._set_inputs(cast_inputs) 

2889 return processed_inputs, targets, is_dict_inputs 

2890 

2891 def _compile_from_inputs( 

2892 self, all_inputs, target, orig_inputs, orig_target 

2893 ): 

2894 if target is not None: 

2895 # We need to use `y` to set the model targets. 

2896 if training_utils_v1.has_tensors(target): 

2897 target = training_utils_v1.cast_if_floating_dtype_and_mismatch( 

2898 target, self.outputs 

2899 ) 

2900 training_utils_v1.validate_input_types( 

2901 target, orig_target, allow_dict=False, field_name="target" 

2902 ) 

2903 if isinstance(target, (list, tuple)): 

2904 all_inputs += list(target) 

2905 else: 

2906 all_inputs.append(target) 

2907 # Type check that all inputs are *either* value *or* symbolic. 

2908 # TODO(fchollet): this check could be removed in Eager mode? 

2909 if any(tf.is_tensor(v) for v in all_inputs): 

2910 if not all(tf.is_tensor(v) for v in all_inputs): 

2911 raise ValueError( 

2912 "Do not pass inputs that mix Numpy arrays and " 

2913 "TensorFlow tensors. " 

2914 "You passed: x=" 

2915 + str(orig_inputs) 

2916 + "; y=" 

2917 + str(orig_target) 

2918 ) 

2919 is_dataset = isinstance( 

2920 orig_inputs, 

2921 ( 

2922 tf.compat.v1.data.Dataset, 

2923 tf.data.Dataset, 

2924 tf.compat.v1.data.Iterator, 

2925 ), 

2926 ) 

2927 if is_dataset or tf.executing_eagerly(): 

2928 target_tensors = None 

2929 else: 

2930 # Handle target tensors if any passed. 

2931 if target is not None: 

2932 if not isinstance(target, (list, tuple)): 

2933 target = [target] 

2934 target_tensors = [v for v in target if _is_symbolic_tensor(v)] 

2935 else: 

2936 target_tensors = None 

2937 

2938 self.compile( 

2939 optimizer=self.optimizer, 

2940 loss=self.loss, 

2941 metrics=self._compile_metrics, 

2942 weighted_metrics=self._compile_weighted_metrics, 

2943 loss_weights=self.loss_weights, 

2944 target_tensors=target_tensors, 

2945 sample_weight_mode=self.sample_weight_mode, 

2946 run_eagerly=self.run_eagerly, 

2947 experimental_run_tf_function=self._experimental_run_tf_function, 

2948 ) 

2949 

2950 # TODO(omalleyt): Consider changing to a more descriptive function name. 

2951 def _set_inputs(self, inputs, outputs=None, training=None): 

2952 """Set model's input and output specs based on the input data received. 

2953 

2954 This is to be used for Model subclasses, which do not know at 

2955 instantiation time what their inputs look like. 

2956 

2957 Args: 

2958 inputs: Single array, or list of arrays. The arrays could be 

2959 placeholders, Numpy arrays, data tensors, or TensorSpecs. 

2960 - if placeholders: the model is built on top of these placeholders, 

2961 and we expect Numpy data to be fed for them when calling 

2962 `fit`/etc. 

2963 - if Numpy data or TensorShapes: we create placeholders matching the 

2964 TensorShapes or shapes of the Numpy arrays. We expect Numpy data 

2965 to be fed for these placeholders when calling `fit`/etc. 

2966 - if data tensors: the model is built on top of these tensors. 

2967 We do not expect any Numpy data to be provided when calling 

2968 `fit`/etc. 

2969 outputs: None, a data tensor, or a list of tensors. If None, the 

2970 outputs will be determined by invoking `self.call()`, otherwise the 

2971 provided value will be used. 

2972 training: Boolean or None. Only relevant in symbolic mode. Specifies 

2973 whether to build the model's graph in inference mode (False), 

2974 training mode (True), or using the Keras learning phase (None). 

2975 Raises: 

2976 ValueError: If dict inputs are passed to a Sequential Model where the 

2977 first layer isn't FeatureLayer. 

2978 """ 

2979 self._set_save_spec(inputs) 

2980 inputs = self._set_input_attrs(inputs) 

2981 

2982 if outputs is None: 

2983 kwargs = {} 

2984 if self._expects_training_arg: 

2985 # In V2 mode, feeding `training=None` is not allowed because any 

2986 # value explicitly passed by the user is respected, even 

2987                     # `None`. 

2988 if ( 

2989 training is None 

2990 and not tf.compat.v1.executing_eagerly_outside_functions() 

2991 ): 

2992 training = backend.learning_phase() 

2993 if training is not None: 

2994 kwargs["training"] = training 

2995 try: 

2996 outputs = self(inputs, **kwargs) 

2997 except NotImplementedError: 

2998 # This Model or a submodel is dynamic and hasn't overridden 

2999 # `compute_output_shape`. 

3000 outputs = None 

3001 

3002 self._set_output_attrs(outputs) 

3003 
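# --- A minimal, hedged sketch (public API, not this module's internals) of the
# --- behavior documented above: a subclassed model has no input spec at
# --- construction time, so its inputs/outputs are only determined once it sees
# --- data. The class and shapes below are hypothetical.
import numpy as np
import tensorflow as tf

class TinyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(2)

    def call(self, inputs):
        return self.dense(inputs)

model = TinyModel()
model.compile(optimizer="sgd", loss="mse")
# The first batch determines the input shape the model is built with.
model.fit(np.ones((4, 3)), np.ones((4, 2)), epochs=1, verbose=0)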

3004 @tf.__internal__.tracking.no_automatic_dependency_tracking 

3005 def _set_input_attrs(self, inputs): 

3006 """Sets attributes related to the inputs of the Model.""" 

3007 if self.inputs: 

3008 raise ValueError("Model inputs are already set.") 

3009 

3010 if self.__class__.__name__ == "Sequential" and not self.built: 

3011 if tf.is_tensor(inputs): 

3012 input_shape = (None,) + tuple(inputs.shape.as_list()[1:]) 

3013 elif isinstance(inputs, tf.TensorShape): 

3014 input_shape = (None,) + tuple(inputs.as_list()[1:]) 

3015 elif isinstance(inputs, dict): 

3016 # We assert that the first layer is a FeatureLayer. 

3017 if not training_utils_v1.is_feature_layer(self.layers[0]): 

3018 raise ValueError( 

3019 "Passing a dictionary input to a Sequential Model " 

3020 "which doesn't have FeatureLayer as the first layer" 

3021 " is an error." 

3022 ) 

3023 input_shape = (None,) 

3024 else: 

3025 input_shape = (None,) + tuple(inputs.shape[1:]) 

3026 self._build_input_shape = input_shape 

3027 

3028 # Cast inputs to the compute dtype. This is primarily used 

3029 # when saving to determine the correct dtype in the input signature. 

3030 inputs = self._maybe_cast_inputs(inputs) 

3031 

3032 # On-the-fly setting of symbolic model inputs (either by using the 

3033 # tensor provided, or by creating a placeholder if Numpy data was 

3034 # provided). 

3035 model_inputs = training_utils_v1.ModelInputs(inputs) 

3036 inputs = model_inputs.get_symbolic_inputs() 

3037 self.inputs = model_inputs.get_symbolic_inputs( 

3038 return_single_as_list=True 

3039 ) 

3040 self.input_names = model_inputs.get_input_names() 

3041 

3042 self._feed_inputs = [] 

3043 self._feed_input_names = [] 

3044 self._feed_input_shapes = [] 

3045 

3046 for k, v in model_inputs.as_dict(): 

3047 if backend.is_placeholder(v): 

3048 self._feed_input_names.append(k) 

3049 self._feed_inputs.append(v) 

3050 self._feed_input_shapes.append(backend.int_shape(v)) 

3051 

3052 return inputs 

3053 
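# --- A small sketch (public API) of the Sequential shape inference performed
# --- above: without an explicit Input layer, the input shape is inferred as
# --- (None,) + data.shape[1:] from the first data the model sees.
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.compile(optimizer="sgd", loss="mse")
model.fit(np.ones((8, 5)), np.ones((8, 2)), epochs=1, verbose=0)
print(model.input_shape)  # (None, 5)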

3054 @tf.__internal__.tracking.no_automatic_dependency_tracking 

3055 def _set_output_attrs(self, outputs): 

3056 """Sets attributes related to the outputs of the Model.""" 

3057 # NOTE(taylorrobie): This convention cannot be changed without updating 

3058 # the data adapter since it assumes nest.flatten ordering. 

3059 outputs = tf.nest.flatten(outputs) 

3060 self.outputs = outputs 

3061 self.output_names = training_utils_v1.generic_output_names(outputs) 

3062 # TODO(scottzhu): Should we cleanup the self._training_endpoints here? 

3063 self.built = True 

3064 

3065 @property 

3066 def _targets(self): 

3067 """The output target tensors for the model.""" 

3068 return [ 

3069 e.training_target.target 

3070 for e in self._training_endpoints 

3071 if e.has_training_target() 

3072 ] 

3073 

3074 @property 

3075 def _feed_targets(self): 

3076 return [ 

3077 e.training_target.target 

3078 for e in self._training_endpoints 

3079 if e.has_feedable_training_target() 

3080 ] 

3081 

3082 @property 

3083 def _feed_output_names(self): 

3084 return [ 

3085 e.output_name 

3086 for e in self._training_endpoints 

3087 if e.has_feedable_training_target() 

3088 ] 

3089 

3090 @property 

3091 def _feed_output_shapes(self): 

3092 return [ 

3093 e.feed_output_shape 

3094 for e in self._training_endpoints 

3095 if e.has_feedable_training_target() 

3096 ] 

3097 

3098 @property 

3099 def _feed_loss_fns(self): 

3100 return [ 

3101 e.loss_fn 

3102 for e in self._training_endpoints 

3103 if e.has_feedable_training_target() 

3104 ] 

3105 

3106 @property 

3107 def _loss_weights_list(self): 

3108 return [e.loss_weight for e in self._training_endpoints] 

3109 

3110 @property 

3111 def _output_loss_metrics(self): 

3112 if hasattr(self, "_training_endpoints"): 

3113 return [ 

3114 e.output_loss_metric 

3115 for e in self._training_endpoints 

3116 if e.output_loss_metric is not None 

3117 ] 

3118 return None 

3119 

3120 @property 

3121 def sample_weights(self): 

3122 return [e.sample_weight for e in self._training_endpoints] 

3123 

3124 @property 

3125 def _sample_weight_modes(self): 

3126 return [e.sample_weight_mode for e in self._training_endpoints] 

3127 

3128 @property 

3129 def _feed_sample_weights(self): 

3130 return [ 

3131 e.sample_weight 

3132 for e in self._training_endpoints 

3133 if e.sample_weight is not None 

3134 ] 

3135 

3136 def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode): 

3137 """Maybe load 1st epoch from checkpoint, considering worker recovery. 

3138 

3139 Refer to tensorflow/python/keras/distribute/worker_training_state.py 

3140 for more information. 

3141 

3142 Args: 

3143          initial_epoch: The original initial_epoch the user passes in to `fit()`. 

3144 mode: The mode for running `model.fit()`. 

3145 

3146 Returns: 

3147          If the training is recovering from a previous failure under a multi-worker 

3148          training setting, return the epoch at which training should continue. 

3149          Otherwise, return the `initial_epoch` the user passed in. 

3150 """ 

3151 if self._training_state is not None: 

3152 return self._training_state.maybe_load_initial_epoch_from_ckpt( 

3153 initial_epoch, mode 

3154 ) 

3155 return initial_epoch 

3156 
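# --- Hedged sketch of the public-API counterpart to the recovery logic above,
# --- assuming a recent TF 2.x where BackupAndRestore is available: the callback
# --- is what lets a restarted worker resume from the epoch it had reached
# --- rather than from the user-supplied initial_epoch. The backup directory is
# --- an illustrative placeholder.
import tensorflow as tf

backup_cb = tf.keras.callbacks.BackupAndRestore(backup_dir="/tmp/keras_backup")
# model.fit(x, y, epochs=10, callbacks=[backup_cb])  # resumes after a restart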

3157 def _get_training_eval_metrics(self): 

3158 """Returns all the metrics that are to be reported. 

3159 

3160          This includes the output loss metrics, the compile metrics/weighted 

3161          metrics, and the `add_metric` metrics. 

3162 """ 

3163 metrics = [] 

3164 metrics.extend(getattr(self, "_output_loss_metrics", None) or []) 

3165 metrics.extend(getattr(self, "metrics", None) or []) 

3166 return metrics 

3167 

3168 def _assert_compile_was_called(self): 

3169 # Checks whether `compile` has been called. If it has been called, 

3170 # then the optimizer is set. This is different from whether the 

3171 # model is compiled 

3172 # (i.e. whether the model is built and its inputs/outputs are set). 

3173 if not self._compile_was_called: 

3174 raise RuntimeError( 

3175 "You must compile your model before " 

3176 "training/testing. " 

3177 "Use `model.compile(optimizer, loss)`." 

3178 ) 

3179 
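# --- Minimal sketch of the failure mode guarded above (public API): calling
# --- fit() on a model that was never compiled raises a RuntimeError with a
# --- message along these lines.
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
try:
    model.fit(np.ones((2, 3)), np.ones((2, 1)), verbose=0)
except RuntimeError as err:
    print(err)  # e.g. "You must compile your model before training/testing..."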

3180 def _in_multi_worker_mode(self): 

3181 """Method to infer if this `Model` is working in multi-worker settings. 

3182 

3183 Multi-worker training refers to the setup where the training is 

3184 distributed across multiple workers, as opposed to the case where 

3185          only a local process performs the training. This function is used to 

3186          infer, for example, whether a distribute coordinator should be run (and 

3187          thus whether TensorFlow servers should be started to communicate with the 

3188          other servers in the cluster), or whether saving/restoring checkpoints is 

3189          relevant for preemption fault tolerance. 

3190 

3191 Experimental. Signature and implementation are subject to change. 

3192 

3193 Returns: 

3194 Whether this model indicates it's working in multi-worker settings. 

3195 """ 

3196 strategy = self._distribution_strategy 

3197 

3198          # If no strategy is attached to the model, use the strategy whose scope 
3198          # this is in. 

3199 if not strategy and tf.distribute.has_strategy(): 

3200 strategy = tf.distribute.get_strategy() 

3201 return strategy and strategy.extended._in_multi_worker_mode() 

3202 
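# --- Hedged sketch of the strategy lookup used above, restricted to public
# --- APIs: inside a strategy scope, tf.distribute.has_strategy() is True and
# --- tf.distribute.get_strategy() returns that strategy. A MirroredStrategy is
# --- a single-worker strategy, so a model under it is not in multi-worker mode.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    print(tf.distribute.has_strategy())              # True
    print(tf.distribute.get_strategy() is strategy)  # True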

3203 @property 

3204 def _trackable_saved_model_saver(self): 

3205 return model_serialization.ModelSavedModelSaver(self) 

3206 

3207 def _get_compile_args(self, user_metrics=True): 

3208 del user_metrics 

3209 self._assert_compile_was_called() 

3210 kwargs = { 

3211 "loss": self.loss, 

3212 "metrics": self._compile_metrics, 

3213 "loss_weights": self.loss_weights, 

3214 "sample_weight_mode": self.sample_weight_mode, 

3215 "weighted_metrics": self._compile_weighted_metrics, 

3216 } 

3217 return kwargs 

3218 

3219 @property 

3220 def _compile_was_called(self): 

3221 return self._v1_compile_was_called 

3222 

3223 

3224class DistributedCallbackModel(Model): 

3225 """Model that is used for callbacks with tf.distribute.Strategy.""" 

3226 

3227 def __init__(self, model): 

3228 super().__init__() 

3229 self.optimizer = model.optimizer 

3230 

3231 def set_original_model(self, orig_model): 

3232 self._original_model = orig_model 

3233 

3234 def save_weights(self, filepath, overwrite=True, save_format=None): 

3235 self._replicated_model.save_weights( 

3236 filepath, overwrite=overwrite, save_format=save_format 

3237 ) 

3238 

3239 def save(self, filepath, overwrite=True, include_optimizer=True): 

3240 # save weights from the distributed model to the original model 

3241 distributed_model_weights = self.get_weights() 

3242 self._original_model.set_weights(distributed_model_weights) 

3243 # TODO(anjalisridhar): Do we need to save the original model here? 

3244 # Saving the first replicated model works as well. 

3245 self._original_model.save( 

3246 filepath, overwrite=True, include_optimizer=False 

3247 ) 

3248 

3249 def load_weights(self, filepath, by_name=False): 

3250 self._original_model.load_weights(filepath, by_name=False) 

3251 # Copy the weights from the original model to each of the replicated 

3252 # models. 

3253 orig_model_weights = self._original_model.get_weights() 

3254 distributed_training_utils_v1.set_weights( 

3255 self._original_model._distribution_strategy, 

3256 self, 

3257 orig_model_weights, 

3258 ) 

3259 
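# --- Sketch of the weight-copy pattern used above, with public APIs and two
# --- hypothetical models that share an architecture: weights are pulled from
# --- one model and pushed into the other.
import tensorflow as tf

source = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
replica = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
replica.set_weights(source.get_weights())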

3260 def __getattr__(self, item): 

3261 # Allowed attributes of the model that can be accessed by the user 

3262 # during a callback. 

3263 if item not in ("_setattr_tracking", "_layers"): 

3264 logging.warning( 

3265 "You are accessing attribute " + item + " of the " 

3266 "DistributedCallbackModel that may not have been set " 

3267 "correctly." 

3268 ) 

3269 return super().__getattr__(item) 

3270 

3271 

3272class _TrainingEndpoint: 

3273 """A container for the training output/target and related entities. 

3274 

3275 In the case of model with multiple outputs, there is a one-to-one mapping 

3276 between model output (y_pred), model target (y_true), loss, metrics etc. 

3277      By unifying these entities into one class, the different entities can 

3278      access information about each other, rather than having to reach into 

3279      separate lists of attributes on the model. 

3280 """ 

3281 

3282 def __init__( 

3283 self, 

3284 output, 

3285 output_name, 

3286 loss_fn, 

3287 loss_weight=None, 

3288 training_target=None, 

3289 output_loss_metric=None, 

3290 sample_weight=None, 

3291 sample_weight_mode=None, 

3292 ): 

3293 """Initialize the _TrainingEndpoint. 

3294 

3295 Note that the output and output_name should be stable as long as the 

3296          model structure doesn't change. The training_target is supposed to be 

3297          mutable since the information is provided via `compile()`. 

3298 

3299 Args: 

3300 output: the output tensor of the model. 

3301 output_name: the unique name of the output tensor. 

3302 loss_fn: the loss function for the output tensor. 

3303 loss_weight: float, the weights for the loss. 

3304 training_target: the _TrainingTarget for the model. 

3305 output_loss_metric: the metric object for the loss function. 

3306          sample_weight: the per-sample weights applied during metric and loss 

3307            calculation. Could be None. 

3308 sample_weight_mode: string, 'temporal', 'samplewise' or None. The mode 

3309 for how the sample_weight is populated. 

3310 """ 

3311 self._output = output 

3312 self._output_name = output_name 

3313 self._loss_fn = loss_fn 

3314 self._loss_weight = loss_weight 

3315 self._training_target = training_target 

3316 self._output_loss_metric = output_loss_metric 

3317 self._sample_weight = sample_weight 

3318 self._sample_weight_mode = sample_weight_mode 

3319 

3320 @property 

3321 def output(self): 

3322 return self._output 

3323 

3324 @property 

3325 def output_name(self): 

3326 return self._output_name 

3327 

3328 @property 

3329 def shape(self): 

3330 return backend.int_shape(self.output) 

3331 

3332 @property 

3333 def loss_fn(self): 

3334 return self._loss_fn 

3335 

3336 @property 

3337 def loss_weight(self): 

3338 return self._loss_weight 

3339 

3340 @loss_weight.setter 

3341 def loss_weight(self, value): 

3342 self._loss_weight = value 

3343 

3344 @property 

3345 def training_target(self): 

3346 return self._training_target 

3347 

3348 @training_target.setter 

3349 def training_target(self, value): 

3350 self._training_target = value 

3351 

3352 def create_training_target(self, target, run_eagerly=False): 

3353 """Create training_target instance and update the self.training_target. 

3354 

3355 Note that the input target should just be a tensor or None, and 

3356          the corresponding training target will be created based on the output and 

3357 loss_fn. 

3358 

3359 Args: 

3360 target: the target tensor for the current output. Could be None. 

3361 run_eagerly: boolean, whether the model is in run_eagerly mode. 

3362 

3363 Raises: 

3364 ValueError if the training_target field for the current instance has 

3365 already been populated. 

3366 """ 

3367 if self.has_training_target(): 

3368 raise ValueError( 

3369 "The training_target field for the _TrainingEndpoint " 

3370 "instance has already been populated" 

3371 ) 

3372 if run_eagerly: 

3373              # When run_eagerly, the target tensor is ignored, and a None 

3374              # placeholder is created instead. 

3375 self.training_target = _TrainingTarget( 

3376 None, feedable=True, skip_target_weights=False 

3377 ) 

3378 return 

3379 

3380 if self.should_skip_target(): 

3381 self.training_target = _TrainingTarget(None) 

3382 else: 

3383 if target is not None and not backend.is_placeholder(target): 

3384 feedable = False 

3385 skip_target_weights = True 

3386 else: 

3387 feedable = True 

3388 skip_target_weights = False 

3389 

3390 if target is None: 

3391 target_dtype = losses.LABEL_DTYPES_FOR_LOSSES.get( 

3392 self.loss_fn, backend.dtype(self.output) 

3393 ) 

3394 

3395 target = backend.placeholder( 

3396 ndim=len(self.shape), 

3397 name=self.output_name + "_target", 

3398 sparse=backend.is_sparse(self.output), 

3399 dtype=target_dtype, 

3400 ) 

3401 

3402 self.training_target = _TrainingTarget( 

3403 target, 

3404 feedable=feedable, 

3405 skip_target_weights=skip_target_weights, 

3406 ) 

3407 

3408 @property 

3409 def output_loss_metric(self): 

3410 return self._output_loss_metric 

3411 

3412 @output_loss_metric.setter 

3413 def output_loss_metric(self, value): 

3414 self._output_loss_metric = value 

3415 

3416 @property 

3417 def sample_weight(self): 

3418 return self._sample_weight 

3419 

3420 @sample_weight.setter 

3421 def sample_weight(self, value): 

3422 self._sample_weight = value 

3423 

3424 @property 

3425 def sample_weight_mode(self): 

3426 return self._sample_weight_mode 

3427 

3428 @sample_weight_mode.setter 

3429 def sample_weight_mode(self, value): 

3430 self._sample_weight_mode = value 

3431 

3432 def should_skip_target(self): 

3433 return self._loss_fn is None 

3434 

3435 def should_skip_target_weights(self): 

3436 return ( 

3437 self.should_skip_target() 

3438 or self.training_target is None 

3439 or self.training_target.skip_target_weights 

3440 ) 

3441 

3442 def has_training_target(self): 

3443 return self.training_target is not None 

3444 

3445 def has_feedable_training_target(self): 

3446 return ( 

3447 not self.should_skip_target() 

3448 and self.training_target is not None 

3449 and self.training_target.feedable 

3450 ) 

3451 

3452 def loss_name(self): 

3453 if self._loss_fn is not None: 

3454 return self._output_name + "_loss" 

3455 return None 

3456 

3457 @property 

3458 def feed_output_shape(self): 

3459 """The output shape for the feedable target.""" 

3460 if not self.has_feedable_training_target(): 

3461 return None 

3462 

3463 if ( 

3464 ( 

3465 isinstance(self.loss_fn, losses.LossFunctionWrapper) 

3466 and self.loss_fn.fn == losses.sparse_categorical_crossentropy 

3467 ) 

3468 ) or (isinstance(self.loss_fn, losses.SparseCategoricalCrossentropy)): 

3469 if backend.image_data_format() == "channels_first": 

3470 return (self.shape[0], 1) + self.shape[2:] 

3471 else: 

3472 return self.shape[:-1] + (1,) 

3473 elif not isinstance(self.loss_fn, losses.Loss) or ( 

3474 isinstance(self.loss_fn, losses.LossFunctionWrapper) 

3475 and (getattr(losses, self.loss_fn.fn.__name__, None) is None) 

3476 ): 

3477 # If the given loss is not an instance of the `Loss` class (custom 

3478 # class) or if the loss function that is wrapped is not in the 

3479 # `losses` module, then it is a user-defined loss and we make no 

3480 # assumptions about it. 

3481 return None 

3482 else: 

3483 return self.shape 

3484 
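# --- Worked sketch of the shape rule above for sparse categorical crossentropy
# --- (public API): predictions carry a trailing class dimension, while targets
# --- are integer class indices of shape (batch,) -- equivalently (batch, 1),
# --- which is what the property above computes for channels_last.
import tensorflow as tf

y_pred = tf.constant([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])  # shape (2, 3)
y_true = tf.constant([1, 0])                               # shape (2,)
loss = tf.keras.losses.SparseCategoricalCrossentropy()(y_true, y_pred)
print(float(loss))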

3485 def sample_weights_mismatch(self): 

3486 """Check if the sample weight and the mode match or not.""" 

3487 # If there is a mismatch between sample weight mode and the placeholders 

3488 # created, then recompile the sub-graphs that depend on sample weights. 

3489 return ( 

3490 self.sample_weight_mode is not None and self.sample_weight is None 

3491 ) or ( 

3492 self.sample_weight_mode is None and self.sample_weight is not None 

3493 ) 

3494 

3495 def populate_sample_weight(self, sample_weight, sample_weight_mode): 

3496 """Populate the sample weight and based on the sample weight mode.""" 

3497 if sample_weight is None and ( 

3498 self.should_skip_target_weights() 

3499 or sample_weight_mode is None 

3500 or tf.executing_eagerly() 

3501 ): 

3502 self._sample_weight = None 

3503 return 

3504 

3505 assert sample_weight_mode in ["temporal", "samplewise"] 

3506 if sample_weight_mode == "temporal": 

3507 default_value = [[1.0]] 

3508 shape = [None, None] 

3509 else: 

3510 # sample_weight_mode == 'samplewise' 

3511 default_value = [1.0] 

3512 shape = [None] 

3513 

3514 if sample_weight is not None: 

3515 if not sample_weight.shape.is_compatible_with(shape): 

3516 raise ValueError( 

3517 "Received sample weight with shape {}. Expected shape " 

3518 "{}.".format(sample_weight.shape, shape) 

3519 ) 

3520 self._sample_weight = sample_weight 

3521 else: 

3522 self._sample_weight = tf.compat.v1.placeholder_with_default( 

3523 tf.constant(default_value, dtype=backend.floatx()), 

3524 shape=shape, 

3525 name=self.output_name + "_sample_weights", 

3526 ) 

3527 
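# --- Sketch of the two sample-weight layouts handled above: "samplewise"
# --- weights carry one scalar per sample (shape [None]), while "temporal"
# --- weights carry one scalar per timestep of every sample (shape
# --- [None, None]). Such arrays would be passed via fit(..., sample_weight=...)
# --- on a model compiled with the matching sample_weight_mode.
import numpy as np

samplewise_weights = np.ones((32,))    # one weight per sample
temporal_weights = np.ones((32, 10))   # one weight per (sample, timestep)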

3528 

3529class _TrainingTarget: 

3530 """Container for a target tensor (y_true) and its metadata (shape, loss...). 

3531 

3532 Args: 

3533 target: A target tensor for the model. It may be `None` if the 

3534        output is excluded from loss computation; the entry is still kept (as 

3535        `None`) because each model output should have a corresponding target. If 

3536        the target is `None`, the rest of the attributes will be `None` as well. 

3537 feedable: Boolean, whether the target is feedable (requires data to be 

3538 passed in `fit` or `train_on_batch`), or not (model compiled with 

3539 `target_tensors` argument). 

3540 skip_target_weights: Boolean, whether the target should be skipped during 

3541 weights calculation. 

3542 """ 

3543 

3544 def __init__(self, target, feedable=False, skip_target_weights=True): 

3545 self._target = target 

3546 self._feedable = feedable 

3547 self._skip_target_weights = skip_target_weights 

3548 

3549 @property 

3550 def target(self): 

3551 return self._target 

3552 

3553 @property 

3554 def feedable(self): 

3555 return self._feedable 

3556 

3557 @property 

3558 def skip_target_weights(self): 

3559 return self._skip_target_weights 

3560 

3561 

3562def _is_symbolic_tensor(x): 

3563 return tf.is_tensor(x) 

3564 

3565 

3566def _convert_scipy_sparse_tensor(value, expected_input): 

3567 """Handle scipy sparse tensor conversions. 

3568 

3569      This method takes `value` and returns the proper conversion. If `value` 

3570      is a SciPy sparse matrix and the expected input is a dense tensor, we 

3571      densify `value`. If `value` is a SciPy sparse matrix and the expected 

3572      input is a TF SparseTensor, we convert `value` to a SparseTensor. If 

3573      `value` is not a SciPy sparse matrix, or SciPy is not imported, it is 

3574      passed through unchanged. 

3575 

3576 Args: 

3577 value: An object that may be a scipy sparse tensor 

3578 expected_input: The expected input placeholder. 

3579 

3580 Returns: 

3581 The possibly-converted 'value'. 

3582 """ 

3583 if issparse is not None and issparse(value): 

3584 if backend.is_sparse(expected_input): 

3585 sparse_coo = value.tocoo() 

3586 row, col = sparse_coo.row, sparse_coo.col 

3587 data, shape = sparse_coo.data, sparse_coo.shape 

3588 indices = np.concatenate( 

3589 (np.expand_dims(row, 1), np.expand_dims(col, 1)), 1 

3590 ) 

3591 return tf.SparseTensor(indices, data, shape) 

3592 else: 

3593 if tf.compat.v1.executing_eagerly_outside_functions(): 

3594 # In TF2 we do not silently densify sparse matrices. 

3595 raise ValueError( 

3596 "A SciPy sparse matrix was passed to a model " 

3597 "that expects dense inputs. Please densify your " 

3598 "inputs first, such as by calling `x.toarray()." 

3599 ) 

3600 return value.toarray() 

3601 else: 

3602 return value 

3603 
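# --- Hedged sketch of the conversion performed above, using SciPy directly:
# --- the COO row/col coordinates are stacked into the indices of a
# --- tf.SparseTensor. The small matrix is purely illustrative.
import numpy as np
import tensorflow as tf
from scipy.sparse import coo_matrix

mat = coo_matrix(np.array([[0.0, 2.0], [3.0, 0.0]]))
indices = np.stack([mat.row, mat.col], axis=1).astype(np.int64)
sparse = tf.SparseTensor(indices, mat.data, mat.shape)
print(tf.sparse.to_dense(tf.sparse.reorder(sparse)))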

3604 

3605def _get_metrics_from_layers(layers): 

3606 """Returns list of metrics from the given layers. 

3607 

3608 This will not include the `compile` metrics of a model layer. 

3609 

3610 Args: 

3611 layers: List of layers. 

3612 

3613 Returns: 

3614 List of metrics. 

3615 """ 

3616 metrics = [] 

3617 layers = layer_utils.filter_empty_layer_containers(layers) 

3618 for layer in layers: 

3619 if isinstance(layer, Model): 

3620 # We cannot call 'metrics' on the model because we do not want to 

3621 # include the metrics that were added in compile API of a nested 

3622 # model. 

3623 metrics.extend(layer._metrics) 

3624 metrics.extend(_get_metrics_from_layers(layer.layers)) 

3625 else: 

3626 metrics.extend(layer.metrics) 

3627 return metrics 

3628 
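# --- Hedged sketch (public API; the exact add_metric signature varies slightly
# --- across Keras versions) of where per-layer metrics come from: values
# --- registered with add_metric() surface on `layer.metrics`, which is what
# --- the helper above walks. The layer class below is hypothetical.
import tensorflow as tf

class ActivationStats(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    def call(self, inputs):
        outputs = self.dense(inputs)
        self.add_metric(tf.reduce_mean(outputs), name="mean_activation")
        return outputs

layer = ActivationStats()
_ = layer(tf.ones((2, 3)))
print([m.name for m in layer.metrics])  # ["mean_activation"]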

3629 

3630def _non_none_constant_value(v): 

3631 constant_value = tf.get_static_value(v) 

3632 return constant_value if constant_value is not None else v 

3633