Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py: 19%

939 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Training-related part of the Keras engine.""" 

16 

17import copy 

18import itertools 

19import json 

20import os 

21import warnings 

22import weakref 

23 

24from tensorflow.python.autograph.lang import directives 

25from tensorflow.python.checkpoint import checkpoint as trackable_utils 

26from tensorflow.python.checkpoint import checkpoint_management 

27from tensorflow.python.data.ops import options as options_lib 

28from tensorflow.python.distribute import collective_all_reduce_strategy 

29from tensorflow.python.distribute import distribute_lib 

30from tensorflow.python.distribute import values as ds_values 

31from tensorflow.python.distribute.coordinator import cluster_coordinator 

32from tensorflow.python.eager import backprop 

33from tensorflow.python.eager import context 

34from tensorflow.python.eager import def_function 

35from tensorflow.python.framework import composite_tensor 

36from tensorflow.python.framework import errors 

37from tensorflow.python.framework import errors_impl 

38from tensorflow.python.framework import func_graph 

39from tensorflow.python.framework import ops 

40from tensorflow.python.framework import sparse_tensor 

41from tensorflow.python.framework import tensor_shape 

42from tensorflow.python.keras import backend 

43from tensorflow.python.keras import callbacks as callbacks_module 

44from tensorflow.python.keras import optimizer_v1 

45from tensorflow.python.keras import optimizers 

46from tensorflow.python.keras.engine import base_layer 

47from tensorflow.python.keras.engine import base_layer_utils 

48from tensorflow.python.keras.engine import compile_utils 

49from tensorflow.python.keras.engine import data_adapter 

50from tensorflow.python.keras.engine import training_utils 

51from tensorflow.python.keras.mixed_precision import loss_scale_optimizer as lso 

52from tensorflow.python.keras.mixed_precision import policy 

53from tensorflow.python.keras.saving import hdf5_format 

54from tensorflow.python.keras.saving import save 

55from tensorflow.python.keras.saving import saving_utils 

56from tensorflow.python.keras.saving.saved_model import json_utils 

57from tensorflow.python.keras.saving.saved_model import model_serialization 

58from tensorflow.python.keras.utils import generic_utils 

59from tensorflow.python.keras.utils import layer_utils 

60from tensorflow.python.keras.utils import tf_utils 

61from tensorflow.python.keras.utils import version_utils 

62from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite 

63from tensorflow.python.keras.utils.io_utils import path_to_string 

64from tensorflow.python.keras.utils.mode_keys import ModeKeys 

65from tensorflow.python.ops import array_ops 

66from tensorflow.python.ops import array_ops_stack 

67from tensorflow.python.ops import math_ops 

68from tensorflow.python.ops import sparse_ops 

69from tensorflow.python.ops import summary_ops_v2 

70from tensorflow.python.ops import variables 

71from tensorflow.python.platform import tf_logging as logging 

72from tensorflow.python.profiler import trace 

73from tensorflow.python.saved_model import constants as sm_constants 

74from tensorflow.python.saved_model import loader_impl as sm_loader 

75from tensorflow.python.trackable import base as trackable 

76from tensorflow.python.training import py_checkpoint_reader 

77from tensorflow.python.types import data as data_types 

78from tensorflow.python.util import nest 

79from tensorflow.python.util import tf_decorator 

80from tensorflow.python.util.tf_export import keras_export 

81from tensorflow.tools.docs import doc_controls 

82 

83 

84# pylint: disable=g-import-not-at-top 

85try: 

86 import h5py 

87except ImportError: 

88 h5py = None 

89# pylint: enable=g-import-not-at-top 

90 

91 

def disable_multi_worker(method):
  """Decorator that rejects calls to `method` while in multi-worker mode.

  Args:
    method: An instance method whose owning object provides
      `_in_multi_worker_mode()`.

  Returns:
    A `tf_decorator`-wrapped version of `method`. The wrapper raises
    `ValueError` when `self._in_multi_worker_mode()` is truthy; otherwise it
    forwards the call, and its result, unchanged.
  """

  def _method_wrapper(self, *args, **kwargs):
    in_multi_worker = self._in_multi_worker_mode()  # pylint: disable=protected-access
    if not in_multi_worker:
      return method(self, *args, **kwargs)
    raise ValueError('{} is not supported in multi-worker mode.'.format(
        method.__name__))

  return tf_decorator.make_decorator(
      target=method, decorator_func=_method_wrapper)

103 

104 

def inject_functional_model_class(cls):
  """Rewrites the class hierarchy of `cls` so that `Functional` appears in it.

  Works recursively over `cls.__bases__`:
  - `Model` (this module) or `training_v1.Model` is replaced by
    `functional.Functional`.
  - The recursion stops at `object`.
  - Every other class has its `__bases__` reassigned to the patched bases.
    An instance of it is then created and discarded, so that any class
    swapping done in `__new__` gets a chance to happen.

  Args:
    cls: The class to patch. It is mutated in place via `__bases__`.

  Returns:
    `functional.Functional` if `cls` is one of the two `Model` classes,
    `object` if `cls` is `object`, otherwise `cls` itself after patching.
  """
  from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.engine import training_v1  # pylint: disable=g-import-not-at-top
  if any(cls == model_cls for model_cls in (Model, training_v1.Model)):
    return functional.Functional
  # With multiple inheritance, stop walking once we reach the root of a
  # branch that does not contain a Keras model.
  if cls == object:
    return object

  patched_bases = [inject_functional_model_class(base)
                   for base in cls.__bases__]
  cls.__bases__ = tuple(patched_bases)
  # Trigger any `__new__` class swapping that should have happened on
  # `Functional` but did not, because `Functional` was not yet in the
  # hierarchy. The resulting instance is discarded.
  cls.__new__(cls)
  return cls

123 

124 

def is_functional_model_init_params(args, kwargs):
  """Returns whether constructor arguments describe a Functional model.

  The arguments count as functional-style when any of these hold:
  - exactly two positional arguments (`inputs, outputs`);
  - exactly one positional argument plus an `outputs` keyword;
  - both `inputs` and `outputs` passed as keywords, regardless of how many
    positional arguments were given.

  Args:
    args: Tuple of positional constructor arguments.
    kwargs: Dict of keyword constructor arguments.

  Returns:
    A Python bool.
  """
  positional_count = len(args)
  if positional_count == 2:
    return True
  has_outputs_kwarg = 'outputs' in kwargs
  if positional_count == 1:
    return has_outputs_kwarg
  return 'inputs' in kwargs and has_outputs_kwarg

129 

130 

131@keras_export('keras.Model', 'keras.models.Model') 

132class Model(base_layer.Layer, version_utils.ModelVersionSelector): 

133 """`Model` groups layers into an object with training and inference features. 

134 

135 Args: 

136 inputs: The input(s) of the model: a `keras.Input` object or list of 

137 `keras.Input` objects. 

138 outputs: The output(s) of the model. See Functional API example below. 

139 name: String, the name of the model. 

140 

141 There are two ways to instantiate a `Model`: 

142 

143 1 - With the "Functional API", where you start from `Input`, 

144 you chain layer calls to specify the model's forward pass, 

145 and finally you create your model from inputs and outputs: 

146 

147 ```python 

148 import tensorflow as tf 

149 

150 inputs = tf.keras.Input(shape=(3,)) 

151 x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs) 

152 outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x) 

153 model = tf.keras.Model(inputs=inputs, outputs=outputs) 

154 ``` 

155 

156 Note: Only dicts, lists, and tuples of input tensors are supported. Nested 

157 inputs are not supported (e.g. lists of list or dicts of dict). 

158 

159 2 - By subclassing the `Model` class: in that case, you should define your 

160 layers in `__init__` and you should implement the model's forward pass 

161 in `call`. 

162 

163 ```python 

164 import tensorflow as tf 

165 

166 class MyModel(tf.keras.Model): 

167 

168 def __init__(self): 

169 super(MyModel, self).__init__() 

170 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 

171 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 

172 

173 def call(self, inputs): 

174 x = self.dense1(inputs) 

175 return self.dense2(x) 

176 

177 model = MyModel() 

178 ``` 

179 

180 If you subclass `Model`, you can optionally have 

181 a `training` argument (boolean) in `call`, which you can use to specify 

182 a different behavior in training and inference: 

183 

184 ```python 

185 import tensorflow as tf 

186 

187 class MyModel(tf.keras.Model): 

188 

189 def __init__(self): 

190 super(MyModel, self).__init__() 

191 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 

192 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 

193 self.dropout = tf.keras.layers.Dropout(0.5) 

194 

195 def call(self, inputs, training=False): 

196 x = self.dense1(inputs) 

197 if training: 

198 x = self.dropout(x, training=training) 

199 return self.dense2(x) 

200 

201 model = MyModel() 

202 ``` 

203 

204 Once the model is created, you can config the model with losses and metrics 

205 with `model.compile()`, train the model with `model.fit()`, or use the model 

206 to do prediction with `model.predict()`. 

207 """ 

208 _TF_MODULE_IGNORED_PROPERTIES = frozenset( 

209 itertools.chain(('_train_counter', '_test_counter', '_predict_counter', 

210 '_steps_per_execution'), 

211 base_layer.Layer._TF_MODULE_IGNORED_PROPERTIES)) # pylint: disable=protected-access 

212 

213 def __new__(cls, *args, **kwargs): 

214 # Signature detection 

215 if is_functional_model_init_params(args, kwargs) and cls == Model: 

216 # Functional model 

217 from tensorflow.python.keras.engine import functional # pylint: disable=g-import-not-at-top 

218 return functional.Functional(skip_init=True, *args, **kwargs) 

219 else: 

220 return super(Model, cls).__new__(cls, *args, **kwargs) 

221 

  @trackable.no_automatic_dependency_tracking
  def __init__(self, *args, **kwargs):
    """Initializes the model.

    There are two paths:

    - Subclassed functional model: the arguments look functional-style (see
      `is_functional_model_init_params`) but `self` is not yet a
      `functional.Functional`.
      - `Functional` is injected into the class hierarchy.
      - `Functional.__init__` runs with the supported kwargs.
      - The remaining kwargs are passed to each direct base class that comes
        after the first `Functional` subclass in `self.__class__.__bases__`.
      - If no such base class exists, leftover kwargs are an error.
    - Otherwise (subclassed model): `kwargs` are validated and forwarded to
      `super(Model, self).__init__`, then the model's training-related
      attributes are initialized.

    Args:
      *args: Positional constructor arguments, e.g. `inputs, outputs` on the
        functional path.
      **kwargs: Keyword constructor arguments.

    Raises:
      TypeError: On the functional path, if keyword arguments other than
        `inputs`, `outputs`, `name`, `trainable`, `skip_init` are given and
        there is no base class after `Functional` to receive them.
    """
    self._is_model_for_instrumentation = True

    # Special case for Subclassed Functional Model, which we couldn't detect
    # when __new__ is called. We only realize it is a functional model when it
    # calls super.__init__ with input and output tensor.
    from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
    if (is_functional_model_init_params(args, kwargs) and
        not isinstance(self, functional.Functional)):
      # Filter the kwargs for multiple inheritance.
      supported_kwargs = ['inputs', 'outputs', 'name', 'trainable', 'skip_init']
      model_kwargs = {k: kwargs[k] for k in kwargs if k in supported_kwargs}
      other_kwargs = {k: kwargs[k] for k in kwargs if k not in supported_kwargs}
      # Mutates `__bases__` of this class (and its ancestors) in place.
      inject_functional_model_class(self.__class__)
      functional.Functional.__init__(self, *args, **model_kwargs)

      # In case there is any multiple inheritance here, we need to call the
      # __init__ for any class that appears after the Functional class.
      clz_to_init = []
      found_functional_class = False
      for clz in self.__class__.__bases__:
        if issubclass(clz, functional.Functional):
          found_functional_class = True
          continue
        if found_functional_class:
          clz_to_init.append(clz)

      if clz_to_init:
        # Each later base receives the same positional args and only the
        # kwargs that `Functional` did not consume.
        for clz in clz_to_init:
          clz.__init__(self, *args, **other_kwargs)
      elif other_kwargs:
        # In case there are unused kwargs, we should raise an error to user, in
        # case they have a typo in the param name.
        raise TypeError(
            'The following keyword arguments aren\'t supported: {}'.format(
                other_kwargs))
      return

    # The following are implemented as property functions:
    # self.trainable_weights
    # self.non_trainable_weights
    # `inputs` / `outputs` will only appear in kwargs if either are misspelled.
    generic_utils.validate_kwargs(kwargs, {
        'trainable', 'dtype', 'dynamic', 'name', 'autocast', 'inputs', 'outputs'
    })
    super(Model, self).__init__(**kwargs)
    # By default, Model is a subclass model, which is not in graph network.
    self._is_graph_network = False

    self.inputs = None
    self.outputs = None
    self.input_names = None
    self.output_names = None
    # stop_training is used by callback to stop training when error happens
    self.stop_training = False
    self.history = None
    # These objects are used in the default `Model.compile`. They are not
    # guaranteed to be set after `Model.compile` is called, as users can
    # override compile with custom logic.
    self.compiled_loss = None
    self.compiled_metrics = None

    # This is True for Sequential networks and Functional networks.
    self._compute_output_and_mask_jointly = False

    # Don't reset compilation if already done. This may occur if calling
    # `__init__` (or `_init_graph_network`) on an already-compiled model
    # such as a Sequential model. Sequential models may need to rebuild
    # themselves after compilation.
    self._maybe_create_attribute('_is_compiled', False)
    self._maybe_create_attribute('optimizer', None)

    # Model must be created under scope of DistStrat it will be trained with.
    if distribute_lib.has_strategy():
      self._distribution_strategy = distribute_lib.get_strategy()
    else:
      self._distribution_strategy = None

    self._cluster_coordinator = None

    # Defaults to value of `tf.config.experimental_functions_run_eagerly`.
    self._run_eagerly = None
    # Initialize cache attrs.
    self._reset_compile_cache()

    # Fault-tolerance handler. Set in `ModelCheckpoint`.
    self._training_state = None
    self._saved_model_inputs_spec = None
    # The checkpoint refers to the model through a weak reference.
    self._checkpoint = trackable_utils.Checkpoint(root=weakref.ref(self))

    # Replaced by a variable in `compile` (via `_configure_steps_per_execution`).
    self._steps_per_execution = None

    self._init_batch_counters()
    # Sentinel checked by `__setattr__` to detect a missing
    # `super().__init__()` call in subclasses.
    self._base_model_initialized = True

317 

318 @trackable.no_automatic_dependency_tracking 

319 def _init_batch_counters(self): 

320 # Untracked Variables, used to keep track of mini-batches seen in `fit`, 

321 # `evaluate`, and `predict`. 

322 agg = variables.VariableAggregationV2.ONLY_FIRST_REPLICA 

323 self._train_counter = variables.Variable(0, dtype='int64', aggregation=agg) 

324 self._test_counter = variables.Variable(0, dtype='int64', aggregation=agg) 

325 self._predict_counter = variables.Variable( 

326 0, dtype='int64', aggregation=agg) 

327 

328 def __setattr__(self, name, value): 

329 if not getattr(self, '_self_setattr_tracking', True): 

330 super(Model, self).__setattr__(name, value) 

331 return 

332 

333 if all( 

334 isinstance(v, (base_layer.Layer, variables.Variable)) or 

335 base_layer_utils.has_weights(v) for v in nest.flatten(value)): 

336 try: 

337 self._base_model_initialized 

338 except AttributeError: 

339 raise RuntimeError( 

340 'It looks like you are subclassing `Model` and you ' 

341 'forgot to call `super().__init__()`.' 

342 ' Always start with this line.') 

343 

344 super(Model, self).__setattr__(name, value) 

345 

346 @generic_utils.default 

347 def build(self, input_shape): 

348 """Builds the model based on input shapes received. 

