Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training_v1.py: 20%

1081 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""V1 Training-related part of the Keras engine.""" 

16 

17import collections 

18import warnings 

19 

20import numpy as np 

21 

22from tensorflow.python import tf2 

23from tensorflow.python.data.ops import dataset_ops 

24from tensorflow.python.data.ops import iterator_ops 

25from tensorflow.python.distribute import distribute_lib 

26from tensorflow.python.distribute import parameter_server_strategy 

27from tensorflow.python.distribute import parameter_server_strategy_v2 

28from tensorflow.python.eager import context 

29from tensorflow.python.eager import def_function 

30from tensorflow.python.framework import constant_op 

31from tensorflow.python.framework import ops 

32from tensorflow.python.framework import sparse_tensor 

33from tensorflow.python.framework import tensor_shape 

34from tensorflow.python.framework import tensor_spec 

35from tensorflow.python.framework import tensor_util 

36from tensorflow.python.framework import type_spec 

37from tensorflow.python.keras import backend 

38from tensorflow.python.keras import losses 

39from tensorflow.python.keras import metrics as metrics_module 

40from tensorflow.python.keras import optimizer_v1 

41from tensorflow.python.keras import optimizers 

42from tensorflow.python.keras.distribute import distributed_training_utils 

43from tensorflow.python.keras.distribute import distributed_training_utils_v1 

44from tensorflow.python.keras.engine import base_layer 

45from tensorflow.python.keras.engine import training as training_lib 

46from tensorflow.python.keras.engine import training_arrays_v1 

47from tensorflow.python.keras.engine import training_distributed_v1 

48from tensorflow.python.keras.engine import training_eager_v1 

49from tensorflow.python.keras.engine import training_generator_v1 

50from tensorflow.python.keras.engine import training_utils 

51from tensorflow.python.keras.engine import training_utils_v1 

52from tensorflow.python.keras.mixed_precision import loss_scale_optimizer 

53from tensorflow.python.keras.mixed_precision import policy 

54from tensorflow.python.keras.optimizer_v2 import optimizer_v2 

55from tensorflow.python.keras.saving import saving_utils 

56from tensorflow.python.keras.saving.saved_model import model_serialization 

57from tensorflow.python.keras.utils import data_utils 

58from tensorflow.python.keras.utils import layer_utils 

59from tensorflow.python.keras.utils import losses_utils 

60from tensorflow.python.keras.utils import tf_inspect 

61from tensorflow.python.keras.utils import tf_utils 

62from tensorflow.python.keras.utils.mode_keys import ModeKeys 

63from tensorflow.python.ops import array_ops 

64from tensorflow.python.ops import math_ops 

65from tensorflow.python.platform import tf_logging as logging 

66from tensorflow.python.trackable import base as trackable 

67from tensorflow.python.types import data as data_types 

68from tensorflow.python.util import nest 

69 

70try: 

71 from scipy.sparse import issparse # pylint: disable=g-import-not-at-top 

72except ImportError: 

73 issparse = None 

74 

75 

76class Model(training_lib.Model): 

77 """`Model` groups layers into an object with training and inference features. 

78 

79 There are two ways to instantiate a `Model`: 

80 

81 1 - With the "functional API", where you start from `Input`, 

82 you chain layer calls to specify the model's forward pass, 

83 and finally you create your model from inputs and outputs: 

84 

85 ```python 

86 import tensorflow as tf 

87 

88 inputs = tf.keras.Input(shape=(3,)) 

89 x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs) 

90 outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x) 

91 model = tf.keras.Model(inputs=inputs, outputs=outputs) 

92 ``` 

93 

94 2 - By subclassing the `Model` class: in that case, you should define your 

95 layers in `__init__` and you should implement the model's forward pass 

96 in `call`. 

97 

98 ```python 

99 import tensorflow as tf 

100 

101 class MyModel(tf.keras.Model): 

102 

103 def __init__(self): 

104 super(MyModel, self).__init__() 

105 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 

106 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 

107 

108 def call(self, inputs): 

109 x = self.dense1(inputs) 

110 return self.dense2(x) 

111 

112 model = MyModel() 

113 ``` 

114 

115 If you subclass `Model`, you can optionally have 

116 a `training` argument (boolean) in `call`, which you can use to specify 

117 a different behavior in training and inference: 

118 

119 ```python 

120 import tensorflow as tf 

121 

122 class MyModel(tf.keras.Model): 

123 

124 def __init__(self): 

125 super(MyModel, self).__init__() 

126 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 

127 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 

128 self.dropout = tf.keras.layers.Dropout(0.5) 

129 

130 def call(self, inputs, training=False): 

131 x = self.dense1(inputs) 

132 if training: 

133 x = self.dropout(x, training=training) 

134 return self.dense2(x) 

135 

136 model = MyModel() 

137 ``` 

138 """ 

139 

140 def __init__(self, *args, **kwargs): 

141 super(Model, self).__init__(*args, **kwargs) 

142 # initializing _distribution_strategy here since it is possible to call 

143 # predict on a model without compiling it. 

144 self._distribution_strategy = None 

145 self._compile_time_distribution_strategy = None 

146 if (ops.executing_eagerly_outside_functions() and 

147 distribute_lib.has_strategy()): 

148 self._set_strategy( 

149 distribute_lib.get_strategy()) 

150 

151 # This flag is used to track if the user is using the deprecated path of 

152 # passing distribution strategy to compile rather than creating the model 

153 # under distribution strategy scope. 

154 self._compile_distribution = False 

155 

156 self._run_eagerly = None 

157 self._experimental_run_tf_function = ( 

158 ops.executing_eagerly_outside_functions()) 

159 

160 self._v1_compile_was_called = False 

161 

162 def _init_batch_counters(self): 

163 pass # Batch counters should not be created in legacy graph mode. 

164 

165 @trackable.no_automatic_dependency_tracking 

166 def _set_strategy(self, strategy): 

167 self._compile_time_distribution_strategy = strategy 

168 

169 def get_weights(self): 

170 """Retrieves the weights of the model. 

171 

172 Returns: 

173 A flat list of Numpy arrays. 

174 """ 

175 strategy = (self._distribution_strategy or 

176 self._compile_time_distribution_strategy) 

177 if strategy: 

178 with strategy.scope(): 

179 return base_layer.Layer.get_weights(self) 

180 return base_layer.Layer.get_weights(self) 

181 
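A minimal sketch of how `get_weights` pairs with `set_weights`; the two-layer model and shapes below are illustrative assumptions, not taken from this module:

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
hidden = tf.keras.layers.Dense(4, activation='relu')(inputs)
model = tf.keras.Model(inputs, tf.keras.layers.Dense(2)(hidden))

weights = model.get_weights()          # flat list of NumPy arrays (kernels and biases)
print([w.shape for w in weights])      # e.g. [(3, 4), (4,), (4, 2), (2,)]

# The same flat list can be pushed back with `set_weights`.
model.set_weights([np.zeros_like(w) for w in weights])
```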

182 def load_weights(self, filepath, by_name=False, skip_mismatch=False): 

183 """Loads all layer weights, either from a TensorFlow or an HDF5 weight file. 

184 

185 If `by_name` is False, weights are loaded based on the network's 

186 topology. This means the architecture should be the same as when the weights 

187 were saved. Note that layers that don't have weights are not taken into 

188 account in the topological ordering, so adding or removing layers is fine as 

189 long as they don't have weights. 

190 

191 If `by_name` is True, weights are loaded into layers only if they share the 

192 same name. This is useful for fine-tuning or transfer-learning models where 

193 some of the layers have changed. 

194 

195 Only topological loading (`by_name=False`) is supported when loading weights 

196 from the TensorFlow format. Note that topological loading differs slightly 

197 between TensorFlow and HDF5 formats for user-defined classes inheriting from 

198 `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the 

199 TensorFlow format loads based on the object-local names of attributes to 

200 which layers are assigned in the `Model`'s constructor. 

201 

202 Args: 

203 filepath: String, path to the weights file to load. For weight files in 

204 TensorFlow format, this is the file prefix (the same as was passed 

205 to `save_weights`). 

206 by_name: Boolean, whether to load weights by name or by topological 

207 order. Only topological loading is supported for weight files in 

208 TensorFlow format. 

209 skip_mismatch: Boolean, whether to skip loading of layers where there is 

210 a mismatch in the number of weights, or a mismatch in the shape of 

211 the weight (only valid when `by_name=True`). 

212 

213 Returns: 

214 When loading a weight file in TensorFlow format, returns the same status 

215 object as `tf.train.Checkpoint.restore`. When graph building, restore 

216 ops are run automatically as soon as the network is built (on first call 

217 for user-defined classes inheriting from `Model`, immediately if it is 

218 already built). 

219 

220 When loading weights in HDF5 format, returns `None`. 

221 

222 Raises: 

223 ImportError: If h5py is not available and the weight file is in HDF5 

224 format. 

225 ValueError: If `skip_mismatch` is set to `True` when `by_name` is 

226 `False`. 

227 """ 

228 if backend.is_tpu_strategy(self._distribution_strategy): 

229 if (self._distribution_strategy.extended.steps_per_run > 1 and 

230 (not saving_utils.is_hdf5_filepath(filepath))): # pylint: disable=protected-access 

231 raise ValueError('Load weights is not yet supported with TPUStrategy ' 

232 'with steps_per_run greater than 1.') 

233 return super(Model, self).load_weights(filepath, by_name, skip_mismatch) 

234 
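A sketch of the two weight formats described above; the checkpoint prefix and HDF5 path are placeholder names, and the HDF5 branch assumes `h5py` is installed:

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation='relu', input_shape=(3,), name='hidden'),
    tf.keras.layers.Dense(1, name='out'),
])

# TensorFlow format: `filepath` is a prefix, and a restore status object is returned.
model.save_weights('/tmp/ckpt/weights')
status = model.load_weights('/tmp/ckpt/weights')

# HDF5 format: topological loading by default; `by_name=True` for transfer learning.
model.save_weights('/tmp/weights.h5')
model.load_weights('/tmp/weights.h5', by_name=True)
```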

235 @trackable.no_automatic_dependency_tracking 

236 def compile(self, 

237 optimizer='rmsprop', 

238 loss=None, 

239 metrics=None, 

240 loss_weights=None, 

241 sample_weight_mode=None, 

242 weighted_metrics=None, 

243 target_tensors=None, 

244 distribute=None, 

245 **kwargs): 

246 """Configures the model for training. 

247 

248 Args: 

249 optimizer: String (name of optimizer) or optimizer instance. 

250 See `tf.keras.optimizers`. 

251 loss: String (name of objective function), objective function or 

252 `tf.keras.losses.Loss` instance. See `tf.keras.losses`. An objective 

253 function is any callable with the signature 

254 `scalar_loss = fn(y_true, y_pred)`. If the model has multiple 

255 outputs, you can use a different loss on each output by passing a 

256 dictionary or a list of losses. The loss value that will be 

257 minimized by the model will then be the sum of all individual 

258 losses. 

259 metrics: List of metrics to be evaluated by the model during training 

260 and testing. Typically you will use `metrics=['accuracy']`. 

261 To specify different metrics for different outputs of a 

262 multi-output model, you could also pass a dictionary, such as 

263 `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`. 

264 You can also pass a list (len = len(outputs)) of lists of metrics 

265 such as `metrics=[['accuracy'], ['accuracy', 'mse']]` or 

266 `metrics=['accuracy', ['accuracy', 'mse']]`. 

267 loss_weights: Optional list or dictionary specifying scalar 

268 coefficients (Python floats) to weight the loss contributions 

269 of different model outputs. 

270 The loss value that will be minimized by the model 

271 will then be the *weighted sum* of all individual losses, 

272 weighted by the `loss_weights` coefficients. 

273 If a list, it is expected to have a 1:1 mapping 

274 to the model's outputs. If a dict, it is expected to map 

275 output names (strings) to scalar coefficients. 

276 sample_weight_mode: If you need to do timestep-wise 

277 sample weighting (2D weights), set this to `"temporal"`. 

278 `None` defaults to sample-wise weights (1D). 

279 If the model has multiple outputs, you can use a different 

280 `sample_weight_mode` on each output by passing a 

281 dictionary or a list of modes. 

282 weighted_metrics: List of metrics to be evaluated and weighted 

283 by sample_weight or class_weight during training and testing. 

284 target_tensors: By default, Keras will create placeholders for the 

285 model's target, which will be fed with the target data during 

286 training. If instead you would like to use your own 

287 target tensors (in turn, Keras will not expect external 

288 Numpy data for these targets at training time), you 

289 can specify them via the `target_tensors` argument. It can be 

290 a single tensor (for a single-output model), a list of tensors, 

291 or a dict mapping output names to target tensors. 

292 distribute: NOT SUPPORTED IN TF 2.0, please create and compile the 

293 model under distribution strategy scope instead of passing it to 

294 compile. 

295 **kwargs: Any additional arguments. 

296 

297 Raises: 

298 ValueError: In case of invalid arguments for 

299 `optimizer`, `loss`, `metrics` or `sample_weight_mode`. 

300 """ 

301 self._assert_built_as_v1() 

302 self._run_eagerly = kwargs.pop('run_eagerly', None) 

303 self._experimental_run_tf_function = kwargs.pop( 

304 'experimental_run_tf_function', True) 

305 self._v1_compile_was_called = True 

306 

307 # Prepare Session arguments (legacy). 

308 kwargs.pop('cloning', None) # Legacy DistStrat argument, never used. 

309 self._from_serialized = kwargs.pop('from_serialized', False) 

310 allowed_kwargs = {'feed_dict', 'fetches', 'options', 'run_metadata'} 

311 unknown_kwargs = set(kwargs.keys()) - allowed_kwargs 

312 if unknown_kwargs: 

313 raise TypeError( 

314 'Invalid keyword argument(s) in `compile`: %s' % (unknown_kwargs,)) 

315 self._function_kwargs = kwargs 

316 if self._function_kwargs: 

317 self._experimental_run_tf_function = False 

318 if self.run_eagerly: 

319 raise ValueError( 

320 'Session keyword arguments are not supported ' 

321 'when `run_eagerly=True`. You passed the following ' 

322 'Session arguments: %s' % (self._function_kwargs,)) 

323 

324 self._set_optimizer(optimizer) 

325 is_any_keras_optimizer_v1 = any( 

326 (isinstance(opt, optimizer_v1.Optimizer) 

327 and not isinstance(opt, optimizer_v1.TFOptimizer) 

328 ) for opt in nest.flatten(self.optimizer)) 

329 

330 if is_any_keras_optimizer_v1 and ops.executing_eagerly_outside_functions(): 

331 raise ValueError('`tf.compat.v1.keras` Optimizer (%s) is ' 

332 'not supported when eager execution is enabled. Use a ' 

333 '`tf.keras` Optimizer instead, or disable eager ' 

334 'execution.' % (optimizer,)) 

335 

336 if ((target_tensors is not None) 

337 or not ops.executing_eagerly_outside_functions()): 

338 # Fallback out of things that aren't supported with v2 loops 

339 self._experimental_run_tf_function = False 

340 

341 if distribute is not None: 

342 if tf2.enabled() or self._experimental_run_tf_function: 

343 raise ValueError( 

344 'Distribute argument in compile is not available in TF 2.0. Please ' 

345 'create the model under the distribution strategy scope.') 

346 logging.warning('Distribute argument in compile is deprecated. Please ' 

347 'create the model under the distribution strategy scope.') 

348 self._distribution_strategy = distribute 

349 self._compile_distribution = True 

350 else: 

351 if distribute_lib.has_strategy(): 

352 # When the user builds the model in the DS scope and cross replica 

353 # context we want distribution strategy to be set but when building the 

354 # replica copies of the models internally we should not be compiling 

355 # with distribution strategy and use the default compilation path. 

356 if distribute_lib.in_cross_replica_context(): 

357 self._distribution_strategy = ( 

358 distribute_lib.get_strategy()) 

359 

360 if isinstance(self._distribution_strategy, 

361 parameter_server_strategy.ParameterServerStrategyV1): 

362 raise NotImplementedError( 

363 '`tf.compat.v1.distribute.experimental.ParameterServerStrategy` ' 

364 'currently only works with the tf.Estimator API') 

365 

366 if isinstance(self._distribution_strategy, 

367 parameter_server_strategy_v2.ParameterServerStrategyV2): 

368 raise NotImplementedError( 

369 '`tf.distribute.experimental.ParameterServerStrategy` is only ' 

370 'supported in TF2.') 

371 

372 if not self._experimental_run_tf_function: 

373 self._validate_compile_param_for_distribution_strategy(self.run_eagerly, 

374 sample_weight_mode, 

375 target_tensors, 

376 weighted_metrics) 

377 # We've disabled automatic dependency tracking for this method, but do want 

378 # to add a checkpoint dependency on the optimizer if it's trackable. 

379 if isinstance(self.optimizer, trackable.Trackable): 

380 self._track_trackable( 

381 self.optimizer, name='optimizer', overwrite=True) 

382 self.loss = loss or {} 

383 self.loss_weights = loss_weights 

384 self.sample_weight_mode = sample_weight_mode 

385 self._compile_metrics = metrics or [] 

386 self._compile_weighted_metrics = weighted_metrics 

387 if self.run_eagerly and target_tensors is not None: 

388 raise ValueError( 

389 'target_tensors argument is not supported when ' 

390 'running a model eagerly.') 

391 

392 # _training_endpoints contains a list of _TrainingEndpoint objects, each 

393 # holding the model output/target/loss and related metadata. 

394 self._training_endpoints = [] 

395 

396 # Used to freeze the behavior of the Model once `compile` has been called. 

397 self._compiled_trainable_state = self._get_trainable_state() 

398 

399 # Set tf.distribute.Strategy specific parameters. 

400 self._distributed_model_cache = {} 

401 self._distributed_function_cache = {} 

402 

403 # Clear any `_eager_losses` that was added. 

404 self._clear_losses() 

405 

406 if (not context.executing_eagerly() and 

407 self._distribution_strategy is not None): 

408 # Ensures a Session is created and configured correctly for Distribution 

409 # Strategy. 

410 backend.configure_and_create_distributed_session( 

411 self._distribution_strategy) 

412 # Initialize model metric attributes. 

413 self._init_metric_attributes() 

414 if not self.built or not self.inputs or not self.outputs: 

415 # Model is not compilable because it does not know its number of inputs 

416 # and outputs, nor their shapes and names. We will compile after the first 

417 # time the model gets called on training data. 

418 return 

419 self._is_compiled = True 

420 

421 # Prepare list of loss functions, same size as model outputs. 

422 self.loss_functions = training_utils_v1.prepare_loss_functions( 

423 self.loss, self.output_names) 

424 

425 target_tensors = self._process_target_tensor_for_compile(target_tensors) 

426 

427 for o, n, l, t in zip(self.outputs, self.output_names, 

428 self.loss_functions, target_tensors): 

429 endpoint = _TrainingEndpoint(o, n, l) 

430 endpoint.create_training_target(t, run_eagerly=self.run_eagerly) 

431 self._training_endpoints.append(endpoint) 

432 

433 # Prepare list of loss weights, same size as model outputs. 

434 training_utils_v1.prepare_loss_weights(self._training_endpoints, 

435 loss_weights) 

436 

437 # Initialization for Eager mode execution. 

438 if self.run_eagerly: 

439 self._compile_eagerly(metrics, weighted_metrics, sample_weight_mode) 

440 return 

441 

442 with backend.get_graph().as_default(): 

443 # Save all metric attributes per output of the model. 

444 self._cache_output_metric_attributes(metrics, weighted_metrics) 

445 

446 # Set metric attributes on model. 

447 self._set_metric_attributes() 

448 

449 # Invoke metric functions (unweighted) for all the outputs. 

450 self._handle_metrics( 

451 self.outputs, 

452 targets=self._targets, 

453 skip_target_masks=self._prepare_skip_target_masks(), 

454 masks=self._prepare_output_masks()) 

455 

456 # Prepare sample weight modes. List with the same length as model outputs. 

457 training_utils_v1.prepare_sample_weight_modes( 

458 self._training_endpoints, sample_weight_mode) 

459 

460 # Creates the model loss and weighted metrics sub-graphs. 

461 self._compile_weights_loss_and_weighted_metrics() 

462 

463 # Functions for train, test and predict will 

464 # be compiled lazily when required. 

465 # This saves time when the user is not using all functions. 

466 self.train_function = None 

467 self.test_function = None 

468 self.predict_function = None 

469 

470 # Collected trainable weights, sorted in topological order. 

471 self._collected_trainable_weights = self.trainable_weights 

472 

473 # Validate all variables were correctly created in distribution scope. 

474 if self._distribution_strategy and not self._compile_distribution: 

475 for v in self.variables: 

476 strategy = self._distribution_strategy 

477 if not strategy.extended.variable_created_in_scope(v): 

478 raise ValueError( 

479 'Variable (%s) was not created in the distribution strategy ' 

480 'scope of (%s). This is most likely because not all layers, ' 

481 'the model, or the optimizer were created inside the distribution ' 

482 'strategy scope. Try to make sure your code looks similar ' 

483 'to the following.\n' 

484 'with strategy.scope():\n' 

485 ' model=_create_model()\n' 

486 ' model.compile(...)' % (v, strategy)) 

487 
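A sketch of a typical single-output `compile` call using the string shortcuts described above; the model and the particular optimizer, loss, and metric choices are assumptions for illustration:

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation='relu', input_shape=(3,)),
    tf.keras.layers.Dense(1),
])

model.compile(
    optimizer='rmsprop',   # string name or a tf.keras.optimizers instance
    loss='mse',            # string name, callable, or tf.keras.losses.Loss instance
    metrics=['mae'],
)
```

For multi-output models, `loss`, `loss_weights`, and `metrics` can instead be dicts keyed by output name, as the argument descriptions above explain.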

488 @trackable.no_automatic_dependency_tracking 

489 def _init_distributed_function_cache_if_not_compiled(self): 

490 if not hasattr(self, '_distributed_function_cache'): 

491 self._distributed_function_cache = {} 

492 

493 @property 

494 def metrics(self): 

495 """Returns the model's metrics added using `compile`, `add_metric` APIs.""" 

496 metrics = [] 

497 if self._is_compiled: 

498 if not hasattr(self, '_v1_compile_was_called'): 

499 # See b/155687393 for more details: the model is created as a v2 

500 # instance but converted to v1. Fall back to the base Model to retrieve 

501 # the metrics. 

502 return super(Model, self).metrics 

503 metrics += self._compile_metric_functions 

504 metrics.extend(self._metrics) 

505 metrics.extend( 

506 _get_metrics_from_layers( 

507 list(self._flatten_layers(include_self=False, recursive=False)))) 

508 return metrics 

509 

510 @property 

511 def metrics_names(self): 

512 """Returns the model's display labels for all outputs.""" 

513 

514 # This property includes all output names including `loss` and per-output 

515 # losses for backward compatibility. 

516 metrics_names = ['loss'] 

517 if self._is_compiled: 

518 if not hasattr(self, '_v1_compile_was_called'): 

519 # See b/155687393 for more details: the model is created as a v2 

520 # instance but converted to v1. Fall back to the base Model to retrieve 

521 # the metric names. 

522 return super(Model, self).metrics_names 

523 

524 # Add output loss metric names to the metric names list. 

525 if len(self._training_endpoints) > 1: 

526 metrics_names.extend([ 

527 e.loss_name() 

528 for e in self._training_endpoints 

529 if not e.should_skip_target() 

530 ]) 

531 

532 # Add all metric names. 

533 metrics_names += [m.name for m in self.metrics] 

534 return metrics_names 

535 

536 @property 

537 def run_eagerly(self): 

538 """Settable attribute indicating whether the model should run eagerly. 

539 

540 Running eagerly means that your model will be run step by step, 

541 like Python code. Your model might run slower, but it should become easier 

542 for you to debug it by stepping into individual layer calls. 

543 

544 By default, we will attempt to compile your model to a static graph to 

545 deliver the best execution performance. 

546 

547 Returns: 

548 Boolean, whether the model should run eagerly. 

549 """ 

550 if self._run_eagerly is True and not context.executing_eagerly(): 

551 raise ValueError('You can only set `run_eagerly=True` if eager execution ' 

552 'is enabled.') 

553 if not self.dynamic: 

554 if self._run_eagerly is None: 

555 # Respect `tf.config.run_functions_eagerly` unless 

556 # `run_eagerly` was explicitly passed to `compile`. 

557 return def_function.functions_run_eagerly() 

558 else: 

559 return self._run_eagerly 

560 else: 

561 if not context.executing_eagerly(): 

562 raise ValueError('Your model contains layers that can only be ' 

563 'successfully run in eager execution (layers ' 

564 'constructed with `dynamic=True`). ' 

565 'You must enable eager execution with ' 

566 '`tf.enable_eager_execution()`.') 

567 if self._run_eagerly is False: 

568 # TODO(fchollet): consider using py_func to enable this. 

569 raise ValueError('Your model contains layers that can only be ' 

570 'successfully run in eager execution (layers ' 

571 'constructed with `dynamic=True`). ' 

572 'You cannot set `run_eagerly=False`.') 

573 return context.executing_eagerly() 

574 

575 @run_eagerly.setter 

576 def run_eagerly(self, value): 

577 self._run_eagerly = value 

578 
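A sketch of toggling the eager training path for debugging, assuming eager execution is enabled (the TF2 default) and a throwaway one-layer model:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(optimizer='rmsprop', loss='mse', run_eagerly=True)
print(model.run_eagerly)   # True: each batch runs step by step, easier to debug

model.run_eagerly = False  # the attribute is also settable after compile
```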

579 def _select_training_loop(self, inputs): 

580 """Select training loop for fit/eval/predict based on the inputs.""" 

581 # TODO(kaftan) or TODO(scottzhu): This check should eventually be nicely 

582 # integrated into the data adapters in the v2 loop. We can't do this yet 

583 # because we currently have to fall back for unhandled data types. 

584 if isinstance(inputs, (iterator_ops.Iterator, 

585 iterator_ops.IteratorBase)): 

586 raise ValueError('For performance reasons Keras `fit`, `evaluate` and ' 

587 '`predict` accept tf.data `Datasets` as input but not ' 

588 'iterators that have been manually generated from ' 

589 'Datasets by users. Please directly pass in the ' 

590 'original `Dataset` object instead of passing in ' 

591 '`iter(dataset)`.') 

592 

593 # Case 1: distribution strategy. 

594 if self._distribution_strategy: 

595 if self._in_multi_worker_mode(): 

596 return training_distributed_v1.DistributionMultiWorkerTrainingLoop( 

597 training_distributed_v1.DistributionSingleWorkerTrainingLoop()) 

598 else: 

599 return training_distributed_v1.DistributionSingleWorkerTrainingLoop() 

600 

601 # Case 2: generator-like. Input is Python generator, or Sequence object, 

602 # or a non-distributed Dataset or iterator in eager execution. 

603 if data_utils.is_generator_or_sequence(inputs): 

604 return training_generator_v1.GeneratorOrSequenceTrainingLoop() 

605 if training_utils_v1.is_eager_dataset_or_iterator(inputs): 

606 return training_generator_v1.EagerDatasetOrIteratorTrainingLoop() 

607 

608 # Case 3: Symbolic tensors or Numpy array-like. 

609 # This includes Datasets and iterators in graph mode (since they 

610 # generate symbolic tensors). 

611 if self.run_eagerly: 

612 return training_generator_v1.GeneratorLikeTrainingLoop() 

613 else: 

614 return training_arrays_v1.ArrayLikeTrainingLoop() 

615 

616 def fit(self, 

617 x=None, 

618 y=None, 

619 batch_size=None, 

620 epochs=1, 

621 verbose=1, 

622 callbacks=None, 

623 validation_split=0., 

624 validation_data=None, 

625 shuffle=True, 

626 class_weight=None, 

627 sample_weight=None, 

628 initial_epoch=0, 

629 steps_per_epoch=None, 

630 validation_steps=None, 

631 validation_freq=1, 

632 max_queue_size=10, 

633 workers=1, 

634 use_multiprocessing=False, 

635 **kwargs): 

636 """Trains the model for a fixed number of epochs (iterations on a dataset). 

637 

638 Args: 

639 x: Input data. It could be: 

640 - A Numpy array (or array-like), or a list of arrays 

641 (in case the model has multiple inputs). 

642 - A TensorFlow tensor, or a list of tensors 

643 (in case the model has multiple inputs). 

644 - A dict mapping input names to the corresponding array/tensors, 

645 if the model has named inputs. 

646 - A `tf.data` dataset. Should return a tuple 

647 of either `(inputs, targets)` or 

648 `(inputs, targets, sample_weights)`. 

649 - A generator or `keras.utils.Sequence` returning `(inputs, targets)` 

650 or `(inputs, targets, sample weights)`. 

651 y: Target data. Like the input data `x`, 

652 it could be either Numpy array(s) or TensorFlow tensor(s). 

653 It should be consistent with `x` (you cannot have Numpy inputs and 

654 tensor targets, or inversely). If `x` is a dataset, generator, 

655 or `keras.utils.Sequence` instance, `y` should 

656 not be specified (since targets will be obtained from `x`). 

657 batch_size: Integer or `None`. 

658 Number of samples per gradient update. 

659 If unspecified, `batch_size` will default to 32. 

660 Do not specify the `batch_size` if your data is in the 

661 form of symbolic tensors, datasets, 

662 generators, or `keras.utils.Sequence` instances (since they generate 

663 batches). 

664 epochs: Integer. Number of epochs to train the model. 

665 An epoch is an iteration over the entire `x` and `y` 

666 data provided. 

667 Note that in conjunction with `initial_epoch`, 

668 `epochs` is to be understood as "final epoch". 

669 The model is not trained for a number of iterations 

670 given by `epochs`, but merely until the epoch 

671 of index `epochs` is reached. 

672 verbose: 0, 1, or 2. Verbosity mode. 

673 0 = silent, 1 = progress bar, 2 = one line per epoch. 

674 Note that the progress bar is not particularly useful when 

675 logged to a file, so verbose=2 is recommended when not running 

676 interactively (e.g., in a production environment). 

677 callbacks: List of `keras.callbacks.Callback` instances. 

678 List of callbacks to apply during training. 

679 See `tf.keras.callbacks`. 

680 validation_split: Float between 0 and 1. 

681 Fraction of the training data to be used as validation data. 

682 The model will set apart this fraction of the training data, 

683 will not train on it, and will evaluate 

684 the loss and any model metrics 

685 on this data at the end of each epoch. 

686 The validation data is selected from the last samples 

687 in the `x` and `y` data provided, before shuffling. This argument is 

688 not supported when `x` is a dataset, generator or 

689 `keras.utils.Sequence` instance. 

690 validation_data: Data on which to evaluate 

691 the loss and any model metrics at the end of each epoch. 

692 The model will not be trained on this data. 

693 `validation_data` will override `validation_split`. 

694 `validation_data` could be: 

695 - tuple `(x_val, y_val)` of Numpy arrays or tensors 

696 - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays 

697 - dataset 

698 For the first two cases, `batch_size` must be provided. 

699 For the last case, `validation_steps` could be provided. 

700 shuffle: Boolean (whether to shuffle the training data 

701 before each epoch) or str (for 'batch'). 

702 'batch' is a special option for dealing with the 

703 limitations of HDF5 data; it shuffles in batch-sized chunks. 

704 Has no effect when `steps_per_epoch` is not `None`. 

705 class_weight: Optional dictionary mapping class indices (integers) 

706 to a weight (float) value, used for weighting the loss function 

707 (during training only). 

708 This can be useful to tell the model to 

709 "pay more attention" to samples from 

710 an under-represented class. 

711 sample_weight: Optional Numpy array of weights for 

712 the training samples, used for weighting the loss function 

713 (during training only). You can either pass a flat (1D) 

714 Numpy array with the same length as the input samples 

715 (1:1 mapping between weights and samples), 

716 or in the case of temporal data, 

717 you can pass a 2D array with shape 

718 `(samples, sequence_length)`, 

719 to apply a different weight to every timestep of every sample. 

720 In this case you should make sure to specify 

721 `sample_weight_mode="temporal"` in `compile()`. This argument is not 

722 supported when `x` is a dataset, generator, or 

723 `keras.utils.Sequence` instance; instead, provide the sample_weights 

724 as the third element of `x`. 

725 initial_epoch: Integer. 

726 Epoch at which to start training 

727 (useful for resuming a previous training run). 

728 steps_per_epoch: Integer or `None`. 

729 Total number of steps (batches of samples) 

730 before declaring one epoch finished and starting the 

731 next epoch. When training with input tensors such as 

732 TensorFlow data tensors, the default `None` is equal to 

733 the number of samples in your dataset divided by 

734 the batch size, or 1 if that cannot be determined. If x is a 

735 `tf.data` dataset, and 'steps_per_epoch' 

736 is None, the epoch will run until the input dataset is exhausted. 

737 This argument is not supported with array inputs. 

738 validation_steps: Only relevant if `validation_data` is provided and 

739 is a `tf.data` dataset. Total number of steps (batches of 

740 samples) to draw before stopping when performing validation 

741 at the end of every epoch. If 'validation_steps' is None, validation 

742 will run until the `validation_data` dataset is exhausted. In the 

743 case of an infinite dataset, it will run into an infinite loop. 

744 If 'validation_steps' is specified and only part of the dataset 

745 will be consumed, the evaluation will start from the beginning of 

746 the dataset at each epoch. This ensures that the same validation 

747 samples are used every time. 

748 validation_freq: Only relevant if validation data is provided. Integer 

749 or `collections.abc.Container` instance (e.g. list, tuple, etc.). 

750 If an integer, specifies how many training epochs to run before a 

751 new validation run is performed, e.g. `validation_freq=2` runs 

752 validation every 2 epochs. If a Container, specifies the epochs on 

753 which to run validation, e.g. `validation_freq=[1, 2, 10]` runs 

754 validation at the end of the 1st, 2nd, and 10th epochs. 

755 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 

756 input only. Maximum size for the generator queue. 

757 If unspecified, `max_queue_size` will default to 10. 

758 workers: Integer. Used for generator or `keras.utils.Sequence` input 

759 only. Maximum number of processes to spin up 

760 when using process-based threading. If unspecified, `workers` 

761 will default to 1. If 0, will execute the generator on the main 

762 thread. 

763 use_multiprocessing: Boolean. Used for generator or 

764 `keras.utils.Sequence` input only. If `True`, use process-based 

765 threading. If unspecified, `use_multiprocessing` will default to 

766 `False`. Note that because this implementation relies on 

767 multiprocessing, you should not pass non-picklable arguments to 

768 the generator as they can't be passed easily to child processes. 

769 **kwargs: Used for backwards compatibility. 

770 

771 Returns: 

772 A `History` object. Its `History.history` attribute is 

773 a record of training loss values and metrics values 

774 at successive epochs, as well as validation loss values 

775 and validation metrics values (if applicable). 

776 

777 Raises: 

778 RuntimeError: If the model was never compiled. 

779 ValueError: In case of mismatch between the provided input data 

780 and what the model expects. 

781 """ 

782 self._assert_built_as_v1() 

783 # Legacy support 

784 if 'nb_epoch' in kwargs: 

785 logging.warning( 

786 'The `nb_epoch` argument in `fit` has been renamed `epochs`.') 

787 epochs = kwargs.pop('nb_epoch') 

788 if kwargs: 

789 raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) 

790 self._assert_compile_was_called() 

791 self._check_call_args('fit') 

792 

793 func = self._select_training_loop(x) 

794 return func.fit( 

795 self, 

796 x=x, 

797 y=y, 

798 batch_size=batch_size, 

799 epochs=epochs, 

800 verbose=verbose, 

801 callbacks=callbacks, 

802 validation_split=validation_split, 

803 validation_data=validation_data, 

804 shuffle=shuffle, 

805 class_weight=class_weight, 

806 sample_weight=sample_weight, 

807 initial_epoch=initial_epoch, 

808 steps_per_epoch=steps_per_epoch, 

809 validation_steps=validation_steps, 

810 validation_freq=validation_freq, 

811 max_queue_size=max_queue_size, 

812 workers=workers, 

813 use_multiprocessing=use_multiprocessing) 

814 
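A sketch of `fit` on synthetic NumPy arrays with a validation split; the shapes, batch size, and epoch count are illustrative assumptions:

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(3,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='rmsprop', loss='mse', metrics=['mae'])

x = np.random.random((100, 3)).astype('float32')
y = np.random.random((100, 1)).astype('float32')

history = model.fit(
    x, y,
    batch_size=16,
    epochs=3,
    validation_split=0.2,  # the last 20% of the arrays is held out
    verbose=2,             # one line per epoch
)
print(history.history.keys())  # per-epoch loss and metric curves
```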

815 def evaluate(self, 

816 x=None, 

817 y=None, 

818 batch_size=None, 

819 verbose=1, 

820 sample_weight=None, 

821 steps=None, 

822 callbacks=None, 

823 max_queue_size=10, 

824 workers=1, 

825 use_multiprocessing=False): 

826 """Returns the loss value & metrics values for the model in test mode. 

827 

828 Computation is done in batches (see the `batch_size` arg). 

829 

830 Args: 

831 x: Input data. It could be: 

832 - A Numpy array (or array-like), or a list of arrays 

833 (in case the model has multiple inputs). 

834 - A TensorFlow tensor, or a list of tensors 

835 (in case the model has multiple inputs). 

836 - A dict mapping input names to the corresponding array/tensors, 

837 if the model has named inputs. 

838 - A `tf.data` dataset. 

839 - A generator or `keras.utils.Sequence` instance. 

840 y: Target data. Like the input data `x`, 

841 it could be either Numpy array(s) or TensorFlow tensor(s). 

842 It should be consistent with `x` (you cannot have Numpy inputs and 

843 tensor targets, or inversely). 

844 If `x` is a dataset, generator or 

845 `keras.utils.Sequence` instance, `y` should not be specified (since 

846 targets will be obtained from the iterator/dataset). 

847 batch_size: Integer or `None`. 

848 Number of samples per batch of computation. 

849 If unspecified, `batch_size` will default to 32. 

850 Do not specify the `batch_size` if your data is in the 

851 form of symbolic tensors, dataset, 

852 generators, or `keras.utils.Sequence` instances (since they generate 

853 batches). 

854 verbose: 0 or 1. Verbosity mode. 

855 0 = silent, 1 = progress bar. 

856 sample_weight: Optional Numpy array of weights for 

857 the test samples, used for weighting the loss function. 

858 You can either pass a flat (1D) 

859 Numpy array with the same length as the input samples 

860 (1:1 mapping between weights and samples), 

861 or in the case of temporal data, 

862 you can pass a 2D array with shape 

863 `(samples, sequence_length)`, 

864 to apply a different weight to every timestep of every sample. 

865 In this case you should make sure to specify 

866 `sample_weight_mode="temporal"` in `compile()`. This argument is not 

867 supported when `x` is a dataset; instead, pass 

868 sample weights as the third element of `x`. 

869 steps: Integer or `None`. 

870 Total number of steps (batches of samples) 

871 before declaring the evaluation round finished. 

872 Ignored with the default value of `None`. 

873 If x is a `tf.data` dataset and `steps` is 

874 None, 'evaluate' will run until the dataset is exhausted. 

875 This argument is not supported with array inputs. 

876 callbacks: List of `keras.callbacks.Callback` instances. 

877 List of callbacks to apply during evaluation. 

878 See [callbacks](/api_docs/python/tf/keras/callbacks). 

879 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 

880 input only. Maximum size for the generator queue. 

881 If unspecified, `max_queue_size` will default to 10. 

882 workers: Integer. Used for generator or `keras.utils.Sequence` input 

883 only. Maximum number of processes to spin up when using 

884 process-based threading. If unspecified, `workers` will default 

885 to 1. If 0, will execute the generator on the main thread. 

886 use_multiprocessing: Boolean. Used for generator or 

887 `keras.utils.Sequence` input only. If `True`, use process-based 

888 threading. If unspecified, `use_multiprocessing` will default to 

889 `False`. Note that because this implementation relies on 

890 multiprocessing, you should not pass non-picklable arguments to 

891 the generator as they can't be passed easily to child processes. 

892 

893 Returns: 

894 Scalar test loss (if the model has a single output and no metrics) 

895 or list of scalars (if the model has multiple outputs 

896 and/or metrics). The attribute `model.metrics_names` will give you 

897 the display labels for the scalar outputs. 

898 

899 Raises: 

900 ValueError: in case of invalid arguments. 

901 """ 

902 self._assert_built_as_v1() 

903 self._assert_compile_was_called() 

904 self._check_call_args('evaluate') 

905 

906 func = self._select_training_loop(x) 

907 return func.evaluate( 

908 self, 

909 x=x, 

910 y=y, 

911 batch_size=batch_size, 

912 verbose=verbose, 

913 sample_weight=sample_weight, 

914 steps=steps, 

915 callbacks=callbacks, 

916 max_queue_size=max_queue_size, 

917 workers=workers, 

918 use_multiprocessing=use_multiprocessing) 

919 
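A sketch of `evaluate` on held-out NumPy arrays; the model and data are assumptions, and the returned scalars line up with `model.metrics_names`:

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(optimizer='rmsprop', loss='mse', metrics=['mae'])

x_test = np.random.random((32, 3)).astype('float32')
y_test = np.random.random((32, 1)).astype('float32')

results = model.evaluate(x_test, y_test, batch_size=8, verbose=0)
for name, value in zip(model.metrics_names, results):
    print(name, value)
```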

920 def predict(self, 

921 x, 

922 batch_size=None, 

923 verbose=0, 

924 steps=None, 

925 callbacks=None, 

926 max_queue_size=10, 

927 workers=1, 

928 use_multiprocessing=False): 

929 """Generates output predictions for the input samples. 

930 

931 Computation is done in batches (see the `batch_size` arg). 

932 

933 Args: 

934 x: Input samples. It could be: 

935 - A Numpy array (or array-like), or a list of arrays 

936 (in case the model has multiple inputs). 

937 - A TensorFlow tensor, or a list of tensors 

938 (in case the model has multiple inputs). 

939 - A `tf.data` dataset. 

940 - A generator or `keras.utils.Sequence` instance. 

941 batch_size: Integer or `None`. 

942 Number of samples per batch of computation. 

943 If unspecified, `batch_size` will default to 32. 

944 Do not specify the `batch_size` if your data is in the 

945 form of symbolic tensors, dataset, 

946 generators, or `keras.utils.Sequence` instances (since they generate 

947 batches). 

948 verbose: Verbosity mode, 0 or 1. 

949 steps: Total number of steps (batches of samples) 

950 before declaring the prediction round finished. 

951 Ignored with the default value of `None`. If x is a `tf.data` 

952 dataset and `steps` is None, `predict` will 

953 run until the input dataset is exhausted. 

954 callbacks: List of `keras.callbacks.Callback` instances. 

955 List of callbacks to apply during prediction. 

956 See [callbacks](/api_docs/python/tf/keras/callbacks). 

957 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 

958 input only. Maximum size for the generator queue. 

959 If unspecified, `max_queue_size` will default to 10. 

960 workers: Integer. Used for generator or `keras.utils.Sequence` input 

961 only. Maximum number of processes to spin up when using 

962 process-based threading. If unspecified, `workers` will default 

963 to 1. If 0, will execute the generator on the main thread. 

964 use_multiprocessing: Boolean. Used for generator or 

965 `keras.utils.Sequence` input only. If `True`, use process-based 

966 threading. If unspecified, `use_multiprocessing` will default to 

967 `False`. Note that because this implementation relies on 

968 multiprocessing, you should not pass non-picklable arguments to 

969 the generator as they can't be passed easily to child processes. 

970 

971 

972 Returns: 

973 Numpy array(s) of predictions. 

974 

975 Raises: 

976 ValueError: In case of mismatch between the provided 

977 input data and the model's expectations, 

978 or in case a stateful model receives a number of samples 

979 that is not a multiple of the batch size. 

980 """ 

981 self._assert_built_as_v1() 

982 self._check_call_args('predict') 

983 

984 func = self._select_training_loop(x) 

985 return func.predict( 

986 self, 

987 x=x, 

988 batch_size=batch_size, 

989 verbose=verbose, 

990 steps=steps, 

991 callbacks=callbacks, 

992 max_queue_size=max_queue_size, 

993 workers=workers, 

994 use_multiprocessing=use_multiprocessing) 

995 
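A sketch of `predict` on a small batch of synthetic inputs; the model and shapes are assumptions:

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])

x_new = np.random.random((5, 3)).astype('float32')
preds = model.predict(x_new, batch_size=2)
print(preds.shape)  # (5, 2): one row of outputs per input sample
```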

996 def reset_metrics(self): 

997 """Resets the state of metrics.""" 

998 metrics = self._get_training_eval_metrics() 

999 for m in metrics: 

1000 m.reset_state() 

1001 

1002 # Reset metrics on all the distributed (cloned) models. 

1003 if self._distribution_strategy: 

1004 distributed_training_utils_v1._reset_metrics(self) # pylint: disable=protected-access 

1005 

1006 def train_on_batch(self, 

1007 x, 

1008 y=None, 

1009 sample_weight=None, 

1010 class_weight=None, 

1011 reset_metrics=True): 

1012 """Runs a single gradient update on a single batch of data. 

1013 

1014 Args: 

1015 x: Input data. It could be: 

1016 - A Numpy array (or array-like), or a list of arrays 

1017 (in case the model has multiple inputs). 

1018 - A TensorFlow tensor, or a list of tensors 

1019 (in case the model has multiple inputs). 

1020 - A dict mapping input names to the corresponding array/tensors, 

1021 if the model has named inputs. 

1022 - A `tf.data` dataset. 

1023 y: Target data. Like the input data `x`, it could be either Numpy 

1024 array(s) or TensorFlow tensor(s). It should be consistent with `x` 

1025 (you cannot have Numpy inputs and tensor targets, or inversely). If 

1026 `x` is a dataset, `y` should not be specified 

1027 (since targets will be obtained from the iterator). 

1028 sample_weight: Optional array of the same length as x, containing 

1029 weights to apply to the model's loss for each sample. In the case of 

1030 temporal data, you can pass a 2D array with shape (samples, 

1031 sequence_length), to apply a different weight to every timestep of 

1032 every sample. In this case you should make sure to specify 

1033 sample_weight_mode="temporal" in compile(). This argument is not 

1034 supported when `x` is a dataset. 

1035 class_weight: Optional dictionary mapping class indices (integers) to a 

1036 weight (float) to apply to the model's loss for the samples from this 

1037 class during training. This can be useful to tell the model to "pay 

1038 more attention" to samples from an under-represented class. 

1039 reset_metrics: If `True`, the metrics returned will be only for this 

1040 batch. If `False`, the metrics will be statefully accumulated across 

1041 batches. 

1042 

1043 Returns: 

1044 Scalar training loss 

1045 (if the model has a single output and no metrics) 

1046 or list of scalars (if the model has multiple outputs 

1047 and/or metrics). The attribute `model.metrics_names` will give you 

1048 the display labels for the scalar outputs. 

1049 

1050 Raises: 

1051 ValueError: In case of invalid user-provided arguments. 

1052 """ 

1053 self._assert_compile_was_called() 

1054 self._check_call_args('train_on_batch') 

1055 

1056 # If at this point we are in the replica context, then it is okay to execute 

1057 # the Eager code path. The expected way to get here is to call `fit`, which 

1058 # calls `train_on_batch` on each replica. 

1059 if (self._distribution_strategy and 

1060 distribute_lib.in_cross_replica_context()): 

1061 raise NotImplementedError('`train_on_batch` is not supported for models ' 

1062 'distributed with tf.distribute.Strategy.') 

1063 # Validate and standardize user data. 

1064 x, y, sample_weights = self._standardize_user_data( 

1065 x, y, sample_weight=sample_weight, class_weight=class_weight, 

1066 extract_tensors_from_dataset=True) 

1067 

1068 # If `self._distribution_strategy` is True, then we are in a replica context 

1069 # at this point because of the check above. `train_on_batch` is being run 

1070 # for each replica by `self._distribution_strategy` and the same code path 

1071 # as Eager is expected to be taken. 

1072 if self.run_eagerly or self._distribution_strategy: 

1073 output_dict = training_eager_v1.train_on_batch( 

1074 self, 

1075 x, 

1076 y, 

1077 sample_weights=sample_weights, 

1078 output_loss_metrics=self._output_loss_metrics) 

1079 outputs = (output_dict['total_loss'] + output_dict['output_losses'] 

1080 + output_dict['metrics']) 

1081 outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access 

1082 else: 

1083 x = training_utils_v1.ModelInputs(x).as_list() 

1084 ins = x + list(y or []) + list(sample_weights or []) 

1085 

1086 if not isinstance(backend.symbolic_learning_phase(), int): 

1087 ins += [True] # Add learning phase value. 

1088 

1089 self._update_sample_weight_modes(sample_weights=sample_weights) 

1090 self._make_train_function() 

1091 outputs = self.train_function(ins) # pylint: disable=not-callable 

1092 

1093 if reset_metrics: 

1094 self.reset_metrics() 

1095 

1096 if len(outputs) == 1: 

1097 return outputs[0] 

1098 return outputs 

1099 
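A sketch of a hand-rolled loop over `train_on_batch`, accumulating metrics across batches via `reset_metrics=False`; the model and random batches are assumptions:

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(optimizer='rmsprop', loss='mse', metrics=['mae'])

for step in range(10):
    x_batch = np.random.random((16, 3)).astype('float32')
    y_batch = np.random.random((16, 1)).astype('float32')
    results = model.train_on_batch(x_batch, y_batch, reset_metrics=False)

print(dict(zip(model.metrics_names, results)))  # loss/metrics accumulated over the loop
```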

1100 def test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True): 

1101 """Test the model on a single batch of samples. 

1102 

1103 Args: 

1104 x: Input data. It could be: 

1105 - A Numpy array (or array-like), or a list of arrays 

1106 (in case the model has multiple inputs). 

1107 - A TensorFlow tensor, or a list of tensors 

1108 (in case the model has multiple inputs). 

1109 - A dict mapping input names to the corresponding array/tensors, 

1110 if the model has named inputs. 

1111 - A `tf.data` dataset. 

1112 y: Target data. Like the input data `x`, 

1113 it could be either Numpy array(s) or TensorFlow tensor(s). 

1114 It should be consistent with `x` (you cannot have Numpy inputs and 

1115 tensor targets, or inversely). If `x` is a dataset `y` should 

1116 not be specified (since targets will be obtained from the iterator). 

1117 sample_weight: Optional array of the same length as x, containing 

1118 weights to apply to the model's loss for each sample. 

1119 In the case of temporal data, you can pass a 2D array 

1120 with shape (samples, sequence_length), 

1121 to apply a different weight to every timestep of every sample. 

1122 In this case you should make sure to specify 

1123 sample_weight_mode="temporal" in compile(). This argument is not 

1124 supported when `x` is a dataset. 

1125 reset_metrics: If `True`, the metrics returned will be only for this 

1126 batch. If `False`, the metrics will be statefully accumulated across 

1127 batches. 

1128 

1129 Returns: 

1130 Scalar test loss (if the model has a single output and no metrics) 

1131 or list of scalars (if the model has multiple outputs 

1132 and/or metrics). The attribute `model.metrics_names` will give you 

1133 the display labels for the scalar outputs. 

1134 

1135 Raises: 

1136 ValueError: In case of invalid user-provided arguments. 

1137 """ 

1138 self._assert_compile_was_called() 

1139 self._check_call_args('test_on_batch') 

1140 

1141 if (self._distribution_strategy and 

1142 distribute_lib.in_cross_replica_context()): 

1143 raise NotImplementedError('`test_on_batch` is not supported for models ' 

1144 'distributed with tf.distribute.Strategy.') 

1145 # Validate and standardize user data. 

1146 x, y, sample_weights = self._standardize_user_data( 

1147 x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True) 

1148 

1149 # If `self._distribution_strategy` is True, then we are in a replica context 

1150 # at this point. 

1151 if self.run_eagerly or self._distribution_strategy: 

1152 output_dict = training_eager_v1.test_on_batch( 

1153 self, 

1154 x, 

1155 y, 

1156 sample_weights=sample_weights, 

1157 output_loss_metrics=self._output_loss_metrics) 

1158 outputs = (output_dict['total_loss'] + output_dict['output_losses'] 

1159 + output_dict['metrics']) 

1160 outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access 

1161 else: 

1162 x = training_utils_v1.ModelInputs(x).as_list() 

1163 inputs = x + list(y or []) + list(sample_weights or []) 

1164 

1165 self._update_sample_weight_modes(sample_weights=sample_weights) 

1166 self._make_test_function() 

1167 outputs = self.test_function(inputs) # pylint: disable=not-callable 

1168 

1169 if reset_metrics: 

1170 self.reset_metrics() 

1171 

1172 if len(outputs) == 1: 

1173 return outputs[0] 

1174 return outputs 

1175 
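A sketch of `test_on_batch`, which computes loss and metrics for one batch without a gradient update; the model and data are assumptions:

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(optimizer='rmsprop', loss='mse', metrics=['mae'])

x_batch = np.random.random((16, 3)).astype('float32')
y_batch = np.random.random((16, 1)).astype('float32')

results = model.test_on_batch(x_batch, y_batch)   # no weights are updated
print(dict(zip(model.metrics_names, results)))
```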

1176 def predict_on_batch(self, x): 

1177 """Returns predictions for a single batch of samples. 

1178 

1179 Args: 

1180 x: Input data. It could be: 

1181 - A Numpy array (or array-like), or a list of arrays 

1182 (in case the model has multiple inputs). 

1183 - A TensorFlow tensor, or a list of tensors 

1184 (in case the model has multiple inputs). 

1185 - A `tf.data` dataset. 

1186 

1187 Returns: 

1188 Numpy array(s) of predictions. 

1189 

1190 Raises: 

1191 ValueError: In case of mismatch between given number of inputs and 

1192 expectations of the model. 

1193 """ 

1194 self._check_call_args('predict_on_batch') 

1195 

1196 if (self._distribution_strategy and 

1197 distribute_lib.in_cross_replica_context()): 

1198 raise NotImplementedError( 

1199 '`predict_on_batch` is not supported for models distributed with' 

1200 ' tf.distribute.Strategy.') 

1201 # Validate and standardize user data. 

1202 inputs, _, _ = self._standardize_user_data( 

1203 x, extract_tensors_from_dataset=True) 

1204 # If `self._distribution_strategy` is True, then we are in a replica context 

1205 # at this point. 

1206 if self.run_eagerly or self._distribution_strategy: 

1207 inputs = training_utils_v1.cast_if_floating_dtype(inputs) 

1208 if isinstance(inputs, collections.abc.Sequence): 

1209 # Unwrap lists with only one input, as we do when training on batch 

1210 if len(inputs) == 1: 

1211 inputs = inputs[0] 

1212 

1213 return self(inputs) # pylint: disable=not-callable 

1214 

1215 self._make_predict_function() 

1216 outputs = self.predict_function(inputs) 

1217 

1218 if len(outputs) == 1: 

1219 return outputs[0] 

1220 return outputs 

1221 
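A sketch of `predict_on_batch`, which runs a single forward pass on exactly the batch given; the model and data are assumptions:

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
x_batch = np.random.random((4, 3)).astype('float32')

preds = model.predict_on_batch(x_batch)
print(np.asarray(preds).shape)  # (4, 2); the cast covers both eager-tensor and NumPy outputs
```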

1222 def fit_generator(self, 

1223 generator, 

1224 steps_per_epoch=None, 

1225 epochs=1, 

1226 verbose=1, 

1227 callbacks=None, 

1228 validation_data=None, 

1229 validation_steps=None, 

1230 validation_freq=1, 

1231 class_weight=None, 

1232 max_queue_size=10, 

1233 workers=1, 

1234 use_multiprocessing=False, 

1235 shuffle=True, 

1236 initial_epoch=0): 

1237 """Fits the model on data yielded batch-by-batch by a Python generator. 

1238 

1239 DEPRECATED: 

1240 `Model.fit` now supports generators, so there is no longer any need to use 

1241 this endpoint. 

1242 """ 

1243 warnings.warn('`Model.fit_generator` is deprecated and ' 

1244 'will be removed in a future version. ' 

1245 'Please use `Model.fit`, which supports generators.') 

1246 return self.fit( 

1247 generator, 

1248 steps_per_epoch=steps_per_epoch, 

1249 epochs=epochs, 

1250 verbose=verbose, 

1251 callbacks=callbacks, 

1252 validation_data=validation_data, 

1253 validation_steps=validation_steps, 

1254 validation_freq=validation_freq, 

1255 class_weight=class_weight, 

1256 max_queue_size=max_queue_size, 

1257 workers=workers, 

1258 use_multiprocessing=use_multiprocessing, 

1259 shuffle=shuffle, 

1260 initial_epoch=initial_epoch) 

1261 

1262 def evaluate_generator(self, 

1263 generator, 

1264 steps=None, 

1265 callbacks=None, 

1266 max_queue_size=10, 

1267 workers=1, 

1268 use_multiprocessing=False, 

1269 verbose=0): 

1270 """Evaluates the model on a data generator. 

1271 

1272 DEPRECATED: 

1273 `Model.evaluate` now supports generators, so there is no longer any need 

1274 to use this endpoint. 

1275 """ 

1276 warnings.warn('`Model.evaluate_generator` is deprecated and ' 

1277 'will be removed in a future version. ' 

1278 'Please use `Model.evaluate`, which supports generators.') 

1279 self._check_call_args('evaluate_generator') 

1280 

1281 return self.evaluate( 

1282 generator, 

1283 steps=steps, 

1284 max_queue_size=max_queue_size, 

1285 workers=workers, 

1286 use_multiprocessing=use_multiprocessing, 

1287 verbose=verbose, 

1288 callbacks=callbacks) 

1289 

1290 def predict_generator(self, 

1291 generator, 

1292 steps=None, 

1293 callbacks=None, 

1294 max_queue_size=10, 

1295 workers=1, 

1296 use_multiprocessing=False, 

1297 verbose=0): 

1298 """Generates predictions for the input samples from a data generator. 

1299 

1300 DEPRECATED: 

1301 `Model.predict` now supports generators, so there is no longer any need 

1302 to use this endpoint. 

1303 """ 

1304 warnings.warn('`Model.predict_generator` is deprecated and ' 

1305 'will be removed in a future version. ' 

1306 'Please use `Model.predict`, which supports generators.') 

1307 return self.predict( 

1308 generator, 

1309 steps=steps, 

1310 max_queue_size=max_queue_size, 

1311 workers=workers, 

1312 use_multiprocessing=use_multiprocessing, 

1313 verbose=verbose, 

1314 callbacks=callbacks) 

1315 

1316 def _check_call_args(self, method_name): 

1317 """Check that `call` has only one positional arg.""" 

1318 # Always allow first arg, regardless of arg name. 

1319 fullargspec = self._call_full_argspec 

1320 if fullargspec.defaults: 

1321 positional_args = fullargspec.args[:-len(fullargspec.defaults)] 

1322 else: 

1323 positional_args = fullargspec.args 

1324 if 'training' in positional_args: 

1325 positional_args.remove('training') 

1326 

1327 # self and first arg can be positional. 

1328 if len(positional_args) > 2: 

1329 extra_args = positional_args[2:] 

1330 raise ValueError( 

1331 'Models passed to `' + method_name + '` can only have `training` ' 

1332 'and the first argument in `call` as positional arguments, ' 

1333 'found: ' + str(extra_args) + '.') 

1334 
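# --- Illustrative sketch (editor's example; not part of this module) ---
# `_check_call_args` accepts `call` signatures with one data argument (plus an
# optional `training` argument) but rejects extra positional arguments. The
# class names below are made-up assumptions.
import tensorflow as tf

class _ExampleOkModel(tf.keras.Model):
  # Passes the check: only `inputs` is positional; `training` is allowed.
  def call(self, inputs, training=None):
    return inputs

class _ExampleNotOkModel(tf.keras.Model):
  # Would fail the check from e.g. `train_on_batch`: `extra` is an additional
  # positional argument in `call`.
  def call(self, inputs, extra, training=None):
    return inputs + extra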

1335 def _set_optimizer(self, optimizer): 

1336 """Sets self.optimizer. 

1337 

1338 Sets self.optimizer to `optimizer`, potentially wrapping it with a 

1339 LossScaleOptimizer. 

1340 

1341 Args: 

1342 optimizer: The optimizer(s) to assign to self.optimizer. 

1343 """ 

1344 if isinstance(optimizer, (list, tuple)): 

1345 self.optimizer = [optimizers.get(opt) for opt in optimizer] 

1346 else: 

1347 self.optimizer = optimizers.get(optimizer) 

1348 

1349 if isinstance(self._dtype_policy, policy.PolicyV1): 

1350 loss_scale = self._dtype_policy.loss_scale 

1351 elif self._dtype_policy.name == 'mixed_float16': 

1352 loss_scale = 'dynamic' 

1353 else: 

1354 loss_scale = None 

1355 

1356 if (loss_scale is not None and 

1357 not isinstance(self.optimizer, 

1358 loss_scale_optimizer.LossScaleOptimizer)): 

1359 if isinstance(self.optimizer, list): 

1360 raise ValueError('When a dtype policy with a loss scale is used, you ' 

1361 'can only pass a single optimizer. Using policy %s ' 

1362 'and got optimizers: %s' % 

1363 (self._dtype_policy, self.optimizer))

1364 if not isinstance(self.optimizer, optimizer_v2.OptimizerV2): 

1365 raise ValueError('"optimizer" must be an instance of ' 

1366 'tf.keras.optimizers.Optimizer when a dtype policy '

1367 'with a loss scale is used, but got: %s. Using policy: '

1368 '%s' % 

1369 (self.optimizer, self._dtype_policy)) 

1370 if loss_scale == 'dynamic': 

1371 self.optimizer = loss_scale_optimizer.LossScaleOptimizer(self.optimizer) 

1372 else: 

1373 self.optimizer = loss_scale_optimizer.LossScaleOptimizerV1( 

1374 self.optimizer, loss_scale) 

1375 
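# --- Illustrative sketch (editor's example; not part of this module) ---
# With a 'mixed_float16' dtype policy, `_set_optimizer` above wraps the
# optimizer in a dynamic loss-scale optimizer. A hedged sketch of the same
# wrapping done explicitly with the public TF 2.4+ API (assumed available).
import tensorflow as tf

def _example_loss_scale_wrapping():
  opt = tf.keras.optimizers.SGD()
  # Dynamic loss scaling, the same default chosen above for 'mixed_float16'.
  return tf.keras.mixed_precision.LossScaleOptimizer(opt)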

1376 def _prepare_validation_data(self, validation_data, batch_size, 

1377 validation_steps): 

1378 """Unpack and check the validation data.""" 

1379 val_x, val_y, val_sample_weights = training_utils_v1.unpack_validation_data( 

1380 validation_data) 

1381 return self._standardize_user_data( 

1382 val_x, 

1383 val_y, 

1384 sample_weight=val_sample_weights, 

1385 batch_size=batch_size, 

1386 steps=validation_steps, 

1387 steps_name='validation_steps') 

1388 
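# --- Illustrative sketch (editor's example; not part of this module) ---
# `validation_data` is unpacked above into (x, y, sample_weights); in user
# code it is normally passed to `fit` as a 2- or 3-tuple. The arrays and
# shapes below are made-up assumptions.
import numpy as np
import tensorflow as tf

def _example_validation_data():
  x_tr = np.random.rand(32, 4).astype('float32')
  y_tr = np.random.rand(32, 1).astype('float32')
  x_val = np.random.rand(8, 4).astype('float32')
  y_val = np.random.rand(8, 1).astype('float32')
  val_w = np.ones((8,), dtype='float32')
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  model.compile(optimizer='sgd', loss='mse')
  model.fit(x_tr, y_tr, validation_data=(x_val, y_val), verbose=0)         # 2-tuple
  model.fit(x_tr, y_tr, validation_data=(x_val, y_val, val_w), verbose=0)  # 3-tuple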

1389 def _validate_compile_param_for_distribution_strategy( 

1390 self, run_eagerly, sample_weight_mode, target_tensors, weighted_metrics): 

1391 # Validate that arguments passed by the user to `compile` are supported by 

1392 # tf.distribute.Strategy. 

1393 if self._distribution_strategy: 

1394 if sample_weight_mode: 

1395 raise NotImplementedError('sample_weight_mode is not supported with ' 

1396 'tf.distribute.Strategy.') 

1397 if weighted_metrics: 

1398 raise NotImplementedError('weighted_metrics is not supported with ' 

1399 'tf.distribute.Strategy.') 

1400 if target_tensors: 

1401 raise ValueError('target_tensors is not supported with ' 

1402 'tf.distribute.Strategy.') 

1403 

1404 if run_eagerly: 

1405 raise ValueError( 

1406 'We currently do not support enabling `run_eagerly` with ' 

1407 'distribution strategy.') 

1408 

1409 if (distributed_training_utils_v1.is_distributing_by_cloning(self) and 

1410 (not self.built or not self.inputs or not self.outputs)): 

1411 raise ValueError( 

1412 'We currently do not support distribution strategy with a ' 

1413 '`Sequential` model that is created without `input_shape`/' 

1414 '`input_dim` set in its first layer or a subclassed model.') 

1415 
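# --- Illustrative sketch (editor's example; not part of this module) ---
# Under a distribution strategy the checks above reject `sample_weight_mode`,
# `weighted_metrics`, `target_tensors` and `run_eagerly=True`, and require a
# `Sequential` model to declare its input shape. A hedged sketch of a compile
# call that stays within those constraints; strategy and model are assumptions.
import tensorflow as tf

def _example_compile_under_strategy():
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, input_shape=(4,))])
    # Only plain optimizer/loss/metrics arguments are used here.
    model.compile(optimizer='sgd', loss='mse', metrics=['mae'])
  return model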

1416 def _process_target_tensor_for_compile(self, target_tensors): 

1417 if self.run_eagerly: 

1418 # Target tensors are not supported with run_eagerly. Create a list with

1419 # None as a placeholder for each output.

1420 return [None for _ in self.output_names] 

1421 

1422 if target_tensors is not None and not (isinstance(target_tensors, list) and 

1423 target_tensors == []): # pylint: disable=g-explicit-bool-comparison 

1424 if isinstance(target_tensors, list): 

1425 if len(target_tensors) != len(self.outputs): 

1426 raise ValueError( 

1427 'When passing a list as `target_tensors`, ' 

1428 'it should have one entry per model output. ' 

1429 'The model has %s outputs, but you passed target_tensors=%s' % 

1430 (len(self.outputs), target_tensors)) 

1431 elif isinstance(target_tensors, dict): 

1432 unexpected_target_tensor_names = set(target_tensors.keys()).difference( 

1433 self.output_names) 

1434 if unexpected_target_tensor_names: 

1435 raise ValueError( 

1436 'Unknown entry in `target_tensors` dictionary: "{name}". ' 

1437 'Only expected the following keys: {keys}'.format( 

1438 name=unexpected_target_tensor_names, 

1439 keys=str(self.output_names))) 

1440 tmp_target_tensors = [] 

1441 for name in self.output_names: 

1442 tmp_target_tensors.append(target_tensors.get(name, None)) 

1443 target_tensors = tmp_target_tensors 

1444 elif tensor_util.is_tf_type(target_tensors): 

1445 target_tensors = [target_tensors] 

1446 else: 

1447 raise TypeError('Expected `target_tensors` to be a list or tuple or ' 

1448 'dict or a single tensor, but got:', target_tensors) 

1449 else: 

1450 # In case target tensor is empty or None, create a list with Nones 

1451 # that has same length as self.output_names. With that, the None check of 

1452 # target tensor can be skipped downstream. 

1453 target_tensors = [None for _ in self.output_names] 

1454 return target_tensors 

1455 

1456 def _compile_eagerly(self, metrics, weighted_metrics, sample_weight_mode): 

1457 # Prepare sample weight modes. List with the same length as model outputs. 

1458 training_utils_v1.prepare_sample_weight_modes( 

1459 self._training_endpoints, sample_weight_mode) 

1460 # Prepare sample weights. 

1461 self._prepare_sample_weights() 

1462 # Save all metric attributes per output of the model. 

1463 self._cache_output_metric_attributes(metrics, weighted_metrics) 

1464 self.total_loss = None 

1465 # Set metric attributes on model. 

1466 self._set_metric_attributes() 

1467 

1468 self._collected_trainable_weights = self.trainable_weights 

1469 

1470 def _update_sample_weight_modes(self, sample_weights=None): 

1471 """Updates sample weight modes based on training/eval inputs. 

1472 

1473 Sample weight placeholders will be created for all or no outputs 

1474 based on whether sample_weight is provided for any output. 

1475 

1476 If model contains `_sample_weight_modes` we check if the input 

1477 `sample_weights` corresponds to the sample weight modes. 

1478 1. Set sample weight mode to be 'temporal' for output i, if `compile` 

1479 sample_weight_mode was set to `temporal` and sample weight inputs 

1480 are given for one or more outputs. 

1481 2. Set sample weight mode to be 'samplewise' for output i, if `compile` 

1482 sample_weight_mode was not set and sample weight inputs are given for 

1483 one or more outputs. 

1484 3. Reset sample weight mode to None for output i if sample weight mode 

1485 was set but there is no sample weight input. 

1486 

1487 Args: 

1488 sample_weights: List of sample weights of the same length as model outputs 

1489 or None. 

1490 """ 

1491 if not self._is_compiled: 

1492 return 

1493 if sample_weights and any(s is not None for s in sample_weights): 

1494 for endpoint in self._training_endpoints: 

1495 endpoint.sample_weight_mode = ( 

1496 endpoint.sample_weight_mode or 'samplewise') 

1497 else: 

1498 for endpoint in self._training_endpoints: 

1499 endpoint.sample_weight_mode = None 

1500 

1501 def _recompile_weights_loss_and_weighted_metrics(self): 

1502 if not self._is_compiled: 

1503 return False 

1504 recompile = any( 

1505 e.sample_weights_mismatch() for e in self._training_endpoints) 

1506 

1507 if recompile: 

1508 self._compile_weights_loss_and_weighted_metrics() 

1509 return recompile 

1510 

1511 @trackable.no_automatic_dependency_tracking 

1512 def _compile_weights_loss_and_weighted_metrics(self, sample_weights=None): 

1513 """Compiles the model loss and weighted metric sub-graphs. 

1514 

1515 This may be used to set graph tensors as sample weights (instead of creating 

1516 placeholders). This functionality is necessary for 

1517 `tf.keras.estimator.model_to_estimator`, which calls Keras models in a v1 

1518 graph, and creates iterator tensors for inputs, targets, and sample weights. 

1519 

1520 Args: 

1521 sample_weights: List of tensors to use as the sample weights. Must be the 

1522 same length as the number of outputs. If left as `None`, placeholders 

1523 are used instead. 

1524 """ 

1525 with backend.get_graph().as_default(): 

1526 if sample_weights is not None: 

1527 self._update_sample_weight_modes(sample_weights) 

1528 self._prepare_sample_weights(sample_weights) 

1529 

1530 masks = self._prepare_output_masks() 

1531 

1532 # Compute weighted metrics. 

1533 self._handle_metrics( 

1534 self.outputs, 

1535 targets=self._targets, 

1536 skip_target_masks=self._prepare_skip_target_masks(), 

1537 sample_weights=self.sample_weights, 

1538 masks=masks, 

1539 return_weighted_metrics=True) 

1540 

1541 # Compute total loss. 

1542 # Used to keep track of the total loss value (stateless). 

1543 # eg., total_loss = loss_weight_1 * output_1_loss_fn(...) + 

1544 # loss_weight_2 * output_2_loss_fn(...) + 

1545 # layer losses. 

1546 self.total_loss = self._prepare_total_loss(masks) 

1547 

1548 def _prepare_skip_target_masks(self): 

1549 """Boolean mask for whether the target in the output list should be skipped. 

1550 

1551 If the loss function corresponding to a model output is None, then this 

1552 output will be skipped during total loss calculation and feed targets 

1553 preparation. 

1554 

1555 Returns: 

1556 A boolean list for whether the corresponding target in the output list 

1557 should be skipped during loss calculation. 

1558 """ 

1559 return [l is None for l in self.loss_functions] 

1560 

1561 def _prepare_output_masks(self): 

1562 """Returns masks corresponding to model outputs.""" 

1563 return [getattr(x, '_keras_mask', None) for x in self.outputs] 

1564 

1565 def _prepare_total_loss(self, masks): 

1566 """Computes total loss from loss functions. 

1567 

1568 Args: 

1569 masks: List of mask values corresponding to each model output. 

1570 

1571 Returns: 

1572 The total loss as a scalar tensor, or `0.` if the model has no loss.

1573 

1574 Raises: 

1575 TypeError: If the model was compiled with `run_eagerly=True`.

1576 """ 

1577 if self.run_eagerly: 

1578 raise TypeError('Total loss cannot be computed when the model is '

1579 'compiled with run_eagerly = True.')

1580 loss_list = [] 

1581 with backend.name_scope('loss'): 

1582 for endpoint, mask in zip(self._training_endpoints, masks): 

1583 if endpoint.should_skip_target(): 

1584 continue 

1585 y_true = endpoint.training_target.target 

1586 y_pred = endpoint.output 

1587 loss_fn = endpoint.loss_fn 

1588 loss_weight = endpoint.loss_weight 

1589 loss_name = endpoint.loss_name() 

1590 sample_weight = endpoint.sample_weight 

1591 

1592 with backend.name_scope(loss_name): 

1593 if mask is not None: 

1594 mask = math_ops.cast(mask, y_pred.dtype) 

1595 # Update weights with mask. 

1596 if sample_weight is None: 

1597 sample_weight = mask 

1598 else: 

1599 # Update dimensions of weights to match with mask if possible. 

1600 mask, _, sample_weight = ( 

1601 losses_utils.squeeze_or_expand_dimensions( 

1602 mask, sample_weight=sample_weight)) 

1603 sample_weight *= mask 

1604 

1605 if hasattr(loss_fn, 'reduction'): 

1606 per_sample_losses = loss_fn.call(y_true, y_pred) 

1607 weighted_losses = losses_utils.compute_weighted_loss( 

1608 per_sample_losses, 

1609 sample_weight=sample_weight, 

1610 reduction=losses_utils.ReductionV2.NONE) 

1611 loss_reduction = loss_fn.reduction 

1612 

1613 # `AUTO` loss reduction defaults to `SUM_OVER_BATCH_SIZE` for all 

1614 # compile use cases. 

1615 if loss_reduction == losses_utils.ReductionV2.AUTO: 

1616 loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE 

1617 

1618 # Compute the stateless loss value. 

1619 output_loss = losses_utils.reduce_weighted_loss( 

1620 weighted_losses, reduction=loss_reduction) 

1621 else: 

1622 # Compute the stateless loss value for a custom loss class. 

1623 # Here we assume that the class takes care of loss reduction 

1624 # because if this class returns a vector value we cannot 

1625 # differentiate between use case where a custom optimizer 

1626 # expects a vector loss value vs unreduced per-sample loss value. 

1627 output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight) 

1628 loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE 

1629 

1630 if len(self.outputs) > 1: 

1631 # Keep track of stateful result tensor for the loss. 

1632 endpoint.output_loss_metric(output_loss) 

1633 

1634 # Scale output loss for distribution. For custom losses we assume 

1635 # reduction was mean. 

1636 if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE: 

1637 output_loss = losses_utils.scale_loss_for_distribution(output_loss) 

1638 

1639 loss_list.append(loss_weight * output_loss) 

1640 if not loss_list and not self.losses: 

1641 raise ValueError('The model cannot be compiled ' 

1642 'because it has no loss to optimize.') 

1643 

1644 # Add regularization penalties and other layer-specific losses. 

1645 custom_losses = self.get_losses_for(None) + self.get_losses_for( 

1646 self.inputs) 

1647 if custom_losses: 

1648 total_custom_loss = math_ops.add_n( 

1649 losses_utils.cast_losses_to_common_dtype(custom_losses)) 

1650 loss_list.append( 

1651 losses_utils.scale_loss_for_distribution(total_custom_loss)) 

1652 

1653 loss_list = losses_utils.cast_losses_to_common_dtype(loss_list) 

1654 if loss_list: 

1655 total_loss = math_ops.add_n(loss_list) 

1656 else: 

1657 total_loss = 0. 

1658 return total_loss 

1659 
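# --- Illustrative sketch (editor's example; not part of this module) ---
# The total loss assembled above is the loss-weighted sum of per-output
# losses plus any layer-added (regularization / `add_loss`) terms. A hedged,
# eager-mode sketch of that arithmetic; the loss weights 1.0 and 0.5 and the
# choice of mse/mae are made-up assumptions.
import tensorflow as tf

def _example_total_loss(y_true_a, y_pred_a, y_true_b, y_pred_b, model):
  per_output = [
      1.0 * tf.reduce_mean(tf.keras.losses.mse(y_true_a, y_pred_a)),
      0.5 * tf.reduce_mean(tf.keras.losses.mae(y_true_b, y_pred_b)),
  ]
  total = tf.add_n(per_output)
  if model.losses:                 # regularization penalties, add_loss terms
    total += tf.add_n(model.losses)
  return total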

1660 def _get_callback_model(self): 

1661 """Returns the Callback Model for this Model.""" 

1662 

1663 if hasattr(self, '_replicated_model') and self._replicated_model: 

1664 # When using training_distributed, we set the callback model 

1665 # to an instance of the `DistributedModel` that we create in 

1666 # the `compile` call. The `DistributedModel` is initialized 

1667 # with the first replicated model. We need to set the callback 

1668 # model to a DistributedModel to allow us to override saving 

1669 # and loading weights when we checkpoint the model during training. 

1670 return self._replicated_model 

1671 if hasattr(self, 'callback_model') and self.callback_model: 

1672 return self.callback_model 

1673 return self 

1674 

1675 @trackable.no_automatic_dependency_tracking 

1676 def _make_callback_model(self, grouped_model): 

1677 first_replicated_model = self._distribution_strategy.unwrap( 

1678 grouped_model)[0] 

1679 # We initialize the callback model with the first replicated model. 

1680 self._replicated_model = DistributedCallbackModel(first_replicated_model) 

1681 self._replicated_model.set_original_model(self) 

1682 

1683 def _validate_or_infer_batch_size(self, batch_size, steps, x): 

1684 """Validates that the `batch_size` provided is consistent with InputLayer. 

1685 

1686 It's possible that the user specified a static batch size in their 

1687 InputLayer. If so, this method checks the provided `batch_size` and `x` 

1688 arguments are consistent with this static batch size. Also, if 

1689 `batch_size` is `None`, this method will attempt to infer the batch size 

1690 from the static batch size of the InputLayer. Lastly, ValueError will be 

1691 raised if `x` is a tf.data.Dataset and `batch_size` is specified as we 

1692 expect users to provide batched datasets. 

1693 

1694 Args: 

1695 batch_size: The batch_size provided as an argument to 

1696 fit/evaluate/predict. 

1697 steps: The steps provided as an argument to fit/evaluate/predict. 

1698 x: The data passed as `x` to fit/evaluate/predict. 

1699 

1700 Returns: 

1701 The validated batch_size, auto-inferred from the first layer if not 

1702 provided. 

1703 """ 

1704 if (isinstance(x, (data_types.DatasetV1, 

1705 data_types.DatasetV2, 

1706 data_utils.Sequence)) or 

1707 tf_inspect.isgenerator(x)): 

1708 if batch_size is not None: 

1709 raise ValueError( 

1710 'The `batch_size` argument must not be specified for the given ' 

1711 'input type. Received input: {}, batch_size: {}'.format( 

1712 x, batch_size)) 

1713 return 

1714 

1715 # Avoids the override in Sequential.layers which filters Input layers. 

1716 # (Which are often the very layers that we're after.) 

1717 layers = self._flatten_layers(include_self=False, recursive=False) 

1718 first_layer = next(layers, None) 

1719 if first_layer: 

1720 # The per-replica static batch size. 

1721 static_batch_size = training_utils.get_static_batch_size(first_layer) 

1722 if static_batch_size is not None: 

1723 

1724 # Determine number of times the user-supplied batch size will be split. 

1725 if (self._distribution_strategy and 

1726 distributed_training_utils.global_batch_size_supported( 

1727 self._distribution_strategy)): 

1728 num_splits_for_ds = self._distribution_strategy.num_replicas_in_sync 

1729 else: 

1730 num_splits_for_ds = 1 

1731 

1732 # Check `batch_size` argument is consistent with InputLayer. 

1733 if batch_size is not None: 

1734 if batch_size % num_splits_for_ds != 0: 

1735 raise ValueError('The `batch_size` argument ({}) must be divisible ' 

1736 'by the number of replicas ({})'.format(

1737 batch_size, num_splits_for_ds)) 

1738 per_replica_batch_size = batch_size // num_splits_for_ds 

1739 

1740 if per_replica_batch_size != static_batch_size: 

1741 raise ValueError('The resulting per-replica batch size ({}) is '

1742 'incompatible with the specified batch size of ' 

1743 'your Input Layer: {}'.format( 

1744 per_replica_batch_size, static_batch_size)) 

1745 

1746 # Check Dataset/Iterator batch size is consistent with InputLayer. 

1747 if isinstance(x, (data_types.DatasetV2, iterator_ops.Iterator, 

1748 iterator_ops.IteratorBase)): 

1749 ds_batch_size = tensor_shape.Dimension( 

1750 nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value 

1751 if ds_batch_size is not None: 

1752 if ds_batch_size % num_splits_for_ds != 0: 

1753 raise ValueError( 

1754 'The batch dimension of your `Dataset` ({}) is not '

1755 'divisible by the number of replicas ({})'.format(

1756 ds_batch_size, num_splits_for_ds)) 

1757 

1758 ds_per_replica_batch_size = ds_batch_size // num_splits_for_ds 

1759 if ds_per_replica_batch_size != static_batch_size: 

1760 raise ValueError('The batch output shape of your `Dataset` is ' 

1761 '{}, which is incompatible with the specified ' 

1762 'batch size of your Input Layer: {}'.format( 

1763 ds_per_replica_batch_size, 

1764 static_batch_size)) 

1765 

1766 # Set inferred batch size from the InputLayer. 

1767 if steps is None: 

1768 batch_size = static_batch_size * num_splits_for_ds 

1769 

1770 if batch_size is None and steps is None: 

1771 # Backwards compatibility 

1772 batch_size = 32 

1773 return batch_size 

1774 
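# --- Illustrative sketch (editor's example; not part of this module) ---
# Worked example of the divisibility check above, with made-up numbers: a
# global `batch_size` of 64 on a strategy with 2 replicas in sync yields a
# per-replica batch of 32, which must match any static batch size declared on
# the InputLayer (e.g. `tf.keras.Input(shape=(4,), batch_size=32)`).
def _example_per_replica_batch(batch_size=64, num_replicas=2):
  if batch_size % num_replicas != 0:
    raise ValueError('global batch size must be divisible by replica count')
  return batch_size // num_replicas   # -> 32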

1775 def _prepare_sample_weights(self, sample_weights=None): 

1776 """Sets sample weight attribute on the model.""" 

1777 # List with the same length as model outputs. 

1778 if sample_weights is not None: 

1779 if len(sample_weights) != len(self._training_endpoints): 

1780 raise ValueError('Provided sample weights must have the same length as '

1781 'the number of outputs. Expected: {}, got: {}.'.format(

1782 len(self._training_endpoints), 

1783 len(sample_weights))) 

1784 else: 

1785 sample_weights = [None] * len(self._training_endpoints) 

1786 for endpoint, weight in zip(self._training_endpoints, sample_weights): 

1787 endpoint.populate_sample_weight(weight, endpoint.sample_weight_mode) 

1788 

1789 def _cache_output_metric_attributes(self, metrics, weighted_metrics): 

1790 """Caches metric name and function attributes for every model output.""" 

1791 output_shapes = [] 

1792 for output in self.outputs: 

1793 if output is None or output.shape.rank is None: 

1794 output_shapes.append(None) 

1795 else: 

1796 output_shapes.append(output.shape.as_list()) 

1797 self._per_output_metrics = training_utils_v1.collect_per_output_metric_info( 

1798 metrics, self.output_names, output_shapes, self.loss_functions, 

1799 from_serialized=self._from_serialized) 

1800 self._per_output_weighted_metrics = ( 

1801 training_utils_v1.collect_per_output_metric_info( 

1802 weighted_metrics, 

1803 self.output_names, 

1804 output_shapes, 

1805 self.loss_functions, 

1806 from_serialized=self._from_serialized, 

1807 is_weighted=True)) 

1808 

1809 def _add_unique_metric_name(self, metric_name, metric_fn, output_index): 

1810 """Makes the metric name unique. 

1811 

1812 If there are multiple outputs for which the metrics are calculated, the 

1813 metric names have to be made unique by appending an integer. 

1814 

1815 Args: 

1816 metric_name: Metric name that corresponds to the metric specified by the 

1817 user. For example: 'acc'. 

1818 metric_fn: The Metric object. 

1819 output_index: The index of the model output for which the metric name is 

1820 being added. 

1821 

1822 Returns: 

1823 String, the unique metric name to use for this output of the model.

1824 """ 

1825 # For multi-output models, prepend the output names to the metric name. 

1826 if len(self.output_names) > 1: 

1827 # If we're loading from an already-serialized model, we've already 

1828 # prepended the output name, and we don't want to do it again. 

1829 # 

1830 # Alternatively, we may be receiving a stateless metric (e.g. the string 

1831 # "accuracy") rather than a `Metric` object, in which case we want to 

1832 # prepend the output name even if we are loading a serialized model. 

1833 if not getattr(metric_fn, '_from_serialized', False): 

1834 metric_name = '%s_%s' % (self.output_names[output_index], metric_name) 

1835 

1836 j = 1 

1837 base_metric_name = metric_name 

1838 while metric_name in self.metrics_names: 

1839 metric_name = '%s_%d' % (base_metric_name, j) 

1840 j += 1 

1841 

1842 return metric_name 

1843 
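# --- Illustrative sketch (editor's example; not part of this module) ---
# A hedged mirror of the naming scheme above: prefix the metric name with the
# output name, then append an integer suffix until the name is unused. The
# `existing` set stands in for `self.metrics_names` and is an assumption.
def _example_unique_metric_name(metric_name, output_name, existing):
  name = '%s_%s' % (output_name, metric_name)
  base, j = name, 1
  while name in existing:
    name = '%s_%d' % (base, j)
    j += 1
  return name

# _example_unique_metric_name('accuracy', 'main', set())
#   -> 'main_accuracy'
# _example_unique_metric_name('accuracy', 'main', {'main_accuracy'})
#   -> 'main_accuracy_1'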

1844 def _init_metric_attributes(self): 

1845 """Initialized model metric attributes.""" 

1846 # List of stateful metric functions. Used for resetting metric state during 

1847 # training/eval. 

1848 self._compile_metric_functions = [] 

1849 

1850 def _set_per_output_metric_attributes(self, metrics_dict, output_index): 

1851 """Sets the metric attributes on the model for the given output. 

1852 

1853 Args: 

1854 metrics_dict: A dict with metric names as keys and metric fns as values. 

1855 output_index: The index of the model output for which the metric 

1856 attributes are added. 

1857 

1858 Returns: 

1859 Metrics dict updated with unique metric names as keys. 

1860 """ 

1861 updated_metrics_dict = collections.OrderedDict() 

1862 for metric_name, metric_fn in metrics_dict.items(): 

1863 metric_name = self._add_unique_metric_name( 

1864 metric_name, metric_fn, output_index) 

1865 

1866 # Update the name on the metric class to be the unique generated name. 

1867 metric_fn._name = metric_name # pylint: disable=protected-access 

1868 updated_metrics_dict[metric_name] = metric_fn 

1869 # Keep track of metric name and function. 

1870 self._compile_metric_functions.append(metric_fn) 

1871 return updated_metrics_dict 

1872 

1873 def _set_metric_attributes(self): 

1874 """Sets the metric attributes on the model for all the model outputs.""" 

1875 updated_per_output_metrics = [] 

1876 updated_per_output_weighted_metrics = [] 

1877 for i, endpoint in enumerate(self._training_endpoints): 

1878 if endpoint.should_skip_target(): 

1879 updated_per_output_metrics.append(self._per_output_metrics[i]) 

1880 updated_per_output_weighted_metrics.append( 

1881 self._per_output_weighted_metrics[i]) 

1882 continue 

1883 updated_per_output_metrics.append( 

1884 self._set_per_output_metric_attributes(self._per_output_metrics[i], 

1885 i)) 

1886 updated_per_output_weighted_metrics.append( 

1887 self._set_per_output_metric_attributes( 

1888 self._per_output_weighted_metrics[i], i)) 

1889 

1890 # Create a metric wrapper for each output loss. This computes mean of an 

1891 # output loss across mini-batches (irrespective of how we reduce within a 

1892 # batch). 

1893 if len(self._training_endpoints) > 1: 

1894 for endpoint in self._training_endpoints: 

1895 if not endpoint.should_skip_target(): 

1896 endpoint.output_loss_metric = metrics_module.Mean( 

1897 name=endpoint.loss_name()) 

1898 

1899 self._per_output_metrics = updated_per_output_metrics 

1900 self._per_output_weighted_metrics = updated_per_output_weighted_metrics 

1901 

1902 def _handle_per_output_metrics(self, 

1903 metrics_dict, 

1904 y_true, 

1905 y_pred, 

1906 mask, 

1907 weights=None): 

1908 """Calls metric functions for a single output. 

1909 

1910 Args: 

1911 metrics_dict: A dict with metric names as keys and metric fns as values. 

1912 y_true: Target output. 

1913 y_pred: Predicted output. 

1914 mask: Computed mask value for the current output. 

1915 weights: Weights to be applied on the current output. 

1916 

1917 Returns: 

1918 A list of metric result tensors. 

1919 """ 

1920 metric_results = [] 

1921 for metric_name, metric_fn in metrics_dict.items(): 

1922 with backend.name_scope(metric_name): 

1923 metric_result = training_utils_v1.call_metric_function( 

1924 metric_fn, y_true, y_pred, weights=weights, mask=mask) 

1925 metric_results.append(metric_result) 

1926 return metric_results 

1927 

1928 def _handle_metrics(self, 

1929 outputs, 

1930 targets=None, 

1931 skip_target_masks=None, 

1932 sample_weights=None, 

1933 masks=None, 

1934 return_weighted_metrics=False, 

1935 return_weighted_and_unweighted_metrics=False): 

1936 """Handles calling metric functions. 

1937 

1938 Args: 

1939 outputs: List of outputs (predictions). 

1940 targets: List of targets. 

1941 skip_target_masks: Optional. List of booleans indicating whether the

1942 corresponding target should be ignored.

1943 sample_weights: Optional list of sample weight arrays. 

1944 masks: List of computed output mask values. 

1945 return_weighted_metrics: Flag that indicates whether weighted metrics 

1946 should be computed instead of unweighted metrics. This flag is ignored 

1947 when `return_weighted_and_unweighted_metrics` is enabled. 

1948 return_weighted_and_unweighted_metrics: Flag that is used to indicate 

1949 whether both weighted and unweighted metrics should be computed. When 

1950 this is not enabled, we use `return_weighted_metrics` param to indicate 

1951 whether weighted or unweighted metrics should be returned. 

1952 

1953 Returns: 

1954 A list of metric result tensors. 

1955 """ 

1956 # TODO(scottzhu): Update this to use the new training_endpoints. Currently 

1957 # the eager and graph logic is a bit different.

1958 skip_target_masks = skip_target_masks or [False] * len(outputs) 

1959 metric_results = [] 

1960 with backend.name_scope('metrics'): 

1961 # Invoke all metrics added using `compile`. 

1962 for i in range(len(outputs)): 

1963 if skip_target_masks[i]: 

1964 continue 

1965 output = outputs[i] if outputs else None 

1966 target = targets[i] if targets else None 

1967 output_mask = masks[i] if masks else None 

1968 

1969 if (return_weighted_and_unweighted_metrics or 

1970 not return_weighted_metrics): 

1971 metric_results.extend( 

1972 self._handle_per_output_metrics(self._per_output_metrics[i], 

1973 target, output, output_mask)) 

1974 if return_weighted_and_unweighted_metrics or return_weighted_metrics: 

1975 metric_results.extend( 

1976 self._handle_per_output_metrics( 

1977 self._per_output_weighted_metrics[i], 

1978 target, 

1979 output, 

1980 output_mask, 

1981 weights=sample_weights[i] if sample_weights else None)) 

1982 return metric_results 

1983 

1984 def _check_trainable_weights_consistency(self): 

1985 """Check trainable weights count consistency. 

1986 

1987 This will raise a warning if `trainable_weights` and 

1988 `_collected_trainable_weights` are inconsistent (i.e. contain different

1989 numbers of parameters).

1990 Inconsistency will typically arise when one modifies `model.trainable` 

1991 without calling `model.compile` again. 

1992 """ 

1993 if not hasattr(self, '_collected_trainable_weights'): 

1994 return 

1995 

1996 if len(self.trainable_weights) != len(self._collected_trainable_weights): 

1997 logging.log_first_n( 

1998 logging.WARN, 'Discrepancy between trainable weights and collected' 

1999 ' trainable weights, did you set `model.trainable`' 

2000 ' without calling `model.compile` afterwards?', 1)

2001 

2002 def _make_train_function(self): 

2003 has_recompiled = self._recompile_weights_loss_and_weighted_metrics() 

2004 self._check_trainable_weights_consistency() 

2005 if isinstance(self.optimizer, list): 

2006 raise ValueError('The `optimizer` in `compile` should be a single ' 

2007 'optimizer.') 

2008 # If we have re-compiled the loss/weighted metric sub-graphs then create 

2009 # train function even if one exists already. This is because 

2010 # `_feed_sample_weights` list has been updated on re-compile. 

2011 if getattr(self, 'train_function', None) is None or has_recompiled: 

2012 # Restore the compiled trainable state. 

2013 current_trainable_state = self._get_trainable_state() 

2014 self._set_trainable_state(self._compiled_trainable_state) 

2015 

2016 inputs = (self._feed_inputs + 

2017 self._feed_targets + 

2018 self._feed_sample_weights) 

2019 if not isinstance(backend.symbolic_learning_phase(), int): 

2020 inputs += [backend.symbolic_learning_phase()] 

2021 

2022 with backend.get_graph().as_default(): 

2023 with backend.name_scope('training'): 

2024 # Training updates 

2025 updates = self.optimizer.get_updates( 

2026 params=self._collected_trainable_weights, loss=self.total_loss) 

2027 # Unconditional updates 

2028 updates += self.get_updates_for(None) 

2029 # Conditional updates relevant to this model 

2030 updates += self.get_updates_for(self.inputs) 

2031 

2032 metrics = self._get_training_eval_metrics() 

2033 metrics_tensors = [ 

2034 m._call_result for m in metrics if hasattr(m, '_call_result') # pylint: disable=protected-access 

2035 ] 

2036 

2037 with backend.name_scope('training'): 

2038 # Gets loss and metrics. Updates weights at each call. 

2039 fn = backend.function( 

2040 inputs, [self.total_loss] + metrics_tensors, 

2041 updates=updates, 

2042 name='train_function', 

2043 **self._function_kwargs) 

2044 setattr(self, 'train_function', fn) 

2045 

2046 # Restore the current trainable state 

2047 self._set_trainable_state(current_trainable_state) 

2048 

2049 def _make_test_function(self): 

2050 has_recompiled = self._recompile_weights_loss_and_weighted_metrics() 

2051 # If we have re-compiled the loss/weighted metric sub-graphs then create 

2052 # test function even if one exists already. This is because 

2053 # `_feed_sample_weights` list has been updated on re-compile. 

2054 if getattr(self, 'test_function', None) is None or has_recompiled: 

2055 inputs = (self._feed_inputs + 

2056 self._feed_targets + 

2057 self._feed_sample_weights) 

2058 

2059 with backend.get_graph().as_default(): 

2060 metrics = self._get_training_eval_metrics() 

2061 metrics_tensors = [ 

2062 m._call_result for m in metrics if hasattr(m, '_call_result') # pylint: disable=protected-access 

2063 ] 

2064 

2065 with backend.name_scope('evaluation'): 

2066 updates = self.state_updates 

2067 # Return loss and metrics, no gradient updates. 

2068 # Does update the network states. 

2069 fn = backend.function( 

2070 inputs, [self.total_loss] + metrics_tensors, 

2071 updates=updates, 

2072 name='test_function', 

2073 **self._function_kwargs) 

2074 setattr(self, 'test_function', fn) 

2075 

2076 def _make_predict_function(self): 

2077 if not hasattr(self, 'predict_function'): 

2078 self.predict_function = None 

2079 if self.predict_function is None: 

2080 inputs = self._feed_inputs 

2081 # Gets network outputs. Does not update weights. 

2082 # Does update the network states. 

2083 kwargs = getattr(self, '_function_kwargs', {}) 

2084 with backend.name_scope(ModeKeys.PREDICT): 

2085 self.predict_function = backend.function( 

2086 inputs, 

2087 self.outputs, 

2088 updates=self.state_updates, 

2089 name='predict_function', 

2090 **kwargs) 

2091 

2092 def _make_execution_function(self, mode): 

2093 if mode == ModeKeys.TRAIN: 

2094 self._make_train_function() 

2095 return self.train_function 

2096 if mode == ModeKeys.TEST: 

2097 self._make_test_function() 

2098 return self.test_function 

2099 if mode == ModeKeys.PREDICT: 

2100 self._make_predict_function() 

2101 return self.predict_function 

2102 

2103 def _distribution_standardize_user_data(self, 

2104 x, 

2105 y=None, 

2106 sample_weight=None, 

2107 class_weight=None, 

2108 batch_size=None, 

2109 validation_split=0, 

2110 shuffle=False, 

2111 epochs=1, 

2112 allow_partial_batch=False): 

2113 """Runs validation checks on input and target data passed by the user. 

2114 

2115 This is called when using tf.distribute.Strategy to train, evaluate or serve 

2116 the model. 

2117 

2118 Args: 

2119 x: Input data. A numpy array or `tf.data` dataset. 

2120 y: Target data. A numpy array or None if x is a `tf.data` dataset. 

2121 sample_weight: An optional sample-weight array passed by the user to 

2122 weight the importance of each sample in `x`. 

2123 class_weight: An optional class-weight array passed by the user to

2124 weight the importance of samples in `x` based on the class they belong 

2125 to, as conveyed by `y`. 

2126 batch_size: Integer batch size. If provided, it is used to run additional 

2127 validation checks on stateful models. 

2128 validation_split: Float between 0 and 1. 

2129 Fraction of the training data to be used as validation data. 

2130 shuffle: Boolean whether to shuffle the training data before each epoch. 

2131 epochs: Integer epochs. If > 1, repeat the numpy training data epochs 

2132 times when converting to training dataset. 

2133 allow_partial_batch: Boolean. If True, a smaller final (partial) batch is

2134 allowed; otherwise all batches are required to have the same size.

2135 

2136 Returns: 

2137 Dataset instance. 

2138 

2139 Raises: 

2140 ValueError: In case of invalid user-provided data. 

2141 RuntimeError: If the model was never compiled. 

2142 """ 

2143 if class_weight: 

2144 raise NotImplementedError('`class_weight` is currently not supported ' 

2145 'when using tf.distribute.Strategy.') 

2146 

2147 if (sample_weight is not None and sample_weight.all() and 

2148 backend.is_tpu_strategy(self._distribution_strategy)): 

2149 raise NotImplementedError('`sample_weight` is currently not supported ' 

2150 'when using TPUStrategy.') 

2151 

2152 # Validates `steps` and `shuffle` arguments right at the beginning 

2153 # since we use it to construct the dataset object. 

2154 # TODO(anjalisridhar): Remove this check once we refactor the 

2155 # _standardize_user_data code path. This check is already present elsewhere 

2156 # in the codebase. 

2157 if isinstance(x, data_types.DatasetV2): 

2158 if shuffle: 

2159 training_utils_v1.verify_dataset_shuffled(x) 

2160 

2161 strategy = self._distribution_strategy 

2162 with strategy.scope(): 

2163 # We should be sure to call get_session() inside the strategy.scope() 

2164 # so the strategy can affect the session options. 

2165 if ops.executing_eagerly_outside_functions(): 

2166 session = None 

2167 else: 

2168 session = backend.get_session() 

2169 

2170 first_x_value = nest.flatten(x)[0] 

2171 if isinstance(first_x_value, np.ndarray): 

2172 x = training_utils.list_to_tuple(x) 

2173 if y is not None: 

2174 y = training_utils.list_to_tuple(y) 

2175 if sample_weight is not None: 

2176 sample_weight = training_utils.list_to_tuple(sample_weight) 

2177 in_tuple = (x, y, sample_weight) 

2178 else: 

2179 in_tuple = (x, y) 

2180 else: 

2181 in_tuple = x 

2182 

2183 ds = strategy.extended.experimental_make_numpy_dataset(in_tuple, 

2184 session=session) 

2185 if shuffle: 

2186 # We want a buffer size that is larger than the batch size provided by 

2187 # the user and provides sufficient randomness. Note that larger 

2188 # numbers introduce more memory usage based on the size of each 

2189 # sample. 

2190 ds = ds.shuffle(max(1024, batch_size * 8)) 

2191 if epochs > 1: 

2192 ds = ds.repeat(epochs) 

2193 

2194 # We need to use the drop_remainder argument to get a known static 

2195 # input shape which is required for TPUs. 

2196 drop_remainder = (not allow_partial_batch and 

2197 strategy.extended.experimental_require_static_shapes) 

2198 

2199 # TODO(b/131720208): We still drop remainder here if number of examples 

2200 # is divisible by batch size, as sometimes dynamic padder will time out 

2201 # with keras.metrics.CategoricalAccuracy() metric. 

2202 if backend.is_tpu_strategy(strategy) and not drop_remainder: 

2203 dataset_size = first_x_value.shape[0] 

2204 if dataset_size % batch_size == 0: 

2205 drop_remainder = True 

2206 

2207 x = ds.batch(batch_size, drop_remainder=drop_remainder) 

2208 else: 

2209 assert isinstance(x, data_types.DatasetV2) 

2210 training_utils_v1.validate_dataset_input(x, y, sample_weight, 

2211 validation_split) 

2212 return x 

2213 
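# --- Illustrative sketch (editor's example; not part of this module) ---
# Outside of a strategy's numpy-ingestion helper, the shuffle/repeat/batch
# pipeline built above corresponds roughly to the following public tf.data
# code; the array shapes, buffer rule and batch size are made-up assumptions.
import numpy as np
import tensorflow as tf

def _example_numpy_to_dataset(batch_size=32, epochs=2, shuffle=True):
  x = np.random.rand(128, 4).astype('float32')
  y = np.random.rand(128, 1).astype('float32')
  ds = tf.data.Dataset.from_tensor_slices((x, y))
  if shuffle:
    ds = ds.shuffle(max(1024, batch_size * 8))  # same buffer rule as above
  if epochs > 1:
    ds = ds.repeat(epochs)
  # drop_remainder=True gives the fully static shapes that TPUs require.
  return ds.batch(batch_size, drop_remainder=True)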

2214 def _standardize_user_data(self, 

2215 x, 

2216 y=None, 

2217 sample_weight=None, 

2218 class_weight=None, 

2219 batch_size=None, 

2220 check_steps=False, 

2221 steps_name='steps', 

2222 steps=None, 

2223 validation_split=0, 

2224 shuffle=False, 

2225 extract_tensors_from_dataset=False): 

2226 """Runs validation checks on input and target data passed by the user. 

2227 

2228 Also standardizes the data to lists of arrays, in order. 

2229 

2230 Also builds and compiles the model on the fly if it is a subclassed model 

2231 that has never been called before (and thus has no inputs/outputs). 

2232 

2233 This is a purely internal method, subject to refactoring at any time. 

2234 

2235 Args: 

2236 x: Input data. It could be: 

2237 - A Numpy array (or array-like), or a list of arrays 

2238 (in case the model has multiple inputs). 

2239 - A TensorFlow tensor, or a list of tensors 

2240 (in case the model has multiple inputs). 

2241 - A dict mapping input names to the corresponding array/tensors, 

2242 if the model has named inputs. 

2243 - A `tf.data` dataset. 

2244 y: Target data. Like the input data `x`, 

2245 it could be either Numpy array(s) or TensorFlow tensor(s). 

2246 It should be consistent with `x` (you cannot have Numpy inputs and 

2247 tensor targets, or vice versa). If `x` is a dataset, `y` should not be

2248 specified (since targets will be obtained from the iterator). 

2249 sample_weight: An optional sample-weight array passed by the user to 

2250 weight the importance of each sample in `x`. 

2251 class_weight: An optional class-weight array passed by the user to

2252 weight the importance of samples in `x` based on the class they belong 

2253 to, as conveyed by `y`. If both `sample_weight` and `class_weight` are 

2254 provided, the weights are multiplied. 

2255 batch_size: Integer batch size. If provided, it is used to run additional 

2256 validation checks on stateful models. 

2257 check_steps: boolean, True if we want to check for validity of `steps` and 

2258 False, otherwise. For example, when we are standardizing one batch of 

2259 data for train_on_batch/predict_on_batch/test_on_batch APIs, `steps` 

2260 value is not required and we should not check for its validity in these 

2261 cases. 

2262 steps_name: The public API's parameter name for `steps`. 

2263 steps: Integer or `None`. Total number of steps (batches of samples) to 

2264 execute. 

2265 validation_split: Float between 0 and 1. 

2266 Fraction of the training data to be used as validation data. 

2267 shuffle: Boolean whether to shuffle the training data before each epoch. 

2268 extract_tensors_from_dataset: Boolean. When `x` is a dataset instance, 

2269 this indicates whether to extract actual tensors from the dataset or 

2270 instead output the dataset instance itself. 

2271 Set to True when calling from `train_on_batch`/etc. 

2272 

2273 Returns: 

2274 A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a dict 

2275 or not), target arrays, sample-weight arrays. 

2276 If the model's input and targets are symbolic, these lists are empty 

2277 (since the model takes no user-provided data, instead the data comes 

2278 from the symbolic inputs/targets). 

2279 

2280 Raises: 

2281 ValueError: In case of invalid user-provided data. 

2282 RuntimeError: If the model was never compiled. 

2283 """ 

2284 if isinstance(x, (data_types.DatasetV1, data_types.DatasetV2)): 

2285 # Graph mode dataset. We'll pass the dataset as-is (unless 

2286 # `extract_tensors_from_dataset` is True, in which case we extract 

2287 # the tensors from the dataset and output them).

2288 training_utils_v1.validate_dataset_input(x, y, sample_weight, 

2289 validation_split) 

2290 if shuffle: 

2291 training_utils_v1.verify_dataset_shuffled(x) 

2292 

2293 is_dataset = True 

2294 if extract_tensors_from_dataset: 

2295 # We do this for `train_on_batch`/etc. 

2296 x, y, sample_weight = training_utils_v1.extract_tensors_from_dataset(x) 

2297 elif isinstance(x, iterator_ops.Iterator): 

2298 # Graph mode iterator. We extract the symbolic tensors. 

2299 training_utils_v1.validate_dataset_input(x, y, sample_weight, 

2300 validation_split) 

2301 iterator = x 

2302 x, y, sample_weight = training_utils_v1.unpack_iterator_input(iterator) 

2303 is_dataset = True 

2304 else: 

2305 is_dataset = False 

2306 

2307 # Validates `steps` argument based on x's type. 

2308 if check_steps: 

2309 training_utils_v1.check_steps_argument(x, steps, steps_name) 

2310 

2311 # First, we build the model on the fly if necessary. 

2312 if not self.inputs: 

2313 all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y) 

2314 is_build_called = True 

2315 else: 

2316 all_inputs = [] 

2317 # Whether this is a subclassed model that expects dictionary inputs 

2318 # rather than list inputs (e.g. FeatureColumn-based models). 

2319 dict_inputs = isinstance(self.inputs, dict) 

2320 is_build_called = False 

2321 y_input = y 

2322 

2323 # Second, we compile the model on the fly if necessary, mostly for subclass 

2324 # models. 

2325 is_compile_called = False 

2326 if not self._is_compiled and self.optimizer: 

2327 self._compile_from_inputs(all_inputs, y_input, x, y) 

2328 is_compile_called = True 

2329 

2330 # In graph mode, if we had just set inputs and targets as symbolic tensors 

2331 # by invoking build and compile on the model respectively, we do not have to 

2332 # feed anything to the model. Model already has input and target data as 

2333 # part of the graph. 

2334 # Note: in this case, `any` and `all` are equivalent since we disallow 

2335 # mixed symbolic/value inputs. 

2336 

2337 # self.run_eagerly is not free to compute, so we want to reuse the value. 

2338 run_eagerly = self.run_eagerly 

2339 

2340 if (not run_eagerly and is_build_called and is_compile_called and 

2341 not is_dataset and any(_is_symbolic_tensor(v) for v in all_inputs)): 

2342 return [], [], None 

2343 

2344 return self._standardize_tensors( 

2345 x, y, sample_weight, 

2346 run_eagerly=run_eagerly, 

2347 dict_inputs=dict_inputs, 

2348 is_dataset=is_dataset, 

2349 class_weight=class_weight, 

2350 batch_size=batch_size) 

2351 

2352 def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs, 

2353 is_dataset, class_weight=None, batch_size=None): 

2354 if run_eagerly: 

2355 # In eager mode, do not do shape validation 

2356 # since the network has no input nodes (placeholders) to be fed. 

2357 feed_input_names = self.input_names 

2358 feed_input_shapes = None 

2359 elif not self._is_graph_network: 

2360 # Case: symbolic-mode subclassed network. Do not do shape validation. 

2361 feed_input_names = self._feed_input_names 

2362 feed_input_shapes = None 

2363 else: 

2364 # Case: symbolic-mode graph network. 

2365 # In this case, we run extensive shape validation checks. 

2366 feed_input_names = self._feed_input_names 

2367 feed_input_shapes = self._feed_input_shapes 

2368 

2369 # Standardize the inputs. 

2370 if not isinstance(x, (data_types.DatasetV1, data_types.DatasetV2)): 

2371 # TODO(fchollet): run static checks with dataset output shape(s). 

2372 x = training_utils_v1.standardize_input_data( 

2373 x, 

2374 feed_input_names, 

2375 feed_input_shapes, 

2376 check_batch_axis=False, # Don't enforce the batch size. 

2377 exception_prefix='input') 

2378 

2379 # Get typespecs for the input data and sanitize it if necessary. 

2380 # TODO(momernick): This should be capable of doing full input validation 

2381 # at all times - validate that this is so and refactor the standardization 

2382 # code. 

2383 if isinstance(x, data_types.DatasetV2): 

2384 x_shapes = dataset_ops.get_structure(x) 

2385 if isinstance(x_shapes, tuple): 

2386 # If the output of a Dataset is a tuple, we assume it's either of the 

2387 # form (x_data, y_data) or (x_data, y_data, sample_weights). In either 

2388 # case, we only care about x_data here. 

2389 x_shapes = x_shapes[0] 

2390 else: 

2391 flat_inputs = nest.flatten(x, expand_composites=False) 

2392 flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False) 

2393 converted_x = [] 

2394 for (a, b) in zip(flat_inputs, flat_expected_inputs): 

2395 converted_x.append(_convert_scipy_sparse_tensor(a, b)) 

2396 x = nest.pack_sequence_as(x, converted_x, expand_composites=False) 

2397 

2398 def _type_spec_from_value(value): 

2399 """Grab type_spec without converting array-likes to tensors.""" 

2400 if tf_utils.is_extension_type(value): 

2401 return value._type_spec # pylint: disable=protected-access 

2402 # Get a TensorSpec for array-like data without 

2403 # converting the data to a Tensor 

2404 if hasattr(value, 'shape') and hasattr(value, 'dtype'): 

2405 return tensor_spec.TensorSpec(value.shape, value.dtype) 

2406 else: 

2407 return type_spec.type_spec_from_value(value) 

2408 

2409 x_shapes = nest.map_structure(_type_spec_from_value, x) 

2410 

2411 flat_inputs = nest.flatten(x_shapes, expand_composites=False) 

2412 flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False) 

2413 for (a, b) in zip(flat_inputs, flat_expected_inputs): 

2414 nest.assert_same_structure(a, b, expand_composites=True) 

2415 

2416 if y is not None: 

2417 # Prepare self._sample_weight_modes. List with the same length as 

2418 # model outputs. 

2419 training_utils_v1.prepare_sample_weight_modes(self._training_endpoints, 

2420 self.sample_weight_mode) 

2421 feed_output_names = self._feed_output_names 

2422 feed_sample_weight_modes = self._sample_weight_modes 

2423 if not self._is_graph_network: 

2424 feed_output_shapes = None 

2425 else: 

2426 feed_output_shapes = self._feed_output_shapes 

2427 

2428 # Standardize the outputs. 

2429 y = training_utils_v1.standardize_input_data( 

2430 y, 

2431 feed_output_names, 

2432 # Don't enforce target shapes to match output shapes. 

2433 # Precise checks will be run in `check_loss_and_target_compatibility`. 

2434 shapes=None, 

2435 check_batch_axis=False, # Don't enforce the batch size. 

2436 exception_prefix='target') 

2437 

2438 # Generate sample-wise weight values given the `sample_weight` and 

2439 # `class_weight` arguments. 

2440 sample_weights = training_utils_v1.standardize_sample_weights( 

2441 sample_weight, feed_output_names) 

2442 class_weights = training_utils_v1.standardize_class_weights( 

2443 class_weight, feed_output_names) 

2444 

2445 sample_weights = [ 

2446 training_utils_v1.standardize_weights(ref, sw, cw, mode) 

2447 for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights, 

2448 feed_sample_weight_modes) 

2449 ] 

2450 # Check that all arrays have the same length. 

2451 if not self._distribution_strategy: 

2452 training_utils_v1.check_array_lengths(x, y, sample_weights) 

2453 if self._is_graph_network and not run_eagerly: 

2454 # Additional checks to avoid users mistakenly using improper loss fns. 

2455 training_utils_v1.check_loss_and_target_compatibility( 

2456 y, self._feed_loss_fns, feed_output_shapes) 

2457 

2458 sample_weights, _, _ = training_utils.handle_partial_sample_weights( 

2459 y, sample_weights, feed_sample_weight_modes, check_all_flat=True) 

2460 else: 

2461 y = [] 

2462 sample_weights = None 

2463 

2464 if self.stateful and batch_size and not is_dataset: 

2465 # Check that for stateful networks, number of samples is a multiple 

2466 # of the static batch size. 

2467 if x[0].shape[0] % batch_size != 0: 

2468 raise ValueError('In a stateful network, ' 

2469 'you should only pass inputs with ' 

2470 'a number of samples that is '

2471 'divisible by the batch size. Found: ' +

2472 str(x[0].shape[0]) + ' samples') 

2473 

2474 # If dictionary inputs were provided, we return a dictionary as well. 

2475 if dict_inputs and not isinstance(x, (data_types.DatasetV1, 

2476 data_types.DatasetV2)): 

2477 x = dict(zip(feed_input_names, x)) 

2478 return x, y, sample_weights 

2479 

2480 def _build_model_with_inputs(self, inputs, targets): 

2481 """Build the model (set model inputs/outputs), mainly for subclass model.""" 

2482 processed_inputs = [] 

2483 is_dict_inputs = False 

2484 orig_inputs = inputs 

2485 # We need to use `inputs` to set the model inputs. 

2486 # If input data is a dataset iterator in graph mode or if it is an eager 

2487 # iterator and only one batch of samples is required, we fetch the data 

2488 # tensors from the iterator and then standardize them. 

2489 if isinstance(inputs, (data_types.DatasetV1, data_types.DatasetV2)): 

2490 inputs, targets, _ = training_utils_v1.extract_tensors_from_dataset( 

2491 inputs) 

2492 # We type-check that `inputs` and `targets` are either single arrays 

2493 # or lists of arrays, and extract a flat list of inputs from the passed 

2494 # structure. 

2495 training_utils_v1.validate_input_types(inputs, orig_inputs) 

2496 

2497 if isinstance(inputs, (list, tuple)): 

2498 processed_inputs += list(inputs) 

2499 elif isinstance(inputs, dict): 

2500 is_dict_inputs = True 

2501 keys = sorted(inputs.keys()) 

2502 processed_inputs = [inputs[k] for k in keys] 

2503 else: 

2504 processed_inputs.append(inputs) 

2505 # Now that we have a flat set of inputs, we make sure that none of them 

2506 # are CompositeTensors or CompositeTensorValues of any type (or scipy 

2507 # sparse arrays, which we treat as SparseTensor values). We cannot safely 

2508 # infer input data from an arbitrary composite tensor, so we don't try - 

2509 # users should explicitly add composite tensor inputs to their subclassed 

2510 # models. 

2511 for input_tensor in processed_inputs: 

2512 if training_utils_v1.is_composite_or_composite_value(input_tensor): 

2513 # TODO(b/132691975): Document subclass-model CT input handling. 

2514 raise ValueError( 

2515 'All SparseTensor and RaggedTensor inputs must be explicitly ' 

2516 'declared using a keras.Input() with sparse=True or ragged=True. ' 

2517 'We found an undeclared input %s. For Sequential models, please ' 

2518 'add a keras.Input() as your first Layer. For subclassed models, ' 

2519 'please call self._set_inputs() on your input set, which you can ' 

2520 'create using keras.Input() for each input to your model.' % 

2521 (input_tensor,)) 

2522 # Build the model using the retrieved inputs (value or symbolic). 

2523 # If values are generated from a dataset, then in symbolic-mode 

2524 # placeholders will be created to match the value shapes. 

2525 if isinstance(orig_inputs, (data_types.DatasetV1, data_types.DatasetV2, 

2526 iterator_ops.Iterator)): 

2527 if not self.inputs: 

2528 # For subclassed models, a robust input spec is not available so we 

2529 # must cast to the model dtype. 

2530 inputs = training_utils_v1.cast_if_floating_dtype(inputs, self.dtype) 

2531 

2532 def create_tensor_spec(t): 

2533 return tensor_spec.TensorSpec(t.shape, t.dtype) 

2534 

2535 cast_inputs = nest.map_structure(create_tensor_spec, inputs) 

2536 elif training_utils_v1.has_tensors(inputs): 

2537 cast_inputs = training_utils_v1.cast_if_floating_dtype(inputs) 

2538 else: 

2539 cast_inputs = inputs 

2540 self._set_inputs(cast_inputs) 

2541 return processed_inputs, targets, is_dict_inputs 

2542 

2543 def _compile_from_inputs(self, all_inputs, target, orig_inputs, orig_target): 

2544 if target is not None: 

2545 # We need to use `y` to set the model targets. 

2546 if training_utils_v1.has_tensors(target): 

2547 target = training_utils_v1.cast_if_floating_dtype_and_mismatch( 

2548 target, self.outputs) 

2549 training_utils_v1.validate_input_types( 

2550 target, orig_target, allow_dict=False, field_name='target') 

2551 if isinstance(target, (list, tuple)): 

2552 all_inputs += list(target) 

2553 else: 

2554 all_inputs.append(target) 

2555 # Type check that all inputs are *either* value *or* symbolic. 

2556 # TODO(fchollet): this check could be removed in Eager mode? 

2557 if any(tensor_util.is_tf_type(v) for v in all_inputs): 

2558 if not all(tensor_util.is_tf_type(v) for v in all_inputs): 

2559 raise ValueError('Do not pass inputs that mix Numpy arrays and ' 

2560 'TensorFlow tensors. ' 

2561 'You passed: x=' + str(orig_inputs) + 

2562 '; y=' + str(orig_target)) 

2563 is_dataset = isinstance(orig_inputs, (data_types.DatasetV1, 

2564 data_types.DatasetV2, 

2565 iterator_ops.Iterator)) 

2566 if is_dataset or context.executing_eagerly(): 

2567 target_tensors = None 

2568 else: 

2569 # Handle target tensors if any passed. 

2570 if target is not None: 

2571 if not isinstance(target, (list, tuple)): 

2572 target = [target] 

2573 target_tensors = [v for v in target if _is_symbolic_tensor(v)] 

2574 else: 

2575 target_tensors = None 

2576 

2577 self.compile( 

2578 optimizer=self.optimizer, 

2579 loss=self.loss, 

2580 metrics=self._compile_metrics, 

2581 weighted_metrics=self._compile_weighted_metrics, 

2582 loss_weights=self.loss_weights, 

2583 target_tensors=target_tensors, 

2584 sample_weight_mode=self.sample_weight_mode, 

2585 run_eagerly=self.run_eagerly, 

2586 experimental_run_tf_function=self._experimental_run_tf_function) 

2587 

2588 # TODO(omalleyt): Consider changing to a more descriptive function name. 

2589 def _set_inputs(self, inputs, outputs=None, training=None): 

2590 """Set model's input and output specs based on the input data received. 

2591 

2592 This is to be used for Model subclasses, which do not know at instantiation 

2593 time what their inputs look like. 

2594 

2595 Args: 

2596 inputs: Single array, or list of arrays. The arrays could be placeholders, 

2597 Numpy arrays, data tensors, or TensorSpecs. 

2598 - if placeholders: the model is built on top of these placeholders, 

2599 and we expect Numpy data to be fed for them when calling `fit`/etc. 

2600 - if Numpy data or TensorShapes: we create placeholders matching the 

2601 TensorShapes or shapes of the Numpy arrays. We expect Numpy data to be 

2602 fed for these placeholders when calling `fit`/etc. 

2603 - if data tensors: the model is built on top of these tensors. 

2604 We do not expect any Numpy data to be provided when calling `fit`/etc. 

2605 outputs: None, a data tensor, or a list of tensors. If None, the 

2606 outputs will be determined by invoking `self.call()`, otherwise the 

2607 provided value will be used. 

2608 training: Boolean or None. Only relevant in symbolic mode. Specifies 

2609 whether to build the model's graph in inference mode (False), training 

2610 mode (True), or using the Keras learning phase (None). 

2611 Raises: 

2612 ValueError: If dict inputs are passed to a Sequential Model where the 

2613 first layer isn't FeatureLayer. 

2614 """ 

2615 self._set_save_spec(inputs) 

2616 inputs = self._set_input_attrs(inputs) 

2617 

2618 if outputs is None: 

2619 kwargs = {} 

2620 if self._expects_training_arg: 

2621 # In V2 mode, feeding `training=None` is not allowed because any value 

2622 # explicitly passed by the user is respected, even `None`.` 

2623 if training is None and not ops.executing_eagerly_outside_functions(): 

2624 training = backend.learning_phase() 

2625 if training is not None: 

2626 kwargs['training'] = training 

2627 try: 

2628 outputs = self(inputs, **kwargs) 

2629 except NotImplementedError: 

2630 # This Model or a submodel is dynamic and hasn't overridden 

2631 # `compute_output_shape`. 

2632 outputs = None 

2633 

2634 self._set_output_attrs(outputs) 

2635 
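A minimal sketch, assuming TF1-style graph execution via `tf.compat.v1`, of the deferred input handling that `_set_inputs` implements for subclassed models: the hypothetical `TinyModel` below has no symbolic inputs until the first batch of Numpy data is seen, at which point placeholders matching those shapes are created.

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # run the V1 training engine shown in this file

class TinyModel(tf.keras.Model):
  """Hypothetical subclassed model; its input spec is unknown up front."""

  def __init__(self):
    super(TinyModel, self).__init__()
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    return self.dense(inputs)

model = TinyModel()
model.compile(optimizer='sgd', loss='mse')
assert not model.inputs  # no input spec yet: the build is deferred
model.fit(np.ones((4, 3)), np.ones((4, 1)), epochs=1, verbose=0)
# After the first `fit` call, placeholders matching the (None, 3) Numpy
# shape have been created and the model reports itself as built.
print(model.inputs, model.outputs, model.built)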

2636 @trackable.no_automatic_dependency_tracking 

2637 def _set_input_attrs(self, inputs): 

2638 """Sets attributes related to the inputs of the Model.""" 

2639 if self.inputs: 

2640 raise ValueError('Model inputs are already set.') 

2641 

2642 if self.__class__.__name__ == 'Sequential' and not self.built: 

2643 if tensor_util.is_tf_type(inputs): 

2644 input_shape = (None,) + tuple(inputs.shape.as_list()[1:]) 

2645 elif isinstance(inputs, tensor_shape.TensorShape): 

2646 input_shape = (None,) + tuple(inputs.as_list()[1:]) 

2647 elif isinstance(inputs, dict): 

2648 # We assert that the first layer is a FeatureLayer. 

2649 if not training_utils_v1.is_feature_layer(self.layers[0]): 

2650 raise ValueError('Passing a dictionary input to a Sequential Model ' 

2651 'which doesn\'t have FeatureLayer as the first layer' 

2652 ' is an error.') 

2653 input_shape = (None,) 

2654 else: 

2655 input_shape = (None,) + tuple(inputs.shape[1:]) 

2656 self._build_input_shape = input_shape 

2657 

2658 # Cast inputs to the compute dtype. This is primarily used 

2659 # when saving to determine the correct dtype in the input signature. 

2660 inputs = self._maybe_cast_inputs(inputs) 

2661 

2662 # On-the-fly setting of symbolic model inputs (either by using the tensor 

2663 # provided, or by creating a placeholder if Numpy data was provided). 

2664 model_inputs = training_utils_v1.ModelInputs(inputs) 

2665 inputs = model_inputs.get_symbolic_inputs() 

2666 self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True) 

2667 self.input_names = model_inputs.get_input_names() 

2668 

2669 self._feed_inputs = [] 

2670 self._feed_input_names = [] 

2671 self._feed_input_shapes = [] 

2672 

2673 for k, v in model_inputs.as_dict(): 

2674 if backend.is_placeholder(v): 

2675 self._feed_input_names.append(k) 

2676 self._feed_inputs.append(v) 

2677 self._feed_input_shapes.append(backend.int_shape(v)) 

2678 

2679 return inputs 

2680 
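A small sketch of the Sequential shape-inference branches above, with made-up batch data: the deferred `_build_input_shape` simply swaps the batch dimension for `None`, and collapses to `(None,)` when dict inputs are handled by a feature layer.

import numpy as np

numpy_batch = np.zeros((32, 10), dtype=np.float32)             # assumed first batch
inferred_from_numpy = (None,) + tuple(numpy_batch.shape[1:])   # -> (None, 10)

dict_batch = {'age': np.zeros((32, 1))}                        # FeatureLayer case
inferred_from_dict = (None,)                                   # shape left opaque

print(inferred_from_numpy, inferred_from_dict)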

2681 @trackable.no_automatic_dependency_tracking 

2682 def _set_output_attrs(self, outputs): 

2683 """Sets attributes related to the outputs of the Model.""" 

2684 # NOTE(taylorrobie): This convention cannot be changed without updating the 

2685 # data adapter since it assumes nest.flatten ordering. 

2686 outputs = nest.flatten(outputs) 

2687 self.outputs = outputs 

2688 self.output_names = training_utils_v1.generic_output_names(outputs) 

2689 # TODO(scottzhu): Should we cleanup the self._training_endpoints here? 

2690 self.built = True 

2691 

2692 @property 

2693 def _targets(self): 

2694 """The output target tensors for the model.""" 

2695 return [ 

2696 e.training_target.target 

2697 for e in self._training_endpoints 

2698 if e.has_training_target() 

2699 ] 

2700 

2701 @property 

2702 def _feed_targets(self): 

2703 return [ 

2704 e.training_target.target 

2705 for e in self._training_endpoints 

2706 if e.has_feedable_training_target() 

2707 ] 

2708 

2709 @property 

2710 def _feed_output_names(self): 

2711 return [ 

2712 e.output_name 

2713 for e in self._training_endpoints 

2714 if e.has_feedable_training_target() 

2715 ] 

2716 

2717 @property 

2718 def _feed_output_shapes(self): 

2719 return [ 

2720 e.feed_output_shape 

2721 for e in self._training_endpoints 

2722 if e.has_feedable_training_target() 

2723 ] 

2724 

2725 @property 

2726 def _feed_loss_fns(self): 

2727 return [ 

2728 e.loss_fn 

2729 for e in self._training_endpoints 

2730 if e.has_feedable_training_target() 

2731 ] 

2732 

2733 @property 

2734 def _loss_weights_list(self): 

2735 return [e.loss_weight for e in self._training_endpoints] 

2736 

2737 @property 

2738 def _output_loss_metrics(self): 

2739 if hasattr(self, '_training_endpoints'): 

2740 return [ 

2741 e.output_loss_metric 

2742 for e in self._training_endpoints 

2743 if e.output_loss_metric is not None 

2744 ] 

2745 return None 

2746 

2747 @property 

2748 def sample_weights(self): 

2749 return [e.sample_weight for e in self._training_endpoints] 

2750 

2751 @property 

2752 def _sample_weight_modes(self): 

2753 return [e.sample_weight_mode for e in self._training_endpoints] 

2754 

2755 @property 

2756 def _feed_sample_weights(self): 

2757 return [e.sample_weight for e in self._training_endpoints 

2758 if e.sample_weight is not None] 

2759 

2760 def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode): 

2761 """Maybe load initial epoch from ckpt considering possible worker recovery. 

2762 

2763 Refer to tensorflow/python/keras/distribute/worker_training_state.py 

2764 for more information. 

2765 

2766 Args: 

2767 initial_epoch: The original `initial_epoch` the user passed in to `fit()`. 

2768 mode: The mode for running `model.fit()`. 

2769 

2770 Returns: 

2771 If the training is recovering from a previous failure under a multi-worker 

2772 training setting, return the epoch at which training should continue. 

2773 Otherwise, return the `initial_epoch` the user passed in. 

2774 """ 

2775 if self._training_state is not None: 

2776 return self._training_state.maybe_load_initial_epoch_from_ckpt( 

2777 initial_epoch, mode) 

2778 return initial_epoch 

2779 

2780 def _get_training_eval_metrics(self): 

2781 """Returns all the metrics that are to be reported. 

2782 

2783 This includes the output loss metrics, compile metrics/weighted metrics, 

2784 and add_metric metrics. 

2785 """ 

2786 metrics = [] 

2787 metrics.extend(getattr(self, '_output_loss_metrics', None) or []) 

2788 metrics.extend(getattr(self, 'metrics', None) or []) 

2789 return metrics 

2790 

2791 def _assert_compile_was_called(self): 

2792 # Checks whether `compile` has been called. If it has been called, 

2793 # then the optimizer is set. This is different from whether the 

2794 # model is compiled 

2795 # (i.e. whether the model is built and its inputs/outputs are set). 

2796 if not self._compile_was_called: 

2797 raise RuntimeError('You must compile your model before ' 

2798 'training/testing. ' 

2799 'Use `model.compile(optimizer, loss)`.') 

2800 

2801 def _in_multi_worker_mode(self): 

2802 """Method to infer if this `Model` is working in multi-worker settings. 

2803 

2804 Multi-worker training refers to the setup where the training is 

2805 distributed across multiple workers, as opposed to the case where 

2806 only a local process performs the training. This function is 

2807 used to infer, for example, whether a distribute coordinator 

2808 should be run (and thus whether TensorFlow servers should be started for 

2809 communication with other servers in the cluster), or whether 

2810 saving/restoring checkpoints is relevant for preemption fault tolerance. 

2811 

2812 Experimental. Signature and implementation are subject to change. 

2813 

2814 Returns: 

2815 Whether this model indicates it's working in multi-worker settings. 

2816 """ 

2817 strategy = self._distribution_strategy 

2818 

2819 # Fall back to the strategy whose scope this is in, if any. 

2820 if not strategy and distribute_lib.has_strategy(): 

2821 strategy = distribute_lib.get_strategy() 

2822 return strategy and strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access 

2823 
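An illustrative sketch of doing the same inference by hand with the public distribute API; `_in_multi_worker_mode` is a private method on the strategy's extended object, and `MirroredStrategy` is used here only as an example of a single-worker strategy.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  active = tf.distribute.get_strategy()
  # Same private check the method above relies on; single-worker strategies
  # such as MirroredStrategy report False.
  print(active.extended._in_multi_worker_mode())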

2824 @property 

2825 def _trackable_saved_model_saver(self): 

2826 return model_serialization.ModelSavedModelSaver(self) 

2827 

2828 def _get_compile_args(self, user_metrics=True): 

2829 del user_metrics 

2830 self._assert_compile_was_called() 

2831 kwargs = { 

2832 'loss': self.loss, 

2833 'metrics': self._compile_metrics, 

2834 'loss_weights': self.loss_weights, 

2835 'sample_weight_mode': self.sample_weight_mode, 

2836 'weighted_metrics': self._compile_weighted_metrics, 

2837 } 

2838 return kwargs 

2839 

2840 @property 

2841 def _compile_was_called(self): 

2842 return self._v1_compile_was_called 

2843 

2844 

2845class DistributedCallbackModel(Model): 

2846 """Model that is used for callbacks with tf.distribute.Strategy.""" 

2847 

2848 def __init__(self, model): 

2849 super(DistributedCallbackModel, self).__init__() 

2850 self.optimizer = model.optimizer 

2851 

2852 def set_original_model(self, orig_model): 

2853 self._original_model = orig_model 

2854 

2855 def save_weights(self, filepath, overwrite=True, save_format=None): 

2856 self._replicated_model.save_weights(filepath, overwrite=overwrite, 

2857 save_format=save_format) 

2858 

2859 def save(self, filepath, overwrite=True, include_optimizer=True): 

2860 # save weights from the distributed model to the original model 

2861 distributed_model_weights = self.get_weights() 

2862 self._original_model.set_weights(distributed_model_weights) 

2863 # TODO(anjalisridhar): Do we need to save the original model here? 

2864 # Saving the first replicated model works as well. 

2865 self._original_model.save(filepath, overwrite=True, include_optimizer=False) 

2866 

2867 def load_weights(self, filepath, by_name=False): 

2868 self._original_model.load_weights(filepath, by_name=by_name) 

2869 # Copy the weights from the original model to each of the replicated models. 

2870 orig_model_weights = self._original_model.get_weights() 

2871 distributed_training_utils_v1.set_weights( 

2872 self._original_model._distribution_strategy, self, # pylint: disable=protected-access 

2873 orig_model_weights) 

2874 

2875 def __getattr__(self, item): 

2876 # Allowed attributes of the model that can be accessed by the user 

2877 # during a callback. 

2878 if item not in ('_setattr_tracking', '_layers'): 

2879 logging.warning('You are accessing attribute ' + item + ' of the ' 

2880 'DistributedCallbackModel that may not have been set ' 

2881 'correctly.') 

2882 return super(DistributedCallbackModel, self).__getattr__(item) 

2883 

2884 

2885class _TrainingEndpoint(object): 

2886 """A container for the training output/target and related entities. 

2887 

2888 In the case of a model with multiple outputs, there is a one-to-one mapping 

2889 between model output (y_pred), model target (y_true), loss, metrics etc. 

2890 By unifying these entities in one class, each entity can access 

2891 information about the others, rather than reaching into separate lists of 

2892 attributes on the model. 

2893 """ 

2894 

2895 def __init__(self, 

2896 output, 

2897 output_name, 

2898 loss_fn, 

2899 loss_weight=None, 

2900 training_target=None, 

2901 output_loss_metric=None, 

2902 sample_weight=None, 

2903 sample_weight_mode=None): 

2904 """Initialize the _TrainingEndpoint. 

2905 

2906 Note that the output and output_name should be stable as long as the model 

2907 structure doesn't change. The training_target is expected to be mutable since 

2908 the information is provided via `compile()`. 

2909 

2910 Args: 

2911 output: the output tensor of the model. 

2912 output_name: the unique name of the output tensor. 

2913 loss_fn: the loss function for the output tensor. 

2914 loss_weight: float, the weight for the loss. 

2915 training_target: the _TrainingTarget for the model. 

2916 output_loss_metric: the metric object for the loss function. 

2917 sample_weight: the weights applied to each sample during metric and 

2918 loss calculation. Could be None. 

2919 sample_weight_mode: string, 'temporal', 'samplewise' or None. The mode for 

2920 how the sample_weight is populated. 

2921 """ 

2922 self._output = output 

2923 self._output_name = output_name 

2924 self._loss_fn = loss_fn 

2925 self._loss_weight = loss_weight 

2926 self._training_target = training_target 

2927 self._output_loss_metric = output_loss_metric 

2928 self._sample_weight = sample_weight 

2929 self._sample_weight_mode = sample_weight_mode 

2930 

2931 @property 

2932 def output(self): 

2933 return self._output 

2934 

2935 @property 

2936 def output_name(self): 

2937 return self._output_name 

2938 

2939 @property 

2940 def shape(self): 

2941 return backend.int_shape(self.output) 

2942 

2943 @property 

2944 def loss_fn(self): 

2945 return self._loss_fn 

2946 

2947 @property 

2948 def loss_weight(self): 

2949 return self._loss_weight 

2950 

2951 @loss_weight.setter 

2952 def loss_weight(self, value): 

2953 self._loss_weight = value 

2954 

2955 @property 

2956 def training_target(self): 

2957 return self._training_target 

2958 

2959 @training_target.setter 

2960 def training_target(self, value): 

2961 self._training_target = value 

2962 

2963 def create_training_target(self, target, run_eagerly=False): 

2964 """Create training_target instance and update the self.training_target. 

2965 

2966 Note that the input target should just be a tensor or None; the 

2967 corresponding training target will be created based on the output and 

2968 loss_fn. 

2969 

2970 Args: 

2971 target: the target tensor for the current output. Could be None. 

2972 run_eagerly: boolean, whether the model is in run_eagerly mode. 

2973 

2974 Raises: 

2975 ValueError if the training_target field for the current instance has 

2976 already been populated. 

2977 """ 

2978 if self.has_training_target(): 

2979 raise ValueError('The training_target field for the _TrainingEndpoint ' 

2980 'instance has already been populated') 

2981 if run_eagerly: 

2982 # When run_eagerly, the target tensor is ignored, and a None placeholder 

2983 # is used instead. 

2984 self.training_target = _TrainingTarget( 

2985 None, feedable=True, skip_target_weights=False) 

2986 return 

2987 

2988 if self.should_skip_target(): 

2989 self.training_target = _TrainingTarget(None) 

2990 else: 

2991 if target is not None and not backend.is_placeholder(target): 

2992 feedable = False 

2993 skip_target_weights = True 

2994 else: 

2995 feedable = True 

2996 skip_target_weights = False 

2997 

2998 if target is None: 

2999 target_dtype = losses.LABEL_DTYPES_FOR_LOSSES.get( 

3000 self.loss_fn, backend.dtype(self.output)) 

3001 

3002 target = backend.placeholder( 

3003 ndim=len(self.shape), 

3004 name=self.output_name + '_target', 

3005 sparse=backend.is_sparse(self.output), 

3006 dtype=target_dtype) 

3007 

3008 self.training_target = _TrainingTarget( 

3009 target, 

3010 feedable=feedable, 

3011 skip_target_weights=skip_target_weights) 

3012 
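A short sketch of the dtype lookup used when the target placeholder is created above: losses that take integer labels are registered in `losses.LABEL_DTYPES_FOR_LOSSES`, so their target placeholder dtype can differ from the output dtype. The `'float32'` fallback below is only an assumed default for the unregistered case.

from tensorflow.python.keras import losses

# Sparse categorical losses take integer labels, so their registered label
# dtype is an integer type rather than the (float) output dtype.
print(losses.LABEL_DTYPES_FOR_LOSSES.get(
    losses.sparse_categorical_crossentropy, 'float32'))
# A loss with no registered label dtype falls through to the output dtype.
print(losses.LABEL_DTYPES_FOR_LOSSES.get(
    losses.mean_squared_error, 'float32'))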

3013 @property 

3014 def output_loss_metric(self): 

3015 return self._output_loss_metric 

3016 

3017 @output_loss_metric.setter 

3018 def output_loss_metric(self, value): 

3019 self._output_loss_metric = value 

3020 

3021 @property 

3022 def sample_weight(self): 

3023 return self._sample_weight 

3024 

3025 @sample_weight.setter 

3026 def sample_weight(self, value): 

3027 self._sample_weight = value 

3028 

3029 @property 

3030 def sample_weight_mode(self): 

3031 return self._sample_weight_mode 

3032 

3033 @sample_weight_mode.setter 

3034 def sample_weight_mode(self, value): 

3035 self._sample_weight_mode = value 

3036 

3037 def should_skip_target(self): 

3038 return self._loss_fn is None 

3039 

3040 def should_skip_target_weights(self): 

3041 return (self.should_skip_target() or self.training_target is None or 

3042 self.training_target.skip_target_weights) 

3043 

3044 def has_training_target(self): 

3045 return self.training_target is not None 

3046 

3047 def has_feedable_training_target(self): 

3048 return (not self.should_skip_target() and 

3049 self.training_target is not None and self.training_target.feedable) 

3050 

3051 def loss_name(self): 

3052 if self._loss_fn is not None: 

3053 return self._output_name + '_loss' 

3054 return None 

3055 

3056 @property 

3057 def feed_output_shape(self): 

3058 """The output shape for the feedable target.""" 

3059 if not self.has_feedable_training_target(): 

3060 return None 

3061 

3062 if ((isinstance(self.loss_fn, losses.LossFunctionWrapper) and 

3063 self.loss_fn.fn == losses.sparse_categorical_crossentropy)) or ( 

3064 isinstance(self.loss_fn, losses.SparseCategoricalCrossentropy)): 

3065 if backend.image_data_format() == 'channels_first': 

3066 return (self.shape[0], 1) + self.shape[2:] 

3067 else: 

3068 return self.shape[:-1] + (1,) 

3069 elif (not isinstance(self.loss_fn, losses.Loss) or 

3070 (isinstance(self.loss_fn, losses.LossFunctionWrapper) and 

3071 (getattr(losses, self.loss_fn.fn.__name__, None) is None))): 

3072 # If the given loss is not an instance of the `Loss` class (custom 

3073 # class) or if the loss function that is wrapped is not in the 

3074 # `losses` module, then it is a user-defined loss and we make no 

3075 # assumptions about it. 

3076 return None 

3077 else: 

3078 return self.shape 

3079 
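A worked example of the feed-shape rule above, with assumed output shapes: for sparse categorical losses the class axis collapses to a single label column.

shape = (None, 10)                                  # 10-way logits, channels_last
channels_last_target = shape[:-1] + (1,)            # -> (None, 1)

shape = (None, 3, 32, 32)                           # per-pixel logits, channels_first
channels_first_target = (shape[0], 1) + shape[2:]   # -> (None, 1, 32, 32)

print(channels_last_target, channels_first_target)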

3080 def sample_weights_mismatch(self): 

3081 """Check if the sample weight and the mode match or not.""" 

3082 # If there is a mismatch between sample weight mode and the placeholders 

3083 # created, then recompile the sub-graphs that depend on sample weights. 

3084 return ( 

3085 (self.sample_weight_mode is not None and self.sample_weight is None) or 

3086 (self.sample_weight_mode is None and self.sample_weight is not None)) 

3087 

3088 def populate_sample_weight(self, sample_weight, sample_weight_mode): 

3089 """Populate the sample weight based on the sample weight mode.""" 

3090 if (sample_weight is None and 

3091 (self.should_skip_target_weights() or sample_weight_mode is None or 

3092 context.executing_eagerly())): 

3093 self._sample_weight = None 

3094 return 

3095 

3096 assert sample_weight_mode in ['temporal', 'samplewise'] 

3097 if sample_weight_mode == 'temporal': 

3098 default_value = [[1.]] 

3099 shape = [None, None] 

3100 else: 

3101 # sample_weight_mode == 'samplewise' 

3102 default_value = [1.] 

3103 shape = [None] 

3104 

3105 if sample_weight is not None: 

3106 if not sample_weight.shape.is_compatible_with(shape): 

3107 raise ValueError('Received sample weight with shape {}. Expected shape ' 

3108 '{}.'.format(sample_weight.shape, shape)) 

3109 self._sample_weight = sample_weight 

3110 else: 

3111 self._sample_weight = array_ops.placeholder_with_default( 

3112 constant_op.constant(default_value, dtype=backend.floatx()), 

3113 shape=shape, 

3114 name=self.output_name + '_sample_weights') 

3115 
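An illustrative sketch, with an assumed batch of 8 sequences of length 5, of the two sample-weight layouts handled above and the placeholder shapes they must be compatible with.

import numpy as np
import tensorflow as tf

samplewise_weights = np.ones((8,))    # one weight per sample; matches shape [None]
temporal_weights = np.ones((8, 5))    # one weight per timestep; matches [None, None]

# Mirrors the compatibility check performed above.
assert tf.TensorShape(samplewise_weights.shape).is_compatible_with([None])
assert tf.TensorShape(temporal_weights.shape).is_compatible_with([None, None])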

3116 

3117class _TrainingTarget(object): 

3118 """Container for a target tensor (y_true) and its metadata (shape, loss...). 

3119 

3120 Args: 

3121 target: A target tensor for the model. It may be `None` if the 

3122 output is excluded from loss computation. The entry is still kept (as None) 

3123 so that each output of the model has a corresponding target. If 

3124 the target is None, the rest of the attributes will be None as well. 

3125 feedable: Boolean, whether the target is feedable (requires data to be 

3126 passed in `fit` or `train_on_batch`), or not (model compiled with 

3127 `target_tensors` argument). 

3128 skip_target_weights: Boolean, whether the target should be skipped during 

3129 weights calculation. 

3130 """ 

3131 

3132 def __init__(self, target, feedable=False, skip_target_weights=True): 

3133 self._target = target 

3134 self._feedable = feedable 

3135 self._skip_target_weights = skip_target_weights 

3136 

3137 @property 

3138 def target(self): 

3139 return self._target 

3140 

3141 @property 

3142 def feedable(self): 

3143 return self._feedable 

3144 

3145 @property 

3146 def skip_target_weights(self): 

3147 return self._skip_target_weights 

3148 

3149 

3150def _is_symbolic_tensor(x): 

3151 return tensor_util.is_tf_type(x) 

3152 

3153 

3154def _convert_scipy_sparse_tensor(value, expected_input): 

3155 """Handle scipy sparse tensor conversions. 

3156 

3157 This method takes a value 'value' and returns the proper conversion. If 

3158 value is a scipy sparse tensor and the expected input is a dense tensor, 

3159 we densify 'value'. If value is a scipy sparse tensor and the expected input 

3160 is a TF SparseTensor, we convert 'value' to a SparseTensor. If 'value' is 

3161 not a scipy sparse tensor, or scipy is not imported, we pass it through 

3162 unchanged. 

3163 

3164 Args: 

3165 value: An object that may be a scipy sparse tensor 

3166 expected_input: The expected input placeholder. 

3167 

3168 Returns: 

3169 The possibly-converted 'value'. 

3170 """ 

3171 if issparse is not None and issparse(value): 

3172 if backend.is_sparse(expected_input): 

3173 sparse_coo = value.tocoo() 

3174 row, col = sparse_coo.row, sparse_coo.col 

3175 data, shape = sparse_coo.data, sparse_coo.shape 

3176 indices = np.concatenate((np.expand_dims(row, 1), np.expand_dims(col, 1)), 

3177 1) 

3178 return sparse_tensor.SparseTensor(indices, data, shape) 

3179 else: 

3180 if ops.executing_eagerly_outside_functions(): 

3181 # In TF2 we do not silently densify sparse matrices. 

3182 raise ValueError('A SciPy sparse matrix was passed to a model ' 

3183 'that expects dense inputs. Please densify your ' 

3184 'inputs first, such as by calling `x.toarray()`.') 

3185 return value.toarray() 

3186 else: 

3187 return value 

3188 
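A standalone sketch, assuming SciPy is installed, of the COO-to-SparseTensor conversion performed above, written against the public `tf.sparse` API rather than the internal modules; the random 4x6 matrix is arbitrary.

import numpy as np
import tensorflow as tf
from scipy import sparse

value = sparse.random(4, 6, density=0.25, format='coo')
indices = np.concatenate(
    (np.expand_dims(value.row, 1), np.expand_dims(value.col, 1)), 1)
st = tf.sparse.SparseTensor(indices, value.data, value.shape)
# In TF2 the caller is asked to densify explicitly instead of passing the
# SciPy matrix to a dense input.
dense = tf.sparse.to_dense(tf.sparse.reorder(st))
print(dense.shape)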

3189 

3190def _get_metrics_from_layers(layers): 

3191 """Returns list of metrics from the given layers. 

3192 

3193 This will not include the `compile` metrics of a model layer. 

3194 

3195 Args: 

3196 layers: List of layers. 

3197 

3198 Returns: 

3199 List of metrics. 

3200 """ 

3201 metrics = [] 

3202 layers = layer_utils.filter_empty_layer_containers(layers) 

3203 for layer in layers: 

3204 if isinstance(layer, Model): 

3205 # We cannot call 'metrics' on the model because we do not want to 

3206 # include the metrics that were added in compile API of a nested model. 

3207 metrics.extend(layer._metrics) # pylint: disable=protected-access 

3208 metrics.extend(_get_metrics_from_layers(layer.layers)) 

3209 else: 

3210 metrics.extend(layer.metrics) 

3211 return metrics 

3212 

3213 

3214def _non_none_constant_value(v): 

3215 constant_value = tensor_util.constant_value(v) 

3216 return constant_value if constant_value is not None else v