349 

350 This is to be used for subclassed models, which do not know at instantiation 

351 time what their inputs look like. 

352 

353 This method only exists for users who want to call `model.build()` in a 

354 standalone way (as a substitute for calling the model on real data to 

355 build it). It will never be called by the framework (and thus it will 

356 never throw unexpected errors in an unrelated workflow). 

357 

358 Args: 

359 input_shape: Single tuple, TensorShape, or list/dict of shapes, where 

360 shapes are tuples, integers, or TensorShapes. 

361 

362 Raises: 

363 ValueError: 

364 1. In case of invalid user-provided data (not of type tuple, 

365 list, TensorShape, or dict). 

366 2. If the model requires call arguments that are agnostic 

367 to the input shapes (positional or kwarg in call signature). 

368 3. If not all layers were properly built. 

369 4. If float type inputs are not supported within the layers. 

370 

371 In each of these cases, the user should build their model by calling it 

372 on real tensor data. 

373 """ 

374 if self._is_graph_network: 

375 super(Model, self).build(input_shape) 

376 return 

377 

378 if input_shape is None: 

379 raise ValueError('Input shape must be defined when calling build on a ' 

380 'model subclass network.') 

381 valid_types = (tuple, list, tensor_shape.TensorShape, dict) 

382 if not isinstance(input_shape, valid_types): 

383 raise ValueError('Specified input shape is not one of the valid types. ' 

384 'Please specify a batch input shape of type tuple or ' 

385 'list of input shapes. User provided ' 

386 'input type: {}'.format(type(input_shape))) 

387 

388 if input_shape and not self.inputs: 

389 # We create placeholders for the `None`s in the shape and build the model 

390 # in a Graph. Since tf.Variable is compatible with both eager execution 

391 # and graph building, the variables created after building the model in 

392 # a Graph are still valid when executing eagerly. 

393 if context.executing_eagerly(): 

394 graph = func_graph.FuncGraph('build_graph') 

395 else: 

396 graph = backend.get_graph() 

397 with graph.as_default(): 

398 if (isinstance(input_shape, list) and 

399 all(d is None or isinstance(d, int) for d in input_shape)): 

400 input_shape = tuple(input_shape) 

401 if isinstance(input_shape, list): 

402 x = [base_layer_utils.generate_placeholders_from_shape(shape) 

403 for shape in input_shape] 

404 elif isinstance(input_shape, dict): 

405 x = { 

406 k: base_layer_utils.generate_placeholders_from_shape(shape) 

407 for k, shape in input_shape.items() 

408 } 

409 else: 

410 x = base_layer_utils.generate_placeholders_from_shape(input_shape) 

411 

412 kwargs = {} 

413 call_signature = self._call_full_argspec 

414 call_args = call_signature.args 

415 # Exclude `self`, `inputs`, and any argument with a default value. 

416 if len(call_args) > 2: 

417 if call_signature.defaults: 

418 call_args = call_args[2:-len(call_signature.defaults)] 

419 else: 

420 call_args = call_args[2:] 

421 for arg in call_args: 

422 if arg == 'training': 

423 # Case where `training` is a positional arg with no default. 

424 kwargs['training'] = False 

425 else: 

426 # Has invalid call signature with unknown positional arguments. 

427 raise ValueError( 

428 'Currently, you cannot build your model if it has ' 

429 'positional or keyword arguments that are not ' 

430 'inputs to the model, but are required for its ' 

431 '`call` method. Instead, in order to instantiate ' 

432 'and build your model, `call` your model on real ' 

433 'tensor data with all expected call arguments.') 

434 elif len(call_args) < 2: 

435 # Signature without `inputs`. 

436 raise ValueError('You can only call `build` on a model if its `call` ' 

437 'method accepts an `inputs` argument.') 

438 try: 

439 self.call(x, **kwargs) 

440 except (errors.InvalidArgumentError, TypeError): 

441 raise ValueError('You cannot build your model by calling `build` ' 

442 'if your layers do not support float type inputs. ' 

443 'Instead, in order to instantiate and build your ' 

444 'model, `call` your model on real tensor data (of ' 

445 'the correct dtype).') 

446 super(Model, self).build(input_shape) 

447 

448 @doc_controls.doc_in_current_and_subclasses 

449 def call(self, inputs, training=None, mask=None): 

450 """Calls the model on new inputs. 

451 

452 In this case `call` just reapplies 

453 all ops in the graph to the new inputs 

454 (e.g. build a new computational graph from the provided inputs). 

455 

456 Note: This method should not be called directly. It is only meant to be 

457 overridden when subclassing `tf.keras.Model`. 

458 To call a model on an input, always use the `__call__` method, 

459 i.e. `model(inputs)`, which relies on the underlying `call` method. 

460 

461 Args: 

462 inputs: Input tensor, or dict/list/tuple of input tensors. 

463 training: Boolean or boolean scalar tensor, indicating whether to run 

464 the `Network` in training mode or inference mode. 

465 mask: A mask or list of masks. A mask can be 

466 either a tensor or None (no mask). 

467 

468 Returns: 

469 A tensor if there is a single output, or 

470 a list of tensors if there are more than one outputs. 

471 """ 

472 raise NotImplementedError('When subclassing the `Model` class, you should ' 

473 'implement a `call` method.') 

474 

475 def compile(self, 

476 optimizer='rmsprop', 

477 loss=None, 

478 metrics=None, 

479 loss_weights=None, 

480 weighted_metrics=None, 

481 run_eagerly=None, 

482 steps_per_execution=None, 

483 **kwargs): 

484 """Configures the model for training. 

485 

486 Args: 

487 optimizer: String (name of optimizer) or optimizer instance. See 

488 `tf.keras.optimizers`. 

489 loss: String (name of objective function), objective function or 

490 `tf.keras.losses.Loss` instance. See `tf.keras.losses`. An objective 

491 function is any callable with the signature `loss = fn(y_true, 

492 y_pred)`, where y_true = ground truth values with shape = 

493 `[batch_size, d0, .. dN]`, except sparse loss functions such as sparse 

494 categorical crossentropy where shape = `[batch_size, d0, .. dN-1]`. 

495 y_pred = predicted values with shape = `[batch_size, d0, .. dN]`. It 

496 returns a weighted loss float tensor. If a custom `Loss` instance is 

497 used and reduction is set to `None`, return value has the shape 

498 `[batch_size, d0, .. dN-1]` i.e. per-sample or per-timestep loss 

499 values; otherwise, it is a scalar. If the model has multiple outputs, 

500 you can use a different loss on each output by passing a dictionary 

501 or a list of losses. The loss value that will be minimized by the 

502 model will then be the sum of all individual losses, unless 

503 `loss_weights` is specified. 

504 metrics: List of metrics to be evaluated by the model during training 

505 and testing. Each of this can be a string (name of a built-in 

506 function), function or a `tf.keras.metrics.Metric` instance. See 

507 `tf.keras.metrics`. Typically you will use `metrics=['accuracy']`. A 

508 function is any callable with the signature `result = fn(y_true, 

509 y_pred)`. To specify different metrics for different outputs of a 

510 multi-output model, you could also pass a dictionary, such as 

511 `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`. 

512 You can also pass a list to specify a metric or a list of metrics 

513 for each output, such as `metrics=[['accuracy'], ['accuracy', 'mse']]` 

514 or `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass the 

515 strings 'accuracy' or 'acc', we convert this to one of 

516 `tf.keras.metrics.BinaryAccuracy`, 

517 `tf.keras.metrics.CategoricalAccuracy`, 

518 `tf.keras.metrics.SparseCategoricalAccuracy` based on the loss 

519 function used and the model output shape. We do a similar 

520 conversion for the strings 'crossentropy' and 'ce' as well. 

521 loss_weights: Optional list or dictionary specifying scalar coefficients 

522 (Python floats) to weight the loss contributions of different model 

523 outputs. The loss value that will be minimized by the model will then 

524 be the *weighted sum* of all individual losses, weighted by the 

525 `loss_weights` coefficients. 

526 If a list, it is expected to have a 1:1 mapping to the model's 

527 outputs. If a dict, it is expected to map output names (strings) 

528 to scalar coefficients. 

529 weighted_metrics: List of metrics to be evaluated and weighted by 

530 `sample_weight` or `class_weight` during training and testing. 

531 run_eagerly: Bool. Defaults to `False`. If `True`, this `Model`'s 

532 logic will not be wrapped in a `tf.function`. Recommended to leave 

533 this as `None` unless your `Model` cannot be run inside a 

534 `tf.function`. `run_eagerly=True` is not supported when using 

535 `tf.distribute.experimental.ParameterServerStrategy`. 

536 steps_per_execution: Int. Defaults to 1. The number of batches to 

537 run during each `tf.function` call. Running multiple batches 

538 inside a single `tf.function` call can greatly improve performance 

539 on TPUs or small models with a large Python overhead. 

540 At most, one full epoch will be run each 

541 execution. If a number larger than the size of the epoch is passed, 

542 the execution will be truncated to the size of the epoch. 

543 Note that if `steps_per_execution` is set to `N`, 

544 `Callback.on_batch_begin` and `Callback.on_batch_end` methods 

545 will only be called every `N` batches 

546 (i.e. before/after each `tf.function` execution). 

547 **kwargs: Arguments supported for backwards compatibility only. 

548 

549 Raises: 

550 ValueError: In case of invalid arguments for 

551 `optimizer`, `loss` or `metrics`. 

552 """ 

553 with self.distribute_strategy.scope(): 

554 if 'experimental_steps_per_execution' in kwargs: 

555 logging.warning('The argument `steps_per_execution` is no longer ' 

556 'experimental. Pass `steps_per_execution` instead of ' 

557 '`experimental_steps_per_execution`.') 

558 if not steps_per_execution: 

559 steps_per_execution = kwargs.pop('experimental_steps_per_execution') 

560 

561 # When compiling from an already-serialized model, we do not want to 

562 # reapply some processing steps (e.g. metric renaming for multi-output 

563 # models, which have prefixes added for each corresponding output name). 

564 from_serialized = kwargs.pop('from_serialized', False) 

565 

566 self._validate_compile(optimizer, metrics, **kwargs) 

567 self._run_eagerly = run_eagerly 

568 

569 self.optimizer = self._get_optimizer(optimizer) 

570 self.compiled_loss = compile_utils.LossesContainer( 

571 loss, loss_weights, output_names=self.output_names) 

572 self.compiled_metrics = compile_utils.MetricsContainer( 

573 metrics, weighted_metrics, output_names=self.output_names, 

574 from_serialized=from_serialized) 

575 

576 self._configure_steps_per_execution(steps_per_execution or 1) 

577 

578 # Initializes attrs that are reset each time `compile` is called. 

579 self._reset_compile_cache() 

580 self._is_compiled = True 

581 

582 self.loss = loss or {} # Backwards compat. 

583 

584 def _get_optimizer(self, optimizer): 

585 """Wraps `optimizer` in `LossScaleOptimizer` if necessary.""" 

586 # The deprecated PolicyV1 has a loss_scale, which we use for backwards 

587 # compatibility to match TF 2.3 behavior. The new Policy does not have a 

588 # loss_scale, so we use dynamic loss scaling if the mixed_float16 policy is 

589 # used. 

590 if isinstance(self._dtype_policy, policy.PolicyV1): 

591 loss_scale = self._dtype_policy.loss_scale 

592 elif self._dtype_policy.name == 'mixed_float16': 

593 loss_scale = 'dynamic' 

594 else: 

595 loss_scale = None 

596 

597 def _get_single_optimizer(opt): 

598 opt = optimizers.get(opt) 

599 if (loss_scale is not None and 

600 not isinstance(opt, lso.LossScaleOptimizer)): 

601 if loss_scale == 'dynamic': 

602 opt = lso.LossScaleOptimizer(opt) 

603 else: 

604 opt = lso.LossScaleOptimizerV1(opt, loss_scale) 

605 return opt 

606 

607 return nest.map_structure(_get_single_optimizer, optimizer) 

608 

609 @trackable.no_automatic_dependency_tracking 

610 def _reset_compile_cache(self): 

611 self.train_function = None 

612 self.test_function = None 

613 self.predict_function = None 

614 # Used to cache the `tf.function`'ed `train_function` to be logged in 

615 # TensorBoard, since the original `train_function` is not necessarily 

616 # a `tf.function` (e.g., with ParameterServerStrategy, the `train_function` 

617 # is a scheduling of the actual training function to a remote worker). 

618 self.train_tf_function = None 

619 

620 # Used to cache `trainable` attr of `Layer`s for `fit`. 

621 self._compiled_trainable_state = self._get_trainable_state() 

622 

623 @trackable.no_automatic_dependency_tracking 

624 def _configure_steps_per_execution(self, steps_per_execution): 

625 self._steps_per_execution = variables.Variable( 

626 steps_per_execution, 

627 dtype='int64', 

628 aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA) 

629 

630 @property 

631 def _should_compute_mask(self): 

632 return False 

633 

634 @property 

635 def metrics(self): 

636 """Returns the model's metrics added using `compile`, `add_metric` APIs. 

637 

638 Note: Metrics passed to `compile()` are available only after a `keras.Model` 

639 has been trained/evaluated on actual data. 

640 

641 Examples: 

642 

643 >>> inputs = tf.keras.layers.Input(shape=(3,)) 

644 >>> outputs = tf.keras.layers.Dense(2)(inputs) 

645 >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs) 

646 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"]) 

647 >>> [m.name for m in model.metrics] 

648 [] 

649 

650 >>> x = np.random.random((2, 3)) 

651 >>> y = np.random.randint(0, 2, (2, 2)) 

652 >>> model.fit(x, y) 

653 >>> [m.name for m in model.metrics] 

654 ['loss', 'mae'] 

655 

656 >>> inputs = tf.keras.layers.Input(shape=(3,)) 

657 >>> d = tf.keras.layers.Dense(2, name='out') 

658 >>> output_1 = d(inputs) 

659 >>> output_2 = d(inputs) 

660 >>> model = tf.keras.models.Model( 

661 ... inputs=inputs, outputs=[output_1, output_2]) 

662 >>> model.add_metric( 

663 ... tf.reduce_sum(output_2), name='mean', aggregation='mean') 

664 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"]) 

665 >>> model.fit(x, (y, y)) 

666 >>> [m.name for m in model.metrics] 

667 ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae', 

668 'out_1_acc', 'mean'] 

669 

670 """ 

671 metrics = [] 

672 if self._is_compiled: 

673 # TODO(omalleyt): Track `LossesContainer` and `MetricsContainer` objects 

674 # so that attr names are not load-bearing. 

675 if self.compiled_loss is not None: 

676 metrics += self.compiled_loss.metrics 

677 if self.compiled_metrics is not None: 

678 metrics += self.compiled_metrics.metrics 

679 

680 for l in self._flatten_layers(): 

681 metrics.extend(l._metrics) # pylint: disable=protected-access 

682 return metrics 

683 

684 @property 

685 def metrics_names(self): 

686 """Returns the model's display labels for all outputs. 

687 

688 Note: `metrics_names` are available only after a `keras.Model` has been 

689 trained/evaluated on actual data. 

690 

691 Examples: 

692 

693 >>> inputs = tf.keras.layers.Input(shape=(3,)) 

694 >>> outputs = tf.keras.layers.Dense(2)(inputs) 

695 >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs) 

696 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"]) 

697 >>> model.metrics_names 

698 [] 

699 

700 >>> x = np.random.random((2, 3)) 

701 >>> y = np.random.randint(0, 2, (2, 2)) 

702 >>> model.fit(x, y) 

703 >>> model.metrics_names 

704 ['loss', 'mae'] 

705 

706 >>> inputs = tf.keras.layers.Input(shape=(3,)) 

707 >>> d = tf.keras.layers.Dense(2, name='out') 

708 >>> output_1 = d(inputs) 

709 >>> output_2 = d(inputs) 

710 >>> model = tf.keras.models.Model( 

711 ... inputs=inputs, outputs=[output_1, output_2]) 

712 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"]) 

713 >>> model.fit(x, (y, y)) 

714 >>> model.metrics_names 

715 ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae', 

716 'out_1_acc'] 

717 

718 """ 

719 

720 # This property includes all output names including `loss` and per-output 

721 # losses for backward compatibility. 

722 return [m.name for m in self.metrics] 

723 

724 @property 

725 def distribute_strategy(self): 

726 """The `tf.distribute.Strategy` this model was created under.""" 

727 return self._distribution_strategy or distribute_lib.get_strategy() 

728 

@property
def run_eagerly(self):
  """Settable attribute indicating whether the model should run eagerly.

  Running eagerly means that your model will be run step by step,
  like Python code. Your model might run slower, but it should become easier
  for you to debug it by stepping into individual layer calls.

  By default, we will attempt to compile your model to a static graph to
  deliver the best execution performance.

  Resolution order:
    1. Dynamic models always run eagerly.
    2. An explicit truthy `run_eagerly` setting runs the model eagerly.
    3. With no explicit setting (`None`), TF's global
       `functions_run_eagerly()` flag decides.

  Returns:
    Boolean, whether the model should run eagerly.

  Raises:
    ValueError: If the model is dynamic but `run_eagerly` was explicitly set
      to `False`, or if `run_eagerly` is enabled while a cluster coordinator
      is attached (`ParameterServerStrategy`).
  """
  requested = self._run_eagerly

  # Identity check on purpose: only an explicit `False` conflicts with a
  # dynamic model; the unset value `None` does not.
  if self.dynamic and requested is False:  # pylint:disable=g-bool-id-comparison
    # TODO(fchollet): consider using py_func to enable this.
    raise ValueError('Your model contains layers that can only be '
                     'successfully run in eager execution (layers '
                     'constructed with `dynamic=True`). '
                     'You cannot set `run_eagerly=False`.')

  if self._cluster_coordinator and requested:
    raise ValueError('When using `Model` with `ParameterServerStrategy`, '
                     '`run_eagerly` is not supported.')

  if self.dynamic:
    return self.dynamic
  if requested:
    return requested
  return def_function.functions_run_eagerly() and requested is None


@run_eagerly.setter
def run_eagerly(self, value):
  # Stored as-is; `None` means "defer to the global setting" in the getter.
  self._run_eagerly = value

765 

def train_step(self, data):
  """The logic for one training step.

  This method can be overridden to support custom training logic.
  For concrete examples of how to override this method see
  [Customizing what happens in fit](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit).
  This method is called by `Model.make_train_function`.

  This method should contain the mathematical logic for one step of training.
  This typically includes the forward pass, loss calculation, backpropagation,
  and metric updates.

  Configuration details for *how* this logic is run (e.g. `tf.function` and
  `tf.distribute.Strategy` settings), should be left to
  `Model.make_train_function`, which can also be overridden.

  Args:
    data: A nested structure of `Tensor`s.

  Returns:
    A `dict` containing values that will be passed to
    `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the
    values of the `Model`'s metrics are returned. Example:
    `{'loss': 0.2, 'accuracy': 0.7}`.
  """
  # These are the only transformations `Model.fit` applies to user-input
  # data when a `tf.data.Dataset` is provided.
  data = data_adapter.expand_1d(data)
  x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
  # Run forward pass. The loss must be computed inside the tape so that its
  # gradient with respect to the trainable variables can be taken.
  with backprop.GradientTape() as tape:
    y_pred = self(x, training=True)
    loss = self.compiled_loss(
        y, y_pred, sample_weight, regularization_losses=self.losses)
  # Run backwards pass.
  self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
  self.compiled_metrics.update_state(y, y_pred, sample_weight)
  # Collect metrics to return. A metric whose result is itself a dict is
  # merged key-by-key; every other metric is stored under its own name.
  return_metrics = {}
  for metric in self.metrics:
    result = metric.result()
    if isinstance(result, dict):
      return_metrics.update(result)
    else:
      return_metrics[metric.name] = result
  return return_metrics

813 

def make_train_function(self):
  """Builds (or returns the cached) callable that runs one training execution.

  Override this to customize how training steps are executed. It is invoked by
  `Model.fit` and `Model.train_on_batch`.

  The callable owns the execution configuration: whether it is wrapped in
  `tf.function`, how it is dispatched through the `tf.distribute.Strategy`,
  and how many steps run per call. The per-step math itself is delegated to
  `Model.train_step`.

  The result is cached on `self.train_function` the first time `Model.fit`
  or `Model.train_on_batch` is called. `Model.compile` clears the cache.

  Returns:
    Function. It accepts a `tf.data.Iterator` and returns a `dict` of values
    handed to `tf.keras.callbacks.Callback.on_train_batch_end`, for example
    `{'loss': 0.2, 'accuracy': 0.7}`.
  """
  already_built = self.train_function
  if already_built is not None:
    return already_built

  def step_function(model, iterator):
    """Runs a single training step."""

    def replica_step(batch):
      step_outputs = model.train_step(batch)
      # Ensure counter is updated only if `train_step` succeeds.
      with ops.control_dependencies(_minimum_control_deps(step_outputs)):
        model._train_counter.assign_add(1)  # pylint: disable=protected-access
      return step_outputs

    batch = next(iterator)
    per_replica = model.distribute_strategy.run(replica_step, args=(batch,))
    merged = reduce_per_replica(
        per_replica, model.distribute_strategy, reduction='first')
    write_scalar_summaries(merged, step=model._train_counter)  # pylint: disable=protected-access
    return merged

  runs_single_step = self._steps_per_execution.numpy().item() == 1
  if runs_single_step:

    def train_function(iterator):
      """Runs a training execution with one step."""
      return step_function(self, iterator)

  else:

    def train_function(iterator):
      """Runs a training execution with multiple steps."""
      for _ in math_ops.range(self._steps_per_execution):
        outputs = step_function(self, iterator)
      return outputs

  if not self.run_eagerly:
    train_function = def_function.function(
        train_function, experimental_relax_shapes=True)
    self.train_tf_function = train_function

  self.train_function = train_function

  if self._cluster_coordinator:
    # The coordinator is read at call time, exactly as the original lambda did.
    def scheduled_train_function(iterator):
      return self._cluster_coordinator.schedule(
          train_function, args=(iterator,))

    self.train_function = scheduled_train_function

  return self.train_function

880 

def fit(self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose='auto',
        callbacks=None,
        validation_split=0.,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False):
  """Trains the model for a fixed number of epochs (iterations on a dataset).

  Args:
      x: Input data. It could be:
        - A Numpy array (or array-like), or a list of arrays
          (in case the model has multiple inputs).
        - A TensorFlow tensor, or a list of tensors
          (in case the model has multiple inputs).
        - A dict mapping input names to the corresponding array/tensors,
          if the model has named inputs.
        - A `tf.data` dataset. Should return a tuple
          of either `(inputs, targets)` or
          `(inputs, targets, sample_weights)`.
        - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
          or `(inputs, targets, sample_weights)`.
        - A `tf.keras.utils.experimental.DatasetCreator`, which wraps a
          callable that takes a single argument of type
          `tf.distribute.InputContext`, and returns a `tf.data.Dataset`.
          `DatasetCreator` should be used when users prefer to specify the
          per-replica batching and sharding logic for the `Dataset`.
          See `tf.keras.utils.experimental.DatasetCreator` doc for more
          information.
        A more detailed description of unpacking behavior for iterator types
        (Dataset, generator, Sequence) is given below. If using
        `tf.distribute.experimental.ParameterServerStrategy`, only
        `DatasetCreator` type is supported for `x`.
      y: Target data. Like the input data `x`,
        it could be either Numpy array(s) or TensorFlow tensor(s).
        It should be consistent with `x` (you cannot have Numpy inputs and
        tensor targets, or inversely). If `x` is a dataset, generator,
        or `keras.utils.Sequence` instance, `y` should
        not be specified (since targets will be obtained from `x`).
      batch_size: Integer or `None`.
          Number of samples per gradient update.
          If unspecified, `batch_size` will default to 32.
          Do not specify the `batch_size` if your data is in the
          form of datasets, generators, or `keras.utils.Sequence` instances
          (since they generate batches).
      epochs: Integer. Number of epochs to train the model.
          An epoch is an iteration over the entire `x` and `y`
          data provided.
          Note that in conjunction with `initial_epoch`,
          `epochs` is to be understood as "final epoch".
          The model is not trained for a number of iterations
          given by `epochs`, but merely until the epoch
          of index `epochs` is reached.
      verbose: 'auto', 0, 1, or 2. Verbosity mode.
          0 = silent, 1 = progress bar, 2 = one line per epoch.
          'auto' defaults to 1 for most cases, but 2 when used with
          `ParameterServerStrategy`. Note that the progress bar is not
          particularly useful when logged to a file, so verbose=2 is
          recommended when not running interactively (e.g., in a production
          environment).
      callbacks: List of `keras.callbacks.Callback` instances.
          List of callbacks to apply during training.
          See `tf.keras.callbacks`. Note `tf.keras.callbacks.ProgbarLogger`
          and `tf.keras.callbacks.History` callbacks are created automatically
          and need not be passed into `model.fit`.
          `tf.keras.callbacks.ProgbarLogger` is created or not based on
          `verbose` argument to `model.fit`.
          Callbacks with batch-level calls are currently unsupported with
          `tf.distribute.experimental.ParameterServerStrategy`, and users are
          advised to implement epoch-level calls instead with an appropriate
          `steps_per_epoch` value.
      validation_split: Float between 0 and 1.
          Fraction of the training data to be used as validation data.
          The model will set apart this fraction of the training data,
          will not train on it, and will evaluate
          the loss and any model metrics
          on this data at the end of each epoch.
          The validation data is selected from the last samples
          in the `x` and `y` data provided, before shuffling. This argument is
          not supported when `x` is a dataset, generator or
          `keras.utils.Sequence` instance.
          `validation_split` is not yet supported with
          `tf.distribute.experimental.ParameterServerStrategy`.
      validation_data: Data on which to evaluate
          the loss and any model metrics at the end of each epoch.
          The model will not be trained on this data. Thus, note the fact
          that the validation loss of data provided using `validation_split`
          or `validation_data` is not affected by regularization layers like
          noise and dropout.
          `validation_data` will override `validation_split`.
          `validation_data` could be:
            - A tuple `(x_val, y_val)` of Numpy arrays or tensors.
            - A tuple `(x_val, y_val, val_sample_weights)` of NumPy arrays.
            - A `tf.data.Dataset`.
            - A Python generator or `keras.utils.Sequence` returning
            `(inputs, targets)` or `(inputs, targets, sample_weights)`.
          `validation_data` is not yet supported with
          `tf.distribute.experimental.ParameterServerStrategy`.
      shuffle: Boolean (whether to shuffle the training data
          before each epoch) or str (for 'batch'). This argument is ignored
          when `x` is a generator or an object of tf.data.Dataset.
          'batch' is a special option for dealing
          with the limitations of HDF5 data; it shuffles in batch-sized
          chunks. Has no effect when `steps_per_epoch` is not `None`.
      class_weight: Optional dictionary mapping class indices (integers)
          to a weight (float) value, used for weighting the loss function
          (during training only).
          This can be useful to tell the model to
          "pay more attention" to samples from
          an under-represented class.
      sample_weight: Optional Numpy array of weights for
          the training samples, used for weighting the loss function
          (during training only). You can either pass a flat (1D)
          Numpy array with the same length as the input samples
          (1:1 mapping between weights and samples),
          or in the case of temporal data,
          you can pass a 2D array with shape
          `(samples, sequence_length)`,
          to apply a different weight to every timestep of every sample. This
          argument is not supported when `x` is a dataset, generator, or
          `keras.utils.Sequence` instance, instead provide the sample_weights
          as the third element of `x`.
      initial_epoch: Integer.
          Epoch at which to start training
          (useful for resuming a previous training run).
      steps_per_epoch: Integer or `None`.
          Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. When training with input tensors such as
          TensorFlow data tensors, the default `None` is equal to
          the number of samples in your dataset divided by
          the batch size, or 1 if that cannot be determined. If x is a
          `tf.data` dataset, and 'steps_per_epoch'
          is None, the epoch will run until the input dataset is exhausted.
          When passing an infinitely repeating dataset, you must specify the
          `steps_per_epoch` argument. If `steps_per_epoch=-1` the training
          will run indefinitely with an infinitely repeating dataset.
          This argument is not supported with array inputs.
          When using `tf.distribute.experimental.ParameterServerStrategy`:
            * `steps_per_epoch=None` is not supported.
      validation_steps: Only relevant if `validation_data` is provided and
          is a `tf.data` dataset. Total number of steps (batches of
          samples) to draw before stopping when performing validation
          at the end of every epoch. If 'validation_steps' is None, validation
          will run until the `validation_data` dataset is exhausted. In the
          case of an infinitely repeated dataset, it will run into an
          infinite loop. If 'validation_steps' is specified and only part of
          the dataset will be consumed, the evaluation will start from the
          beginning of the dataset at each epoch. This ensures that the same
          validation samples are used every time.
      validation_batch_size: Integer or `None`.
          Number of samples per validation batch.
          If unspecified, will default to `batch_size`.
          Do not specify the `validation_batch_size` if your data is in the
          form of datasets, generators, or `keras.utils.Sequence` instances
          (since they generate batches).
      validation_freq: Only relevant if validation data is provided. Integer
          or `collections.abc.Container` instance (e.g. list, tuple, etc.).
          If an integer, specifies how many training epochs to run before a
          new validation run is performed, e.g. `validation_freq=2` runs
          validation every 2 epochs. If a Container, specifies the epochs on
          which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
          validation at the end of the 1st, 2nd, and 10th epochs.
      max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
          input only. Maximum size for the generator queue.
          If unspecified, `max_queue_size` will default to 10.
      workers: Integer. Used for generator or `keras.utils.Sequence` input
          only. Maximum number of processes to spin up
          when using process-based threading. If unspecified, `workers`
          will default to 1.
      use_multiprocessing: Boolean. Used for generator or
          `keras.utils.Sequence` input only. If `True`, use process-based
          threading. If unspecified, `use_multiprocessing` will default to
          `False`. Note that because this implementation relies on
          multiprocessing, you should not pass non-picklable arguments to
          the generator as they can't be passed easily to children processes.

  Unpacking behavior for iterator-like inputs:
      A common pattern is to pass a tf.data.Dataset, generator, or
    tf.keras.utils.Sequence to the `x` argument of fit, which will in fact
    yield not only features (x) but optionally targets (y) and sample weights.
    Keras requires that the output of such iterator-likes be unambiguous. The
    iterator should return a tuple of length 1, 2, or 3, where the optional
    second and third elements will be used for y and sample_weight
    respectively. Any other type provided will be wrapped in a length one
    tuple, effectively treating everything as 'x'. When yielding dicts, they
    should still adhere to the top-level tuple structure.
    e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate
    features, targets, and weights from the keys of a single dict.
      A notable unsupported data type is the namedtuple. The reason is that
    it behaves like both an ordered datatype (tuple) and a mapping
    datatype (dict). So given a namedtuple of the form:
        `namedtuple("example_tuple", ["y", "x"])`
    it is ambiguous whether to reverse the order of the elements when
    interpreting the value. Even worse is a tuple of the form:
        `namedtuple("other_tuple", ["x", "y", "z"])`
    where it is unclear if the tuple was intended to be unpacked into x, y,
    and sample_weight or passed through as a single element to `x`. As a
    result the data processing code will simply raise a ValueError if it
    encounters a namedtuple. (Along with instructions to remedy the issue.)

  Returns:
      A `History` object. Its `History.history` attribute is
      a record of training loss values and metrics values
      at successive epochs, as well as validation loss values
      and validation metrics values (if applicable).

  Raises:
      RuntimeError: 1. If the model was never compiled or,
      2. If `model.fit` is wrapped in `tf.function`.

      ValueError: In case of mismatch between the provided input data
          and what the model expects or when the input data is empty.
  """
  # Legacy graph support is contained in `training_v1.Model`.
  version_utils.disallow_legacy_graph('Model', 'fit')
  self._assert_compile_was_called()
  self._check_call_args('fit')
  _disallow_inside_tf_function('fit')

  if verbose == 'auto':
    if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
      verbose = 2  # Default to epoch-level logging for PSStrategy.
    else:
      verbose = 1  # Default to batch-level logging otherwise.

  if validation_split:
    # Create the validation data using the training data. Only supported for
    # `Tensor` and `NumPy` input.
    (x, y, sample_weight), validation_data = (
        data_adapter.train_validation_split(
            (x, y, sample_weight), validation_split=validation_split))

  if validation_data:
    val_x, val_y, val_sample_weight = (
        data_adapter.unpack_x_y_sample_weight(validation_data))

  if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
    self._cluster_coordinator = cluster_coordinator.ClusterCoordinator(
        self.distribute_strategy)

  with self.distribute_strategy.scope(), \
       training_utils.RespectCompiledTrainableState(self):
    # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
    data_handler = data_adapter.get_data_handler(
        x=x,
        y=y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        initial_epoch=initial_epoch,
        epochs=epochs,
        shuffle=shuffle,
        class_weight=class_weight,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        model=self,
        steps_per_execution=self._steps_per_execution)

    # Container that configures and calls `tf.keras.Callback`s.
    if not isinstance(callbacks, callbacks_module.CallbackList):
      callbacks = callbacks_module.CallbackList(
          callbacks,
          add_history=True,
          add_progbar=verbose != 0,
          model=self,
          verbose=verbose,
          epochs=epochs,
          steps=data_handler.inferred_steps)

    self.stop_training = False
    self.train_function = self.make_train_function()
    self._train_counter.assign(0)
    callbacks.on_train_begin()
    training_logs = None
    # Handle fault-tolerance for multi-worker.
    # TODO(omalleyt): Fix the ordering issues that mean this has to
    # happen after `callbacks.on_train_begin`.
    data_handler._initial_epoch = (  # pylint: disable=protected-access
        self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
    logs = None
    try:
      for epoch, iterator in data_handler.enumerate_epochs():
        self.reset_metrics()
        callbacks.on_epoch_begin(epoch)
        with data_handler.catch_stop_iteration():
          for step in data_handler.steps():
            with trace.Trace(
                'train',
                epoch_num=epoch,
                step_num=step,
                batch_size=batch_size,
                _r=1):
              callbacks.on_train_batch_begin(step)
              tmp_logs = self.train_function(iterator)
              if data_handler.should_sync:
                context.async_wait()
              logs = tmp_logs  # No error, now safe to assign to logs.
              end_step = step + data_handler.step_increment
              callbacks.on_train_batch_end(end_step, logs)
              if self.stop_training:
                break

        logs = tf_utils.sync_to_numpy_or_python_type(logs)
        if logs is None:
          raise ValueError('Expect x to be a non-empty array or dataset.')
        epoch_logs = copy.copy(logs)

        # Run validation.
        if validation_data and self._should_eval(epoch, validation_freq):
          # Create data_handler for evaluation and cache it so that
          # `evaluate(..., _use_cached_eval_dataset=True)` reuses it across
          # epochs instead of rebuilding the validation pipeline each time.
          if getattr(self, '_eval_data_handler', None) is None:
            self._eval_data_handler = data_adapter.get_data_handler(
                x=val_x,
                y=val_y,
                sample_weight=val_sample_weight,
                batch_size=validation_batch_size or batch_size,
                steps_per_epoch=validation_steps,
                initial_epoch=0,
                epochs=1,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self,
                steps_per_execution=self._steps_per_execution)
          val_logs = self.evaluate(
              x=val_x,
              y=val_y,
              sample_weight=val_sample_weight,
              batch_size=validation_batch_size or batch_size,
              steps=validation_steps,
              callbacks=callbacks,
              max_queue_size=max_queue_size,
              workers=workers,
              use_multiprocessing=use_multiprocessing,
              return_dict=True,
              _use_cached_eval_dataset=True)
          val_logs = {'val_' + name: val for name, val in val_logs.items()}
          epoch_logs.update(val_logs)

        callbacks.on_epoch_end(epoch, epoch_logs)
        training_logs = epoch_logs
        if self.stop_training:
          break
    finally:
      # The cached eval data_handler is only valid for this `fit` call. Drop
      # it even if training or validation raised; otherwise the next `fit`
      # would find a non-None handler and validate on stale data.
      if getattr(self, '_eval_data_handler', None) is not None:
        del self._eval_data_handler
    callbacks.on_train_end(logs=training_logs)
    return self.history

1244 

def test_step(self, data):
  """The logic for one evaluation step.

  This method can be overridden to support custom evaluation logic.
  This method is called by `Model.make_test_function`.

  This function should contain the mathematical logic for one step of
  evaluation.
  This typically includes the forward pass, loss calculation, and metrics
  updates.

  Configuration details for *how* this logic is run (e.g. `tf.function` and
  `tf.distribute.Strategy` settings), should be left to
  `Model.make_test_function`, which can also be overridden.

  Args:
    data: A nested structure of `Tensor`s.

  Returns:
    A `dict` containing values that will be passed to
    `tf.keras.callbacks.CallbackList.on_test_batch_end`. Typically, the
    values of the `Model`'s metrics are returned.
  """
  data = data_adapter.expand_1d(data)
  x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

  y_pred = self(x, training=False)
  # Updates stateful loss metrics. The returned loss value is not needed here
  # because nothing is differentiated during evaluation.
  self.compiled_loss(
      y, y_pred, sample_weight, regularization_losses=self.losses)
  self.compiled_metrics.update_state(y, y_pred, sample_weight)
  # Collect metrics to return. A metric whose result is itself a dict is
  # merged key-by-key; every other metric is stored under its own name.
  return_metrics = {}
  for metric in self.metrics:
    result = metric.result()
    if isinstance(result, dict):
      return_metrics.update(result)
    else:
      return_metrics[metric.name] = result
  return return_metrics

1285 

def make_test_function(self):
  """Builds (or returns the cached) callable that runs one evaluation execution.

  Override this to customize how evaluation steps are executed. It is invoked
  by `Model.evaluate` and `Model.test_on_batch`.

  The callable owns the execution configuration: whether it is wrapped in
  `tf.function`, how it is dispatched through the `tf.distribute.Strategy`,
  and how many steps run per call. The per-step math itself is delegated to
  `Model.test_step`.

  The result is cached on `self.test_function` the first time
  `Model.evaluate` or `Model.test_on_batch` is called. `Model.compile` clears
  the cache.

  Returns:
    Function. It accepts a `tf.data.Iterator` and returns a `dict` of values
    handed to `tf.keras.callbacks.Callback.on_test_batch_end`.
  """
  already_built = self.test_function
  if already_built is not None:
    return already_built

  def step_function(model, iterator):
    """Runs a single evaluation step."""

    def replica_step(batch):
      step_outputs = model.test_step(batch)
      # Ensure counter is updated only if `test_step` succeeds.
      with ops.control_dependencies(_minimum_control_deps(step_outputs)):
        model._test_counter.assign_add(1)  # pylint: disable=protected-access
      return step_outputs

    batch = next(iterator)
    per_replica = model.distribute_strategy.run(replica_step, args=(batch,))
    return reduce_per_replica(
        per_replica, model.distribute_strategy, reduction='first')

  runs_single_step = self._steps_per_execution.numpy().item() == 1
  if runs_single_step:

    def test_function(iterator):
      """Runs an evaluation execution with one step."""
      return step_function(self, iterator)

  else:

    def test_function(iterator):
      """Runs an evaluation execution with multiple steps."""
      for _ in math_ops.range(self._steps_per_execution):
        outputs = step_function(self, iterator)
      return outputs

  if not self.run_eagerly:
    test_function = def_function.function(
        test_function, experimental_relax_shapes=True)

  self.test_function = test_function

  if self._cluster_coordinator:
    # The coordinator is read at call time, exactly as the original lambda did.
    def scheduled_test_function(iterator):
      return self._cluster_coordinator.schedule(
          test_function, args=(iterator,))

    self.test_function = scheduled_test_function

  return self.test_function

1349 

1350 def evaluate(self, 

1351 x=None, 

1352 y=None, 

1353 batch_size=None, 

1354 verbose=1, 

1355 sample_weight=None, 

1356 steps=None, 

1357 callbacks=None, 

1358 max_queue_size=10, 

1359 workers=1, 

1360 use_multiprocessing=False, 

1361 return_dict=False, 

1362 **kwargs): 

1363 """Returns the loss value & metrics values for the model in test mode. 

1364 

1365 Computation is done in batches (see the `batch_size` arg.) 

1366 

1367 Args: 

1368 x: Input data. It could be: 

1369 - A Numpy array (or array-like), or a list of arrays 

1370 (in case the model has multiple inputs). 

1371 - A TensorFlow tensor, or a list of tensors 

1372 (in case the model has multiple inputs). 

1373 - A dict mapping input names to the corresponding array/tensors, 

1374 if the model has named inputs. 

1375 - A `tf.data` dataset. Should return a tuple 

1376 of either `(inputs, targets)` or 

1377 `(inputs, targets, sample_weights)`. 

1378 - A generator or `keras.utils.Sequence` returning `(inputs, targets)` 

1379 or `(inputs, targets, sample_weights)`. 

1380 A more detailed description of unpacking behavior for iterator types 

1381 (Dataset, generator, Sequence) is given in the `Unpacking behavior 

1382 for iterator-like inputs` section of `Model.fit`. 

1383 y: Target data. Like the input data `x`, it could be either Numpy 

1384 array(s) or TensorFlow tensor(s). It should be consistent with `x` 

1385 (you cannot have Numpy inputs and tensor targets, or inversely). If 

1386 `x` is a dataset, generator or `keras.utils.Sequence` instance, `y` 

1387 should not be specified (since targets will be obtained from the 

1388 iterator/dataset). 

1389 batch_size: Integer or `None`. Number of samples per batch of 

1390 computation. If unspecified, `batch_size` will default to 32. Do not 

1391 specify the `batch_size` if your data is in the form of a dataset, 

1392 generators, or `keras.utils.Sequence` instances (since they generate 

1393 batches). 

1394 verbose: 0 or 1. Verbosity mode. 0 = silent, 1 = progress bar. 

1395 sample_weight: Optional Numpy array of weights for the test samples, 

1396 used for weighting the loss function. You can either pass a flat (1D) 

1397 Numpy array with the same length as the input samples 

1398 (1:1 mapping between weights and samples), or in the case of 

1399 temporal data, you can pass a 2D array with shape `(samples, 

1400 sequence_length)`, to apply a different weight to every timestep 

1401 of every sample. This argument is not supported when `x` is a 

1402 dataset, instead pass sample weights as the third element of `x`. 

1403 steps: Integer or `None`. Total number of steps (batches of samples) 

1404 before declaring the evaluation round finished. Ignored with the 

1405 default value of `None`. If x is a `tf.data` dataset and `steps` is 

1406 None, 'evaluate' will run until the dataset is exhausted. This 

1407 argument is not supported with array inputs. 

1408 callbacks: List of `keras.callbacks.Callback` instances. List of 

1409 callbacks to apply during evaluation. See 

1410 [callbacks](/api_docs/python/tf/keras/callbacks). 

1411 max_queue_size: Integer. Used for generator or `keras.utils.Sequence` 

1412 input only. Maximum size for the generator queue. If unspecified, 

1413 `max_queue_size` will default to 10. 

1414 workers: Integer. Used for generator or `keras.utils.Sequence` input 

1415 only. Maximum number of processes to spin up when using process-based 

1416 threading. If unspecified, `workers` will default to 1. 

1417 use_multiprocessing: Boolean. Used for generator or 

1418 `keras.utils.Sequence` input only. If `True`, use process-based 

1419 threading. If unspecified, `use_multiprocessing` will default to 

1420 `False`. Note that because this implementation relies on 

1421 multiprocessing, you should not pass non-picklable arguments to the 

1422 generator as they can't be passed easily to children processes. 

1423 return_dict: If `True`, loss and metric results are returned as a dict, 

1424 with each key being the name of the metric. If `False`, they are 

1425 returned as a list. 

1426 **kwargs: Unused at this time. 

1427 

1428 See the discussion of `Unpacking behavior for iterator-like inputs` for 

1429 `Model.fit`. 

1430 

1431 `Model.evaluate` is not yet supported with 

1432 `tf.distribute.experimental.ParameterServerStrategy`. 

1433 

1434 Returns: 

1435 Scalar test loss (if the model has a single output and no metrics) 

1436 or list of scalars (if the model has multiple outputs 

1437 and/or metrics). The attribute `model.metrics_names` will give you 

1438 the display labels for the scalar outputs. 

1439 

1440 Raises: 

1441 RuntimeError: If `model.evaluate` is wrapped in `tf.function`. 

1442 ValueError: in case of invalid arguments. 

1443 """ 

1444 version_utils.disallow_legacy_graph('Model', 'evaluate') 

1445 self._assert_compile_was_called() 

1446 self._check_call_args('evaluate') 

1447 _disallow_inside_tf_function('evaluate') 

1448 use_cached_eval_dataset = kwargs.pop('_use_cached_eval_dataset', False) 

1449 if kwargs: 

1450 raise TypeError('Invalid keyword arguments: %s' % (kwargs,)) 

1451 

1452 if self.distribute_strategy._should_use_with_coordinator: # pylint: disable=protected-access 

1453 self._cluster_coordinator = cluster_coordinator.ClusterCoordinator( 

1454 self.distribute_strategy) 

1455 

1456 with self.distribute_strategy.scope(): 

1457 # Use cached evaluation data only when it's called in `Model.fit` 

1458 if (use_cached_eval_dataset 

1459 and getattr(self, '_eval_data_handler', None) is not None): 

1460 data_handler = self._eval_data_handler 

1461 else: 

1462 # Creates a `tf.data.Dataset` and handles batch and epoch iteration. 

1463 data_handler = data_adapter.get_data_handler( 

1464 x=x, 

1465 y=y, 

1466 sample_weight=sample_weight, 

1467 batch_size=batch_size, 

1468 steps_per_epoch=steps, 

1469 initial_epoch=0, 

1470 epochs=1, 

1471 max_queue_size=max_queue_size, 

1472 workers=workers, 

1473 use_multiprocessing=use_multiprocessing, 

1474 model=self, 

1475 steps_per_execution=self._steps_per_execution) 

1476 

1477 # Container that configures and calls `tf.keras.Callback`s. 

1478 if not isinstance(callbacks, callbacks_module.CallbackList): 

1479 callbacks = callbacks_module.CallbackList( 

1480 callbacks, 

1481 add_history=True, 

1482 add_progbar=verbose != 0, 

1483 model=self, 

1484 verbose=verbose, 

1485 epochs=1, 

1486 steps=data_handler.inferred_steps) 

1487 

1488 logs = {} 

1489 self.test_function = self.make_test_function() 

1490 self._test_counter.assign(0) 

1491 callbacks.on_test_begin() 

1492 for _, iterator in data_handler.enumerate_epochs(): # Single epoch. 

1493 self.reset_metrics() 

1494 with data_handler.catch_stop_iteration(): 

1495 for step in data_handler.steps(): 

1496 with trace.Trace('test', step_num=step, _r=1): 

1497 callbacks.on_test_batch_begin(step) 

1498 tmp_logs = self.test_function(iterator) 

1499 if data_handler.should_sync: 

1500 context.async_wait() 

1501 logs = tmp_logs # No error, now safe to assign to logs. 

1502 end_step = step + data_handler.step_increment 

1503 callbacks.on_test_batch_end(end_step, logs) 

1504 logs = tf_utils.sync_to_numpy_or_python_type(logs) 

1505 callbacks.on_test_end(logs=logs) 

1506 

1507 if return_dict: 

1508 return logs 

1509 else: 

1510 return flatten_metrics_in_order(logs, self.metrics_names) 

1511 

def predict_step(self, data):
  """Runs the forward pass for one batch of inference.

  Override this method to customize inference. It is invoked from the step
  function built by `Model.make_predict_function`.

  Keep only the math of a single inference step here (typically just the
  forward pass). Execution settings such as `tf.function` compilation or
  `tf.distribute.Strategy` handling belong in `Model.make_predict_function`,
  which can be overridden as well.

  Args:
    data: A nested structure of `Tensor`s.

  Returns:
    The result of one inference step, typically the output of calling the
    `Model` on data.
  """
  expanded_data = data_adapter.expand_1d(data)
  # Only the inputs are needed for inference; any targets or sample weights
  # in `data` are discarded.
  inputs, _, _ = data_adapter.unpack_x_y_sample_weight(expanded_data)
  return self(inputs, training=False)

1535 

def make_predict_function(self):
  """Creates a function that executes one step of inference.

  This method can be overridden to support custom inference logic.
  This method is called by `Model.predict` and `Model.predict_on_batch`.

  Typically, this method directly controls `tf.function` and
  `tf.distribute.Strategy` settings, and delegates the actual inference
  logic to `Model.predict_step`.

  This function is cached the first time `Model.predict` or
  `Model.predict_on_batch` is called. The cache is cleared whenever
  `Model.compile` is called.

  Returns:
    Function. The function created by this method should accept a
    `tf.data.Iterator`, and return the outputs of the `Model`.
  """
  # Reuse the previously built function, if any.
  if self.predict_function is not None:
    return self.predict_function

  def step_function(model, iterator):
    """Runs a single prediction step on the next element of `iterator`."""

    def run_step(data):
      outputs = model.predict_step(data)
      # Ensure counter is updated only if `predict_step` succeeds.
      with ops.control_dependencies(_minimum_control_deps(outputs)):
        model._predict_counter.assign_add(1)  # pylint: disable=protected-access
      return outputs

    data = next(iterator)
    outputs = model.distribute_strategy.run(run_step, args=(data,))
    # Merge the per-replica results into one structure (reduction='concat').
    outputs = reduce_per_replica(
        outputs, model.distribute_strategy, reduction='concat')
    return outputs

  if (self._steps_per_execution is None or
      self._steps_per_execution.numpy().item() == 1):

    def predict_function(iterator):
      """Runs a prediction execution with one step."""
      return step_function(self, iterator)

  else:

    def predict_function(iterator):
      """Runs a prediction execution with multiple steps."""
      outputs = step_function(self, iterator)
      for _ in math_ops.range(self._steps_per_execution - 1):
        # `outputs` grows through `concat` on every iteration, so its loop
        # shape invariants are declared with a dynamic batch dimension.
        directives.set_loop_options(
            shape_invariants=[(
                t, tf_utils.get_tensor_spec(t, dynamic_batch=True).shape)
                              for t in nest.flatten(outputs)])
        step_outputs = step_function(self, iterator)
        outputs = nest.map_structure(lambda t1, t2: concat([t1, t2]), outputs,
                                     step_outputs)
      return outputs

  if not self.run_eagerly:
    predict_function = def_function.function(
        predict_function, experimental_relax_shapes=True)

  self.predict_function = predict_function
  return self.predict_function

1601 

def predict(self,
            x,
            batch_size=None,
            verbose=0,
            steps=None,
            callbacks=None,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False):
  """Generates output predictions for the input samples.

  Computation is done in batches. This method is designed for performance in
  large scale inputs. For small amount of inputs that fit in one batch,
  directly using `__call__` is recommended for faster execution, e.g.,
  `model(x)`, or `model(x, training=False)` if you have layers such as
  `tf.keras.layers.BatchNormalization` that behave differently during
  inference. Also, note the fact that test loss is not affected by
  regularization layers like noise and dropout.

  Args:
    x: Input samples. It could be:
      - A Numpy array (or array-like), or a list of arrays
        (in case the model has multiple inputs).
      - A TensorFlow tensor, or a list of tensors
        (in case the model has multiple inputs).
      - A `tf.data` dataset.
      - A generator or `keras.utils.Sequence` instance.
      A more detailed description of unpacking behavior for iterator types
      (Dataset, generator, Sequence) is given in the `Unpacking behavior
      for iterator-like inputs` section of `Model.fit`.
    batch_size: Integer or `None`.
      Number of samples per batch.
      If unspecified, `batch_size` will default to 32.
      Do not specify the `batch_size` if your data is in the
      form of dataset, generators, or `keras.utils.Sequence` instances
      (since they generate batches).
    verbose: Verbosity mode, 0 or 1.
    steps: Total number of steps (batches of samples)
      before declaring the prediction round finished.
      Ignored with the default value of `None`. If x is a `tf.data`
      dataset and `steps` is None, `predict` will
      run until the input dataset is exhausted.
    callbacks: List of `keras.callbacks.Callback` instances.
      List of callbacks to apply during prediction.
      See [callbacks](/api_docs/python/tf/keras/callbacks).
    max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
      input only. Maximum size for the generator queue.
      If unspecified, `max_queue_size` will default to 10.
    workers: Integer. Used for generator or `keras.utils.Sequence` input
      only. Maximum number of processes to spin up when using
      process-based threading. If unspecified, `workers` will default
      to 1.
    use_multiprocessing: Boolean. Used for generator or
      `keras.utils.Sequence` input only. If `True`, use process-based
      threading. If unspecified, `use_multiprocessing` will default to
      `False`. Note that because this implementation relies on
      multiprocessing, you should not pass non-picklable arguments to
      the generator as they can't be passed easily to children processes.

  See the discussion of `Unpacking behavior for iterator-like inputs` for
  `Model.fit`. Note that Model.predict uses the same interpretation rules as
  `Model.fit` and `Model.evaluate`, so inputs must be unambiguous for all
  three methods.

  Returns:
    Numpy array(s) of predictions.

  Raises:
    RuntimeError: If `model.predict` is wrapped in `tf.function`.
    ValueError: In case of mismatch between the provided
      input data and the model's expectations,
      or in case a stateful model receives a number of samples
      that is not a multiple of the batch size,
      or if `x` yields no batches at all.
  """
  version_utils.disallow_legacy_graph('Model', 'predict')
  self._check_call_args('predict')
  _disallow_inside_tf_function('predict')

  # TODO(yashkatariya): Cache model on the coordinator for faster prediction.
  # If running under PSS, then swap it with OneDeviceStrategy so that
  # execution will run on the coordinator.
  original_pss_strategy = None
  if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
    original_pss_strategy = self.distribute_strategy
    self._distribution_strategy = None

  # Cluster coordinator is set by `.fit()` and `.evaluate()` which is not
  # needed in `.predict()` because all the predictions happen on the
  # coordinator/locally.
  if self._cluster_coordinator:
    self._cluster_coordinator = None

  # The swapped-out PSS strategy must be restored even if prediction fails
  # (e.g. the empty-input ValueError below or an error raised by a callback);
  # otherwise the model would be left without its original strategy.
  try:
    outputs = None
    with self.distribute_strategy.scope():
      # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
      dataset_types = (data_types.DatasetV1, data_types.DatasetV2)
      if (self._in_multi_worker_mode() or _is_tpu_multi_host(
          self.distribute_strategy)) and isinstance(x, dataset_types):
        try:
          options = options_lib.Options()
          data_option = options_lib.AutoShardPolicy.DATA
          options.experimental_distribute.auto_shard_policy = data_option
          x = x.with_options(options)
        except ValueError:
          warnings.warn('Using Model.predict with '
                        'MultiWorkerDistributionStrategy or TPUStrategy and '
                        'AutoShardPolicy.FILE might lead to out-of-order result'
                        '. Consider setting it to AutoShardPolicy.DATA.')

      data_handler = data_adapter.get_data_handler(
          x=x,
          batch_size=batch_size,
          steps_per_epoch=steps,
          initial_epoch=0,
          epochs=1,
          max_queue_size=max_queue_size,
          workers=workers,
          use_multiprocessing=use_multiprocessing,
          model=self,
          steps_per_execution=self._steps_per_execution)

      # Container that configures and calls `tf.keras.Callback`s.
      if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            model=self,
            verbose=verbose,
            epochs=1,
            steps=data_handler.inferred_steps)

      self.predict_function = self.make_predict_function()
      self._predict_counter.assign(0)
      callbacks.on_predict_begin()
      batch_outputs = None
      for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
        with data_handler.catch_stop_iteration():
          for step in data_handler.steps():
            callbacks.on_predict_batch_begin(step)
            tmp_batch_outputs = self.predict_function(iterator)
            if data_handler.should_sync:
              context.async_wait()
            batch_outputs = tmp_batch_outputs  # No error, now safe to assign.
            if outputs is None:
              # First batch: start one list per (nested) output.
              outputs = nest.map_structure(lambda batch_output: [batch_output],
                                           batch_outputs)
            else:
              nest.map_structure_up_to(
                  batch_outputs,
                  lambda output, batch_output: output.append(batch_output),
                  outputs, batch_outputs)
            end_step = step + data_handler.step_increment
            callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs})
      if batch_outputs is None:
        raise ValueError('Expect x to be a non-empty array or dataset.')
      callbacks.on_predict_end()
    all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
  finally:
    # If originally PSS strategy was used, then replace it back since predict
    # is running under `OneDeviceStrategy` after the swap and once its done
    # we need to replace it back to PSS again.
    if original_pss_strategy is not None:
      self._distribution_strategy = original_pss_strategy

  return tf_utils.sync_to_numpy_or_python_type(all_outputs)

1768 

def reset_metrics(self):
  """Clears the accumulated state of every metric tracked by the model.

  Each entry of `self.metrics` has its `reset_state()` method called, in order.

  Examples:

  >>> inputs = tf.keras.layers.Input(shape=(3,))
  >>> outputs = tf.keras.layers.Dense(2)(inputs)
  >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
  >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])

  >>> x = np.random.random((2, 3))
  >>> y = np.random.randint(0, 2, (2, 2))
  >>> _ = model.fit(x, y, verbose=0)
  >>> assert all(float(m.result()) for m in model.metrics)

  >>> model.reset_metrics()
  >>> assert all(float(m.result()) == 0 for m in model.metrics)

  """
  for metric in self.metrics:
    metric.reset_state()

1790 

def train_on_batch(self,
                   x,
                   y=None,
                   sample_weight=None,
                   class_weight=None,
                   reset_metrics=True,
                   return_dict=False):
  """Performs one gradient update using a single batch of data.

  Args:
    x: Input data. It could be:
      - A Numpy array (or array-like), or a list of arrays
          (in case the model has multiple inputs).
      - A TensorFlow tensor, or a list of tensors
          (in case the model has multiple inputs).
      - A dict mapping input names to the corresponding array/tensors,
          if the model has named inputs.
    y: Target data. Like the input data `x`, it could be either Numpy
      array(s) or TensorFlow tensor(s). It should be consistent with `x`
      (you cannot have Numpy inputs and tensor targets, or inversely).
    sample_weight: Optional array of the same length as x, containing
      weights to apply to the model's loss for each sample. For temporal
      data, a 2D array with shape (samples, sequence_length) applies a
      different weight to every timestep of every sample.
    class_weight: Optional dictionary mapping class indices (integers) to a
      weight (float) applied to the model's loss for the samples of that
      class during training. Useful to make the model "pay more attention"
      to samples from an under-represented class.
    reset_metrics: If `True`, metric state is reset after this batch, so the
      returned metrics cover only this batch. If `False`, metrics keep
      accumulating across batches.
    return_dict: If `True`, loss and metric results are returned as a dict
      keyed by metric name. If `False`, they are returned as a list.

  Returns:
    Scalar training loss
    (if the model has a single output and no metrics)
    or list of scalars (if the model has multiple outputs
    and/or metrics). The attribute `model.metrics_names` will give you
    the display labels for the scalar outputs.

  Raises:
    RuntimeError: If `model.train_on_batch` is wrapped in `tf.function`.
    ValueError: In case of invalid user-provided arguments.
  """
  self._assert_compile_was_called()
  self._check_call_args('train_on_batch')
  _disallow_inside_tf_function('train_on_batch')

  with self.distribute_strategy.scope():
    with training_utils.RespectCompiledTrainableState(self):
      batch_iterator = data_adapter.single_batch_iterator(
          self.distribute_strategy, x, y, sample_weight, class_weight)
      self.train_function = self.make_train_function()
      logs = self.train_function(batch_iterator)

  if reset_metrics:
    self.reset_metrics()
  logs = tf_utils.sync_to_numpy_or_python_type(logs)
  if not return_dict:
    return flatten_metrics_in_order(logs, self.metrics_names)
  return logs

1856 

def test_on_batch(self,
                  x,
                  y=None,
                  sample_weight=None,
                  reset_metrics=True,
                  return_dict=False):
  """Evaluates the model on a single batch of samples.

  Args:
    x: Input data. It could be:
      - A Numpy array (or array-like), or a list of arrays (in case the
          model has multiple inputs).
      - A TensorFlow tensor, or a list of tensors (in case the model has
          multiple inputs).
      - A dict mapping input names to the corresponding array/tensors, if
          the model has named inputs.
    y: Target data. Like the input data `x`, it could be either Numpy
      array(s) or TensorFlow tensor(s). It should be consistent with `x`
      (you cannot have Numpy inputs and tensor targets, or inversely).
    sample_weight: Optional array of the same length as x, containing
      weights to apply to the model's loss for each sample. For temporal
      data, a 2D array with shape (samples, sequence_length) applies a
      different weight to every timestep of every sample.
    reset_metrics: If `True`, metric state is reset after this batch, so the
      returned metrics cover only this batch. If `False`, metrics keep
      accumulating across batches.
    return_dict: If `True`, loss and metric results are returned as a dict
      keyed by metric name. If `False`, they are returned as a list.

  Returns:
    Scalar test loss (if the model has a single output and no metrics)
    or list of scalars (if the model has multiple outputs
    and/or metrics). The attribute `model.metrics_names` will give you
    the display labels for the scalar outputs.

  Raises:
    RuntimeError: If `model.test_on_batch` is wrapped in `tf.function`.
    ValueError: In case of invalid user-provided arguments.
  """
  self._assert_compile_was_called()
  self._check_call_args('test_on_batch')
  _disallow_inside_tf_function('test_on_batch')

  with self.distribute_strategy.scope():
    batch_iterator = data_adapter.single_batch_iterator(
        self.distribute_strategy, x, y, sample_weight)
    self.test_function = self.make_test_function()
    logs = self.test_function(batch_iterator)

  if reset_metrics:
    self.reset_metrics()
  logs = tf_utils.sync_to_numpy_or_python_type(logs)
  if not return_dict:
    return flatten_metrics_in_order(logs, self.metrics_names)
  return logs

1914 

def predict_on_batch(self, x):
  """Computes predictions for one batch of samples.

  Args:
    x: Input data. It could be:
      - A Numpy array (or array-like), or a list of arrays (in case the
          model has multiple inputs).
      - A TensorFlow tensor, or a list of tensors (in case the model has
          multiple inputs).

  Returns:
    Numpy array(s) of predictions.

  Raises:
    RuntimeError: If `model.predict_on_batch` is wrapped in `tf.function`.
    ValueError: In case of mismatch between given number of inputs and
      expectations of the model.
  """
  self._check_call_args('predict_on_batch')
  _disallow_inside_tf_function('predict_on_batch')
  with self.distribute_strategy.scope():
    batch_iterator = data_adapter.single_batch_iterator(
        self.distribute_strategy, x)
    self.predict_function = self.make_predict_function()
    batch_predictions = self.predict_function(batch_iterator)
  return tf_utils.sync_to_numpy_or_python_type(batch_predictions)

1940 

def fit_generator(self,
                  generator,
                  steps_per_epoch=None,
                  epochs=1,
                  verbose=1,
                  callbacks=None,
                  validation_data=None,
                  validation_steps=None,
                  validation_freq=1,
                  class_weight=None,
                  max_queue_size=10,
                  workers=1,
                  use_multiprocessing=False,
                  shuffle=True,
                  initial_epoch=0):
  """Fits the model on data yielded batch-by-batch by a Python generator.

  DEPRECATED:
    `Model.fit` now supports generators, so there is no longer any need to use
    this endpoint. This method emits a warning and forwards every argument to
    `Model.fit` unchanged.
  """
  warnings.warn('`Model.fit_generator` is deprecated and '
                'will be removed in a future version. '
                'Please use `Model.fit`, which supports generators.')
  forwarded_kwargs = dict(
      steps_per_epoch=steps_per_epoch,
      epochs=epochs,
      verbose=verbose,
      callbacks=callbacks,
      validation_data=validation_data,
      validation_steps=validation_steps,
      validation_freq=validation_freq,
      class_weight=class_weight,
      max_queue_size=max_queue_size,
      workers=workers,
      use_multiprocessing=use_multiprocessing,
      shuffle=shuffle,
      initial_epoch=initial_epoch)
  return self.fit(generator, **forwarded_kwargs)

1980 

def evaluate_generator(self,
                       generator,
                       steps=None,
                       callbacks=None,
                       max_queue_size=10,
                       workers=1,
                       use_multiprocessing=False,
                       verbose=0):
  """Evaluates the model on a data generator.

  DEPRECATED:
    `Model.evaluate` now supports generators, so there is no longer any need
    to use this endpoint. This method emits a warning, validates the call via
    `_check_call_args`, and forwards every argument to `Model.evaluate`.
  """
  warnings.warn('`Model.evaluate_generator` is deprecated and '
                'will be removed in a future version. '
                'Please use `Model.evaluate`, which supports generators.')
  self._check_call_args('evaluate_generator')

  forwarded_kwargs = dict(
      steps=steps,
      max_queue_size=max_queue_size,
      workers=workers,
      use_multiprocessing=use_multiprocessing,
      verbose=verbose,
      callbacks=callbacks)
  return self.evaluate(generator, **forwarded_kwargs)

2008 

def predict_generator(self,
                      generator,
                      steps=None,
                      callbacks=None,
                      max_queue_size=10,
                      workers=1,
                      use_multiprocessing=False,
                      verbose=0):
  """Generates predictions for the input samples from a data generator.

  DEPRECATED:
    `Model.predict` now supports generators, so there is no longer any need
    to use this endpoint. This method emits a warning and forwards every
    argument to `Model.predict` unchanged.
  """
  warnings.warn('`Model.predict_generator` is deprecated and '
                'will be removed in a future version. '
                'Please use `Model.predict`, which supports generators.')
  forwarded_kwargs = dict(
      steps=steps,
      max_queue_size=max_queue_size,
      workers=workers,
      use_multiprocessing=use_multiprocessing,
      verbose=verbose,
      callbacks=callbacks)
  return self.predict(generator, **forwarded_kwargs)

2034 

2035 ###################################################################### 

2036 # Functions below are not training related. They are for model weights 

2037 # tracking, save/load, serialization, etc. 

2038 ###################################################################### 

2039 

@property
def trainable_weights(self):
  """Trainable variables of this model, passed through `_dedup_weights`.

  Variables from every tracked trackable come first, followed by the model's
  own `_trainable_weights`. An empty list is returned when the model itself
  is marked non-trainable.
  """
  self._assert_weights_created()
  if not self._trainable:
    return []
  collected = [
      var for tracked in self._self_tracked_trackables
      for var in tracked.trainable_variables
  ]
  collected.extend(self._trainable_weights)
  return self._dedup_weights(collected)

2050 

@property
def non_trainable_weights(self):
  """Non-trainable variables of this model, passed through `_dedup_weights`.

  When the model itself is marked non-trainable, its trainable variables are
  reported here too, placed before the non-trainable ones.
  """
  self._assert_weights_created()
  trackables = self._self_tracked_trackables
  frozen = [
      var for tracked in trackables
      for var in tracked.non_trainable_variables
  ]
  if self._trainable:
    combined = frozen + self._non_trainable_weights
  else:
    # Keep the order: all trainable vars, then all non-trainable vars.
    trainable = [
        var for tracked in trackables
        for var in tracked.trainable_variables
    ]
    combined = (trainable + self._trainable_weights +
                frozen + self._non_trainable_weights)
  return self._dedup_weights(combined)

2072 

def get_weights(self):
  """Returns the current values of the model's weights.

  The parent implementation is invoked inside the model's distribution
  strategy scope.

  Returns:
    A flat list of Numpy arrays.
  """
  strategy = self.distribute_strategy
  with strategy.scope():
    weight_values = super(Model, self).get_weights()
  return weight_values

2081 

def save(self,
         filepath,
         overwrite=True,
         include_optimizer=True,
         save_format=None,
         signatures=None,
         options=None,
         save_traces=True):
  # pylint: disable=line-too-long
  """Saves the model to Tensorflow SavedModel or a single HDF5 file.

  This is a thin wrapper that forwards all arguments to
  `tf.keras.models.save_model`; see it or the
  [Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/)
  for details.

  Args:
    filepath: String, PathLike, path to SavedModel or H5 file to save the
      model.
    overwrite: Whether to silently overwrite any existing file at the
      target location, or provide the user with a manual prompt.
    include_optimizer: If True, save optimizer's state together.
    save_format: Either `'tf'` or `'h5'`, indicating whether to save the
      model to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X,
      and 'h5' in TF 1.X.
    signatures: Signatures to save with the SavedModel. Applicable to the
      'tf' format only. Please see the `signatures` argument in
      `tf.saved_model.save` for details.
    options: (only applies to SavedModel format)
      `tf.saved_model.SaveOptions` object that specifies options for
      saving to SavedModel.
    save_traces: (only applies to SavedModel format) When enabled, the
      SavedModel will store the function traces for each layer. This
      can be disabled, so that only the configs of each layer are stored.
      Defaults to `True`. Disabling this will decrease serialization time
      and reduce file size, but it requires that all custom layers/models
      implement a `get_config()` method.

  Example:

  ```python
  model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
  del model  # deletes the existing model

  # returns a compiled model
  # identical to the previous one
  model = tf.keras.models.load_model('my_model.h5')
  ```
  """
  # pylint: enable=line-too-long
  save.save_model(
      self,
      filepath,
      overwrite,
      include_optimizer,
      save_format,
      signatures,
      options,
      save_traces,
  )

2135 

2136 def save_weights(self, 

2137 filepath, 

2138 overwrite=True, 

2139 save_format=None, 

2140 options=None): 

2141 """Saves all layer weights. 

2142 

2143 Either saves in HDF5 or in TensorFlow format based on the `save_format` 

2144 argument. 

2145 

2146 When saving in HDF5 format, the weight file has: 

2147 - `layer_names` (attribute), a list of strings 

2148 (ordered names of model layers). 

2149 - For every layer, a `group` named `layer.name` 

2150 - For every such layer group, a group attribute `weight_names`, 

2151 a list of strings 

2152 (ordered names of weights tensor of the layer). 

2153 - For every weight in the layer, a dataset 

2154 storing the weight value, named after the weight tensor. 

2155 

2156 When saving in TensorFlow format, all objects referenced by the network are 

2157 saved in the same format as `tf.train.Checkpoint`, including any `Layer` 

2158 instances or `Optimizer` instances assigned to object attributes. For 

2159 networks constructed from inputs and outputs using `tf.keras.Model(inputs, 

2160 outputs)`, `Layer` instances used by the network are tracked/saved 

2161 automatically. For user-defined classes which inherit from `tf.keras.Model`, 

2162 `Layer` instances must be assigned to object attributes, typically in the 

2163 constructor. See the documentation of `tf.train.Checkpoint` and 

2164 `tf.keras.Model` for details. 

2165 

2166 While the formats are the same, do not mix `save_weights` and 

2167 `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should be 

2168 loaded using `Model.load_weights`. Checkpoints saved using 

2169 `tf.train.Checkpoint.save` should be restored using the corresponding 

2170 `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over 

2171 `save_weights` for training checkpoints. 

2172 

2173 The TensorFlow format matches objects and variables by starting at a root 

2174 object, `self` for `save_weights`, and greedily matching attribute 

2175 names. For `Model.save` this is the `Model`, and for `Checkpoint.save` this 

2176 is the `Checkpoint` even if the `Checkpoint` has a model attached. This 

2177 means saving a `tf.keras.Model` using `save_weights` and loading into a 

2178 `tf.train.Checkpoint` with a `Model` attached (or vice versa) will not match 

2179 the `Model`'s variables. See the [guide to training 

2180 checkpoints](https://www.tensorflow.org/guide/checkpoint) for details 

2181 on the TensorFlow format. 

2182 

2183 Args: 

2184 filepath: String or PathLike, path to the file to save the weights to. 

2185 When saving in TensorFlow format, this is the prefix used for 

2186 checkpoint files (multiple files are generated). Note that the '.h5' 

2187 suffix causes weights to be saved in HDF5 format. 

2188 overwrite: Whether to silently overwrite any existing file at the 

2189 target location, or provide the user with a manual prompt. 

2190 save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or 

2191 '.keras' will default to HDF5 if `save_format` is `None`. Otherwise 

2192 `None` defaults to 'tf'. 

2193 options: Optional `tf.train.CheckpointOptions` object that specifies 

2194 options for saving weights. 

2195 

2196 Raises: 

2197 ImportError: If h5py is not available when attempting to save in HDF5 

2198 format. 

2199 ValueError: For invalid/unknown format arguments. 

2200 """ 

2201 self._assert_weights_created() 

2202 filepath = path_to_string(filepath) 

2203 filepath_is_h5 = saving_utils.is_hdf5_filepath(filepath) 

2204 if save_format is None: 

2205 if filepath_is_h5: 

2206 save_format = 'h5' 

2207 else: 

2208 save_format = 'tf' 

2209 else: 

2210 user_format = save_format.lower().strip() 

2211 if user_format in ('tensorflow', 'tf'): 

2212 save_format = 'tf' 

2213 elif user_format in ('hdf5', 'h5', 'keras'): 

2214 save_format = 'h5' 

2215 else: 

2216 raise ValueError( 

2217 'Unknown format "%s". Was expecting one of {"tf", "h5"}.' % ( 

2218 save_format,)) 

2219 if save_format == 'tf' and filepath_is_h5: 

2220 raise ValueError( 

2221 ('save_weights got save_format="tf"/"tensorflow", but the ' 

2222 'filepath ("%s") looks like an HDF5 file. Omit the ".h5"/".keras" ' 

2223 'when saving in TensorFlow format.') 

2224 % filepath) 

2225 

2226 if save_format == 'h5' and h5py is None: 

2227 raise ImportError( 

2228 '`save_weights` requires h5py when saving in hdf5.') 

2229 if save_format == 'tf': 

2230 check_filepath = filepath + '.index' 

2231 else: 

2232 check_filepath = filepath 

2233 # If file exists and should not be overwritten: 

2234 if not overwrite and os.path.isfile(check_filepath): 

2235 proceed = ask_to_proceed_with_overwrite(check_filepath) 

2236 if not proceed: 

2237 return 

2238 if save_format == 'h5': 

2239 with h5py.File(filepath, 'w') as f: 

2240 hdf5_format.save_weights_to_hdf5_group(f, self.layers) 

2241 else: 

2242 if not context.executing_eagerly(): 

2243 # Call `get_session` to initialize any uninitialized variables. 

2244 backend.get_session() 

2245 self._checkpoint.write(filepath, options=options) 

2246 # Record this checkpoint so it's visible from tf.train.latest_checkpoint. 

2247 checkpoint_management.update_checkpoint_state_internal( 

2248 save_dir=os.path.dirname(filepath), 

2249 model_checkpoint_path=filepath, 

2250 save_relative_paths=True, 

2251 all_model_checkpoint_paths=[filepath]) 

2252 

def load_weights(self,
                 filepath,
                 by_name=False,
                 skip_mismatch=False,
                 options=None):
  """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.

  If `by_name` is False weights are loaded based on the network's
  topology. This means the architecture should be the same as when the weights
  were saved. Note that layers that don't have weights are not taken into
  account in the topological ordering, so adding or removing layers is fine as
  long as they don't have weights.

  If `by_name` is True, weights are loaded into layers only if they share the
  same name. This is useful for fine-tuning or transfer-learning models where
  some of the layers have changed.

  Only topological loading (`by_name=False`) is supported when loading weights
  from the TensorFlow format. Note that topological loading differs slightly
  between TensorFlow and HDF5 formats for user-defined classes inheriting from
  `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
  TensorFlow format loads based on the object-local names of attributes to
  which layers are assigned in the `Model`'s constructor.

  Args:
    filepath: String or PathLike, path to the weights file to load. For
      weight files in TensorFlow format, this is the file prefix (the same as
      was passed to `save_weights`). This can also be a path to a SavedModel
      saved from `model.save`.
    by_name: Boolean, whether to load weights by name or by topological
      order. Only topological loading is supported for weight files in
      TensorFlow format.
    skip_mismatch: Boolean, whether to skip loading of layers where there is
      a mismatch in the number of weights, or a mismatch in the shape of
      the weight (only valid when `by_name=True`).
    options: Optional `tf.train.CheckpointOptions` object that specifies
      options for loading weights.

  Returns:
    When loading a weight file in TensorFlow format, returns the same status
    object as `tf.train.Checkpoint.restore`. When graph building, restore
    ops are run automatically as soon as the network is built (on first call
    for user-defined classes inheriting from `Model`, immediately if it is
    already built).

    When loading weights in HDF5 format, returns `None`.

  Raises:
    ImportError: If h5py is not available and the weight file is in HDF5
      format.
    ValueError: If `skip_mismatch` is set to `True` when `by_name` is
      `False`; if a non-HDF5 file is loaded under a TPU strategy with
      `steps_per_run` greater than 1; or if HDF5 weights are loaded into a
      subclassed model that has not created its variables yet.
    NotImplementedError: If `by_name=True` is requested for TensorFlow-format
      weights.
  """
  # Normalize PathLike inputs up front: the TPU check below hands `filepath`
  # to `saving_utils.is_hdf5_filepath` before `_detect_save_format` would
  # otherwise convert it to a string.
  filepath = path_to_string(filepath)
  if backend.is_tpu_strategy(self._distribution_strategy):
    if (self._distribution_strategy.extended.steps_per_run > 1 and
        (not saving_utils.is_hdf5_filepath(filepath))):
      raise ValueError('Load weights is not yet supported with TPUStrategy '
                       'with steps_per_run greater than 1.')
  if skip_mismatch and not by_name:
    raise ValueError(
        'When calling model.load_weights, skip_mismatch can only be set to '
        'True when by_name is True.')

  filepath, save_format = _detect_save_format(filepath)
  if save_format == 'tf':
    # Reject the unsupported argument before reading, so an invalid call does
    # not first start a checkpoint restore into this model.
    if by_name:
      raise NotImplementedError(
          'Weights may only be loaded based on topology into Models when '
          'loading TensorFlow-formatted weights (got by_name=True to '
          'load_weights).')
    status = self._checkpoint.read(filepath, options)
    if not context.executing_eagerly():
      session = backend.get_session()
      # Restore existing variables (if any) immediately, and set up a
      # streaming restore for any variables created in the future.
      trackable_utils.streaming_restore(status=status, session=session)
    status.assert_nontrivial_match()
  else:
    status = None
    if h5py is None:
      raise ImportError(
          '`load_weights` requires h5py when loading weights from HDF5.')
    if not self._is_graph_network and not self.built:
      raise ValueError(
          'Unable to load weights saved in HDF5 format into a subclassed '
          'Model which has not created its variables yet. Call the Model '
          'first, then load the weights.')
    self._assert_weights_created()
    with h5py.File(filepath, 'r') as f:
      # Presumably a file written by `model.save`, whose weights live under
      # a `model_weights` group rather than at the root.
      if 'layer_names' not in f.attrs and 'model_weights' in f:
        f = f['model_weights']
      if by_name:
        hdf5_format.load_weights_from_hdf5_group_by_name(
            f, self.layers, skip_mismatch=skip_mismatch)
      else:
        hdf5_format.load_weights_from_hdf5_group(f, self.layers)

  # Perform any layer defined finalization of the layer state.
  for layer in self.layers:
    layer.finalize_state()
  return status

2354 

def _updated_config(self):
  """Builds the config dict shared by the serialization helpers.

  Returns:
    A dict holding the class name, the output of `get_config()`, the Keras
    version string and the backend name.
  """
  from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top

  base_config = self.get_config()
  return dict(
      class_name=self.__class__.__name__,
      config=base_config,
      keras_version=keras_version,
      backend=backend.backend(),
  )

2371 

def get_config(self):
  """Returns the model configuration; the base `Model` does not provide one.

  Raises:
    NotImplementedError: Always.
  """
  raise NotImplementedError()

2374 

@classmethod
def from_config(cls, config, custom_objects=None):
  """Rebuilds a model of type `cls` from a functional-style `config`.

  `from_config` assumes `cls` is either `Functional` or a child class of it.
  When `cls` is meant to behave like a `Functional` subclass but only
  inherits from `Model`, the model has to be built via `cls(...)` rather than
  `Functional.from_config`, which is what this implementation does.

  Args:
    config: Model config, as produced by `get_config()`.
    custom_objects: Optional dict of custom objects used during
      reconstruction.

  Returns:
    A new instance of `cls`.
  """
  from tensorflow.python.keras.engine import functional  # pylint: disable=g-import-not-at-top
  with generic_utils.SharedObjectLoadingScope():
    inputs, outputs, layers_created = functional.reconstruct_from_config(
        config, custom_objects)
    # `cls` may be user-defined or `Functional` itself.
    rebuilt = cls(inputs=inputs, outputs=outputs, name=config.get('name'))
    functional.connect_ancillary_layers(rebuilt, layers_created)
    return rebuilt

2391 

def to_json(self, **kwargs):
  """Serializes the network configuration to a JSON string.

  To load a network from a JSON save file, use
  `keras.models.model_from_json(json_string, custom_objects={})`.

  Args:
    **kwargs: Extra keyword arguments forwarded to `json.dumps()`.

  Returns:
    A JSON string.
  """
  return json.dumps(self._updated_config(),
                    default=json_utils.get_json_type,
                    **kwargs)

2408 

def to_yaml(self, **kwargs):
  """Formerly returned the network configuration as a YAML string.

  Note: Since TF 2.6, this method is no longer supported and will raise a
  RuntimeError. Use `to_json()` instead.

  To load a network from a yaml save file, use
  `keras.models.model_from_yaml(yaml_string, custom_objects={})`.

  `custom_objects` should be a dictionary mapping
  the names of custom losses / layers / etc to the corresponding
  functions / classes.

  Args:
    **kwargs: Formerly forwarded to `yaml.dump()`; now ignored.

  Returns:
    Never returns.

  Raises:
    RuntimeError: Always, because the method poses a security risk.
  """
  message = ('Method `model.to_yaml()` has been removed due to security risk '
             'of arbitrary code execution. Please use `model.to_json()` '
             'instead.')
  raise RuntimeError(message)

2436 

def reset_states(self):
  """Calls `reset_states()` on each direct child layer that is stateful.

  Layers lacking a `reset_states` method, or whose `stateful` attribute is
  missing or falsy, are left untouched.
  """
  for child in self.layers:
    if not hasattr(child, 'reset_states'):
      continue
    if getattr(child, 'stateful', False):
      child.reset_states()

2441 

@property
@doc_controls.do_not_generate_docs
def state_updates(self):
  """Deprecated, do NOT use!

  Collects the `updates` of every direct child layer whose `stateful`
  attribute is truthy. This used to separate training updates from state
  updates, e.g. to update a layer's internal state during prediction.

  Returns:
    A list of update ops.
  """
  warnings.warn('`Model.state_updates` will be removed in a future version. '
                'This property should not be used in TensorFlow 2.0, '
                'as `updates` are applied automatically.')
  collected = []
  for child in self.layers:
    if not getattr(child, 'stateful', False) or not hasattr(child, 'updates'):
      continue
    collected += child.updates
  return collected

2465 

@property
def weights(self):
  """Returns the deduplicated list of all layer variables/weights.

  Note: This will not track the weights of nested `tf.Modules` that are not
  themselves Keras layers.

  Returns:
    A list of variables.
  """
  raw_weights = self._undeduplicated_weights
  return self._dedup_weights(raw_weights)

2477 

2478 @property 

2479 def _undeduplicated_weights(self): 

2480 """Returns the undeduplicated list of all layer variables/weights.""" 

2481 self._assert_weights_created() 

2482 weights = [] 

2483 for layer in self._self_tracked_trackables: 

2484 weights += layer.variables 

2485 weights += (self._trainable_weights + self._non_trainable_weights) 

2486 return weights 

2487 

def summary(self, line_length=None, positions=None, print_fn=None):
  """Prints a string summary of the network.

  Args:
    line_length: Total length of printed lines (e.g. set this to adapt the
      display to different terminal window sizes).
    positions: Relative or absolute positions of log elements in each line.
      If not provided, defaults to `[.33, .55, .67, 1.]`.
    print_fn: Print function to use. Defaults to `print`. It is called on
      each line of the summary; pass a custom function to capture the
      summary as a string.

  Raises:
    ValueError: if `summary()` is called before the model is built.
  """
  if self.built:
    layer_utils.print_summary(
        self, line_length=line_length, positions=positions, print_fn=print_fn)
    return
  raise ValueError(
      'This model has not yet been built. Build the model first by calling '
      '`build()` or calling `fit()` with some data, or specify an '
      '`input_shape` argument in the first layer(s) for automatic build.')

2516 

@property
def layers(self):
  """Returns the model's direct child layers, excluding the model itself."""
  direct_children = self._flatten_layers(include_self=False, recursive=False)
  return list(direct_children)

2520 

def get_layer(self, name=None, index=None):
  """Retrieves a layer based on either its name (unique) or index.

  Exactly one of `name` or `index` must be provided. Indices are based on
  order of horizontal graph traversal (bottom-up); negative indices count
  from the end, as with Python lists.

  Args:
    name: String, name of layer.
    index: Integer, index of layer.

  Returns:
    A layer instance.

  Raises:
    ValueError: If both or neither of `name` and `index` are given, if
      `index` is out of range, or if no layer is named `name`.
  """
  # TODO(fchollet): We could build a dictionary based on layer names
  # since they are constant, but we have not done that yet.
  if index is not None and name is not None:
    raise ValueError('Provide only a layer name or a layer index.')
  if index is None and name is None:
    raise ValueError('Provide either a layer name or layer index.')

  # `self.layers` builds a new list on every access, so fetch it once.
  all_layers = self.layers

  if index is not None:
    # Reject out-of-range negative indices too, instead of letting the list
    # lookup raise an IndexError.
    if not -len(all_layers) <= index < len(all_layers):
      raise ValueError('Was asked to retrieve layer at index ' + str(index) +
                       ' but model only has ' + str(len(all_layers)) +
                       ' layers.')
    return all_layers[index]

  for layer in all_layers:
    if layer.name == name:
      return layer
  # `str()` keeps the message intact if a non-string name was passed.
  raise ValueError('No such layer: ' + str(name) + '.')

2556 

@trackable.no_automatic_dependency_tracking
def _set_save_spec(self, inputs):
  """Records `TensorSpec`s describing `inputs` for use when saving.

  Only the first call has an effect; later calls return immediately. Specs
  are named after `self.input_names`, or after pseudo names derived from
  `inputs` when the model has no input names.

  Args:
    inputs: Nested structure of input tensors.
  """
  if self._saved_model_inputs_spec is not None:
    return  # Already set.

  names = (self.input_names or
           compile_utils.create_pseudo_input_names(inputs))
  flat_specs = [
      tf_utils.get_tensor_spec(tensor, dynamic_batch=False, name=spec_name)
      for spec_name, tensor in zip(names, nest.flatten(inputs))
  ]
  packed_specs = nest.pack_sequence_as(inputs, flat_specs)
  self._saved_model_inputs_spec = packed_specs

  # Sequential models also remember the input shapes they were built with.
  is_sequential = self.__class__.__name__ == 'Sequential'
  if is_sequential and self._build_input_shape is None:
    self._build_input_shape = nest.map_structure(
        lambda spec: None if spec is None else spec.shape, packed_specs)

2580 

2581 def _assert_weights_created(self): 

2582 """Asserts that all the weights for the model have been created. 

2583 

2584 For a non-dynamic model, the weights must already be created after the 

2585 layer has been called. For a dynamic model, the exact list of weights can 

2586 never be known for certain since it may change at any time during execution. 

2587 

2588 We run this check right before accessing weights or getting the Numpy value 

2589 for the current weights. Otherwise, if the layer has never been called, 

2590 the user would just get an empty list, which is misleading. 

2591 

2592 Raises: 

2593 ValueError: if the weights of the network has not yet been created. 

2594 """ 

2595 if self.dynamic: 

2596 return 

2597 

2598 if ('build' in self.__class__.__dict__ and 

2599 self.__class__ != Model and 

2600 not self.built): 

2601 # For any model that has customized build() method but hasn't 

2602 # been invoked yet, this will cover both sequential and subclass model. 

2603 # Also make sure to exclude Model class itself which has build() defined. 

2604 raise ValueError('Weights for model %s have not yet been created. ' 

2605 'Weights are created when the Model is first called on ' 

2606 'inputs or `build()` is called with an `input_shape`.' % 

2607 self.name) 

2608 

2609 def _check_call_args(self, method_name): 

2610 """Check that `call` has only one positional arg.""" 

2611 # Always allow first arg, regardless of arg name. 

2612 fullargspec = self._call_full_argspec 

2613 if fullargspec.defaults: 

2614 positional_args = fullargspec.args[:-len(fullargspec.defaults)] 

2615 else: 

2616 positional_args = fullargspec.args 

2617 if 'training' in positional_args: 

2618 positional_args.remove('training') 

2619 

2620 # self and first arg can be positional. 

2621 if len(positional_args) > 2: 

2622 extra_args = positional_args[2:] 

2623 raise ValueError( 

2624 'Models passed to `' + method_name + '` can only have `training` ' 

2625 'and the first argument in `call` as positional arguments, ' 

2626 'found: ' + str(extra_args) + '.') 

2627 

def _validate_compile(self, optimizer, metrics, **kwargs):
  """Performs validation checks for the default `compile`.

  Args:
    optimizer: Optimizer (or nested structure of optimizers) given to
      `compile`.
    metrics: Metrics (possibly nested) given to `compile`.
    **kwargs: Remaining `compile` keyword arguments. `cloning` and
      `experimental_run_tf_function` are discarded and `sample_weight_mode`
      is accepted.

  Raises:
    ValueError: If a `tf.compat.v1.keras` optimizer is passed; if the
      `distribute` or `target_tensors` arguments are set; or if model,
      metric or optimizer variables were created outside the model's
      distribution strategy scope.
    TypeError: If any other keyword argument is passed.
  """
  if any(
      isinstance(opt, optimizer_v1.Optimizer)
      for opt in nest.flatten(optimizer)):
    # Format into a single message (passing several positional args to
    # ValueError would render the message as a tuple).
    raise ValueError(
        '`tf.compat.v1.keras` Optimizer (%s) is '
        'not supported when eager execution is enabled. Use a '
        '`tf.keras` Optimizer instead, or disable eager '
        'execution.' % (optimizer,))

  kwargs.pop('cloning', None)  # Legacy DistStrat argument, never used.
  kwargs.pop('experimental_run_tf_function', None)  # Always `True`.
  if kwargs.pop('distribute', None) is not None:
    raise ValueError(
        'Distribute argument in compile is not available in TF 2.0 please '
        'create the model under the distribution strategy scope.')
  if kwargs.pop('target_tensors', None) is not None:
    raise ValueError(
        'target_tensors argument is not supported when executing eagerly.')
  invalid_kwargs = set(kwargs) - {'sample_weight_mode'}
  if invalid_kwargs:
    raise TypeError('Invalid keyword argument(s) in `compile`: %s' %
                    (invalid_kwargs,))

  # Model must be created and compiled with the same DistStrat.
  if self.built and distribute_lib.has_strategy():
    strategy = distribute_lib.get_strategy()
    for v in self.variables:
      if not strategy.extended.variable_created_in_scope(v):
        raise ValueError(
            'Variable (%s) was not created in the distribution strategy '
            'scope of (%s). It is most likely due to not all layers or '
            'the model or optimizer being created outside the distribution '
            'strategy scope. Try to make sure your code looks similar '
            'to the following.\n'
            'with strategy.scope():\n'
            '  model=_create_model()\n'
            '  model.compile(...)' % (v, strategy))

  # Model metrics must be created in the same distribution strategy scope
  # as the model.
  strategy = self.distribute_strategy
  for metric in nest.flatten(metrics):
    for v in getattr(metric, 'variables', []):
      if not strategy.extended.variable_created_in_scope(v):
        raise ValueError(
            'Metric (%s) passed to model.compile was created inside of a '
            'different distribution strategy scope than the model. All '
            'metrics must be created in the same distribution strategy '
            'scope as the model (in this case %s). If you pass in a string '
            'identifier for a metric to compile the metric will '
            'automatically be created in the correct distribution '
            'strategy scope.' % (metric, strategy)
        )

  # Optimizers must be created in the same distribution strategy scope
  # as the model.
  for opt in nest.flatten(optimizer):
    for v in getattr(opt, '_weights', []):
      if not strategy.extended.variable_created_in_scope(v):
        raise ValueError(
            'Optimizer (%s) passed to model.compile was created inside of a '
            'different distribution strategy scope than the model. All '
            'optimizers must be created in the same distribution strategy '
            'scope as the model (in this case %s). If you pass in a string '
            'identifier for an optimizer to compile the optimizer will '
            'automatically be created in the correct distribution '
            'strategy scope.' % (opt, strategy))

2697 

2698 def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch): 

2699 """Maybe load initial epoch from ckpt considering possible worker recovery. 

2700 

2701 Refer to tensorflow/python/keras/distribute/worker_training_state.py 

2702 for more information. 

2703 

2704 Args: 

2705 initial_epoch: The original initial_epoch user passes in in `fit()`. 

2706 

2707 Returns: 

2708 If the training is recovering from previous failure under multi-worker 

2709 training setting, return the epoch the training is supposed to continue 

2710 at. Otherwise, return the `initial_epoch` the user passes in. 

2711 """ 

2712 if self._training_state is not None: 

2713 return self._training_state.maybe_load_initial_epoch_from_ckpt( 

2714 initial_epoch, mode=ModeKeys.TRAIN) 

2715 return initial_epoch 

2716 

2717 def _assert_compile_was_called(self): 

2718 # Checks whether `compile` has been called. If it has been called, 

2719 # then the optimizer is set. This is different from whether the 

2720 # model is compiled 

2721 # (i.e. whether the model is built and its inputs/outputs are set). 

2722 if not self._is_compiled: 

2723 raise RuntimeError('You must compile your model before ' 

2724 'training/testing. ' 

2725 'Use `model.compile(optimizer, loss)`.') 

2726 

2727 def _set_inputs(self, inputs, outputs=None, training=None): 

2728 """This method is for compat with Modelv1. Only inputs are needed here.""" 

2729 self._set_save_spec(inputs) 

2730 

@property
def _trackable_saved_model_saver(self):
  """Returns a `ModelSavedModelSaver` bound to this model."""
  saver_cls = model_serialization.ModelSavedModelSaver
  return saver_cls(self)

2734 

def _trackable_children(self, save_type='checkpoint', **kwargs):
  """Returns the trackable children, hiding execution functions for SavedModel.

  For `save_type='savedmodel'` the cached `train_function`,
  `test_function`, `predict_function` and `train_tf_function` are detached
  while the parent implementation collects children. They are restored
  afterwards, even if the parent call raises.

  Args:
    save_type: Kind of save being performed, e.g. 'checkpoint' or
      'savedmodel'.
    **kwargs: Forwarded to the parent implementation.

  Returns:
    The children reported by the parent implementation.
  """
  if save_type != 'savedmodel':
    return super(Model, self)._trackable_children(save_type, **kwargs)

  # SavedModel needs to ignore the execution functions.
  train_function = self.train_function
  test_function = self.test_function
  predict_function = self.predict_function
  train_tf_function = self.train_tf_function
  self.train_function = None
  self.test_function = None
  self.predict_function = None
  self.train_tf_function = None
  try:
    return super(Model, self)._trackable_children(save_type, **kwargs)
  finally:
    # Always reattach, so a failed save does not leave the model without
    # its compiled execution functions.
    self.train_function = train_function
    self.test_function = test_function
    self.predict_function = predict_function
    self.train_tf_function = train_tf_function

2756 

2757 def _should_eval(self, epoch, validation_freq): 

2758 epoch = epoch + 1 # one-index the user-facing epoch. 

2759 if isinstance(validation_freq, int): 

2760 return epoch % validation_freq == 0 

2761 elif isinstance(validation_freq, list): 

2762 return epoch in validation_freq 

2763 else: 

2764 raise ValueError('Expected `validation_freq` to be a list or int.') 

2765 

2766 ###################################################################### 

2767 # Functions below exist only as v1 / v2 compatibility shims. 

2768 ###################################################################### 

2769 

2770 def _get_compile_args(self, user_metrics=True): 

2771 """Used for saving or cloning a Model. 

2772 

2773 Args: 

2774 user_metrics: Whether to return user-supplied metrics or `Metric` objects. 

2775 Defaults to returning the user-supplied metrics. 

2776 

2777 Returns: 

2778 Dictionary of arguments that were used when compiling the model. 

2779 """ 

2780 self._assert_compile_was_called() 

2781 # pylint: disable=protected-access 

2782 

2783 saved_metrics = self.compiled_metrics._user_metrics 

2784 saved_weighted_metrics = self.compiled_metrics._user_weighted_metrics 

2785 

2786 if not user_metrics: 

2787 if saved_metrics is not None: 

2788 saved_metrics = self.compiled_metrics._metrics 

2789 if saved_weighted_metrics is not None: 

2790 saved_weighted_metrics = self.compiled_metrics._weighted_metrics 

2791 

2792 compile_args = { 

2793 'optimizer': self.optimizer, 

2794 'loss': self.compiled_loss._user_losses, 

2795 'metrics': saved_metrics, 

2796 'weighted_metrics': saved_weighted_metrics, 

2797 'loss_weights': self.compiled_loss._user_loss_weights, 

2798 } 

2799 # pylint: enable=protected-access 

2800 return compile_args 

2801 

2802 def _get_callback_model(self): 

2803 return self 

2804 

2805 def _in_multi_worker_mode(self): 

2806 return self.distribute_strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access 

2807 

2808 @property 

2809 def _compile_was_called(self): 

2810 return self._is_compiled 

2811 

2812 

def reduce_per_replica(values, strategy, reduction='first'):
  """Reduces every `PerReplica` object in `values` to a single value.

  Args:
    values: Structure of `PerReplica` objects or `Tensor`s. `Tensor`s are
      returned as-is (except under multi-worker 'concat', see below).
    strategy: `tf.distribute.Strategy` object.
    reduction: One of 'first', 'concat'. 'first' keeps the first replica's
      value; 'concat' concatenates all replicas' values.

  Returns:
    Structure of `Tensor`s.

  Raises:
    ValueError: If a `PerReplica` value is met and `reduction` is neither
      'first' nor 'concat'.
  """

  def _reduce_one(value):
    """Reduces one leaf of `values`."""
    # Multi-worker collective concat is applied to every leaf, whether or
    # not it is a PerReplica.
    if reduction == 'concat' and _collective_all_reduce_multi_worker(strategy):
      return _multi_worker_concat(value, strategy)
    if not _is_per_replica_instance(value):
      return value
    if reduction == 'first':
      return strategy.unwrap(value)[0]
    if reduction == 'concat':
      if _is_tpu_multi_host(strategy):
        return _tpu_multi_host_concat(value, strategy)
      return concat(strategy.unwrap(value))
    raise ValueError('`reduction` must be "first" or "concat".')

  return nest.map_structure(_reduce_one, values)

2843 

2844 

def concat(tensors, axis=0):
  """Joins `tensors` along `axis`.

  The kind of join is chosen from the first element only: `SparseTensor`s
  use `sparse_concat_v2`, rank-0 tensors/variables are stacked, and
  everything else goes through `array_ops.concat`.
  """
  head = tensors[0]
  if isinstance(head, sparse_tensor.SparseTensor):
    return sparse_ops.sparse_concat_v2(axis=axis, sp_inputs=tensors)
  if _is_scalar(head):
    return array_ops_stack.stack(tensors, axis=axis)
  return array_ops.concat(tensors, axis=axis)

2853 

2854 

def _is_tpu_multi_host(strategy):
  """Returns whether `strategy` is a TPU strategy spanning several hosts."""
  is_tpu = backend.is_tpu_strategy(strategy)
  return is_tpu and strategy.extended.num_hosts > 1

2858 

2859 

def _tpu_multi_host_concat(v, strategy):
  """Correctly order TPU PerReplica objects, then concatenate them."""
  per_replica_values = strategy.unwrap(v)
  # When distributed datasets are created from Tensors / NumPy,
  # TPUStrategy.experimental_distribute_dataset shards data in
  # (Replica, Host) order, and TPUStrategy.unwrap returns it in
  # (Host, Replica) order. Taking every `per_host`-th value starting at each
  # offset regroups them.
  # TODO(b/150317897): Figure out long-term plan here.
  per_host = strategy.extended.num_replicas_per_host
  regrouped = list(itertools.chain.from_iterable(
      per_replica_values[offset::per_host] for offset in range(per_host)))
  return concat(regrouped)

2873 

2874 

def _collective_all_reduce_multi_worker(strategy):
  """Returns whether `strategy` is a CollectiveAllReduceStrategy in multi-worker mode."""
  is_collective = isinstance(
      strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy)
  return is_collective and strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access

2879 

2880 

2881# TODO(wxinyi): merge this with _tpu_multi_host_concat once we have all_gather 

2882# for all strategies 

def _multi_worker_concat(v, strategy):
  """Order PerReplica objects for CollectiveAllReduceStrategy and concat.

  All values are gathered with `strategy.gather`, split back into
  `strategy.num_replicas_in_sync` pieces using the gathered leading-dimension
  sizes, and reordered by taking every `num_replicas_per_worker`-th piece
  starting at each local offset before being concatenated.

  Args:
    v: A `PerReplica` object or a plain tensor.
    strategy: The CollectiveAllReduceStrategy in use.
  """
  gathered = strategy.gather(v, axis=0)
  # v might not have the same shape on different replicas, so record each
  # replica's leading-dimension size.
  if _is_per_replica_instance(v):
    local_sizes = array_ops.concat(
        [array_ops.expand_dims_v2(array_ops.shape(replica_value)[0], axis=0)
         for replica_value in v.values],
        axis=0)
    all_sizes = strategy.gather(local_sizes, axis=0)
  else:
    # v is a tensor. This may happen when, say, we have 2x1 multi-worker.
    all_sizes = strategy.gather(
        array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0), axis=0)

  pieces = array_ops.split(
      gathered,
      num_or_size_splits=all_sizes,
      num=strategy.num_replicas_in_sync)
  per_worker = len(strategy.extended.worker_devices)
  reordered = []
  for offset in range(per_worker):
    reordered += pieces[offset::per_worker]
  return concat(reordered)

2908 

2909 

def _is_scalar(x):
  """Returns True if `x` is a rank-0 `Tensor` or `Variable`."""
  if not isinstance(x, (ops.Tensor, variables.Variable)):
    return False
  return x.shape.rank == 0

2912 

2913 

def write_scalar_summaries(logs, step):
  """Writes each scalar entry of `logs` as a `batch_`-prefixed summary."""
  scalar_entries = (
      (key, val) for key, val in logs.items() if _is_scalar(val))
  for key, val in scalar_entries:
    summary_ops_v2.scalar('batch_' + key, val, step=step)

2918 

2919 

def _minimum_control_deps(outputs):
  """Returns the minimum control dependencies to ensure step succeeded.

  Returns an empty list when executing eagerly. Otherwise returns a
  one-element list with the first non-`Variable` entry of the flattened
  `outputs` (composites expanded), or an empty list if there is none.
  """
  if context.executing_eagerly():
    return []  # Control dependencies not needed.
  for candidate in nest.flatten(outputs, expand_composites=True):
    if isinstance(candidate, variables.Variable):
      continue  # Variables can't be control dependencies.
    return [candidate]
  return []  # No viable Tensor or Op to use for control deps.

2930 

2931 

def _disallow_inside_tf_function(method_name):
  """Raises if called while tracing a `tf.function`.

  Args:
    method_name: Name of the `Model` method being guarded, used in the error
      message.

  Raises:
    RuntimeError: If `ops.inside_function()` is true.
  """
  if ops.inside_function():
    error_msg = (
        'Detected a call to `Model.{method_name}` inside a `tf.function`. '
        '`Model.{method_name}` is a high-level endpoint that manages its own '
        '`tf.function`. Please move the call to `Model.{method_name}` outside '
        'of all enclosing `tf.function`s. Note that you can call a `Model` '
        'directly on `Tensor`s inside a `tf.function` like: `model(x)`.'
    ).format(method_name=method_name)
    raise RuntimeError(error_msg)

2942 

2943 

def _detect_save_format(filepath):
  """Returns path to weights file and save format.

  Args:
    filepath: String or PathLike pointing at an HDF5 file, a TensorFlow
      checkpoint prefix, or a SavedModel directory.

  Returns:
    A `(path, format)` tuple where format is 'h5' or 'tf'. For a SavedModel
    directory, the path is that of its variables checkpoint.

  Raises:
    ValueError: If `filepath` looks like a SavedModel directory whose
      checkpoint is missing or unreadable.
  """
  filepath = path_to_string(filepath)
  if saving_utils.is_hdf5_filepath(filepath):
    return filepath, 'h5'

  # Filepath could be a TensorFlow checkpoint file prefix or SavedModel
  # directory. It's possible for filepath to be both a prefix and directory.
  # Prioritize checkpoint over SavedModel.
  if _is_readable_tf_checkpoint(filepath):
    return filepath, 'tf'

  if not sm_loader.contains_saved_model(filepath):
    # Not a TensorFlow checkpoint. This filepath is likely an H5 file that
    # doesn't have the hdf5/keras extensions.
    return filepath, 'h5'

  variables_ckpt = os.path.join(filepath, sm_constants.VARIABLES_DIRECTORY,
                                sm_constants.VARIABLES_FILENAME)
  if not _is_readable_tf_checkpoint(variables_ckpt):
    raise ValueError('Unable to load weights. filepath {} appears to be a '
                     'SavedModel directory, but checkpoint either doesn\'t '
                     'exist, or is incorrectly formatted.'.format(filepath))
  return variables_ckpt, 'tf'

2971 

2972 

def _is_readable_tf_checkpoint(filepath):
  """Returns whether `filepath` opens as a TensorFlow checkpoint.

  Only `DataLossError` is mapped to False; any other error raised by the
  checkpoint reader propagates to the caller.
  """
  try:
    py_checkpoint_reader.NewCheckpointReader(filepath)
  except errors_impl.DataLossError:
    # The checkpoint is not readable in TensorFlow format.
    return False
  return True

2980 

2981 

def flatten_metrics_in_order(logs, metrics_names):
  """Turns the `logs` dict into a list as per key order of `metrics_names`.

  Values whose keys appear in `metrics_names` come first, in that order;
  the remaining entries follow, sorted by key.

  Args:
    logs: Dict mapping metric names to values.
    metrics_names: Ordered iterable of metric names.

  Returns:
    The single value if exactly one value was collected, otherwise the list
    of values (possibly empty).
  """
  # Materialize once so one-shot iterables are handled, and build a set so
  # the leftover pass is linear rather than a list scan per key.
  ordered_names = list(metrics_names)
  known_names = set(ordered_names)
  results = [logs[name] for name in ordered_names if name in logs]
  results.extend(logs[key] for key in sorted(logs) if key not in known_names)
  if len(results) == 1:
    return results[0]
  return results

2994 

2995 

def _is_per_replica_instance(obj):
  """Returns whether `obj` is both a `DistributedValues` and a `CompositeTensor`."""
  required_bases = (ds_values.DistributedValues,
                    composite_tensor.CompositeTensor)
  return all(isinstance(obj, base) for base in required_bases)