Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/training.py: 18%

1163 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.

2 #

3 # Licensed under the Apache License, Version 2.0 (the "License");

4 # you may not use this file except in compliance with the License.

5 # You may obtain a copy of the License at

6 #

7 # http://www.apache.org/licenses/LICENSE-2.0

8 #

9 # Unless required by applicable law or agreed to in writing, software

10 # distributed under the License is distributed on an "AS IS" BASIS,

11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

12 # See the License for the specific language governing permissions and

13 # limitations under the License.

14 # ==============================================================================

15 """Training-related part of the Keras engine."""

16

17 import copy

18 import itertools

19 import json

20 import warnings

21 import weakref

22

23 import numpy as np

24 import tensorflow.compat.v2 as tf

25

26 from keras.src import backend

27 from keras.src import callbacks as callbacks_module

28 from keras.src import optimizers

29 from keras.src.dtensor import layout_map as layout_map_lib

30 from keras.src.engine import base_layer

31 from keras.src.engine import base_layer_utils

32 from keras.src.engine import compile_utils

33 from keras.src.engine import data_adapter

34 from keras.src.engine import input_layer as input_layer_module

35 from keras.src.engine import training_utils

36 from keras.src.metrics import base_metric

37 from keras.src.mixed_precision import loss_scale_optimizer as lso

38 from keras.src.optimizers import optimizer

39 from keras.src.optimizers import optimizer_v1

40 from keras.src.saving import pickle_utils

41 from keras.src.saving import saving_api

42 from keras.src.saving import saving_lib

43 from keras.src.saving import serialization_lib

44 from keras.src.saving.legacy import serialization

45 from keras.src.saving.legacy.saved_model import json_utils

46 from keras.src.saving.legacy.saved_model import model_serialization

47 from keras.src.utils import generic_utils

48 from keras.src.utils import io_utils

49 from keras.src.utils import layer_utils

50 from keras.src.utils import tf_inspect

51 from keras.src.utils import tf_utils

52 from keras.src.utils import traceback_utils

53 from keras.src.utils import version_utils

54 from keras.src.utils.mode_keys import ModeKeys

55

56 # isort: off

57 from tensorflow.python.eager import context

58 from tensorflow.python.platform import tf_logging as logging

59 from tensorflow.python.util.tf_export import keras_export

60 from tensorflow.python.distribute import distribute_utils

61 from tensorflow.python.distribute import input_ops

62 from tensorflow.tools.docs import doc_controls

63

64 try:

65 import h5py

66 except ImportError:

67 h5py = None

68

69

70 @keras_export("keras.Model", "keras.models.Model")

71 class Model(base_layer.Layer, version_utils.ModelVersionSelector):

72 """A model grouping layers into an object with training/inference features. 

73 

74 Args: 

75 inputs: The input(s) of the model: a `keras.Input` object or a 

76 combination of `keras.Input` objects in a dict, list or tuple. 

77 outputs: The output(s) of the model: a tensor that originated from 

78 `keras.Input` objects or a combination of such tensors in a dict, 

79 list or tuple. See Functional API example below. 

80 name: String, the name of the model. 

81 

82 There are two ways to instantiate a `Model`: 

83 

84 1 - With the "Functional API", where you start from `Input`, 

85 you chain layer calls to specify the model's forward pass, 

86 and finally you create your model from inputs and outputs: 

87 

88 ```python 

89 import tensorflow as tf 

90 

91 inputs = tf.keras.Input(shape=(3,)) 

92 x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs) 

93 outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x) 

94 model = tf.keras.Model(inputs=inputs, outputs=outputs) 

95 ``` 

96 

97 Note: Only dicts, lists, and tuples of input tensors are supported. Nested 

98 inputs are not supported (e.g. lists of lists or dicts of dicts).

99 

100 A new Functional API model can also be created by using the 

101 intermediate tensors. This enables you to quickly extract sub-components 

102 of the model. 

103 

104 Example: 

105 

106 ```python 

107 inputs = keras.Input(shape=(None, None, 3)) 

108 processed = keras.layers.RandomCrop(width=32, height=32)(inputs) 

109 conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed) 

110 pooling = keras.layers.GlobalAveragePooling2D()(conv) 

111 feature = keras.layers.Dense(10)(pooling) 

112 

113 full_model = keras.Model(inputs, feature) 

114 backbone = keras.Model(processed, conv) 

115 activations = keras.Model(conv, feature) 

116 ``` 

117 

118 Note that the `backbone` and `activations` models are not

119 created with `keras.Input` objects, but with the tensors that originate

120 from `keras.Input` objects. Under the hood, the layers and weights will

121 be shared across these models, so that the user can train the `full_model`, and

122 use `backbone` or `activations` to do feature extraction.

123 The inputs and outputs of the model can be nested structures of tensors as 

124 well, and the created models are standard Functional API models that support 

125 all the existing APIs. 

126 

127 2 - By subclassing the `Model` class: in that case, you should define your 

128 layers in `__init__()` and you should implement the model's forward pass 

129 in `call()`. 

130 

131 ```python 

132 import tensorflow as tf 

133 

134 class MyModel(tf.keras.Model): 

135 

136 def __init__(self): 

137 super().__init__() 

138 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 

139 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 

140 

141 def call(self, inputs): 

142 x = self.dense1(inputs) 

143 return self.dense2(x) 

144 

145 model = MyModel() 

146 ``` 

147 

148 If you subclass `Model`, you can optionally have 

149 a `training` argument (boolean) in `call()`, which you can use to specify 

150 a different behavior in training and inference: 

151 

152 ```python 

153 import tensorflow as tf 

154 

155 class MyModel(tf.keras.Model): 

156 

157 def __init__(self): 

158 super().__init__() 

159 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 

160 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 

161 self.dropout = tf.keras.layers.Dropout(0.5) 

162 

163 def call(self, inputs, training=False): 

164 x = self.dense1(inputs) 

165 if training: 

166 x = self.dropout(x, training=training) 

167 return self.dense2(x) 

168 

169 model = MyModel() 

170 ``` 

171 

172 Once the model is created, you can configure the model with losses and metrics

173 with `model.compile()`, train the model with `model.fit()`, or use the model 

174 to do prediction with `model.predict()`. 
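
For instance, a minimal end-to-end sketch (here `x_train`, `y_train` and
`x_test` stand in for your own NumPy arrays or tensors):

```python
model.compile(optimizer="adam", loss="mse")
model.fit(x_train, y_train, epochs=5, batch_size=32)
predictions = model.predict(x_test)
```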

175 """ 

176 

177 _TF_MODULE_IGNORED_PROPERTIES = frozenset( 

178 itertools.chain( 

179 ( 

180 "_train_counter", 

181 "_test_counter", 

182 "_predict_counter", 

183 "_steps_per_execution", 

184 ), 

185 base_layer.Layer._TF_MODULE_IGNORED_PROPERTIES, 

186 ) 

187 ) 

188 _SCALAR_UPRANKING_ON = False 

189 

190 def __new__(cls, *args, **kwargs): 

191 # Signature detection 

192 if is_functional_model_init_params(args, kwargs) and cls == Model: 

193 # Functional model 

194 from keras.src.engine import functional 

195 

196 return functional.Functional(skip_init=True, *args, **kwargs) 

197 else: 

198 return super(Model, cls).__new__(cls, *args, **kwargs) 

199 

200 @tf.__internal__.tracking.no_automatic_dependency_tracking 

201 @traceback_utils.filter_traceback 

202 def __init__(self, *args, **kwargs): 

203 self._is_model_for_instrumentation = True 

204 base_layer.keras_api_gauge.get_cell("model").set(True) 

205 

206 # Special case for Subclassed Functional Model, which we couldn't detect 

207 # when __new__ is called. We only realize it is a functional model when 

208 # it calls super().__init__() with input and output tensors.

209 from keras.src.engine import functional 

210 

211 if is_functional_model_init_params(args, kwargs) and not isinstance( 

212 self, functional.Functional 

213 ): 

214 # Filter the kwargs for multiple inheritance. 

215 supported_kwargs = [ 

216 "inputs", 

217 "outputs", 

218 "name", 

219 "trainable", 

220 "skip_init", 

221 ] 

222 model_kwargs = { 

223 k: kwargs[k] for k in kwargs if k in supported_kwargs 

224 } 

225 other_kwargs = { 

226 k: kwargs[k] for k in kwargs if k not in supported_kwargs 

227 } 

228 inject_functional_model_class(self.__class__) 

229 functional.Functional.__init__(self, *args, **model_kwargs) 

230 

231 # In case there is any multiple inheritance here, we need to call 

232 # the __init__ for any class that appears after the Functional 

233 # class. 

234 clz_to_init = [] 

235 found_functional_class = False 

236 for clz in self.__class__.__bases__: 

237 if issubclass(clz, functional.Functional): 

238 found_functional_class = True 

239 continue 

240 if found_functional_class: 

241 clz_to_init.append(clz) 

242 

243 if clz_to_init: 

244 for clz in clz_to_init: 

245 clz.__init__(self, *args, **other_kwargs) 

246 elif other_kwargs: 

247 # If there are unused kwargs, raise an error to the

248 # user, in case they have a typo in the param name.

249 raise TypeError( 

250 "The following keyword arguments passed to `Model` aren't " 

251 "supported: {}.".format(other_kwargs) 

252 ) 

253 return 

254 

255 base_layer.keras_api_gauge.get_cell("Model subclass").set(True) 

256 # The following are implemented as property functions: 

257 # self.trainable_weights 

258 # self.non_trainable_weights 

259 # `inputs` / `outputs` will only appear in kwargs if either are 

260 # misspelled. 

261 generic_utils.validate_kwargs( 

262 kwargs, 

263 { 

264 "trainable", 

265 "dtype", 

266 "dynamic", 

267 "name", 

268 "autocast", 

269 "inputs", 

270 "outputs", 

271 }, 

272 ) 

273 super().__init__(**kwargs) 

274 # By default, Model is a subclass model, which is not in graph network. 

275 self._is_graph_network = False 

276 

277 self.inputs = None 

278 self.outputs = None 

279 self.input_names = None 

280 self.output_names = None 

281 # stop_training is used by callbacks to stop training when an error happens

282 self.stop_training = False 

283 self.history = None 

284 # These objects are used in the default `Model.compile`. They are not 

285 # guaranteed to be set after `Model.compile` is called, as users can 

286 # override compile with custom logic. 

287 self.compiled_loss = None 

288 self.compiled_metrics = None 

289 

290 # This is True for Sequential networks and Functional networks. 

291 self._compute_output_and_mask_jointly = False 

292 

293 # Don't reset compilation if already done. This may occur if calling 

294 # `__init__` (or `_init_graph_network`) on an already-compiled model 

295 # such as a Sequential model. Sequential models may need to rebuild 

296 # themselves after compilation. 

297 self._maybe_create_attribute("_is_compiled", False) 

298 self._maybe_create_attribute("optimizer", None) 

299 

300 # Model must be created under scope of DistStrat it will be trained 

301 # with. 

302 if tf.distribute.has_strategy(): 

303 self._distribution_strategy = tf.distribute.get_strategy() 

304 else: 

305 self._distribution_strategy = None 

306 self._distribute_reduction_method = None 

307 

308 self._cluster_coordinator = None 

309 

310 # Defaults to value of `tf.config.experimental_functions_run_eagerly`. 

311 self._run_eagerly = None 

312 # Initialize cache attrs. 

313 self._reset_compile_cache() 

314 

315 # Fault-tolerance handler. Set in `ModelCheckpoint`. 

316 self._training_state = None 

317 self._saved_model_inputs_spec = None 

318 self._saved_model_arg_spec = None 

319 self._checkpoint = tf.train.Checkpoint(root=weakref.ref(self)) 

320 

321 self._steps_per_execution = None 

322 

323 self._init_batch_counters() 

324 self._base_model_initialized = True 

325 

326 # `jit_compile` starts off with None as default and gets overwritten by 

327 # the value specified in `Model.compile`, and this is effective for 

328 # `fit`, `evaluate`, and `predict`. 

329 self._jit_compile = None 

330 

331 self._layout_map = layout_map_lib.get_current_layout_map() 

332 

333 @tf.__internal__.tracking.no_automatic_dependency_tracking 

334 def _init_batch_counters(self): 

335 # Untracked Variables, used to keep track of mini-batches seen in `fit`, 

336 # `evaluate`, and `predict`. 

337 if not tf.inside_function(): 

338 # Creating variables inside tf.function is not allowed, hence 

339 # these would otherwise prevent users from creating Keras layers 

340 # inside tf.function. 

341 # These variables are not connected to outputs so they have no 

342 # effect on graph generation anyway. 

343 agg = tf.VariableAggregation.ONLY_FIRST_REPLICA 

344 self._train_counter = tf.Variable(0, dtype="int64", aggregation=agg) 

345 self._test_counter = tf.Variable(0, dtype="int64", aggregation=agg) 

346 self._predict_counter = tf.Variable( 

347 0, dtype="int64", aggregation=agg 

348 ) 

349 

350 def __setattr__(self, name, value): 

351 if not getattr(self, "_self_setattr_tracking", True): 

352 super().__setattr__(name, value) 

353 return 

354 

355 if all( 

356 isinstance(v, (base_layer.Layer, tf.Variable)) 

357 or base_layer_utils.has_weights(v) 

358 for v in tf.nest.flatten(value) 

359 ): 

360 try: 

361 self._base_model_initialized 

362 except AttributeError: 

363 raise RuntimeError( 

364 "It looks like you are subclassing `Model` and you " 

365 "forgot to call `super().__init__()`." 

366 " Always start with this line." 

367 ) 

368 

369 super().__setattr__(name, value) 

370 

371 def __reduce__(self): 

372 if self.built: 

373 return ( 

374 pickle_utils.deserialize_model_from_bytecode, 

375 (pickle_utils.serialize_model_as_bytecode(self),), 

376 ) 

377 else: 

378 # SavedModel (and hence serialize_model_as_bytecode) only support 

379 # built models, but if the model is not built, 

380 # it may be possible to serialize as a plain Python object, 

381 # as long as the constituent parts (layers, optimizers, losses, 

382 # etc.) can be serialized as plain Python objects. Thus we call up 

383 # the superclass hierarchy to get an implementation of __reduce__ 

384 # that can pickle this Model as a plain Python object. 

385 return super().__reduce__() 

386 

387 def __deepcopy__(self, memo): 

388 if self.built: 

389 new = pickle_utils.deserialize_model_from_bytecode( 

390 pickle_utils.serialize_model_as_bytecode(self) 

391 ) 

392 memo[id(self)] = new 

393 else: 

394 # See comment in __reduce__ for explanation 

395 deserializer, serialized, *rest = super().__reduce__() 

396 new = deserializer(*serialized) 

397 memo[id(self)] = new 

398 if rest: 

399 state = copy.deepcopy(rest[0], memo=memo) 

400 new.__setstate__(state) 

401 return new 

402 

403 def __copy__(self): 

404 return self.__deepcopy__({}) 

405 

406 @generic_utils.default 

407 def build(self, input_shape): 

408 """Builds the model based on input shapes received. 

409 

410 This is to be used for subclassed models, which do not know at 

411 instantiation time what their inputs look like. 

412 

413 This method only exists for users who want to call `model.build()` in a 

414 standalone way (as a substitute for calling the model on real data to 

415 build it). It will never be called by the framework (and thus it will 

416 never throw unexpected errors in an unrelated workflow). 

417 
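For example, a minimal sketch for a subclassed model such as the `MyModel`
example in the class docstring, whose `call()` accepts a single input (the
batch dimension may be left as `None`):

```python
model = MyModel()
model.build(input_shape=(None, 3))  # creates the model's weights
model.summary()
```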

418 Args: 

419 input_shape: Single tuple, `TensorShape` instance, or list/dict of 

420 shapes, where shapes are tuples, integers, or `TensorShape` 

421 instances. 

422 

423 Raises: 

424 ValueError: 

425 1. In case of invalid user-provided data (not of type tuple, 

426 list, `TensorShape`, or dict). 

427 2. If the model requires call arguments that are agnostic 

428 to the input shapes (positional or keyword arg in call 

429 signature). 

430 3. If not all layers were properly built. 

431 4. If float type inputs are not supported within the layers. 

432 

433 In each of these cases, the user should build their model by calling 

434 it on real tensor data. 

435 """ 

436 if self._is_graph_network: 

437 super().build(input_shape) 

438 return 

439 

440 if input_shape is None: 

441 raise ValueError( 

442 "Input shape must be defined when calling `build()` on " 

443 "a `Model` subclass." 

444 ) 

445 valid_types = (tuple, list, tf.TensorShape, dict) 

446 if not isinstance(input_shape, valid_types): 

447 raise ValueError( 

448 "Specified input shape is not one of the valid types. " 

449 "Please specify a batch input shape of type tuple or " 

450 "list of input shapes. User provided " 

451 "input type: {}.".format(type(input_shape)) 

452 ) 

453 

454 if input_shape and not self.inputs: 

455 # We create placeholders for the `None`s in the shape and build the 

456 # model in a Graph. Since tf.Variable is compatible with both eager 

457 # execution and graph building, the variables created after building 

458 # the model in a Graph are still valid when executing eagerly. 

459 if tf.executing_eagerly(): 

460 graph = tf.__internal__.FuncGraph("build_graph") 

461 else: 

462 graph = backend.get_graph() 

463 with graph.as_default(): 

464 if isinstance(input_shape, list) and all( 

465 d is None or isinstance(d, int) for d in input_shape 

466 ): 

467 input_shape = tuple(input_shape) 

468 if isinstance(input_shape, list): 

469 x = [ 

470 base_layer_utils.generate_placeholders_from_shape(shape) 

471 for shape in input_shape 

472 ] 

473 elif isinstance(input_shape, dict): 

474 x = { 

475 k: base_layer_utils.generate_placeholders_from_shape( 

476 shape 

477 ) 

478 for k, shape in input_shape.items() 

479 } 

480 else: 

481 x = base_layer_utils.generate_placeholders_from_shape( 

482 input_shape 

483 ) 

484 

485 kwargs = {} 

486 call_signature = self._call_spec.full_argspec 

487 call_args = call_signature.args 

488 # Exclude `self`, `inputs`, and any argument with a default 

489 # value. 

490 if len(call_args) > 2: 

491 if call_signature.defaults: 

492 call_args = call_args[2 : -len(call_signature.defaults)] 

493 else: 

494 call_args = call_args[2:] 

495 for arg in call_args: 

496 if arg == "training": 

497 # Case where `training` is a positional arg with no 

498 # default. 

499 kwargs["training"] = False 

500 else: 

501 # Has invalid call signature with unknown positional 

502 # arguments. 

503 raise ValueError( 

504 "Currently, you cannot build your model if it " 

505 "has positional or keyword arguments that are " 

506 "not inputs to the model, but are required for " 

507 "its `call()` method. Instead, in order to " 

508 "instantiate and build your model, `call()` " 

509 "your model on real tensor data with all " 

510 "expected call arguments. The argument " 

511 "for `call()` can be a single list/tuple that " 

512 "contains multiple inputs." 

513 ) 

514 elif len(call_args) < 2: 

515 # Signature without `inputs`. 

516 raise ValueError( 

517 "You can only call `build()` on a model if its " 

518 "`call()` method accepts an `inputs` argument." 

519 ) 

520 try: 

521 self.call(x, **kwargs) 

522 except (tf.errors.InvalidArgumentError, TypeError) as e: 

523 raise ValueError( 

524 "You cannot build your model by calling `build` " 

525 "if your layers do not support float type inputs. " 

526 "Instead, in order to instantiate and build your " 

527 "model, call your model on real tensor data (of " 

528 "the correct dtype).\n\nThe actual error from " 

529 f"`call` is: {e}." 

530 ) 

531 super().build(input_shape) 

532 

533 @traceback_utils.filter_traceback 

534 def __call__(self, *args, **kwargs): 

535 if self._layout_map is not None and not self.built: 

536 # Note that this method is only overridden for DTensor and layout 

537 # injection purpose. 

538 # Capture the inputs and create graph input as replacement for model 

539 # to initialize its weights first. 

540 copied_args = copy.copy(args) 

541 copied_kwargs = copy.copy(kwargs) 

542 

543 ( 

544 inputs, 

545 copied_args, 

546 copied_kwargs, 

547 ) = self._call_spec.split_out_first_arg(copied_args, copied_kwargs) 

548 

549 def _convert_to_graph_inputs(x): 

550 if isinstance(x, (tf.Tensor, np.ndarray, float, int)): 

551 x = tf.convert_to_tensor(x) 

552 return input_layer_module.Input(x.shape) 

553 

554 # TODO(scottzhu): maybe better handle mask and training flag. 

555 inputs = tf.nest.map_structure(_convert_to_graph_inputs, inputs) 

556 copied_args = tf.nest.map_structure( 

557 _convert_to_graph_inputs, copied_args 

558 ) 

559 copied_kwargs = tf.nest.map_structure( 

560 _convert_to_graph_inputs, copied_kwargs 

561 ) 

562 

563 with layout_map_lib.layout_map_scope(self._layout_map): 

564 # We ignore the result here. 

565 super().__call__(inputs, *copied_args, **copied_kwargs) 

566 

567 layout_map_lib._map_subclass_model_variable(self, self._layout_map) 

568 

569 return super().__call__(*args, **kwargs) 

570 

571 @doc_controls.doc_in_current_and_subclasses 

572 def call(self, inputs, training=None, mask=None): 

573 """Calls the model on new inputs and returns the outputs as tensors. 

574 

575 In this case `call()` just reapplies 

576 all ops in the graph to the new inputs 

577 (i.e., it builds a new computational graph from the provided inputs).

578 

579 Note: This method should not be called directly. It is only meant to be 

580 overridden when subclassing `tf.keras.Model`. 

581 To call a model on an input, always use the `__call__()` method, 

582 i.e. `model(inputs)`, which relies on the underlying `call()` method. 

583 

584 Args: 

585 inputs: Input tensor, or dict/list/tuple of input tensors. 

586 training: Boolean or boolean scalar tensor, indicating whether to 

587 run the `Network` in training mode or inference mode. 

588 mask: A mask or list of masks. A mask can be either a boolean tensor 

589 or None (no mask). For more details, check the guide 

590 [here](https://www.tensorflow.org/guide/keras/masking_and_padding). 

591 

592 Returns: 

593 A tensor if there is a single output, or 

594 a list of tensors if there is more than one output.

595 """ 

596 raise NotImplementedError( 

597 "Unimplemented `tf.keras.Model.call()`: if you " 

598 "intend to create a `Model` with the Functional " 

599 "API, please provide `inputs` and `outputs` " 

600 "arguments. Otherwise, subclass `Model` with an " 

601 "overridden `call()` method." 

602 ) 

603 

604 @traceback_utils.filter_traceback 

605 def compile( 

606 self, 

607 optimizer="rmsprop", 

608 loss=None, 

609 metrics=None, 

610 loss_weights=None, 

611 weighted_metrics=None, 

612 run_eagerly=None, 

613 steps_per_execution=None, 

614 jit_compile=None, 

615 pss_evaluation_shards=0, 

616 **kwargs, 

617 ): 

618 """Configures the model for training. 

619 

620 Example: 

621 

622 ```python 

623 model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), 

624 loss=tf.keras.losses.BinaryCrossentropy(), 

625 metrics=[tf.keras.metrics.BinaryAccuracy(), 

626 tf.keras.metrics.FalseNegatives()]) 

627 ``` 

628 

629 Args: 

630 optimizer: String (name of optimizer) or optimizer instance. See 

631 `tf.keras.optimizers`. 

632 loss: Loss function. May be a string (name of loss function), or 

633 a `tf.keras.losses.Loss` instance. See `tf.keras.losses`. A loss 

634 function is any callable with the signature `loss = fn(y_true, 

635 y_pred)`, where `y_true` are the ground truth values, and 

636 `y_pred` are the model's predictions. 

637 `y_true` should have shape 

638 `(batch_size, d0, .. dN)` (except in the case of 

639 sparse loss functions such as 

640 sparse categorical crossentropy which expects integer arrays of 

641 shape `(batch_size, d0, .. dN-1)`). 

642 `y_pred` should have shape `(batch_size, d0, .. dN)`. 

643 The loss function should return a float tensor. 

644 If a custom `Loss` instance is 

645 used and reduction is set to `None`, return value has shape 

646 `(batch_size, d0, .. dN-1)` i.e. per-sample or per-timestep loss 

647 values; otherwise, it is a scalar. If the model has multiple 

648 outputs, you can use a different loss on each output by passing a 

649 dictionary or a list of losses. The loss value that will be 

650 minimized by the model will then be the sum of all individual 

651 losses, unless `loss_weights` is specified. 

652 metrics: List of metrics to be evaluated by the model during 

653 training and testing. Each of this can be a string (name of a 

654 built-in function), function or a `tf.keras.metrics.Metric` 

655 instance. See `tf.keras.metrics`. Typically you will use 

656 `metrics=['accuracy']`. 

657 A function is any callable with the signature `result = fn(y_true, 

658 y_pred)`. To specify different metrics for different outputs of a 

659 multi-output model, you could also pass a dictionary, such as 

660 `metrics={'output_a':'accuracy', 'output_b':['accuracy', 'mse']}`. 

661 You can also pass a list to specify a metric or a list of metrics 

662 for each output, such as 

663 `metrics=[['accuracy'], ['accuracy', 'mse']]` 

664 or `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass the 

665 strings 'accuracy' or 'acc', we convert this to one of 

666 `tf.keras.metrics.BinaryAccuracy`, 

667 `tf.keras.metrics.CategoricalAccuracy`, 

668 `tf.keras.metrics.SparseCategoricalAccuracy` based on the shapes 

669 of the targets and of the model output. We do a similar 

670 conversion for the strings 'crossentropy' and 'ce' as well. 

671 The metrics passed here are evaluated without sample weighting; if 

672 you would like sample weighting to apply, you can specify your 

673 metrics via the `weighted_metrics` argument instead. 

674 loss_weights: Optional list or dictionary specifying scalar 

675 coefficients (Python floats) to weight the loss contributions of 

676 different model outputs. The loss value that will be minimized by 

677 the model will then be the *weighted sum* of all individual 

678 losses, weighted by the `loss_weights` coefficients. If a list, 

679 it is expected to have a 1:1 mapping to the model's outputs. If a 

680 dict, it is expected to map output names (strings) to scalar 

681 coefficients. 

682 weighted_metrics: List of metrics to be evaluated and weighted by 

683 `sample_weight` or `class_weight` during training and testing. 

684 run_eagerly: Bool. If `True`, this `Model`'s logic will not be 

685 wrapped in a `tf.function`. Recommended to leave this as `None` 

686 unless your `Model` cannot be run inside a `tf.function`. 

687 `run_eagerly=True` is not supported when using 

688 `tf.distribute.experimental.ParameterServerStrategy`. Defaults to 

689 `False`. 

690 steps_per_execution: Int. The number of batches to 

691 run during each `tf.function` call. Running multiple batches 

692 inside a single `tf.function` call can greatly improve performance 

693 on TPUs or small models with a large Python overhead. At most, one 

694 full epoch will be run each execution. If a number larger than the 

695 size of the epoch is passed, the execution will be truncated to 

696 the size of the epoch. Note that if `steps_per_execution` is set 

697 to `N`, `Callback.on_batch_begin` and `Callback.on_batch_end` 

698 methods will only be called every `N` batches (i.e. before/after 

699 each `tf.function` execution). Defaults to `1`. 

700 jit_compile: If `True`, compile the model training step with XLA. 

701 [XLA](https://www.tensorflow.org/xla) is an optimizing compiler 

702 for machine learning. 

703 `jit_compile` is not enabled by default.

704 Note that `jit_compile=True` 

705 may not necessarily work for all models. 

706 For more information on supported operations please refer to the 

707 [XLA documentation](https://www.tensorflow.org/xla). 

708 Also refer to 

709 [known XLA issues](https://www.tensorflow.org/xla/known_issues) 

710 for more details. 

711 pss_evaluation_shards: Integer or 'auto'. Used for 

712 `tf.distribute.ParameterServerStrategy` training only. This arg 

713 sets the number of shards to split the dataset into, to enable an 

714 exact visitation guarantee for evaluation, meaning the model will 

715 be applied to each dataset element exactly once, even if workers 

716 fail. The dataset must be sharded to ensure separate workers do 

717 not process the same data. The number of shards should be at least 

718 the number of workers for good performance. A value of 'auto' 

719 turns on exact evaluation and uses a heuristic for the number of 

720 shards based on the number of workers. A value of 0 means no

721 visitation guarantee is provided. NOTE: Custom implementations of

722 `Model.test_step` will be ignored when doing exact evaluation. 

723 Defaults to `0`. 

724 **kwargs: Arguments supported for backwards compatibility only. 

725 """ 

726 if jit_compile and not tf_utils.can_jit_compile(warn=True): 

727 jit_compile = False 

728 base_layer.keras_api_gauge.get_cell("compile").set(True) 

729 self._compile_config = serialization_lib.Config( 

730 optimizer=optimizer, 

731 loss=loss, 

732 metrics=metrics, 

733 loss_weights=loss_weights, 

734 weighted_metrics=weighted_metrics, 

735 run_eagerly=run_eagerly, 

736 steps_per_execution=steps_per_execution, 

737 jit_compile=jit_compile, 

738 ) 

739 with self.distribute_strategy.scope(): 

740 if "experimental_steps_per_execution" in kwargs: 

741 logging.warning( 

742 "The argument `steps_per_execution` is no longer " 

743 "experimental. Pass `steps_per_execution` instead of " 

744 "`experimental_steps_per_execution`." 

745 ) 

746 if not steps_per_execution: 

747 steps_per_execution = kwargs.pop( 

748 "experimental_steps_per_execution" 

749 ) 

750 

751 # When compiling from an already-serialized model, we do not want to 

752 # reapply some processing steps (e.g. metric renaming for 

753 # multi-output models, which have prefixes added for each 

754 # corresponding output name). 

755 from_serialized = kwargs.pop("from_serialized", False) 

756 

757 self._validate_compile(optimizer, metrics, **kwargs) 

758 self._run_eagerly = run_eagerly 

759 

760 self.optimizer = self._get_optimizer(optimizer) 

761 if isinstance(loss, compile_utils.LossesContainer): 

762 self.compiled_loss = loss 

763 else: 

764 self.compiled_loss = compile_utils.LossesContainer( 

765 loss, loss_weights, output_names=self.output_names 

766 ) 

767 self.compiled_metrics = compile_utils.MetricsContainer( 

768 metrics, 

769 weighted_metrics, 

770 output_names=self.output_names, 

771 from_serialized=from_serialized, 

772 ) 

773 

774 self._configure_steps_per_execution(steps_per_execution or 1) 

775 

776 self._pss_evaluation_shards = self._infer_exact_eval_shards( 

777 pss_evaluation_shards 

778 ) 

779 

780 # Initializes attrs that are reset each time `compile` is called. 

781 self._reset_compile_cache() 

782 self._is_compiled = True 

783 self.loss = loss or {} 

784 if (self._run_eagerly or self.dynamic) and jit_compile: 

785 raise ValueError( 

786 "You cannot enable `run_eagerly` and `jit_compile` " 

787 "at the same time." 

788 ) 

789 else: 

790 self._jit_compile = jit_compile 

791 

792 def _get_optimizer(self, optimizer): 

793 """Wraps `optimizer` in `LossScaleOptimizer` if necessary.""" 

794 

795 def _get_single_optimizer(opt): 

796 opt = optimizers.get(opt) 

797 if self.dtype_policy.name == "mixed_float16" and not isinstance( 

798 opt, lso.BaseLossScaleOptimizer 

799 ): 

800 # Loss scaling is necessary with mixed_float16 for models to 

801 # converge to the same accuracy as with float32. 

802 opt = lso.BaseLossScaleOptimizer(opt) 

803 return opt 

804 

805 return tf.nest.map_structure(_get_single_optimizer, optimizer) 

806 

807 @tf.__internal__.tracking.no_automatic_dependency_tracking 

808 def _reset_compile_cache(self): 

809 self.train_function = None 

810 self.test_function = None 

811 self.predict_function = None 

812 # Used to cache the `tf.function`'ed `train_function` to be logged in 

813 # TensorBoard, since the original `train_function` is not necessarily 

814 # a `tf.function` (e.g., with ParameterServerStrategy, the 

815 # `train_function` is a scheduling of the actual training function to a 

816 # remote worker). 

817 self.train_tf_function = None 

818 

819 # Used to cache `trainable` attr of `Layer`s for `fit`. 

820 self._compiled_trainable_state = self._get_trainable_state() 

821 

822 @tf.__internal__.tracking.no_automatic_dependency_tracking 

823 def _configure_steps_per_execution(self, steps_per_execution): 

824 self._steps_per_execution = tf.Variable( 

825 steps_per_execution, 

826 dtype="int64", 

827 aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, 

828 ) 

829 

830 @property 

831 def _should_compute_mask(self): 

832 return False 

833 

834 @property 

835 def metrics(self): 

836 """Return metrics added using `compile()` or `add_metric()`. 

837 

838 Note: Metrics passed to `compile()` are available only after a 

839 `keras.Model` has been trained/evaluated on actual data. 

840 

841 Examples: 

842 

843 >>> inputs = tf.keras.layers.Input(shape=(3,)) 

844 >>> outputs = tf.keras.layers.Dense(2)(inputs) 

845 >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs) 

846 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"]) 

847 >>> [m.name for m in model.metrics] 

848 [] 

849 

850 >>> x = np.random.random((2, 3)) 

851 >>> y = np.random.randint(0, 2, (2, 2)) 

852 >>> model.fit(x, y) 

853 >>> [m.name for m in model.metrics] 

854 ['loss', 'mae'] 

855 

856 >>> inputs = tf.keras.layers.Input(shape=(3,)) 

857 >>> d = tf.keras.layers.Dense(2, name='out') 

858 >>> output_1 = d(inputs) 

859 >>> output_2 = d(inputs) 

860 >>> model = tf.keras.models.Model( 

861 ... inputs=inputs, outputs=[output_1, output_2]) 

862 >>> model.add_metric( 

863 ... tf.reduce_sum(output_2), name='mean', aggregation='mean') 

864 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"]) 

865 >>> model.fit(x, (y, y)) 

866 >>> [m.name for m in model.metrics] 

867 ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae', 

868 'out_1_acc', 'mean'] 

869 

870 """ 

871 metrics = [] 

872 if self._is_compiled: 

873 if self.compiled_loss is not None: 

874 metrics += self.compiled_loss.metrics 

875 if self.compiled_metrics is not None: 

876 metrics += self.compiled_metrics.metrics 

877 

878 for l in self._flatten_layers(): 

879 metrics.extend(l._metrics) 

880 return metrics 

881 

882 @property 

883 def metrics_names(self): 

884 """Returns the model's display labels for all outputs. 

885 

886 Note: `metrics_names` are available only after a `keras.Model` has been 

887 trained/evaluated on actual data. 

888 

889 Examples: 

890 

891 >>> inputs = tf.keras.layers.Input(shape=(3,)) 

892 >>> outputs = tf.keras.layers.Dense(2)(inputs) 

893 >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs) 

894 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"]) 

895 >>> model.metrics_names 

896 [] 

897 

898 >>> x = np.random.random((2, 3)) 

899 >>> y = np.random.randint(0, 2, (2, 2)) 

900 >>> model.fit(x, y) 

901 >>> model.metrics_names 

902 ['loss', 'mae'] 

903 

904 >>> inputs = tf.keras.layers.Input(shape=(3,)) 

905 >>> d = tf.keras.layers.Dense(2, name='out') 

906 >>> output_1 = d(inputs) 

907 >>> output_2 = d(inputs) 

908 >>> model = tf.keras.models.Model( 

909 ... inputs=inputs, outputs=[output_1, output_2]) 

910 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"]) 

911 >>> model.fit(x, (y, y)) 

912 >>> model.metrics_names 

913 ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae', 

914 'out_1_acc'] 

915 

916 """ 

917 

918 # This property includes all output names including `loss` and 

919 # per-output losses for backward compatibility. 

920 return [m.name for m in self.metrics] 

921 

922 @property 

923 def distribute_strategy(self): 

924 """The `tf.distribute.Strategy` this model was created under.""" 

925 return self._distribution_strategy or tf.distribute.get_strategy() 

926 

927 @property 

928 def run_eagerly(self): 

929 """Settable attribute indicating whether the model should run eagerly. 

930 

931 Running eagerly means that your model will be run step by step, 

932 like Python code. Your model might run slower, but it should become 

933 easier for you to debug it by stepping into individual layer calls. 

934 

935 By default, we will attempt to compile your model to a static graph to 

936 deliver the best execution performance. 

937 
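For example, to debug a model step by step (a sketch):

```python
model.compile(optimizer="adam", loss="mse", run_eagerly=True)
# or, equivalently, after compiling:
model.run_eagerly = True
```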

938 Returns: 

939 Boolean, whether the model should run eagerly. 

940 """ 

941 if self.dynamic and self._run_eagerly is False:

942 # TODO(fchollet): consider using py_func to enable this. 

943 raise ValueError( 

944 "Your model contains layers that can only be " 

945 "successfully run in eager execution (layers " 

946 "constructed with `dynamic=True`). " 

947 "You cannot set `run_eagerly=False`." 

948 ) 

949 

950 if self._cluster_coordinator and self._run_eagerly: 

951 raise ValueError( 

952 "When using `Model` with `ParameterServerStrategy`, " 

953 "`run_eagerly` is not supported." 

954 ) 

955 

956 # Run eagerly logic, by priority: 

957 # (1) Dynamic models must be run eagerly. 

958 # (2) Explicitly setting run_eagerly causes a Model to be run eagerly. 

959 # (3) Not explicitly setting run_eagerly defaults to TF's global 

960 # setting. 

961 return ( 

962 self.dynamic 

963 or self._run_eagerly 

964 or (tf.config.functions_run_eagerly() and self._run_eagerly is None) 

965 ) 

966 

967 @run_eagerly.setter 

968 def run_eagerly(self, value): 

969 self._run_eagerly = value 

970 

971 @property 

972 def jit_compile(self): 

973 """Specify whether to compile the model with XLA. 

974 

975 [XLA](https://www.tensorflow.org/xla) is an optimizing compiler 

976 for machine learning. `jit_compile` is not enabled by default. 

977 Note that `jit_compile=True` may not necessarily work for all models. 

978 

979 For more information on supported operations please refer to the 

980 [XLA documentation](https://www.tensorflow.org/xla). Also refer to 

981 [known XLA issues](https://www.tensorflow.org/xla/known_issues) 

982 for more details. 
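
For example (a sketch; note that assigning to this attribute resets any
previously cached `train`/`test`/`predict` functions):

```python
model.compile(optimizer="adam", loss="mse", jit_compile=True)
# or after compiling:
model.jit_compile = True
```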

983 """ 

984 return self._jit_compile 

985 

986 @jit_compile.setter 

987 def jit_compile(self, value): 

988 # Function remains cached with previous jit_compile settings 

989 if self._jit_compile == value: 

990 # Avoid resetting compiler cache if possible if the value is the 

991 # same 

992 return 

993 # Check if TensorFlow is compiled with XLA before setting the value 

994 if value and not tf_utils.can_jit_compile(warn=True): 

995 self._jit_compile = False 

996 return 

997 

998 self._jit_compile = value 

999 # Setting `jit_compile` should invalidate previously cached functions. 

1000 self._reset_compile_cache() 

1001 

1002 @property 

1003 def distribute_reduction_method(self): 

1004 """The method employed to reduce per-replica values during training. 

1005 

1006 Unless specified, the value "auto" will be assumed, indicating that 

1007 the reduction strategy should be chosen based on the current 

1008 running environment. 

1009 See `reduce_per_replica` function for more details. 

1010 

1011 """ 

1012 return self._distribute_reduction_method or "auto" 

1013 

1014 @distribute_reduction_method.setter 

1015 def distribute_reduction_method(self, value): 

1016 self._distribute_reduction_method = value 

1017 

1018 def _validate_target_and_loss(self, y, loss): 

1019 """Raises error if target or loss is not found. 

1020 

1021 This method verifies that the target and loss are properly populated 

1022 when applicable, or raises errors. 

1023 

1024 Args: 

1025 y: the target for training. 

1026 loss: the total loss tensor including loss added via `compile` and 

1027 `add_loss`. 

1028 """ 

1029 

1030 # `self.loss` references the loss added via `compile` call. If users 

1031 # have provided such, the target must be provided; otherwise it's a user 

1032 # error. Note that `self.loss` does not include losses added via 

1033 # `add_loss`, and it is a valid use when such loss from `add_loss` 

1034 # exists and target does not. 

1035 if self.loss and y is None: 

1036 raise ValueError( 

1037 "Target data is missing. Your model was compiled with " 

1038 f"loss={self.loss}, " 

1039 "and therefore expects target data to be provided in `fit()`." 

1040 ) 

1041 

1042 # For training, there must be compiled loss or regularization loss to 

1043 # exist in order to apply the gradients. If one is not found, it means 

1044 # no loss was supplied via `compile` or `add_loss`. 

1045 elif loss is None: 

1046 raise ValueError( 

1047 "No loss found. You may have forgotten to provide a `loss` " 

1048 "argument in the `compile()` method." 

1049 ) 

1050 

1051 def train_step(self, data): 

1052 """The logic for one training step. 

1053 

1054 This method can be overridden to support custom training logic. 

1055 For concrete examples of how to override this method see 

1056 [Customizing what happens in fit]( 

1057 https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit). 

1058 This method is called by `Model.make_train_function`. 

1059 

1060 This method should contain the mathematical logic for one step of 

1061 training. This typically includes the forward pass, loss calculation, 

1062 backpropagation, and metric updates. 

1063 

1064 Configuration details for *how* this logic is run (e.g. `tf.function` 

1065 and `tf.distribute.Strategy` settings), should be left to 

1066 `Model.make_train_function`, which can also be overridden. 

1067 

1068 Args: 

1069 data: A nested structure of `Tensor`s. 

1070 

1071 Returns: 

1072 A `dict` containing values that will be passed to 

1073 `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the 

1074 values of the `Model`'s metrics are returned. Example: 

1075 `{'loss': 0.2, 'accuracy': 0.7}`. 
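
For illustration, an override that reproduces the default behavior could
look like the following sketch (assuming `data` unpacks to an `(x, y)`
tuple and the model was compiled with a loss and metrics):

```python
class CustomModel(tf.keras.Model):

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # forward pass
            loss = self.compiled_loss(
                y, y_pred, regularization_losses=self.losses)
        # Backward pass and weight update.
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
```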

1076 """ 

1077 x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data) 

1078 # Run forward pass. 

1079 with tf.GradientTape() as tape: 

1080 y_pred = self(x, training=True) 

1081 loss = self.compute_loss(x, y, y_pred, sample_weight) 

1082 self._validate_target_and_loss(y, loss) 

1083 # Run backwards pass. 

1084 self.optimizer.minimize(loss, self.trainable_variables, tape=tape) 

1085 return self.compute_metrics(x, y, y_pred, sample_weight) 

1086 

1087 def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None): 

1088 """Compute the total loss, validate it, and return it. 

1089 

1090 Subclasses can optionally override this method to provide custom loss 

1091 computation logic. 

1092 

1093 Example: 

1094 ```python 

1095 class MyModel(tf.keras.Model): 

1096 

1097 def __init__(self, *args, **kwargs): 

1098 super(MyModel, self).__init__(*args, **kwargs) 

1099 self.loss_tracker = tf.keras.metrics.Mean(name='loss') 

1100 

1101 def compute_loss(self, x, y, y_pred, sample_weight): 

1102 loss = tf.reduce_mean(tf.math.squared_difference(y_pred, y)) 

1103 loss += tf.add_n(self.losses) 

1104 self.loss_tracker.update_state(loss) 

1105 return loss 

1106 

1107 def reset_metrics(self): 

1108 self.loss_tracker.reset_states() 

1109 

1110 @property 

1111 def metrics(self): 

1112 return [self.loss_tracker] 

1113 

1114 tensors = tf.random.uniform((10, 10)), tf.random.uniform((10,)) 

1115 dataset = tf.data.Dataset.from_tensor_slices(tensors).repeat().batch(1) 

1116 

1117 inputs = tf.keras.layers.Input(shape=(10,), name='my_input') 

1118 outputs = tf.keras.layers.Dense(10)(inputs) 

1119 model = MyModel(inputs, outputs) 

1120 model.add_loss(tf.reduce_sum(outputs)) 

1121 

1122 optimizer = tf.keras.optimizers.SGD() 

1123 model.compile(optimizer, loss='mse', steps_per_execution=10) 

1124 model.fit(dataset, epochs=2, steps_per_epoch=10) 

1125 print('My custom loss: ', model.loss_tracker.result().numpy()) 

1126 ``` 

1127 

1128 Args: 

1129 x: Input data. 

1130 y: Target data. 

1131 y_pred: Predictions returned by the model (output of `model(x)`) 

1132 sample_weight: Sample weights for weighting the loss function. 

1133 

1134 Returns: 

1135 The total loss as a `tf.Tensor`, or `None` if no loss results (which 

1136 is the case when called by `Model.test_step`). 

1137 """ 

1138 del x # The default implementation does not use `x`. 

1139 return self.compiled_loss( 

1140 y, y_pred, sample_weight, regularization_losses=self.losses 

1141 ) 

1142 

1143 def compute_metrics(self, x, y, y_pred, sample_weight): 

1144 """Update metric states and collect all metrics to be returned. 

1145 

1146 Subclasses can optionally override this method to provide custom metric 

1147 updating and collection logic. 

1148 

1149 Example: 

1150 ```python 

1151 class MyModel(tf.keras.Sequential): 

1152 

1153 def compute_metrics(self, x, y, y_pred, sample_weight): 

1154 

1155 # This super call updates `self.compiled_metrics` and returns 

1156 # results for all metrics listed in `self.metrics`. 

1157 metric_results = super(MyModel, self).compute_metrics( 

1158 x, y, y_pred, sample_weight) 

1159 

1160 # Note that `self.custom_metric` is not listed in `self.metrics`. 

1161 self.custom_metric.update_state(x, y, y_pred, sample_weight) 

1162 metric_results['custom_metric_name'] = self.custom_metric.result() 

1163 return metric_results 

1164 ``` 

1165 

1166 Args: 

1167 x: Input data. 

1168 y: Target data. 

1169 y_pred: Predictions returned by the model (output of `model.call(x)`) 

1170 sample_weight: Sample weights for weighting the loss function. 

1171 

1172 Returns: 

1173 A `dict` containing values that will be passed to 

1174 `tf.keras.callbacks.CallbackList.on_train_batch_end()`. Typically, the 

1175 values of the metrics listed in `self.metrics` are returned. Example: 

1176 `{'loss': 0.2, 'accuracy': 0.7}`. 

1177 """ 

1178 del x # The default implementation does not use `x`. 

1179 self.compiled_metrics.update_state(y, y_pred, sample_weight) 

1180 return self.get_metrics_result() 

1181 

1182 def get_metrics_result(self): 

1183 """Returns the model's metrics values as a dict. 

1184 

1185 If any of the metric results is a dict (containing multiple metrics),

1186 each of them is added to the top-level dict returned by this method.

1187 

1188 Returns: 

1189 A `dict` containing values of the metrics listed in `self.metrics`. 

1190 Example: 

1191 `{'loss': 0.2, 'accuracy': 0.7}`. 

1192 """ 

1193 # Collect metrics to return 

1194 return_metrics = {} 

1195 for metric in self.metrics: 

1196 result = metric.result() 

1197 if isinstance(result, dict): 

1198 return_metrics.update(result) 

1199 else: 

1200 return_metrics[metric.name] = result 

1201 return return_metrics 

1202 

1203 def _validate_and_get_metrics_result(self, logs): 

1204 """Returns model metrics as a dict if the keys match with input logs. 

1205 

1206 When the training / evaluation is performed with asynchronous steps, as

1207 is the case with `tf.distribute.ParameterServerStrategy`, the last

1208 scheduled `train / test_step` may not give the latest metrics because it

1209 is not guaranteed to be executed last. This method gets metrics from

1210 the model directly instead of relying on the return value of the last

1211 step function.

1212 

1213 It logs a warning if the metric results could not be overridden when 

1214 used with `tf.distribute.ParameterServerStrategy`. 

1215 

1216 When the user has custom train / test step functions, the metrics 

1217 returned may be different from `Model.metrics`. In those instances, 

1218 this function will be no-op and return the logs. 

1219 

1220 Args: 

1221 logs: A `dict` of metrics returned by train / test step function. 

1222 

1223 Returns: 

1224 A `dict` containing values of the metrics listed in `self.metrics` 

1225 when logs and model metrics keys match. Otherwise it returns input 

1226 `logs`. 

1227 """ 

1228 PSS_WARN_MSG = "Could not get Model metric results. \ 

1229 Using the results of last step function could lead to incorrect \ 

1230 results when used with ParameterServerStrategy" 

1231 try: 

1232 metric_logs = self.get_metrics_result() 

1233 except TypeError: 

1234 if self._cluster_coordinator: 

1235 logging.warning(PSS_WARN_MSG) 

1236 else: 

1237 # Verify that the train / test step logs passed in and the metric

1238 # logs have matching keys; they can differ with custom step functions.

1239 if isinstance(logs, dict) and set(logs.keys()) == set( 

1240 metric_logs.keys() 

1241 ): 

1242 logs = tf_utils.sync_to_numpy_or_python_type(metric_logs) 

1243 elif self._cluster_coordinator: 

1244 logging.warning(PSS_WARN_MSG) 

1245 return logs 

1246 

1247 def _aggregate_exact_metrics(self, logs): 

1248 # When doing exact evaluation, `logs` is a list of each data shard's 

1249 # metric variables, which will be used to update the metrics. 

1250 for shard_result in logs: 

1251 for metric in self.metrics: 

1252 if metric.name not in shard_result.keys(): 

1253 logging.log_first_n( 

1254 logging.WARN, 

1255 f"No matching result found for metric {metric.name}. " 

1256 "This metric's computed result may be incorrect.", 

1257 3, 

1258 ) 

1259 continue 

1260 metric_result = shard_result[metric.name] 

1261 if len(metric_result) != len(metric.weights): 

1262 raise ValueError( 

1263 f"Expected {len(metric.weights)} variables in result " 

1264 f"for metric {metric.name}, but found " 

1265 f"{len(metric_result)}." 

1266 ) 

1267 for weight, val in zip(metric.weights, metric_result): 

1268 weight.assign_add(val) 

1269 return self.get_metrics_result() 

1270 

1271 def make_train_function(self, force=False): 

1272 """Creates a function that executes one step of training. 

1273 

1274 This method can be overridden to support custom training logic. 

1275 This method is called by `Model.fit` and `Model.train_on_batch`. 

1276 

1277 Typically, this method directly controls `tf.function` and 

1278 `tf.distribute.Strategy` settings, and delegates the actual training 

1279 logic to `Model.train_step`. 

1280 

1281 This function is cached the first time `Model.fit` or 

1282 `Model.train_on_batch` is called. The cache is cleared whenever 

1283 `Model.compile` is called. You can skip the cache and regenerate the

1284 function with `force=True`.

1285 

1286 Args: 

1287 force: Whether to regenerate the train function and skip the cached 

1288 function if available. 

1289 

1290 Returns: 

1291 Function. The function created by this method should accept a 

1292 `tf.data.Iterator`, and return a `dict` containing values that will 

1293 be passed to `tf.keras.Callbacks.on_train_batch_end`, such as 

1294 `{'loss': 0.2, 'accuracy': 0.7}`. 
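
For example, to rebuild the function and bypass the cache (a sketch):

```python
train_fn = model.make_train_function(force=True)
```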

1295 """ 

1296 if self.train_function is not None and not force: 

1297 return self.train_function 

1298 

1299 def step_function(model, iterator): 

1300 """Runs a single training step.""" 

1301 

1302 def run_step(data): 

1303 outputs = model.train_step(data) 

1304 # Ensure counter is updated only if `train_step` succeeds. 

1305 with tf.control_dependencies(_minimum_control_deps(outputs)): 

1306 model._train_counter.assign_add(1) 

1307 return outputs 

1308 

1309 if self.jit_compile and not isinstance( 

1310 model.distribute_strategy, 

1311 ( 

1312 tf.compat.v1.distribute.experimental.TPUStrategy, 

1313 tf.distribute.TPUStrategy, 

1314 ), 

1315 ): 

1316 # TODO(b/258249546): Explicit `jit_compile=True` on TPU causes 

1317 # unexpected behavior, so we skip TPU training now. 

1318 run_step = tf.function( 

1319 run_step, jit_compile=True, reduce_retracing=True 

1320 ) 

1321 data = next(iterator) 

1322 outputs = model.distribute_strategy.run(run_step, args=(data,)) 

1323 outputs = reduce_per_replica( 

1324 outputs, 

1325 self.distribute_strategy, 

1326 reduction=self.distribute_reduction_method, 

1327 ) 

1328 return outputs 

1329 

1330 # Special case if steps_per_execution is one. 

1331 if ( 

1332 self._steps_per_execution is None 

1333 or self._steps_per_execution.numpy().item() == 1 

1334 ): 

1335 

1336 def train_function(iterator): 

1337 """Runs a training execution with a single step.""" 

1338 return step_function(self, iterator) 

1339 

1340 if not self.run_eagerly: 

1341 train_function = tf.function( 

1342 train_function, reduce_retracing=True 

1343 ) 

1344 self.train_tf_function = train_function 

1345 

1346 if self._cluster_coordinator: 

1347 self.train_function = ( 

1348 lambda it: self._cluster_coordinator.schedule( 

1349 train_function, args=(it,) 

1350 ) 

1351 ) 

1352 else: 

1353 self.train_function = train_function 

1354 

1355 # If we're using a coordinator, use the value of 

1356 # self._steps_per_execution at the time the function is 

1357 # called/scheduled, and not when it is actually executed. 

1358 elif self._cluster_coordinator: 

1359 

1360 def train_function(iterator, steps_per_execution): 

1361 """Runs a training execution with multiple steps.""" 

1362 for _ in tf.range(steps_per_execution): 

1363 outputs = step_function(self, iterator) 

1364 return outputs 

1365 

1366 if not self.run_eagerly: 

1367 train_function = tf.function( 

1368 train_function, reduce_retracing=True 

1369 ) 

1370 self.train_tf_function = train_function 

1371 

1372 self.train_function = lambda it: self._cluster_coordinator.schedule( 

1373 train_function, args=(it, self._steps_per_execution.value()) 

1374 ) 

1375 else: 

1376 

1377 def train_function(iterator): 

1378 """Runs a training execution with multiple steps.""" 

1379 for _ in tf.range(self._steps_per_execution): 

1380 outputs = step_function(self, iterator) 

1381 return outputs 

1382 

1383 if not self.run_eagerly: 

1384 train_function = tf.function( 

1385 train_function, reduce_retracing=True 

1386 ) 

1387 self.train_tf_function = train_function 

1388 self.train_function = train_function 

1389 

1390 return self.train_function 

1391 

1392 @traceback_utils.filter_traceback 

1393 def fit( 

1394 self, 

1395 x=None, 

1396 y=None, 

1397 batch_size=None, 

1398 epochs=1, 

1399 verbose="auto", 

1400 callbacks=None, 

1401 validation_split=0.0, 

1402 validation_data=None, 

1403 shuffle=True, 

1404 class_weight=None, 

1405 sample_weight=None, 

1406 initial_epoch=0, 

1407 steps_per_epoch=None, 

1408 validation_steps=None, 

1409 validation_batch_size=None, 

1410 validation_freq=1, 

1411 max_queue_size=10, 

1412 workers=1, 

1413 use_multiprocessing=False, 

1414 ): 

1415 """Trains the model for a fixed number of epochs (dataset iterations). 

1416 

1417 Args: 

1418 x: Input data. It could be: 

1419 - A Numpy array (or array-like), or a list of arrays 

1420 (in case the model has multiple inputs). 

1421 - A TensorFlow tensor, or a list of tensors 

1422 (in case the model has multiple inputs). 

1423 - A dict mapping input names to the corresponding array/tensors, 

1424 if the model has named inputs. 

1425 - A `tf.data` dataset. Should return a tuple 

1426 of either `(inputs, targets)` or 

1427 `(inputs, targets, sample_weights)`. 

1428 - A generator or `keras.utils.Sequence` returning `(inputs, 

1429 targets)` or `(inputs, targets, sample_weights)`. 

1430 - A `tf.keras.utils.experimental.DatasetCreator`, which wraps a 

1431 callable that takes a single argument of type 

1432 `tf.distribute.InputContext`, and returns a `tf.data.Dataset`. 

1433 `DatasetCreator` should be used when users prefer to specify the 

1434 per-replica batching and sharding logic for the `Dataset`. 

1435 See `tf.keras.utils.experimental.DatasetCreator` doc for more 

1436 information. 

1437 A more detailed description of unpacking behavior for iterator 

1438 types (Dataset, generator, Sequence) is given below. If these 

1439 include `sample_weights` as a third component, note that sample 

1440 weighting applies to the `weighted_metrics` argument but not the 

1441 `metrics` argument in `compile()`. If using 

1442 `tf.distribute.experimental.ParameterServerStrategy`, only 

1443 `DatasetCreator` type is supported for `x`. 

1444 y: Target data. Like the input data `x`, 

1445 it could be either Numpy array(s) or TensorFlow tensor(s). 

1446 It should be consistent with `x` (you cannot have Numpy inputs and 

1447 tensor targets, or inversely). If `x` is a dataset, generator, 

1448 or `keras.utils.Sequence` instance, `y` should 

1449 not be specified (since targets will be obtained from `x`). 

1450 batch_size: Integer or `None`. 

1451 Number of samples per gradient update. 

1452 If unspecified, `batch_size` will default to 32. 

1453 Do not specify the `batch_size` if your data is in the 

1454 form of datasets, generators, or `keras.utils.Sequence` 

1455 instances (since they generate batches). 

1456 epochs: Integer. Number of epochs to train the model. 

1457 An epoch is an iteration over the entire `x` and `y` 

1458 data provided 

1459 (unless the `steps_per_epoch` flag is set to 

1460 something other than None). 

1461 Note that in conjunction with `initial_epoch`, 

1462 `epochs` is to be understood as "final epoch". 

1463 The model is not trained for a number of iterations 

1464 given by `epochs`, but merely until the epoch 

1465 of index `epochs` is reached. 

1466 verbose: 'auto', 0, 1, or 2. Verbosity mode. 

1467 0 = silent, 1 = progress bar, 2 = one line per epoch. 

1468 'auto' becomes 1 for most cases, but 2 when used with 

1469 `ParameterServerStrategy`. Note that the progress bar is not 

1470 particularly useful when logged to a file, so verbose=2 is 

1471 recommended when not running interactively (e.g., in a production 

1472 environment). Defaults to 'auto'. 

1473 callbacks: List of `keras.callbacks.Callback` instances. 

1474 List of callbacks to apply during training. 

1475 See `tf.keras.callbacks`. Note 

1476 `tf.keras.callbacks.ProgbarLogger` and 

1477 `tf.keras.callbacks.History` callbacks are created automatically 

1478 and need not be passed into `model.fit`. 

1479 Whether `tf.keras.callbacks.ProgbarLogger` is created depends on 

1480 the `verbose` argument to `model.fit`. 

1481 Callbacks with batch-level calls are currently unsupported with 

1482 `tf.distribute.experimental.ParameterServerStrategy`, and users 

1483 are advised to implement epoch-level calls instead with an 

1484 appropriate `steps_per_epoch` value. 

1485 validation_split: Float between 0 and 1. 

1486 Fraction of the training data to be used as validation data. 

1487 The model will set apart this fraction of the training data, 

1488 will not train on it, and will evaluate 

1489 the loss and any model metrics 

1490 on this data at the end of each epoch. 

1491 The validation data is selected from the last samples 

1492 in the `x` and `y` data provided, before shuffling. This 

1493 argument is not supported when `x` is a dataset, generator or 

1494 `keras.utils.Sequence` instance. 

1495 If both `validation_data` and `validation_split` are provided, 

1496 `validation_data` will override `validation_split`. 

1497 `validation_split` is not yet supported with 

1498 `tf.distribute.experimental.ParameterServerStrategy`. 

1499 validation_data: Data on which to evaluate 

1500 the loss and any model metrics at the end of each epoch. 

1501 The model will not be trained on this data. Thus, note 

1502 that the validation loss of data provided using 

1503 `validation_split` or `validation_data` is not affected by 

1504 regularization layers like noise and dropout. 

1505 `validation_data` will override `validation_split`. 

1506 `validation_data` could be: 

1507 - A tuple `(x_val, y_val)` of Numpy arrays or tensors. 

1508 - A tuple `(x_val, y_val, val_sample_weights)` of Numpy 

1509 arrays. 

1510 - A `tf.data.Dataset`. 

1511 - A Python generator or `keras.utils.Sequence` returning 

1512 `(inputs, targets)` or `(inputs, targets, sample_weights)`. 

1513 `validation_data` is not yet supported with 

1514 `tf.distribute.experimental.ParameterServerStrategy`. 

1515 shuffle: Boolean (whether to shuffle the training data 

1516 before each epoch) or str (for 'batch'). This argument is 

1517 ignored when `x` is a generator or a `tf.data.Dataset` object. 

1518 'batch' is a special option for dealing 

1519 with the limitations of HDF5 data; it shuffles in batch-sized 

1520 chunks. Has no effect when `steps_per_epoch` is not `None`. 

1521 class_weight: Optional dictionary mapping class indices (integers) 

1522 to a weight (float) value, used for weighting the loss function 

1523 (during training only). 

1524 This can be useful to tell the model to 

1525 "pay more attention" to samples from 

1526 an under-represented class. When `class_weight` is specified 

1527 and targets have a rank of 2 or greater, either `y` must be 

1528 one-hot encoded, or an explicit final dimension of `1` must 

1529 be included for sparse class labels. 

1530 sample_weight: Optional Numpy array of weights for 

1531 the training samples, used for weighting the loss function 

1532 (during training only). You can either pass a flat (1D) 

1533 Numpy array with the same length as the input samples 

1534 (1:1 mapping between weights and samples), 

1535 or in the case of temporal data, 

1536 you can pass a 2D array with shape 

1537 `(samples, sequence_length)`, 

1538 to apply a different weight to every timestep of every sample. 

1539 This argument is not supported when `x` is a dataset, generator, 

1540 or `keras.utils.Sequence` instance; instead, provide the 

1541 sample_weights as the third element of `x`. 

1542 Note that sample weighting does not apply to metrics specified 

1543 via the `metrics` argument in `compile()`. To apply sample 

1544 weighting to your metrics, you can specify them via the 

1545 `weighted_metrics` in `compile()` instead. 

1546 initial_epoch: Integer. 

1547 Epoch at which to start training 

1548 (useful for resuming a previous training run). 

1549 steps_per_epoch: Integer or `None`. 

1550 Total number of steps (batches of samples) 

1551 before declaring one epoch finished and starting the 

1552 next epoch. When training with input tensors such as 

1553 TensorFlow data tensors, the default `None` is equal to 

1554 the number of samples in your dataset divided by 

1555 the batch size, or 1 if that cannot be determined. If `x` is a 

1556 `tf.data` dataset, and `steps_per_epoch` 

1557 is None, the epoch will run until the input dataset is 

1558 exhausted. When passing an infinitely repeating dataset, you 

1559 must specify the `steps_per_epoch` argument. If 

1560 `steps_per_epoch=-1`, the training will run indefinitely with an 

1561 infinitely repeating dataset. This argument is not supported 

1562 with array inputs. 

1563 When using `tf.distribute.experimental.ParameterServerStrategy`: 

1564 * `steps_per_epoch=None` is not supported. 

1565 validation_steps: Only relevant if `validation_data` is provided and 

1566 is a `tf.data` dataset. Total number of steps (batches of 

1567 samples) to draw before stopping when performing validation 

1568 at the end of every epoch. If `validation_steps` is None, 

1569 validation will run until the `validation_data` dataset is 

1570 exhausted. In the case of an infinitely repeating dataset, it 

1571 will run into an infinite loop. If `validation_steps` is 

1572 specified and only part of the dataset is consumed, the 

1573 evaluation will start from the beginning of the dataset at each 

1574 epoch. This ensures that the same validation samples are used 

1575 every time. 

1576 validation_batch_size: Integer or `None`. 

1577 Number of samples per validation batch. 

1578 If unspecified, will default to `batch_size`. 

1579 Do not specify the `validation_batch_size` if your data is in 

1580 the form of datasets, generators, or `keras.utils.Sequence` 

1581 instances (since they generate batches). 

1582 validation_freq: Only relevant if validation data is provided. 

1583 Integer or `collections.abc.Container` instance (e.g. list, tuple, 

1584 etc.). If an integer, specifies how many training epochs to run 

1585 before a new validation run is performed, e.g. `validation_freq=2` 

1586 runs validation every 2 epochs. If a Container, specifies the 

1587 epochs on which to run validation, e.g. 

1588 `validation_freq=[1, 2, 10]` runs validation at the end of the 

1589 1st, 2nd, and 10th epochs. 

1590 max_queue_size: Integer. Used for generator or 

1591 `keras.utils.Sequence` input only. Maximum size for the generator 

1592 queue. If unspecified, `max_queue_size` will default to 10. 

1593 workers: Integer. Used for generator or `keras.utils.Sequence` input 

1594 only. Maximum number of processes to spin up 

1595 when using process-based threading. If unspecified, `workers` 

1596 will default to 1. 

1597 use_multiprocessing: Boolean. Used for generator or 

1598 `keras.utils.Sequence` input only. If `True`, use process-based 

1599 threading. If unspecified, `use_multiprocessing` will default to 

1600 `False`. Note that because this implementation relies on 

1601 multiprocessing, you should not pass non-picklable arguments to 

1602 the generator, as they can't be passed easily to child 

1603 processes. 

1604 

1605 Unpacking behavior for iterator-like inputs: 

1606 A common pattern is to pass a tf.data.Dataset, generator, or 

1607 tf.keras.utils.Sequence to the `x` argument of fit, which will in fact 

1608 yield not only features (x) but optionally targets (y) and sample 

1609 weights. Keras requires that the output of such iterator-likes be 

1610 unambiguous. The iterator should return a tuple of length 1, 2, or 3, 

1611 where the optional second and third elements will be used for y and 

1612 sample_weight respectively. Any other type provided will be wrapped in 

1613 a length one tuple, effectively treating everything as 'x'. When 

1614 yielding dicts, they should still adhere to the top-level tuple 

1615 structure. 

1616 e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate 

1617 features, targets, and weights from the keys of a single dict. 

1618 A notable unsupported data type is the namedtuple. The reason is 

1619 that it behaves like both an ordered datatype (tuple) and a mapping 

1620 datatype (dict). So given a namedtuple of the form: 

1621 `namedtuple("example_tuple", ["y", "x"])` 

1622 it is ambiguous whether to reverse the order of the elements when 

1623 interpreting the value. Even worse is a tuple of the form: 

1624 `namedtuple("other_tuple", ["x", "y", "z"])` 

1625 where it is unclear if the tuple was intended to be unpacked into x, 

1626 y, and sample_weight or passed through as a single element to `x`. As 

1627 a result, the data processing code will simply raise a ValueError if 

1628 it encounters a namedtuple, along with instructions on how to remedy 

1629 the issue. 

1630 

1631 Returns: 

1632 A `History` object. Its `History.history` attribute is 

1633 a record of training loss values and metrics values 

1634 at successive epochs, as well as validation loss values 

1635 and validation metrics values (if applicable). 

1636 

1637 Raises: 

1638 RuntimeError: 1. If the model was never compiled, or 

1639 2. If `model.fit` is wrapped in a `tf.function`. 

1640 

1641 ValueError: In case of mismatch between the provided input data 

1642 and what the model expects or when the input data is empty. 

1643 """ 

1644 base_layer.keras_api_gauge.get_cell("fit").set(True) 

1645 # Legacy graph support is contained in `training_v1.Model`. 

1646 version_utils.disallow_legacy_graph("Model", "fit") 

1647 self._assert_compile_was_called() 

1648 self._check_call_args("fit") 

1649 _disallow_inside_tf_function("fit") 

1650 

1651 verbose = _get_verbosity(verbose, self.distribute_strategy) 

1652 

1653 if validation_split and validation_data is None: 

1654 # Create the validation data using the training data. Only supported 

1655 # for `Tensor` and `NumPy` input. 

1656 ( 

1657 x, 

1658 y, 

1659 sample_weight, 

1660 ), validation_data = data_adapter.train_validation_split( 

1661 (x, y, sample_weight), validation_split=validation_split 

1662 ) 

1663 

1664 if validation_data: 

1665 ( 

1666 val_x, 

1667 val_y, 

1668 val_sample_weight, 

1669 ) = data_adapter.unpack_x_y_sample_weight(validation_data) 

1670 

1671 if self.distribute_strategy._should_use_with_coordinator: 

1672 self._cluster_coordinator = ( 

1673 tf.distribute.experimental.coordinator.ClusterCoordinator( 

1674 self.distribute_strategy 

1675 ) 

1676 ) 

1677 

1678 with self.distribute_strategy.scope(), training_utils.RespectCompiledTrainableState( # noqa: E501 

1679 self 

1680 ): 

1681 # Creates a `tf.data.Dataset` and handles batch and epoch iteration. 

1682 data_handler = data_adapter.get_data_handler( 

1683 x=x, 

1684 y=y, 

1685 sample_weight=sample_weight, 

1686 batch_size=batch_size, 

1687 steps_per_epoch=steps_per_epoch, 

1688 initial_epoch=initial_epoch, 

1689 epochs=epochs, 

1690 shuffle=shuffle, 

1691 class_weight=class_weight, 

1692 max_queue_size=max_queue_size, 

1693 workers=workers, 

1694 use_multiprocessing=use_multiprocessing, 

1695 model=self, 

1696 steps_per_execution=self._steps_per_execution, 

1697 ) 

1698 

1699 # Container that configures and calls `tf.keras.Callback`s. 

1700 if not isinstance(callbacks, callbacks_module.CallbackList): 

1701 callbacks = callbacks_module.CallbackList( 

1702 callbacks, 

1703 add_history=True, 

1704 add_progbar=verbose != 0, 

1705 model=self, 

1706 verbose=verbose, 

1707 epochs=epochs, 

1708 steps=data_handler.inferred_steps, 

1709 ) 

1710 

1711 self.stop_training = False 

1712 self.train_function = self.make_train_function() 

1713 self._train_counter.assign(0) 

1714 callbacks.on_train_begin() 

1715 training_logs = None 

1716 # Handle fault-tolerance for multi-worker. 

1717 # TODO(omalleyt): Fix the ordering issues that mean this has to 

1718 # happen after `callbacks.on_train_begin`. 

1719 steps_per_epoch_inferred = ( 

1720 steps_per_epoch or data_handler.inferred_steps 

1721 ) 

1722 ( 

1723 data_handler._initial_epoch, 

1724 data_handler._initial_step, 

1725 ) = self._maybe_load_initial_counters_from_ckpt( 

1726 steps_per_epoch_inferred, initial_epoch 

1727 ) 

1728 logs = None 

1729 for epoch, iterator in data_handler.enumerate_epochs(): 

1730 self.reset_metrics() 

1731 callbacks.on_epoch_begin(epoch) 

1732 with data_handler.catch_stop_iteration(): 

1733 for step in data_handler.steps(): 

1734 with tf.profiler.experimental.Trace( 

1735 "train", 

1736 epoch_num=epoch, 

1737 step_num=step, 

1738 batch_size=batch_size, 

1739 _r=1, 

1740 ): 

1741 callbacks.on_train_batch_begin(step) 

1742 tmp_logs = self.train_function(iterator) 

1743 if data_handler.should_sync: 

1744 context.async_wait() 

1745 # No error, now safe to assign to logs. 

1746 logs = tmp_logs 

1747 end_step = step + data_handler.step_increment 

1748 callbacks.on_train_batch_end(end_step, logs) 

1749 if self.stop_training: 

1750 break 

1751 

1752 logs = tf_utils.sync_to_numpy_or_python_type(logs) 

1753 if logs is None: 

1754 raise ValueError( 

1755 "Unexpected result of `train_function` " 

1756 "(Empty logs). This could be due to issues in input " 

1757 "pipeline that resulted in an empty dataset. " 

1758 "Otherwise, please use " 

1759 "`Model.compile(..., run_eagerly=True)`, or " 

1760 "`tf.config.run_functions_eagerly(True)` for more " 

1761 "information of where went wrong, or file a " 

1762 "issue/bug to `tf.keras`." 

1763 ) 

1764 # Override with model metrics instead of last step logs 

1765 logs = self._validate_and_get_metrics_result(logs) 

1766 epoch_logs = copy.copy(logs) 

1767 

1768 # Run validation. 

1769 if validation_data and self._should_eval( 

1770 epoch, validation_freq 

1771 ): 

1772 if self._pss_evaluation_shards: 

1773 self._disallow_exact_eval_with_add_metrics() 

1774 # Create data_handler for evaluation and cache it. 

1775 if getattr(self, "_eval_data_handler", None) is None: 

1776 self._eval_data_handler = data_adapter.get_data_handler( 

1777 x=val_x, 

1778 y=val_y, 

1779 sample_weight=val_sample_weight, 

1780 batch_size=validation_batch_size or batch_size, 

1781 steps_per_epoch=validation_steps, 

1782 initial_epoch=0, 

1783 epochs=1, 

1784 max_queue_size=max_queue_size, 

1785 workers=workers, 

1786 use_multiprocessing=use_multiprocessing, 

1787 model=self, 

1788 steps_per_execution=self._steps_per_execution, 

1789 pss_evaluation_shards=self._pss_evaluation_shards, 

1790 ) 

1791 val_logs = self.evaluate( 

1792 x=val_x, 

1793 y=val_y, 

1794 sample_weight=val_sample_weight, 

1795 batch_size=validation_batch_size or batch_size, 

1796 steps=validation_steps, 

1797 callbacks=callbacks, 

1798 max_queue_size=max_queue_size, 

1799 workers=workers, 

1800 use_multiprocessing=use_multiprocessing, 

1801 return_dict=True, 

1802 _use_cached_eval_dataset=True, 

1803 ) 

1804 val_logs = { 

1805 "val_" + name: val for name, val in val_logs.items() 

1806 } 

1807 epoch_logs.update(val_logs) 

1808 

1809 callbacks.on_epoch_end(epoch, epoch_logs) 

1810 training_logs = epoch_logs 

1811 if self.stop_training: 

1812 break 

1813 

1814 if isinstance(self.optimizer, optimizer.Optimizer) and epochs > 0: 

1815 self.optimizer.finalize_variable_values( 

1816 self.trainable_variables 

1817 ) 

1818 

1819 # If eval data_handler exists, delete it after all epochs are done. 

1820 if getattr(self, "_eval_data_handler", None) is not None: 

1821 del self._eval_data_handler 

1822 callbacks.on_train_end(logs=training_logs) 

1823 return self.history 

1824 
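# Example: a minimal sketch of typical `fit` usage; the model, data, and
# hyperparameters below are illustrative assumptions, not part of the
# original source:
#
#     import numpy as np
#     import tensorflow as tf
#
#     model = tf.keras.Sequential(
#         [tf.keras.layers.Dense(8, activation="relu"),
#          tf.keras.layers.Dense(1)]
#     )
#     model.compile(optimizer="adam", loss="mse", metrics=["mae"])
#     x = np.random.random((100, 4)).astype("float32")
#     y = np.random.random((100, 1)).astype("float32")
#     history = model.fit(
#         x, y, batch_size=16, epochs=3, validation_split=0.2, verbose=0
#     )
#     # `History.history` maps metric names to per-epoch values.
#     assert len(history.history["loss"]) == 3
#     assert "val_mae" in history.history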

1825 def test_step(self, data): 

1826 """The logic for one evaluation step. 

1827 

1828 This method can be overridden to support custom evaluation logic. 

1829 This method is called by `Model.make_test_function`. 

1830 

1831 This function should contain the mathematical logic for one step of 

1832 evaluation. 

1833 This typically includes the forward pass, loss calculation, and metrics 

1834 updates. 

1835 

1836 Configuration details for *how* this logic is run (e.g. `tf.function` 

1837 and `tf.distribute.Strategy` settings), should be left to 

1838 `Model.make_test_function`, which can also be overridden. 

1839 

1840 Args: 

1841 data: A nested structure of `Tensor`s. 

1842 

1843 Returns: 

1844 A `dict` containing values that will be passed to 

1845 `tf.keras.callbacks.CallbackList.on_test_batch_end`. Typically, the 

1846 values of the `Model`'s metrics are returned. 

1847 """ 

1848 x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data) 

1849 

1850 y_pred = self(x, training=False) 

1851 # Updates stateful loss metrics. 

1852 self.compute_loss(x, y, y_pred, sample_weight) 

1853 return self.compute_metrics(x, y, y_pred, sample_weight) 

1854 
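# Example: a hedged sketch of overriding `test_step` in a subclass; the
# `CustomModel` class is an illustrative assumption, not part of the
# original source:
#
#     import tensorflow as tf
#
#     class CustomModel(tf.keras.Model):
#         def test_step(self, data):
#             x, y, sample_weight = (
#                 tf.keras.utils.unpack_x_y_sample_weight(data)
#             )
#             y_pred = self(x, training=False)
#             # Updates stateful loss metrics, as in the default step.
#             self.compute_loss(x, y, y_pred, sample_weight)
#             # Any custom evaluation-time logic can be added here.
#             return self.compute_metrics(x, y, y_pred, sample_weight)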

1855 def _make_test_function_exact(self): 

1856 if getattr(self, "_shard_test_function", None): 

1857 return self._shard_test_function 

1858 

1859 def step_function(batch): 

1860 def run_step(data): 

1861 # TODO(b/272050910): Use sample_weight for weighted metrics. 

1862 x, y, _ = data_adapter.unpack_x_y_sample_weight(data) 

1863 y_pred = self(x, training=False) 

1864 return x, y, y_pred 

1865 

1866 if self._jit_compile: 

1867 run_step = tf.function( 

1868 run_step, jit_compile=True, reduce_retracing=True 

1869 ) 

1870 

1871 outputs = self.distribute_strategy.run(run_step, args=(batch,)) 

1872 outputs = reduce_per_replica( 

1873 outputs, 

1874 self.distribute_strategy, 

1875 reduction=self.distribute_reduction_method, 

1876 ) 

1877 return outputs 

1878 

1879 def shard_test_function(dataset, total_shards, shard_idx): 

1880 local_metrics = [] 

1881 with tf_utils.with_metric_local_vars_scope(): 

1882 for metric in self.compiled_metrics.metrics: 

1883 local_metrics.append(base_metric.clone_metric(metric)) 

1884 for metric in self.compiled_loss.metrics: 

1885 local_metrics.append(base_metric.clone_metric(metric)) 

1886 dataset = input_ops.auto_shard_dataset( 

1887 dataset, total_shards, shard_idx 

1888 ) 

1889 iterator = iter(dataset) 

1890 with distribute_utils.cache_variable_reads(): 

1891 for batch in iterator: 

1892 x, y, y_pred = step_function(batch) 

1893 for local_metric in local_metrics: 

1894 local_metric.update_state(y, y_pred) 

1895 outputs = {metric.name: metric.weights for metric in local_metrics} 

1896 with tf.control_dependencies(_minimum_control_deps(outputs)): 

1897 self._test_counter.assign_add(1) 

1898 return outputs 

1899 

1900 if not self.run_eagerly: 

1901 shard_test_function = tf.function( 

1902 shard_test_function, reduce_retracing=True 

1903 ) 

1904 

1905 self._shard_test_function = ( 

1906 lambda *args: self._cluster_coordinator.schedule( 

1907 shard_test_function, 

1908 args=args, 

1909 ) 

1910 ) 

1911 return self._shard_test_function 

1912 

1913 def make_test_function(self, force=False): 

1914 """Creates a function that executes one step of evaluation. 

1915 

1916 This method can be overridden to support custom evaluation logic. 

1917 This method is called by `Model.evaluate` and `Model.test_on_batch`. 

1918 

1919 Typically, this method directly controls `tf.function` and 

1920 `tf.distribute.Strategy` settings, and delegates the actual evaluation 

1921 logic to `Model.test_step`. 

1922 

1923 This function is cached the first time `Model.evaluate` or 

1924 `Model.test_on_batch` is called. The cache is cleared whenever 

1925 `Model.compile` is called. You can skip the cache and generate the 

1926 function again with `force=True`. 

1927 

1928 Args: 

1929 force: Whether to regenerate the test function and skip the cached 

1930 function if available. 

1931 

1932 Returns: 

1933 Function. The function created by this method should accept a 

1934 `tf.data.Iterator`, and return a `dict` containing values that will 

1935 be passed to `tf.keras.callbacks.CallbackList.on_test_batch_end`. 

1936 """ 

1937 if self.test_function is not None and not force: 

1938 return self.test_function 

1939 

1940 def step_function(model, iterator): 

1941 """Runs a single evaluation step.""" 

1942 

1943 def run_step(data): 

1944 outputs = model.test_step(data) 

1945 # Ensure counter is updated only if `test_step` succeeds. 

1946 with tf.control_dependencies(_minimum_control_deps(outputs)): 

1947 model._test_counter.assign_add(1) 

1948 return outputs 

1949 

1950 if self.jit_compile: 

1951 run_step = tf.function( 

1952 run_step, jit_compile=True, reduce_retracing=True 

1953 ) 

1954 

1955 data = next(iterator) 

1956 outputs = model.distribute_strategy.run(run_step, args=(data,)) 

1957 outputs = reduce_per_replica( 

1958 outputs, 

1959 self.distribute_strategy, 

1960 reduction=self.distribute_reduction_method, 

1961 ) 

1962 return outputs 

1963 

1964 # Special case if steps_per_execution is one. 

1965 if ( 

1966 self._steps_per_execution is None 

1967 or self._steps_per_execution.numpy().item() == 1 

1968 ): 

1969 

1970 def test_function(iterator): 

1971 """Runs a test execution with a single step.""" 

1972 return step_function(self, iterator) 

1973 

1974 if not self.run_eagerly: 

1975 test_function = tf.function( 

1976 test_function, reduce_retracing=True 

1977 ) 

1978 

1979 if self._cluster_coordinator: 

1980 self.test_function = ( 

1981 lambda it: self._cluster_coordinator.schedule( 

1982 test_function, args=(it,) 

1983 ) 

1984 ) 

1985 else: 

1986 self.test_function = test_function 

1987 

1988 # If we're using a coordinator, use the value of 

1989 # self._steps_per_execution at the time the function is 

1990 # called/scheduled, and not when it is actually executed. 

1991 elif self._cluster_coordinator: 

1992 

1993 def test_function(iterator, steps_per_execution): 

1994 """Runs a test execution with multiple steps.""" 

1995 for _ in tf.range(steps_per_execution): 

1996 outputs = step_function(self, iterator) 

1997 return outputs 

1998 

1999 if not self.run_eagerly: 

2000 test_function = tf.function( 

2001 test_function, reduce_retracing=True 

2002 ) 

2003 

2004 self.test_function = lambda it: self._cluster_coordinator.schedule( 

2005 test_function, args=(it, self._steps_per_execution.value()) 

2006 ) 

2007 else: 

2008 

2009 def test_function(iterator): 

2010 """Runs a test execution with multiple steps.""" 

2011 for _ in tf.range(self._steps_per_execution): 

2012 outputs = step_function(self, iterator) 

2013 return outputs 

2014 

2015 if not self.run_eagerly: 

2016 test_function = tf.function( 

2017 test_function, reduce_retracing=True 

2018 ) 

2019 self.test_function = test_function 

2020 

2021 return self.test_function 

2022 
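# Example: the caching behavior described in the docstring above
# (illustrative; `model` is assumed to be any compiled model):
#
#     fn = model.make_test_function()
#     assert model.make_test_function() is fn        # cached
#     forced = model.make_test_function(force=True)  # regenerated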

2023 @traceback_utils.filter_traceback 

2024 def evaluate( 

2025 self, 

2026 x=None, 

2027 y=None, 

2028 batch_size=None, 

2029 verbose="auto", 

2030 sample_weight=None, 

2031 steps=None, 

2032 callbacks=None, 

2033 max_queue_size=10, 

2034 workers=1, 

2035 use_multiprocessing=False, 

2036 return_dict=False, 

2037 **kwargs, 

2038 ): 

2039 """Returns the loss value & metrics values for the model in test mode. 

2040 

2041 Computation is done in batches (see the `batch_size` arg). 

2042 

2043 Args: 

2044 x: Input data. It could be: 

2045 - A Numpy array (or array-like), or a list of arrays 

2046 (in case the model has multiple inputs). 

2047 - A TensorFlow tensor, or a list of tensors 

2048 (in case the model has multiple inputs). 

2049 - A dict mapping input names to the corresponding array/tensors, 

2050 if the model has named inputs. 

2051 - A `tf.data` dataset. Should return a tuple 

2052 of either `(inputs, targets)` or 

2053 `(inputs, targets, sample_weights)`. 

2054 - A generator or `keras.utils.Sequence` returning `(inputs, 

2055 targets)` or `(inputs, targets, sample_weights)`. 

2056 A more detailed description of unpacking behavior for iterator 

2057 types (Dataset, generator, Sequence) is given in the `Unpacking 

2058 behavior for iterator-like inputs` section of `Model.fit`. 

2059 y: Target data. Like the input data `x`, it could be either Numpy 

2060 array(s) or TensorFlow tensor(s). It should be consistent with `x` 

2061 (you cannot have Numpy inputs and tensor targets, or inversely). 

2062 If `x` is a dataset, generator or `keras.utils.Sequence` instance, 

2063 `y` should not be specified (since targets will be obtained from 

2064 the iterator/dataset). 

2065 batch_size: Integer or `None`. Number of samples per batch of 

2066 computation. If unspecified, `batch_size` will default to 32. Do 

2067 not specify the `batch_size` if your data is in the form of a 

2068 dataset, generators, or `keras.utils.Sequence` instances (since 

2069 they generate batches). 

2070 verbose: `"auto"`, 0, 1, or 2. Verbosity mode. 

2071 0 = silent, 1 = progress bar, 2 = single line. 

2072 `"auto"` becomes 1 for most cases, and to 2 when used with 

2073 `ParameterServerStrategy`. Note that the progress bar is not 

2074 particularly useful when logged to a file, so `verbose=2` is 

2075 recommended when not running interactively (e.g. in a production 

2076 environment). Defaults to 'auto'. 

2077 sample_weight: Optional Numpy array of weights for the test samples, 

2078 used for weighting the loss function. You can either pass a flat 

2079 (1D) Numpy array with the same length as the input samples 

2080 (1:1 mapping between weights and samples), or in the case of 

2081 temporal data, you can pass a 2D array with shape `(samples, 

2082 sequence_length)`, to apply a different weight to every 

2083 timestep of every sample. This argument is not supported when 

2084 `x` is a dataset, instead pass sample weights as the third 

2085 element of `x`. 

2086 steps: Integer or `None`. Total number of steps (batches of samples) 

2087 before declaring the evaluation round finished. Ignored with the 

2088 default value of `None`. If `x` is a `tf.data` dataset and `steps` 

2089 is None, `evaluate` will run until the dataset is exhausted. This 

2090 argument is not supported with array inputs. 

2091 callbacks: List of `keras.callbacks.Callback` instances. List of 

2092 callbacks to apply during evaluation. See 

2093 [callbacks](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks). 

2094 max_queue_size: Integer. Used for generator or 

2095 `keras.utils.Sequence` input only. Maximum size for the generator 

2096 queue. If unspecified, `max_queue_size` will default to 10. 

2097 workers: Integer. Used for generator or `keras.utils.Sequence` input 

2098 only. Maximum number of processes to spin up when using 

2099 process-based threading. If unspecified, `workers` will default to 

2100 1. 

2101 use_multiprocessing: Boolean. Used for generator or 

2102 `keras.utils.Sequence` input only. If `True`, use process-based 

2103 threading. If unspecified, `use_multiprocessing` will default to 

2104 `False`. Note that because this implementation relies on 

2105 multiprocessing, you should not pass non-picklable arguments to 

2106 the generator, as they can't be passed easily to child 

2107 processes. 

2108 return_dict: If `True`, loss and metric results are returned as a 

2109 dict, with each key being the name of the metric. If `False`, they 

2110 are returned as a list. 

2111 **kwargs: Unused at this time. 

2112 

2113 See the discussion of `Unpacking behavior for iterator-like inputs` for 

2114 `Model.fit`. 

2115 

2116 Returns: 

2117 Scalar test loss (if the model has a single output and no metrics) 

2118 or list of scalars (if the model has multiple outputs 

2119 and/or metrics). The attribute `model.metrics_names` will give you 

2120 the display labels for the scalar outputs. 

2121 

2122 Raises: 

2123 RuntimeError: If `model.evaluate` is wrapped in a `tf.function`. 

2124 """ 

2125 base_layer.keras_api_gauge.get_cell("evaluate").set(True) 

2126 version_utils.disallow_legacy_graph("Model", "evaluate") 

2127 self._assert_compile_was_called() 

2128 self._check_call_args("evaluate") 

2129 self._check_sample_weight_warning(x, sample_weight) 

2130 _disallow_inside_tf_function("evaluate") 

2131 use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False) 

2132 if kwargs: 

2133 raise TypeError(f"Invalid keyword arguments: {list(kwargs.keys())}") 

2134 

2135 if self.distribute_strategy._should_use_with_coordinator: 

2136 self._cluster_coordinator = ( 

2137 tf.distribute.experimental.coordinator.ClusterCoordinator( 

2138 self.distribute_strategy 

2139 ) 

2140 ) 

2141 

2142 verbose = _get_verbosity(verbose, self.distribute_strategy) 

2143 if self._pss_evaluation_shards: 

2144 self._disallow_exact_eval_with_add_metrics() 

2145 with self.distribute_strategy.scope(): 

2146 # Use cached evaluation data only when it's called in `Model.fit` 

2147 if ( 

2148 use_cached_eval_dataset 

2149 and getattr(self, "_eval_data_handler", None) is not None 

2150 ): 

2151 data_handler = self._eval_data_handler 

2152 else: 

2153 # Creates a `tf.data.Dataset` and handles batch and epoch 

2154 # iteration. 

2155 data_handler = data_adapter.get_data_handler( 

2156 x=x, 

2157 y=y, 

2158 sample_weight=sample_weight, 

2159 batch_size=batch_size, 

2160 steps_per_epoch=steps, 

2161 initial_epoch=0, 

2162 epochs=1, 

2163 max_queue_size=max_queue_size, 

2164 workers=workers, 

2165 use_multiprocessing=use_multiprocessing, 

2166 model=self, 

2167 steps_per_execution=self._steps_per_execution, 

2168 pss_evaluation_shards=self._pss_evaluation_shards, 

2169 ) 

2170 

2171 # Container that configures and calls `tf.keras.Callback`s. 

2172 if not isinstance(callbacks, callbacks_module.CallbackList): 

2173 callbacks = callbacks_module.CallbackList( 

2174 callbacks, 

2175 add_history=True, 

2176 add_progbar=verbose != 0, 

2177 model=self, 

2178 verbose=verbose, 

2179 epochs=1, 

2180 steps=data_handler.inferred_steps, 

2181 ) 

2182 

2183 # Initialize to prevent errors if 0 epochs are evaluated. 

2184 logs = {} 

2185 

2186 test_function_runner = self._get_test_function_runner(callbacks) 

2187 self._test_counter.assign(0) 

2188 callbacks.on_test_begin() 

2189 for ( 

2190 _, 

2191 dataset_or_iterator, 

2192 ) in data_handler.enumerate_epochs(): # Single epoch. 

2193 self.reset_metrics() 

2194 with data_handler.catch_stop_iteration(): 

2195 for step in data_handler.steps(): 

2196 with tf.profiler.experimental.Trace( 

2197 "test", step_num=step, _r=1 

2198 ): 

2199 callbacks.on_test_batch_begin(step) 

2200 logs = test_function_runner.run_step( 

2201 dataset_or_iterator, 

2202 data_handler, 

2203 step, 

2204 self._pss_evaluation_shards, 

2205 ) 

2206 

2207 logs = tf_utils.sync_to_numpy_or_python_type(logs) 

2208 # Override with model metrics instead of last step logs 

2209 if self._pss_evaluation_shards: 

2210 logs = self._aggregate_exact_metrics(logs) 

2211 else: 

2212 logs = self._validate_and_get_metrics_result(logs) 

2213 callbacks.on_test_end(logs=logs) 

2214 

2215 if return_dict: 

2216 return logs 

2217 else: 

2218 return flatten_metrics_in_order(logs, self.metrics_names) 

2219 
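# Example: a minimal sketch of `evaluate` usage (the compiled `model`,
# `x`, and `y` are illustrative assumptions):
#
#     results = model.evaluate(x, y, batch_size=16, verbose=0)
#     # The list pairs up with `model.metrics_names`, e.g. ["loss", "mae"].
#     print(dict(zip(model.metrics_names, results)))
#     # Or request a dict directly:
#     results = model.evaluate(x, y, return_dict=True, verbose=0)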

2220 def _disallow_exact_eval_with_add_metrics(self): 

2221 metrics_from_add_metric = [ 

2222 metric 

2223 for layer in self._flatten_layers() 

2224 for metric in layer._metrics 

2225 ] 

2226 compiled_metrics = self.compiled_metrics.metrics 

2227 if any( 

2228 [ 

2229 metric not in compiled_metrics 

2230 for metric in metrics_from_add_metric 

2231 ] 

2232 ): 

2233 raise ValueError( 

2234 "Detected that a metric was added to this model " 

2235 "via `Model.add_metric`. This is not currently " 

2236 "supported when using exact evaluation with " 

2237 "`tf.distribute.ParameterServerStrategy`." 

2238 ) 

2239 

2240 def _infer_exact_eval_shards(self, pss_evaluation_shards): 

2241 if not self.distribute_strategy._should_use_with_coordinator: 

2242 return 0 

2243 if pss_evaluation_shards == "auto": 

2244 # TODO(b/264265138) evaluate and improve this heuristic 

2245 return self.distribute_strategy._num_workers * 5 

2246 return pss_evaluation_shards 

2247 

2248 def _get_test_function_runner(self, callbacks): 

2249 if ( 

2250 self._pss_evaluation_shards 

2251 and self.distribute_strategy._should_use_with_coordinator 

2252 ): 

2253 self.test_function = self._make_test_function_exact() 

2254 test_function_runner = _ExactTestFunction( 

2255 self.test_function, callbacks 

2256 ) 

2257 else: 

2258 self.test_function = self.make_test_function() 

2259 test_function_runner = _TestFunction(self.test_function, callbacks) 

2260 return test_function_runner 

2261 

2262 def predict_step(self, data): 

2263 """The logic for one inference step. 

2264 

2265 This method can be overridden to support custom inference logic. 

2266 This method is called by `Model.make_predict_function`. 

2267 

2268 This method should contain the mathematical logic for one step of 

2269 inference. This typically includes the forward pass. 

2270 

2271 Configuration details for *how* this logic is run (e.g. `tf.function` 

2272 and `tf.distribute.Strategy` settings), should be left to 

2273 `Model.make_predict_function`, which can also be overridden. 

2274 

2275 Args: 

2276 data: A nested structure of `Tensor`s. 

2277 

2278 Returns: 

2279 The result of one inference step, typically the output of calling the 

2280 `Model` on data. 

2281 """ 

2282 x, _, _ = data_adapter.unpack_x_y_sample_weight(data) 

2283 return self(x, training=False) 

2284 
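# Example: a hedged sketch of overriding `predict_step`; the subclass and
# its post-processing are illustrative assumptions, not part of the
# original source:
#
#     import tensorflow as tf
#
#     class ThresholdingModel(tf.keras.Model):
#         def predict_step(self, data):
#             x, _, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
#             probs = self(x, training=False)
#             # Return hard 0/1 labels instead of raw probabilities.
#             return tf.cast(probs > 0.5, tf.int32)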

2285 def make_predict_function(self, force=False): 

2286 """Creates a function that executes one step of inference. 

2287 

2288 This method can be overridden to support custom inference logic. 

2289 This method is called by `Model.predict` and `Model.predict_on_batch`. 

2290 

2291 Typically, this method directly controls `tf.function` and 

2292 `tf.distribute.Strategy` settings, and delegates the actual evaluation 

2293 logic to `Model.predict_step`. 

2294 

2295 This function is cached the first time `Model.predict` or 

2296 `Model.predict_on_batch` is called. The cache is cleared whenever 

2297 `Model.compile` is called. You can skip the cache and generate the 

2298 function again with `force=True`. 

2299 

2300 Args: 

2301 force: Whether to regenerate the predict function and skip the cached 

2302 function if available. 

2303 

2304 Returns: 

2305 Function. The function created by this method should accept a 

2306 `tf.data.Iterator`, and return the outputs of the `Model`. 

2307 """ 

2308 if self.predict_function is not None and not force: 

2309 return self.predict_function 

2310 

2311 def step_function(model, iterator): 

2312 """Runs a single evaluation step.""" 

2313 

2314 def run_step(data): 

2315 outputs = model.predict_step(data) 

2316 # Ensure counter is updated only if `predict_step` succeeds. 

2317 with tf.control_dependencies(_minimum_control_deps(outputs)): 

2318 model._predict_counter.assign_add(1) 

2319 return outputs 

2320 

2321 if self.jit_compile: 

2322 run_step = tf.function( 

2323 run_step, jit_compile=True, reduce_retracing=True 

2324 ) 

2325 

2326 data = next(iterator) 

2327 outputs = model.distribute_strategy.run(run_step, args=(data,)) 

2328 outputs = reduce_per_replica( 

2329 outputs, self.distribute_strategy, reduction="concat" 

2330 ) 

2331 return outputs 

2332 

2333 # Special case if steps_per_execution is one. 

2334 if ( 

2335 self._steps_per_execution is None 

2336 or self._steps_per_execution.numpy().item() == 1 

2337 ): 

2338 

2339 def predict_function(iterator): 

2340 """Runs an evaluation execution with a single step.""" 

2341 return step_function(self, iterator) 

2342 

2343 else: 

2344 

2345 def predict_function(iterator): 

2346 """Runs an evaluation execution with multiple steps.""" 

2347 outputs = step_function(self, iterator) 

2348 for _ in tf.range(self._steps_per_execution - 1): 

2349 tf.autograph.experimental.set_loop_options( 

2350 shape_invariants=[ 

2351 ( 

2352 outputs, 

2353 tf.nest.map_structure( 

2354 lambda t: tf_utils.get_tensor_spec( 

2355 t, dynamic_batch=True 

2356 ).shape, 

2357 outputs, 

2358 ), 

2359 ) 

2360 ] 

2361 ) 

2362 step_outputs = step_function(self, iterator) 

2363 outputs = tf.nest.map_structure( 

2364 lambda t1, t2: concat([t1, t2]), outputs, step_outputs 

2365 ) 

2366 return outputs 

2367 

2368 if not self.run_eagerly: 

2369 predict_function = tf.function( 

2370 predict_function, reduce_retracing=True 

2371 ) 

2372 self.predict_function = predict_function 

2373 

2374 return self.predict_function 

2375 

2376 @traceback_utils.filter_traceback 

2377 def predict( 

2378 self, 

2379 x, 

2380 batch_size=None, 

2381 verbose="auto", 

2382 steps=None, 

2383 callbacks=None, 

2384 max_queue_size=10, 

2385 workers=1, 

2386 use_multiprocessing=False, 

2387 ): 

2388 """Generates output predictions for the input samples. 

2389 

2390 Computation is done in batches. This method is designed for batch 

2391 processing of large numbers of inputs. It is not intended for use 

2392 inside loops that iterate over your data and process small numbers 

2393 of inputs at a time. 

2394 

2395 For small numbers of inputs that fit in one batch, 

2396 directly use `__call__()` for faster execution, e.g., 

2397 `model(x)`, or `model(x, training=False)` if you have layers such as 

2398 `tf.keras.layers.BatchNormalization` that behave differently during 

2399 inference. You may pair the individual model call with a `tf.function` 

2400 for additional performance inside your inner loop. 

2401 If you need access to numpy array values instead of tensors after your 

2402 model call, you can use `tensor.numpy()` to get the numpy array value of 

2403 an eager tensor. 

2404 

2405 Also, note that test loss is not affected by 

2406 regularization layers like noise and dropout. 

2407 

2408 Note: See [this FAQ entry]( 

2409 https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call) 

2410 for more details about the difference between `Model` methods 

2411 `predict()` and `__call__()`. 

2412 

2413 Args: 

2414 x: Input samples. It could be: 

2415 - A Numpy array (or array-like), or a list of arrays 

2416 (in case the model has multiple inputs). 

2417 - A TensorFlow tensor, or a list of tensors 

2418 (in case the model has multiple inputs). 

2419 - A `tf.data` dataset. 

2420 - A generator or `keras.utils.Sequence` instance. 

2421 A more detailed description of unpacking behavior for iterator 

2422 types (Dataset, generator, Sequence) is given in the `Unpacking 

2423 behavior for iterator-like inputs` section of `Model.fit`. 

2424 batch_size: Integer or `None`. 

2425 Number of samples per batch. 

2426 If unspecified, `batch_size` will default to 32. 

2427 Do not specify the `batch_size` if your data is in the 

2428 form of dataset, generators, or `keras.utils.Sequence` instances 

2429 (since they generate batches). 

2430 verbose: `"auto"`, 0, 1, or 2. Verbosity mode. 

2431 0 = silent, 1 = progress bar, 2 = single line. 

2432 `"auto"` becomes 1 for most cases, and to 2 when used with 

2433 `ParameterServerStrategy`. Note that the progress bar is not 

2434 particularly useful when logged to a file, so `verbose=2` is 

2435 recommended when not running interactively (e.g. in a production 

2436 environment). Defaults to 'auto'. 

2437 steps: Total number of steps (batches of samples) 

2438 before declaring the prediction round finished. 

2439 Ignored with the default value of `None`. If `x` is a `tf.data` 

2440 dataset and `steps` is None, `predict()` will 

2441 run until the input dataset is exhausted. 

2442 callbacks: List of `keras.callbacks.Callback` instances. 

2443 List of callbacks to apply during prediction. 

2444 See [callbacks]( 

2445 https://www.tensorflow.org/api_docs/python/tf/keras/callbacks). 

2446 max_queue_size: Integer. Used for generator or 

2447 `keras.utils.Sequence` input only. Maximum size for the 

2448 generator queue. If unspecified, `max_queue_size` will default 

2449 to 10. 

2450 workers: Integer. Used for generator or `keras.utils.Sequence` input 

2451 only. Maximum number of processes to spin up when using 

2452 process-based threading. If unspecified, `workers` will default 

2453 to 1. 

2454 use_multiprocessing: Boolean. Used for generator or 

2455 `keras.utils.Sequence` input only. If `True`, use process-based 

2456 threading. If unspecified, `use_multiprocessing` will default to 

2457 `False`. Note that because this implementation relies on 

2458 multiprocessing, you should not pass non-picklable arguments to 

2459 the generator, as they can't be passed easily to child 

2460 processes. 

2461 

2462 See the discussion of `Unpacking behavior for iterator-like inputs` for 

2463 `Model.fit`. Note that `Model.predict` uses the same interpretation rules 

2464 as `Model.fit` and `Model.evaluate`, so inputs must be unambiguous for 

2465 all three methods. 

2466 

2467 Returns: 

2468 Numpy array(s) of predictions. 

2469 

2470 Raises: 

2471 RuntimeError: If `model.predict` is wrapped in a `tf.function`. 

2472 ValueError: In case of mismatch between the provided 

2473 input data and the model's expectations, 

2474 or in case a stateful model receives a number of samples 

2475 that is not a multiple of the batch size. 

2476 """ 

2477 base_layer.keras_api_gauge.get_cell("predict").set(True) 

2478 version_utils.disallow_legacy_graph("Model", "predict") 

2479 self._check_call_args("predict") 

2480 _disallow_inside_tf_function("predict") 

2481 

2482 # TODO(yashkatariya): Cache model on the coordinator for faster 

2483 # prediction. If running under PSS, then swap it with OneDeviceStrategy 

2484 # so that execution will run on the coordinator. 

2485 original_pss_strategy = None 

2486 if self.distribute_strategy._should_use_with_coordinator: 

2487 original_pss_strategy = self.distribute_strategy 

2488 self._distribution_strategy = None 

2489 

2490 # The cluster coordinator is set by `.fit()` and `.evaluate()`; it is 

2491 # not needed in `.predict()` because all the predictions happen on the 

2492 # coordinator/locally. 

2493 if self._cluster_coordinator: 

2494 self._cluster_coordinator = None 

2495 

2496 verbose = _get_verbosity(verbose, self.distribute_strategy) 

2497 outputs = None 

2498 with self.distribute_strategy.scope(): 

2499 # Creates a `tf.data.Dataset` and handles batch and epoch iteration. 

2500 dataset_types = (tf.compat.v1.data.Dataset, tf.data.Dataset) 

2501 if ( 

2502 self._in_multi_worker_mode() 

2503 or _is_tpu_multi_host(self.distribute_strategy) 

2504 ) and isinstance(x, dataset_types): 

2505 try: 

2506 options = tf.data.Options() 

2507 data_option = tf.data.experimental.AutoShardPolicy.DATA 

2508 options.experimental_distribute.auto_shard_policy = ( 

2509 data_option 

2510 ) 

2511 x = x.with_options(options) 

2512 except ValueError: 

2513 warnings.warn( 

2514 "Using Model.predict with MultiWorkerMirroredStrategy " 

2515 "or TPUStrategy and AutoShardPolicy.FILE might lead to " 

2516 "out-of-order result. Consider setting it to " 

2517 "AutoShardPolicy.DATA.", 

2518 stacklevel=2, 

2519 ) 

2520 

2521 data_handler = data_adapter.get_data_handler( 

2522 x=x, 

2523 batch_size=batch_size, 

2524 steps_per_epoch=steps, 

2525 initial_epoch=0, 

2526 epochs=1, 

2527 max_queue_size=max_queue_size, 

2528 workers=workers, 

2529 use_multiprocessing=use_multiprocessing, 

2530 model=self, 

2531 steps_per_execution=self._steps_per_execution, 

2532 ) 

2533 

2534 # Container that configures and calls `tf.keras.Callback`s. 

2535 if not isinstance(callbacks, callbacks_module.CallbackList): 

2536 callbacks = callbacks_module.CallbackList( 

2537 callbacks, 

2538 add_history=True, 

2539 add_progbar=verbose != 0, 

2540 model=self, 

2541 verbose=verbose, 

2542 epochs=1, 

2543 steps=data_handler.inferred_steps, 

2544 ) 

2545 

2546 self.predict_function = self.make_predict_function() 

2547 self._predict_counter.assign(0) 

2548 callbacks.on_predict_begin() 

2549 batch_outputs = None 

2550 for _, iterator in data_handler.enumerate_epochs(): # Single epoch. 

2551 with data_handler.catch_stop_iteration(): 

2552 for step in data_handler.steps(): 

2553 callbacks.on_predict_batch_begin(step) 

2554 tmp_batch_outputs = self.predict_function(iterator) 

2555 if data_handler.should_sync: 

2556 context.async_wait() 

2557 batch_outputs = ( 

2558 tmp_batch_outputs # No error, now safe to assign. 

2559 ) 

2560 if outputs is None: 

2561 outputs = tf.nest.map_structure( 

2562 lambda batch_output: [batch_output], 

2563 batch_outputs, 

2564 ) 

2565 else: 

2566 tf.__internal__.nest.map_structure_up_to( 

2567 batch_outputs, 

2568 lambda output, batch_output: output.append( 

2569 batch_output 

2570 ), 

2571 outputs, 

2572 batch_outputs, 

2573 ) 

2574 end_step = step + data_handler.step_increment 

2575 callbacks.on_predict_batch_end( 

2576 end_step, {"outputs": batch_outputs} 

2577 ) 

2578 if batch_outputs is None: 

2579 raise ValueError( 

2580 "Unexpected result of `predict_function` " 

2581 "(Empty batch_outputs). Please use " 

2582 "`Model.compile(..., run_eagerly=True)`, or " 

2583 "`tf.config.run_functions_eagerly(True)` for more " 

2584 "information of where went wrong, or file a " 

2585 "issue/bug to `tf.keras`." 

2586 ) 

2587 callbacks.on_predict_end() 

2588 all_outputs = tf.__internal__.nest.map_structure_up_to( 

2589 batch_outputs, potentially_ragged_concat, outputs 

2590 ) 

2591 

2592 # If a PSS strategy was originally used, restore it now: predict ran 

2593 # under `OneDeviceStrategy` after the swap, and once it is done we 

2594 # need to swap the PSS strategy back in. 

2595 if original_pss_strategy is not None: 

2596 self._distribution_strategy = original_pss_strategy 

2597 

2598 return tf_utils.sync_to_numpy_or_python_type(all_outputs) 

2599 
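# Example (illustrative, following the docstring's guidance above):
# `predict` for batched inference over many samples, `__call__` for a
# single small batch:
#
#     preds = model.predict(x, batch_size=32, verbose=0)  # numpy array(s)
#     small = model(x[:4], training=False)                # eager tensor
#     small_np = small.numpy()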

2600 def reset_metrics(self): 

2601 """Resets the state of all the metrics in the model. 

2602 

2603 Examples: 

2604 

2605 >>> inputs = tf.keras.layers.Input(shape=(3,)) 

2606 >>> outputs = tf.keras.layers.Dense(2)(inputs) 

2607 >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs) 

2608 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"]) 

2609 

2610 >>> x = np.random.random((2, 3)) 

2611 >>> y = np.random.randint(0, 2, (2, 2)) 

2612 >>> _ = model.fit(x, y, verbose=0) 

2613 >>> assert all(float(m.result()) for m in model.metrics) 

2614 

2615 >>> model.reset_metrics() 

2616 >>> assert all(float(m.result()) == 0 for m in model.metrics) 

2617 

2618 """ 

2619 for m in self.metrics: 

2620 m.reset_state() 

2621 

2622 def train_on_batch( 

2623 self, 

2624 x, 

2625 y=None, 

2626 sample_weight=None, 

2627 class_weight=None, 

2628 reset_metrics=True, 

2629 return_dict=False, 

2630 ): 

2631 """Runs a single gradient update on a single batch of data. 

2632 

2633 Args: 

2634 x: Input data. It could be: 

2635 - A Numpy array (or array-like), or a list of arrays 

2636 (in case the model has multiple inputs). 

2637 - A TensorFlow tensor, or a list of tensors 

2638 (in case the model has multiple inputs). 

2639 - A dict mapping input names to the corresponding array/tensors, 

2640 if the model has named inputs. 

2641 y: Target data. Like the input data `x`, it could be either Numpy 

2642 array(s) or TensorFlow tensor(s). 

2643 sample_weight: Optional array of the same length as `x`, containing 

2644 weights to apply to the model's loss for each sample. In the case 

2645 of temporal data, you can pass a 2D array with shape (samples, 

2646 sequence_length), to apply a different weight to every timestep of 

2647 every sample. 

2648 class_weight: Optional dictionary mapping class indices (integers) 

2649 to a weight (float) to apply to the model's loss for the samples 

2650 from this class during training. This can be useful to tell the 

2651 model to "pay more attention" to samples from an under-represented 

2652 class. When `class_weight` is specified and targets have a rank of 

2653 2 or greater, either `y` must be one-hot encoded, or an explicit 

2654 final dimension of `1` must be included for sparse class labels. 

2655 reset_metrics: If `True`, the metrics returned will be only for this 

2656 batch. If `False`, the metrics will be statefully accumulated 

2657 across batches. 

2658 return_dict: If `True`, loss and metric results are returned as a 

2659 dict, with each key being the name of the metric. If `False`, they 

2660 are returned as a list. 

2661 

2662 Returns: 

2663 Scalar training loss 

2664 (if the model has a single output and no metrics) 

2665 or list of scalars (if the model has multiple outputs 

2666 and/or metrics). The attribute `model.metrics_names` will give you 

2667 the display labels for the scalar outputs. 

2668 

2669 Raises: 

2670 RuntimeError: If `model.train_on_batch` is wrapped in a `tf.function`. 

2671 """ 

2672 self._assert_compile_was_called() 

2673 self._check_call_args("train_on_batch") 

2674 _disallow_inside_tf_function("train_on_batch") 

2675 if reset_metrics: 

2676 self.reset_metrics() 

2677 with self.distribute_strategy.scope(), training_utils.RespectCompiledTrainableState( # noqa: E501 

2678 self 

2679 ): 

2680 iterator = data_adapter.single_batch_iterator( 

2681 self.distribute_strategy, x, y, sample_weight, class_weight 

2682 ) 

2683 self.train_function = self.make_train_function() 

2684 logs = self.train_function(iterator) 

2685 

2686 logs = tf_utils.sync_to_numpy_or_python_type(logs) 

2687 if return_dict: 

2688 return logs 

2689 else: 

2690 return flatten_metrics_in_order(logs, self.metrics_names) 

2691 
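# Example: a minimal sketch of a hand-rolled loop built on
# `train_on_batch`; the slicing scheme and data are illustrative
# assumptions, not part of the original source:
#
#     for epoch in range(3):
#         for i in range(0, len(x), 16):
#             logs = model.train_on_batch(
#                 x[i : i + 16], y[i : i + 16], return_dict=True
#             )
#         print(epoch, logs["loss"])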

2692 def test_on_batch( 

2693 self, 

2694 x, 

2695 y=None, 

2696 sample_weight=None, 

2697 reset_metrics=True, 

2698 return_dict=False, 

2699 ): 

2700 """Test the model on a single batch of samples. 

2701 

2702 Args: 

2703 x: Input data. It could be: 

2704 - A Numpy array (or array-like), or a list of arrays (in case the 

2705 model has multiple inputs). 

2706 - A TensorFlow tensor, or a list of tensors (in case the model has 

2707 multiple inputs). 

2708 - A dict mapping input names to the corresponding array/tensors, 

2709 if the model has named inputs. 

2710 y: Target data. Like the input data `x`, it could be either Numpy 

2711 array(s) or TensorFlow tensor(s). It should be consistent with `x` 

2712 (you cannot have Numpy inputs and tensor targets, or inversely). 

2713 sample_weight: Optional array of the same length as `x`, containing 

2714 weights to apply to the model's loss for each sample. In the case 

2715 of temporal data, you can pass a 2D array with shape (samples, 

2716 sequence_length), to apply a different weight to every timestep of 

2717 every sample. 

2718 reset_metrics: If `True`, the metrics returned will be only for this 

2719 batch. If `False`, the metrics will be statefully accumulated 

2720 across batches. 

2721 return_dict: If `True`, loss and metric results are returned as a 

2722 dict, with each key being the name of the metric. If `False`, they 

2723 are returned as a list. 

2724 

2725 Returns: 

2726 Scalar test loss (if the model has a single output and no metrics) 

2727 or list of scalars (if the model has multiple outputs 

2728 and/or metrics). The attribute `model.metrics_names` will give you 

2729 the display labels for the scalar outputs. 

2730 

2731 Raises: 

2732 RuntimeError: If `model.test_on_batch` is wrapped in a 

2733 `tf.function`. 

2734 """ 

2735 self._assert_compile_was_called() 

2736 self._check_call_args("test_on_batch") 

2737 _disallow_inside_tf_function("test_on_batch") 

2738 if reset_metrics: 

2739 self.reset_metrics() 

2740 with self.distribute_strategy.scope(): 

2741 iterator = data_adapter.single_batch_iterator( 

2742 self.distribute_strategy, x, y, sample_weight 

2743 ) 

2744 self.test_function = self.make_test_function() 

2745 logs = self.test_function(iterator) 

2746 

2747 logs = tf_utils.sync_to_numpy_or_python_type(logs) 

2748 if return_dict: 

2749 return logs 

2750 else: 

2751 return flatten_metrics_in_order(logs, self.metrics_names) 

2752 
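# Example (illustrative): accumulate metrics over several held-out
# batches by disabling the per-batch reset; `x_val` and `y_val` are
# assumed arrays:
#
#     model.reset_metrics()
#     for i in range(0, len(x_val), 16):
#         logs = model.test_on_batch(
#             x_val[i : i + 16], y_val[i : i + 16],
#             reset_metrics=False, return_dict=True,
#         )
#     # `logs` now reflects metrics accumulated across all batches.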

2753 def predict_on_batch(self, x): 

2754 """Returns predictions for a single batch of samples. 

2755 

2756 Args: 

2757 x: Input data. It could be: 

2758 - A Numpy array (or array-like), or a list of arrays (in case the 

2759 model has multiple inputs). 

2760 - A TensorFlow tensor, or a list of tensors (in case the model has 

2761 multiple inputs). 

2762 

2763 Returns: 

2764 Numpy array(s) of predictions. 

2765 

2766 Raises: 

2767 RuntimeError: If `model.predict_on_batch` is wrapped in a 

2768 `tf.function`. 

2769 """ 

2770 self._check_call_args("predict_on_batch") 

2771 _disallow_inside_tf_function("predict_on_batch") 

2772 with self.distribute_strategy.scope(): 

2773 iterator = data_adapter.single_batch_iterator( 

2774 self.distribute_strategy, x 

2775 ) 

2776 self.predict_function = self.make_predict_function() 

2777 outputs = self.predict_function(iterator) 

2778 return tf_utils.sync_to_numpy_or_python_type(outputs) 

2779 
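# Example (illustrative): single-batch inference without the epoch and
# callback machinery of `predict`:
#
#     batch_preds = model.predict_on_batch(x[:16])  # numpy array(s)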

2780 @doc_controls.do_not_generate_docs 

2781 def fit_generator( 

2782 self, 

2783 generator, 

2784 steps_per_epoch=None, 

2785 epochs=1, 

2786 verbose=1, 

2787 callbacks=None, 

2788 validation_data=None, 

2789 validation_steps=None, 

2790 validation_freq=1, 

2791 class_weight=None, 

2792 max_queue_size=10, 

2793 workers=1, 

2794 use_multiprocessing=False, 

2795 shuffle=True, 

2796 initial_epoch=0, 

2797 ): 

2798 """Fits the model on data yielded batch-by-batch by a Python generator. 

2799 

2800 DEPRECATED: 

2801 `Model.fit` now supports generators, so there is no longer any need to 

2802 use this endpoint. 

2803 """ 

2804 warnings.warn( 

2805 "`Model.fit_generator` is deprecated and " 

2806 "will be removed in a future version. " 

2807 "Please use `Model.fit`, which supports generators.", 

2808 stacklevel=2, 

2809 ) 

2810 return self.fit( 

2811 generator, 

2812 steps_per_epoch=steps_per_epoch, 

2813 epochs=epochs, 

2814 verbose=verbose, 

2815 callbacks=callbacks, 

2816 validation_data=validation_data, 

2817 validation_steps=validation_steps, 

2818 validation_freq=validation_freq, 

2819 class_weight=class_weight, 

2820 max_queue_size=max_queue_size, 

2821 workers=workers, 

2822 use_multiprocessing=use_multiprocessing, 

2823 shuffle=shuffle, 

2824 initial_epoch=initial_epoch, 

2825 ) 

2826 

2827 @doc_controls.do_not_generate_docs 

2828 def evaluate_generator( 

2829 self, 

2830 generator, 

2831 steps=None, 

2832 callbacks=None, 

2833 max_queue_size=10, 

2834 workers=1, 

2835 use_multiprocessing=False, 

2836 verbose=0, 

2837 ): 

2838 """Evaluates the model on a data generator. 

2839 

2840 DEPRECATED: 

2841 `Model.evaluate` now supports generators, so there is no longer any 

2842 need to use this endpoint. 

2843 """ 

2844 warnings.warn( 

2845 "`Model.evaluate_generator` is deprecated and " 

2846 "will be removed in a future version. " 

2847 "Please use `Model.evaluate`, which supports generators.", 

2848 stacklevel=2, 

2849 ) 

2850 self._check_call_args("evaluate_generator") 

2851 

2852 return self.evaluate( 

2853 generator, 

2854 steps=steps, 

2855 max_queue_size=max_queue_size, 

2856 workers=workers, 

2857 use_multiprocessing=use_multiprocessing, 

2858 verbose=verbose, 

2859 callbacks=callbacks, 

2860 ) 

2861 

2862 @doc_controls.do_not_generate_docs 

2863 def predict_generator( 

2864 self, 

2865 generator, 

2866 steps=None, 

2867 callbacks=None, 

2868 max_queue_size=10, 

2869 workers=1, 

2870 use_multiprocessing=False, 

2871 verbose=0, 

2872 ): 

2873 """Generates predictions for the input samples from a data generator. 

2874 

2875 DEPRECATED: 

2876 `Model.predict` now supports generators, so there is no longer any 

2877 need to use this endpoint. 

2878 """ 

2879 warnings.warn( 

2880 "`Model.predict_generator` is deprecated and " 

2881 "will be removed in a future version. " 

2882 "Please use `Model.predict`, which supports generators.", 

2883 stacklevel=2, 

2884 ) 

2885 return self.predict( 

2886 generator, 

2887 steps=steps, 

2888 max_queue_size=max_queue_size, 

2889 workers=workers, 

2890 use_multiprocessing=use_multiprocessing, 

2891 verbose=verbose, 

2892 callbacks=callbacks, 

2893 ) 

2894 

2895 ###################################################################### 

2896 # Functions below are not training related. They are for model weights 

2897 # tracking, save/load, serialization, etc. 

2898 ###################################################################### 

2899 

2900 @property 

2901 def trainable_weights(self): 

2902 self._assert_weights_created() 

2903 if not self._trainable: 

2904 return [] 

2905 trainable_variables = [] 

2906 for trackable_obj in self._self_tracked_trackables: 

2907 trainable_variables += trackable_obj.trainable_variables 

2908 trainable_variables += self._trainable_weights 

2909 return self._dedup_weights(trainable_variables) 

2910 

2911 @property 

2912 def non_trainable_weights(self): 

2913 self._assert_weights_created() 

2914 non_trainable_variables = [] 

2915 for trackable_obj in self._self_tracked_trackables: 

2916 non_trainable_variables += trackable_obj.non_trainable_variables 

2917 

2918 if not self._trainable: 

2919 # Return order is all trainable vars, then all non-trainable vars. 

2920 trainable_variables = [] 

2921 for trackable_obj in self._self_tracked_trackables: 

2922 trainable_variables += trackable_obj.trainable_variables 

2923 

2924 non_trainable_variables = ( 

2925 trainable_variables 

2926 + self._trainable_weights 

2927 + non_trainable_variables 

2928 + self._non_trainable_weights 

2929 ) 

2930 else: 

2931 non_trainable_variables = ( 

2932 non_trainable_variables + self._non_trainable_weights 

2933 ) 

2934 

2935 return self._dedup_weights(non_trainable_variables) 

2936 

2937 def get_weights(self): 

2938 """Retrieves the weights of the model. 

2939 

2940 Returns: 

2941 A flat list of Numpy arrays. 

2942 """ 

2943 with self.distribute_strategy.scope(): 

2944 return super().get_weights() 

2945 

2946 @traceback_utils.filter_traceback 

2947 def save(self, filepath, overwrite=True, save_format=None, **kwargs): 

2948 """Saves a model as a TensorFlow SavedModel or HDF5 file. 

2949 

2950 See the [Serialization and Saving guide]( 

2951 https://keras.io/guides/serialization_and_saving/) for details. 

2952 

2953 Args: 

2955 filepath: `str` or `pathlib.Path` object. Path where to save the 

2956 model. 

2957 overwrite: Whether we should overwrite any existing model at the 

2958 target location, or instead ask the user via an interactive 

2959 prompt. 

2960 save_format: Either `"keras"`, `"tf"`, `"h5"`, 

2961 indicating whether to save the model 

2962 in the native Keras format (`.keras`), 

2963 in the TensorFlow SavedModel format 

2964 (referred to as "SavedModel" below), 

2965 or in the legacy HDF5 format (`.h5`). 

2966 Defaults to `"tf"` in TF 2.X, and `"h5"` in TF 1.X. 

2967 

2968 SavedModel format arguments: 

2969 include_optimizer: Only applied to SavedModel and legacy HDF5 

2970 formats. If False, do not save the optimizer state. 

2971 Defaults to `True`. 

2972 signatures: Only applies to SavedModel format. Signatures to save 

2973 with the SavedModel. See the `signatures` argument in 

2974 `tf.saved_model.save` for details. 

2975 options: Only applies to SavedModel format. 

2976 `tf.saved_model.SaveOptions` object that specifies SavedModel 

2977 saving options. 

2978 save_traces: Only applies to SavedModel format. When enabled, the 

2979 SavedModel will store the function traces for each layer. This 

2980 can be disabled, so that only the configs of each layer are 

2981 stored. Defaults to `True`. 

2982 Disabling this will decrease serialization time 

2983 and reduce file size, but it requires that all custom 

2984 layers/models implement a `get_config()` method. 

2985 

2986 Example: 

2987 

2988 ```python 

2989 model = tf.keras.Sequential([ 

2990 tf.keras.layers.Dense(5, input_shape=(3,)), 

2991 tf.keras.layers.Softmax()]) 

2992 model.save("model.keras") 

2993 loaded_model = tf.keras.models.load_model("model.keras") 

2994 x = tf.random.uniform((10, 3)) 

2995 assert np.allclose(model.predict(x), loaded_model.predict(x)) 

2996 ``` 

2997 

2998 Note that `model.save()` is an alias for `tf.keras.models.save_model()`. 

2999 """ 

3000 saving_api.save_model( 

3001 self, 

3002 filepath=filepath, 

3003 overwrite=overwrite, 

3004 save_format=save_format, 

3005 **kwargs, 

3006 ) 

3007 

3008 @traceback_utils.filter_traceback 

3009 def save_weights( 

3010 self, filepath, overwrite=True, save_format=None, options=None 

3011 ): 

3012 """Saves all layer weights. 

3013 

3014 Either saves in HDF5 or in TensorFlow format based on the `save_format` 

3015 argument. 

3016 

3017 When saving in HDF5 format, the weight file has: 

3018 - `layer_names` (attribute), a list of strings 

3019 (ordered names of model layers). 

3020 - For every layer, a `group` named `layer.name` 

3021 - For every such layer group, a group attribute `weight_names`, 

3022 a list of strings 

3023 (ordered names of weights tensor of the layer). 

3024 - For every weight in the layer, a dataset 

3025 storing the weight value, named after the weight tensor. 

3026 

3027 When saving in TensorFlow format, all objects referenced by the network 

3028 are saved in the same format as `tf.train.Checkpoint`, including any 

3029 `Layer` instances or `Optimizer` instances assigned to object 

3030 attributes. For networks constructed from inputs and outputs using 

3031 `tf.keras.Model(inputs, outputs)`, `Layer` instances used by the network 

3032 are tracked/saved automatically. For user-defined classes which inherit 

3033 from `tf.keras.Model`, `Layer` instances must be assigned to object 

3034 attributes, typically in the constructor. See the documentation of 

3035 `tf.train.Checkpoint` and `tf.keras.Model` for details. 

3036 

3037 While the formats are the same, do not mix `save_weights` and 

3038 `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should 

3039 be loaded using `Model.load_weights`. Checkpoints saved using 

3040 `tf.train.Checkpoint.save` should be restored using the corresponding 

3041 `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over 

3042 `save_weights` for training checkpoints. 

3043 

3044 The TensorFlow format matches objects and variables by starting at a 

3045 root object, `self` for `save_weights`, and greedily matching attribute 

3046 names. For `Model.save` this is the `Model`, and for `Checkpoint.save` 

3047 this is the `Checkpoint` even if the `Checkpoint` has a model attached. 

3048 This means saving a `tf.keras.Model` using `save_weights` and loading 

3049 into a `tf.train.Checkpoint` with a `Model` attached (or vice versa) 

3050 will not match the `Model`'s variables. See the 

3051 [guide to training checkpoints]( 

3052 https://www.tensorflow.org/guide/checkpoint) for details on 

3053 the TensorFlow format. 

3054 

3055 Args: 

3056 filepath: String or PathLike, path to the file to save the weights 

3057 to. When saving in TensorFlow format, this is the prefix used 

3058 for checkpoint files (multiple files are generated). Note that 

3059 the '.h5' suffix causes weights to be saved in HDF5 format. 

3060 overwrite: Whether to silently overwrite any existing file at the 

3061 target location, or provide the user with a manual prompt. 

3062 save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or 

3063 '.keras' will default to HDF5 if `save_format` is `None`. 

3064 Otherwise, `None` becomes 'tf'. Defaults to `None`. 

3065 options: Optional `tf.train.CheckpointOptions` object that specifies 

3066 options for saving weights. 

3067 

3068 Raises: 

3069 ImportError: If `h5py` is not available when attempting to save in 

3070 HDF5 format. 
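
Example:

```python
# Editor's sketch, assuming `model` is a built Keras model. The ".h5"
# suffix selects HDF5; a bare path prefix selects the TF checkpoint
# format.
model.save_weights("ckpt/my_model")  # TF format: writes an index file plus data shards
model.save_weights("weights.h5")     # single HDF5 file (requires h5py)
model.load_weights("ckpt/my_model")
```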

3071 """ 

3072 saving_api.save_weights( 

3073 self, 

3074 filepath=filepath, 

3075 overwrite=overwrite, 

3076 save_format=save_format, 

3077 options=options, 

3078 ) 

3079 

3080 @traceback_utils.filter_traceback 

3081 def load_weights( 

3082 self, filepath, skip_mismatch=False, by_name=False, options=None 

3083 ): 

3084 """Loads all layer weights from a saved file. 

3085 

3086 The saved file could be a SavedModel file, a `.keras` file (v3 saving 

3087 format), or a file created via `model.save_weights()`. 

3088 

3089 By default, weights are loaded based on the network's 

3090 topology. This means the architecture should be the same as when the 

3091 weights were saved. Note that layers that don't have weights are not 

3092 taken into account in the topological ordering, so adding or removing 

3093 layers is fine as long as they don't have weights. 

3094 

3095 **Partial weight loading** 

3096 

3097 If you have modified your model, for instance by adding a new layer 

3098 (with weights) or by changing the shape of the weights of a layer, 

3099 you can choose to ignore errors and continue loading 

3100 by setting `skip_mismatch=True`. In this case any layer with 

3101 mismatching weights will be skipped. A warning will be displayed 

3102 for each skipped layer. 

3103 

3104 **Weight loading by name** 

3105 

3106 If your weights are saved as a `.h5` file created 

3107 via `model.save_weights()`, you can use the argument `by_name=True`. 

3108 

3109 In this case, weights are loaded into layers only if they share 

3110 the same name. This is useful for fine-tuning or transfer-learning 

3111 models where some of the layers have changed. 

3112 

3113 Note that only topological loading (`by_name=False`) is supported when 

3114 loading weights from the `.keras` v3 format or from the TensorFlow 

3115 SavedModel format. 

3116 

3117 Args: 

3118 filepath: String, path to the weights file to load. For weight files 

3119 in TensorFlow format, this is the file prefix (the same as was 

3120 passed to `save_weights()`). This can also be a path to a 

3121 SavedModel or a `.keras` file (v3 saving format) saved 

3122 via `model.save()`. 

3123 skip_mismatch: Boolean, whether to skip loading of layers where 

3124 there is a mismatch in the number of weights, or a mismatch in 

3125 the shape of the weights. 

3126 by_name: Boolean, whether to load weights by name or by topological 

3127 order. Only topological loading is supported for weight files in 

3128 the `.keras` v3 format or in the TensorFlow SavedModel format. 

3129 options: Optional `tf.train.CheckpointOptions` object that specifies 

3130 options for loading weights (only valid for a SavedModel file). 
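
Example:

```python
# Editor's sketch: name-based partial loading from an HDF5 weights
# file; "base_weights.h5" is a placeholder path. Layers whose weight
# counts or shapes changed are skipped with a warning.
model.load_weights("base_weights.h5", by_name=True, skip_mismatch=True)
```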

3131 """ 

3132 return saving_api.load_weights( 

3133 self, 

3134 filepath=filepath, 

3135 by_name=by_name, 

3136 skip_mismatch=skip_mismatch, 

3137 options=options, 

3138 ) 

3139 

3140 def _updated_config(self): 

3141 """Util shared between different serialization methods. 

3142 

3143 Returns: 

3144 Model config with Keras version information added. 

3145 """ 

3146 from keras.src import __version__ as keras_version 

3147 

3148 config = self.get_config() 

3149 model_config = { 

3150 "class_name": self.__class__.__name__, 

3151 "config": config, 

3152 "keras_version": keras_version, 

3153 "backend": backend.backend(), 

3154 } 

3155 return model_config 

3156 

3157 @generic_utils.default 

3158 def get_config(self): 

3159 """Returns the config of the `Model`. 

3160 

3161 Config is a Python dictionary (serializable) containing the 

3162 configuration of an object, which in this case is a `Model`. This allows 

3163 the `Model` to be reinstantiated later (without its trained weights) 

3164 from this configuration. 

3165 

3166 Note that `get_config()` does not guarantee to return a fresh copy of 

3167 dict every time it is called. The callers should make a copy of the 

3168 returned dict if they want to modify it. 

3169 

3170 Developers of subclassed `Model` are advised to override this method, 

3171 and continue to update the dict from `super(MyModel, self).get_config()` 

3172 to provide the proper configuration of this `Model`. The default config 

3173 returns a config dict for the `__init__()` parameters if they are 

3174 basic types, and raises `NotImplementedError` in cases where a custom 

3175 `get_config()` implementation is required for the subclassed model. 

3176 

3177 Returns: 

3178 Python dictionary containing the configuration of this `Model`. 
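
Example:

```python
# Editor's sketch of the recommended override pattern for a subclassed
# `Model`; all names below are illustrative.
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs):
        return self.dense(inputs)

    def get_config(self):
        config = super().get_config()  # parent config (may be empty here)
        config.update({"units": self.units})
        return config
```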

3179 """ 

3180 # If the subclass doesn't implement `get_config()`, parse from init args; 

3181 # otherwise default to empty dict 

3182 if generic_utils.is_default(self.get_config): 

3183 try: 

3184 config = base_layer.Layer.get_config(self) 

3185 except NotImplementedError: 

3186 config = {} 

3187 logging.warning( 

3188 "Model's `__init__()` arguments contain non-serializable " 

3189 "objects. Please implement a `get_config()` method in the " 

3190 "subclassed Model for proper saving and loading. " 

3191 "Defaulting to empty config." 

3192 ) 

3193 else: 

3194 config = {} 

3195 return config 

3196 

3197 @classmethod 

3198 def from_config(cls, config, custom_objects=None): 

3199 # `from_config` assumes `cls` is either `Functional` or a child class of 

3200 # `Functional`. In the case that `cls` is meant to behave like a child 

3201 # class of `Functional` but only inherits from the `Model` class, we 

3202 # have to call `cls(...)` instead of `Functional.from_config`. 

3203 from keras.src.engine import functional 

3204 

3205 with serialization.SharedObjectLoadingScope(): 

3206 functional_config_keys = [ 

3207 "name", 

3208 "layers", 

3209 "input_layers", 

3210 "output_layers", 

3211 ] 

3212 is_functional_config = all( 

3213 key in config for key in functional_config_keys 

3214 ) 

3215 argspec = tf_inspect.getfullargspec(cls.__init__) 

3216 functional_init_args = tf_inspect.getfullargspec( 

3217 functional.Functional.__init__ 

3218 ).args[1:] 

3219 revivable_as_functional = ( 

3220 cls in {functional.Functional, Model} 

3221 or argspec.args[1:] == functional_init_args 

3222 or (argspec.varargs == "args" and argspec.varkw == "kwargs") 

3223 ) 

3224 if is_functional_config and revivable_as_functional: 

3225 # Revive Functional model 

3226 # (but not Functional subclasses with a custom __init__) 

3227 inputs, outputs, layers = functional.reconstruct_from_config( 

3228 config, custom_objects 

3229 ) 

3230 model = cls( 

3231 inputs=inputs, outputs=outputs, name=config.get("name") 

3232 ) 

3233 functional.connect_ancillary_layers(model, layers) 

3234 

3235 else: 

3236 # Either the model has a custom __init__, or the config 

3237 # does not contain all the information necessary to 

3238 # revive a Functional model. This happens when the user creates 

3239 # subclassed models where `get_config()` is returning 

3240 # insufficient information to be considered a Functional model. 

3241 # In this case, we fall back to provide all config into the 

3242 # constructor of the class. 

3243 try: 

3244 model = cls(**config) 

3245 except TypeError as e: 

3246 raise TypeError( 

3247 "Unable to revive model from config. When overriding " 

3248 "the `get_config()` method, make sure that the " 

3249 "returned config contains all items used as arguments " 

3250 f"in the constructor to {cls}, " 

3251 "which is the default behavior. " 

3252 "You can override this default behavior by defining a " 

3253 "`from_config(cls, config)` class method to specify " 

3254 "how to create an " 

3255 f"instance of {cls.__name__} from its config.\n\n" 

3256 f"Received config={config}\n\n" 

3257 f"Error encountered during deserialization: {e}" 

3258 ) 

3259 return model 

3260 

3261 def to_json(self, **kwargs): 

3262 """Returns a JSON string containing the network configuration. 

3263 

3264 To load a network from a JSON save file, use 

3265 `keras.models.model_from_json(json_string, custom_objects={})`. 

3266 

3267 Args: 

3268 **kwargs: Additional keyword arguments to be passed to 

3269 `json.dumps()`. 

3270 

3271 Returns: 

3272 A JSON string. 
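
Example:

```python
# Editor's sketch: the JSON holds only the architecture, so weights
# must be transferred separately. Assumes a serializable (e.g.
# functional or Sequential) model.
json_string = model.to_json()
reloaded = tf.keras.models.model_from_json(json_string)
reloaded.set_weights(model.get_weights())
```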

3273 """ 

3274 model_config = self._updated_config() 

3275 return json.dumps( 

3276 model_config, default=json_utils.get_json_type, **kwargs 

3277 ) 

3278 

3279 def to_yaml(self, **kwargs): 

3280 """Returns a yaml string containing the network configuration. 

3281 

3282 Note: Since TF 2.6, this method is no longer supported and will raise a 

3283 RuntimeError. 

3284 

3285 To load a network from a yaml save file, use 

3286 `keras.models.model_from_yaml(yaml_string, custom_objects={})`. 

3287 

3288 `custom_objects` should be a dictionary mapping 

3289 the names of custom losses / layers / etc to the corresponding 

3290 functions / classes. 

3291 

3292 Args: 

3293 **kwargs: Additional keyword arguments 

3294 to be passed to `yaml.dump()`. 

3295 

3296 Returns: 

3297 A YAML string. 

3298 

3299 Raises: 

3300 RuntimeError: always raised, as this method poses a security risk. 

3301 """ 

3302 raise RuntimeError( 

3303 "Method `model.to_yaml()` has been removed due to security risk of " 

3304 "arbitrary code execution. Please use `model.to_json()` instead." 

3305 ) 

3306 

3307 def reset_states(self): 

3308 for layer in self.layers: 

3309 if hasattr(layer, "reset_states") and getattr( 

3310 layer, "stateful", False 

3311 ): 

3312 layer.reset_states() 

3313 

3314 @property 

3315 @doc_controls.do_not_generate_docs 

3316 def state_updates(self): 

3317 """Deprecated, do NOT use! 

3318 

3319 Returns the `updates` from all layers that are stateful. 

3320 

3321 This is useful for separating training updates and 

3322 state updates, e.g. when we need to update a layer's internal state 

3323 during prediction. 

3324 

3325 Returns: 

3326 A list of update ops. 

3327 """ 

3328 warnings.warn( 

3329 "`Model.state_updates` will be removed in a future version. " 

3330 "This property should not be used in TensorFlow 2.0, " 

3331 "as `updates` are applied automatically.", 

3332 stacklevel=2, 

3333 ) 

3334 state_updates = [] 

3335 for layer in self.layers: 

3336 if getattr(layer, "stateful", False): 

3337 if hasattr(layer, "updates"): 

3338 state_updates += layer.updates 

3339 return state_updates 

3340 

3341 @property 

3342 def weights(self): 

3343 """Returns the list of all layer variables/weights. 

3344 

3345 Note: This will not track the weights of nested `tf.Module`s that are 

3346 not themselves Keras layers. 

3347 

3348 Returns: 

3349 A list of variables. 

3350 """ 

3351 return self._dedup_weights(self._undeduplicated_weights) 

3352 

3353 @property 

3354 def _undeduplicated_weights(self): 

3355 """Returns the undeduplicated list of all layer variables/weights.""" 

3356 self._assert_weights_created() 

3357 weights = [] 

3358 for layer in self._self_tracked_trackables: 

3359 weights += layer.variables 

3360 weights += self._trainable_weights + self._non_trainable_weights 

3361 return weights 

3362 

3363 def summary( 

3364 self, 

3365 line_length=None, 

3366 positions=None, 

3367 print_fn=None, 

3368 expand_nested=False, 

3369 show_trainable=False, 

3370 layer_range=None, 

3371 ): 

3372 """Prints a string summary of the network. 

3373 

3374 Args: 

3375 line_length: Total length of printed lines 

3376 (e.g. set this to adapt the display to different 

3377 terminal window sizes). 

3378 positions: Relative or absolute positions of log elements 

3379 in each line. If not provided, becomes 

3380 `[0.3, 0.6, 0.70, 1.]`. Defaults to `None`. 

3381 print_fn: Print function to use. By default, prints to `stdout`. 

3382 If `stdout` doesn't work in your environment, change to `print`. 

3383 It will be called on each line of the summary. 

3384 You can set it to a custom function 

3385 in order to capture the string summary. 

3386 expand_nested: Whether to expand the nested models. 

3387 Defaults to `False`. 

3388 show_trainable: Whether to show if a layer is trainable. 

3389 Defaults to `False`. 

3390 layer_range: a list or tuple of 2 strings, 

3391 which is the starting layer name and ending layer name 

3392 (both inclusive) indicating the range of layers to be printed 

3393 in summary. It also accepts regex patterns instead of exact 

3394 names. In that case, the start predicate is the first element 

3395 that matches `layer_range[0]`, and the end predicate is the 

3396 last element that matches `layer_range[1]`. 

3397 Defaults to `None`, which considers all layers of the model. 

3398 

3399 Raises: 

3400 ValueError: if `summary()` is called before the model is built. 
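
Example:

```python
# Editor's sketch: capture the summary as a string via `print_fn`,
# assuming `model` has already been built.
lines = []
model.summary(print_fn=lines.append)
summary_text = "\n".join(lines)
```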

3401 """ 

3402 if not self.built: 

3403 raise ValueError( 

3404 "This model has not yet been built. " 

3405 "Build the model first by calling `build()` or by calling " 

3406 "the model on a batch of data." 

3407 ) 

3408 layer_utils.print_summary( 

3409 self, 

3410 line_length=line_length, 

3411 positions=positions, 

3412 print_fn=print_fn, 

3413 expand_nested=expand_nested, 

3414 show_trainable=show_trainable, 

3415 layer_range=layer_range, 

3416 ) 

3417 

3418 @property 

3419 def layers(self): 

3420 return list(self._flatten_layers(include_self=False, recursive=False)) 

3421 

3422 @layers.setter 

3423 def layers(self, _): 

3424 raise AttributeError( 

3425 "`Model.layers` attribute is reserved and should not be used. " 

3426 "Please use another name." 

3427 ) 

3428 

3429 def get_layer(self, name=None, index=None): 

3430 """Retrieves a layer based on either its name (unique) or index. 

3431 

3432 If `name` and `index` are both provided, a `ValueError` will be raised. 

3433 Indices are based on order of horizontal graph traversal (bottom-up). 

3434 

3435 Args: 

3436 name: String, name of layer. 

3437 index: Integer, index of layer. 

3438 

3439 Returns: 

3440 A layer instance. 
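
Example:

```python
# Editor's sketch, assuming a built model with at least two layers.
layer = model.get_layer(index=1)
same = model.get_layer(name=layer.name)
assert layer is same
```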

3441 """ 

3442 # TODO(fchollet): We could build a dictionary based on layer names 

3443 # since they are constant, but we have not done that yet. 

3444 if index is not None and name is not None: 

3445 raise ValueError( 

3446 "Provide only a layer name or a layer index. Received: " 

3447 f"index={index}, name={name}." 

3448 ) 

3449 

3450 if index is not None: 

3451 if len(self.layers) <= index: 

3452 raise ValueError( 

3453 f"Was asked to retrieve layer at index {index}" 

3454 f" but model only has {len(self.layers)}" 

3455 " layers." 

3456 ) 

3457 else: 

3458 return self.layers[index] 

3459 

3460 if name is not None: 

3461 for layer in self.layers: 

3462 if layer.name == name: 

3463 return layer 

3464 raise ValueError( 

3465 f"No such layer: {name}. Existing layers are: " 

3466 f"{list(layer.name for layer in self.layers)}." 

3467 ) 

3468 raise ValueError( 

3469 "Provide either a layer name or layer index at `get_layer`." 

3470 ) 

3471 

3472 def get_weight_paths(self): 

3473 """Retrieve all the variables and their paths for the model. 

3474 

3475 The variable path (string) is a stable key to identify a `tf.Variable` 

3476 instance owned by the model. It can be used to specify variable-specific 

3477 configurations (e.g. DTensor, quantization) from a global view. 

3478 

3479 This method returns a dict with weight object paths as keys 

3480 and the corresponding `tf.Variable` instances as values. 

3481 

3482 Note that if the model is a subclassed model and the weights haven't 

3483 been initialized, an empty dict will be returned. 

3484 

3485 Returns: 

3486 A dict where keys are variable paths and values are `tf.Variable` 

3487 instances. 

3488 

3489 Example: 

3490 

3491 ```python 

3492 class SubclassModel(tf.keras.Model): 

3493 

3494 def __init__(self, name=None): 

3495 super().__init__(name=name) 

3496 self.d1 = tf.keras.layers.Dense(10) 

3497 self.d2 = tf.keras.layers.Dense(20) 

3498 

3499 def call(self, inputs): 

3500 x = self.d1(inputs) 

3501 return self.d2(x) 

3502 

3503 model = SubclassModel() 

3504 model(tf.zeros((10, 10))) 

3505 weight_paths = model.get_weight_paths() 

3506 # weight_paths: 

3507 # { 

3508 # 'd1.kernel': model.d1.kernel, 

3509 # 'd1.bias': model.d1.bias, 

3510 # 'd2.kernel': model.d2.kernel, 

3511 # 'd2.bias': model.d2.bias, 

3512 # } 

3513 

3514 # Functional model 

3515 inputs = tf.keras.Input((10,), batch_size=10) 

3516 x = tf.keras.layers.Dense(20, name='d1')(inputs) 

3517 output = tf.keras.layers.Dense(30, name='d2')(x) 

3518 model = tf.keras.Model(inputs, output) 

3519 d1 = model.layers[1] 

3520 d2 = model.layers[2] 

3521 weight_paths = model.get_weight_paths() 

3522 # weight_paths: 

3523 # { 

3524 # 'd1.kernel': d1.kernel, 

3525 # 'd1.bias': d1.bias, 

3526 # 'd2.kernel': d2.kernel, 

3527 # 'd2.bias': d2.bias, 

3528 # } 

3529 ``` 

3530 """ 

3531 result = {} 

3532 ( 

3533 descendants, 

3534 object_paths_dict, 

3535 ) = tf.__internal__.tracking.ObjectGraphView( 

3536 self 

3537 ).breadth_first_traversal() 

3538 for descendant in descendants: 

3539 if isinstance(descendant, tf.Variable): 

3540 trackable_references = object_paths_dict[descendant] 

3541 object_path = ".".join([t.name for t in trackable_references]) 

3542 result[object_path] = descendant 

3543 return result 

3544 

3545 def get_compile_config(self): 

3546 """Returns a serialized config with information for compiling the model. 

3547 

3548 This method returns a config dictionary containing all the information 

3549 (optimizer, loss, metrics, etc.) with which the model was compiled. 

3550 

3551 Returns: 

3552 A dict containing information for compiling the model, or `None` 

 if `compile()` was never called. 

3553 """ 

3554 if self._is_compiled and hasattr(self, "_compile_config"): 

3555 return self._compile_config.serialize() 

3556 

3557 def compile_from_config(self, config): 

3558 """Compiles the model with the information given in config. 

3559 

3560 This method uses the information in the config (optimizer, loss, 

3561 metrics, etc.) to compile the model. 

3562 

3563 Args: 

3564 config: Dict containing information for compiling the model. 
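
Example:

```python
# Editor's sketch: carry `compile()` state over to a fresh copy of the
# model (assumes a cloneable, e.g. functional or Sequential, model).
model.compile(optimizer="adam", loss="mse")
config = model.get_compile_config()
clone = tf.keras.models.clone_model(model)  # uncompiled copy
clone.compile_from_config(config)
```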

3565 """ 

3566 has_overridden_compile = self.__class__.compile != Model.compile 

3567 if has_overridden_compile: 

3568 logging.warning( 

3569 "`compile()` was not called as part of model loading " 

3570 "because the model's `compile()` method is custom. " 

3571 "All subclassed Models that have `compile()` " 

3572 "overridden should also override " 

3573 "`get_compile_config()` and `compile_from_config(config)`. " 

3574 "Alternatively, you can " 

3575 "call `compile()` manually after loading." 

3576 ) 

3577 return 

3578 config = saving_lib.deserialize_keras_object(config) 

3579 self.compile(**config) 

3580 if hasattr(self, "optimizer") and self.built: 

3581 # Create optimizer variables. 

3582 self.optimizer.build(self.trainable_variables) 

3583 

3584 def export(self, filepath): 

3585 """Create a SavedModel artifact for inference (e.g. via TF-Serving). 

3586 

3587 This method lets you export a model to a lightweight SavedModel artifact 

3588 that contains the model's forward pass only (its `call()` method) 

3589 and can be served via e.g. TF-Serving. The forward pass is registered 

3590 under the name `serve()` (see example below). 

3591 

3592 The original code of the model (including any custom layers you may 

3593 have used) is *no longer* necessary to reload the artifact -- it is 

3594 entirely standalone. 

3595 

3596 Args: 

3597 filepath: `str` or `pathlib.Path` object. Path where to save 

3598 the artifact. 

3599 

3600 Example: 

3601 

3602 ```python 

3603 # Create the artifact 

3604 model.export("path/to/location") 

3605 

3606 # Later, in a different process / environment... 

3607 reloaded_artifact = tf.saved_model.load("path/to/location") 

3608 predictions = reloaded_artifact.serve(input_data) 

3609 ``` 

3610 

3611 If you would like to customize your serving endpoints, you can 

3612 use the lower-level `keras.export.ExportArchive` class. The `export()` 

3613 method relies on `ExportArchive` internally. 

3614 """ 

3615 from keras.src.export import export_lib 

3616 

3617 export_lib.export_model(self, filepath) 

3618 

3619 @tf.__internal__.tracking.no_automatic_dependency_tracking 

3620 def _set_save_spec(self, inputs, args=None, kwargs=None): 

3621 """Defines the save spec so that serialization can trace `call()`. 

3622 

3623 The TensorSpecs of the call function `inputs`, `args`, and `kwargs` are 

3624 saved into a tuple of `([inputs] + args, kwargs)`. The input 

3625 `TensorSpec` names are updated to match the built `input_names`. 

3626 

3627 The specs can be retrieved with the `save_spec` property. 

3628 

3629 Args: 

3630 inputs: possibly nested inputs passed into the call function. 

3631 args: a list of positional arguments passed into call. 

3632 kwargs: a dictionary of keyword arguments passed into call. 

3633 """ 

3634 if self._saved_model_inputs_spec is not None: 

3635 return # Already set. 

3636 args = args or [] 

3637 kwargs = kwargs or {} 

3638 

3639 input_names = self.input_names 

3640 if not input_names: 

3641 input_names = compile_utils.create_pseudo_input_names(inputs) 

3642 

3643 flat_inputs = tf.nest.flatten(inputs) 

3644 inputs_spec = [] 

3645 for name, tensor in zip(input_names, flat_inputs): 

3646 inputs_spec.append( 

3647 tf_utils.get_tensor_spec(tensor, dynamic_batch=False, name=name) 

3648 ) 

3649 inputs_spec = tf.nest.pack_sequence_as(inputs, inputs_spec) 

3650 super()._set_save_spec(inputs_spec, args, kwargs) 

3651 

3652 # Store the input shapes 

3653 if ( 

3654 self.__class__.__name__ == "Sequential" 

3655 and self._build_input_shape is None 

3656 ): 

3657 self._build_input_shape = tf.nest.map_structure( 

3658 lambda x: None if x is None else x.shape, inputs_spec 

3659 ) 

3660 

3661 def save_spec(self, dynamic_batch=True): 

3662 """Returns the `tf.TensorSpec` of call args as a tuple `(args, kwargs)`. 

3663 

3664 This value is automatically defined after calling the model for the 

3665 first time. Afterwards, you can use it when exporting the model for 

3666 serving: 

3667 

3668 ```python 

3669 model = tf.keras.Model(...) 

3670 

3671 @tf.function 

3672 def serve(*args, **kwargs): 

3673 outputs = model(*args, **kwargs) 

3674 # Apply postprocessing steps, or add additional outputs. 

3675 ... 

3676 return outputs 

3677 

3678 # arg_specs is `[tf.TensorSpec(...), ...]`. kwarg_specs, in this 

3679 # example, is an empty dict since functional models do not use keyword 

3680 # arguments. 

3681 arg_specs, kwarg_specs = model.save_spec() 

3682 

3683 model.save(path, signatures={ 

3684 'serving_default': serve.get_concrete_function(*arg_specs, 

3685 **kwarg_specs) 

3686 }) 

3687 ``` 

3688 

3689 Args: 

3690 dynamic_batch: Whether to set the batch sizes of all the returned 

3691 `tf.TensorSpec` to `None`. (Note that when defining functional or 

3692 Sequential models with `tf.keras.Input([...], batch_size=X)`, the 

3693 batch size will always be preserved). Defaults to `True`. 

3694 Returns: 

3695 If the model inputs are defined, returns a tuple `(args, kwargs)`. All 

3696 elements in `args` and `kwargs` are `tf.TensorSpec`. 

3697 If the model inputs are not defined, returns `None`. 

3698 The model inputs are automatically set when calling the model, 

3699 `model.fit`, `model.evaluate` or `model.predict`. 

3700 """ 

3701 return self._get_save_spec(dynamic_batch, inputs_only=False) 

3702 

3703 def _assert_weights_created(self): 

3704 """Asserts that all the weights for the model have been created. 

3705 

3706 For a non-dynamic model, the weights must already be created after the 

3707 layer has been called. For a dynamic model, the exact list of weights 

3708 can never be known for certain since it may change at any time during 

3709 execution. 

3710 

3711 We run this check right before accessing weights or getting the Numpy 

3712 value for the current weights. Otherwise, if the layer has never been 

3713 called, the user would just get an empty list, which is misleading. 

3714 

3715 Raises: 

3716 ValueError: if the weights of the network have not yet been created. 

3717 """ 

3718 if self.dynamic: 

3719 return 

3720 

3721 if ( 

3722 "build" in self.__class__.__dict__ 

3723 and self.__class__ != Model 

3724 and not self.built 

3725 ): 

3726 # This handles any model with a customized build() method that has 

3727 # not been invoked yet, covering both Sequential and subclassed models. 

3728 # Also make sure to exclude the Model class itself, which has build() 

3729 # defined. 

3730 raise ValueError( 

3731 f"Weights for model '{self.name}' have not yet been " 

3732 "created. " 

3733 "Weights are created when the model is first called on " 

3734 "inputs or `build()` is called with an `input_shape`." 

3735 ) 

3736 

3737 def _check_call_args(self, method_name): 

3738 """Check that `call()` has only one positional arg.""" 

3739 # Always allow first arg, regardless of arg name. 

3740 fullargspec = self._call_spec.full_argspec 

3741 if fullargspec.defaults: 

3742 positional_args = fullargspec.args[: -len(fullargspec.defaults)] 

3743 else: 

3744 positional_args = fullargspec.args 

3745 if "training" in positional_args: 

3746 positional_args.remove("training") 

3747 

3748 # self and first arg can be positional. 

3749 if len(positional_args) > 2: 

3750 extra_args = positional_args[2:] 

3751 raise ValueError( 

3752 f"Models passed to `{method_name}` can only have `training` " 

3753 "and the first argument in `call()` as positional arguments, " 

3754 f"found: {extra_args}." 

3755 ) 

3756 

3757 def _validate_compile(self, optimizer, metrics, **kwargs): 

3758 """Performs validation checks for the default `compile()`.""" 

3759 if any( 

3760 isinstance(opt, optimizer_v1.Optimizer) 

3761 for opt in tf.nest.flatten(optimizer) 

3762 ): 

3763 raise ValueError( 

3764 f"`tf.compat.v1.keras` Optimizer ({optimizer}) is " 

3765 "not supported when eager execution is enabled. Use a " 

3766 "`tf.keras` Optimizer instead, or disable eager " 

3767 "execution." 

3768 ) 

3769 

3770 kwargs.pop("cloning", None) # Legacy DistStrat argument, never used. 

3771 kwargs.pop("experimental_run_tf_function", None) # Always `True`. 

3772 distribute_arg = kwargs.pop("distribute", None) 

3773 if distribute_arg is not None: 

3774 raise ValueError( 

3775 "`distribute` argument in compile is not available in TF 2.0. " 

3776 "Please create the model under the `strategy.scope()`. " 

3777 f"Received: {distribute_arg}." 

3778 ) 

3779 target_tensor_arg = kwargs.pop("target_tensors", None) 

3780 if target_tensor_arg is not None: 

3781 raise ValueError( 

3782 "`target_tensors` argument is not supported when executing " 

3783 f"eagerly. Received: {target_tensor_arg}." 

3784 ) 

3785 invalid_kwargs = set(kwargs) - {"sample_weight_mode"} 

3786 if invalid_kwargs: 

3787 raise TypeError( 

3788 "Invalid keyword argument(s) in `compile()`: " 

3789 f"{(invalid_kwargs,)}. Valid keyword arguments include " 

3790 '"cloning", "experimental_run_tf_function", "distribute",' 

3791 ' "target_tensors", or "sample_weight_mode".' 

3792 ) 

3793 

3794 # Model must be created and compiled with the same DistStrat. 

3795 if self.built and tf.distribute.has_strategy(): 

3796 strategy = tf.distribute.get_strategy() 

3797 for v in self.variables: 

3798 if not strategy.extended.variable_created_in_scope(v): 

3799 raise ValueError( 

3800 f"Variable ({v}) was not created in the distribution " 

3801 f"strategy scope of ({strategy}). It is most likely " 

3802 "because some layers, model, or optimizer was being " 

3803 "created outside the distribution strategy scope. Try " 

3804 "to make sure your code looks similar " 

3805 "to the following.\nwith strategy.scope():\n" 

3806 " model=_create_model()\n" 

3807 " model.compile(...)" 

3808 ) 

3809 

3810 # Model metrics must be created in the same distribution strategy scope 

3811 # as the model. 

3812 strategy = self.distribute_strategy 

3813 for metric in tf.nest.flatten(metrics): 

3814 for v in getattr(metric, "variables", []): 

3815 if not strategy.extended.variable_created_in_scope(v): 

3816 raise ValueError( 

3817 f"Metric ({metric}) passed to `model.compile` was " 

3818 "created inside a different distribution strategy " 

3819 "scope than the model. All metrics must be created " 

3820 "in the same distribution strategy " 

3821 f"scope as the model (in this case {strategy}). " 

3822 "If you pass in a string identifier for a metric to " 

3823 "compile, the metric will automatically be created " 

3824 "in the correct distribution strategy scope." 

3825 ) 

3826 

3827 # Optimizer variables must be created in the same distribution strategy 

3828 # scope as the model. 

3829 for opt in tf.nest.flatten(optimizer): 

3830 for v in getattr(opt, "_weights", []): 

3831 if not strategy.extended.variable_created_in_scope(v): 

3832 raise ValueError( 

3833 f"Optimizer ({optimizer}) passed to `model.compile` " 

3834 "was created inside a different distribution strategy " 

3835 "scope than the model. All optimizers must be created " 

3836 "in the same distribution strategy scope as the model " 

3837 f"(in this case {strategy}). If you pass in a string " 

3838 "identifier for an optimizer to compile, the optimizer " 

3839 "will automatically be created in the correct " 

3840 "distribution strategy scope." 

3841 ) 

3842 

3843 def _maybe_load_initial_counters_from_ckpt( 

3844 self, steps_per_epoch, initial_epoch 

3845 ): 

3846 """Maybe load initial epoch from ckpt, considering worker recovery. 

3847 

3848 Refer to tensorflow/python/keras/distribute/worker_training_state.py 

3849 for more information. 

3850 

3851 Args: 

3852 steps_per_epoch: The number of steps per epoch. 

3853 initial_epoch: The original initial_epoch user passes in `fit()`. 

3855 

3856 Returns: 

3857 If the training is recovering from previous failure under multi-worker 

3858 training setting, return the (epoch, step) the training is supposed to 

3859 continue at. Otherwise, return the `initial_epoch, initial_step` the 

3860 user passes in. 

3861 """ 

3862 initial_step = 0 

3863 if self._training_state is not None: 

3864 return self._training_state.maybe_load_initial_counters_from_ckpt( 

3865 steps_per_epoch, initial_epoch, mode=ModeKeys.TRAIN 

3866 ) 

3867 return (initial_epoch, initial_step) 

3868 

3869 def _assert_compile_was_called(self): 

3870 # Checks whether `compile` has been called. If it has been called, 

3871 # then the optimizer is set. This is different from whether the 

3872 # model is compiled 

3873 # (i.e. whether the model is built and its inputs/outputs are set). 

3874 if not self._is_compiled: 

3875 raise RuntimeError( 

3876 "You must compile your model before " 

3877 "training/testing. " 

3878 "Use `model.compile(optimizer, loss)`." 

3879 ) 

3880 

3881 def _check_sample_weight_warning(self, x, sample_weight): 

3882 # Datasets can include sample weight, by returning a tuple with the 

3883 # structure of `(x, y, sample_weight)`. 

3884 sample_weight_present = sample_weight is not None or ( 

3885 isinstance(x, tf.data.Dataset) 

3886 and isinstance(x.element_spec, tuple) 

3887 and len(x.element_spec) == 3 

3888 ) 

3889 

3890 if ( 

3891 sample_weight_present 

3892 and self.compiled_metrics._user_weighted_metrics is None 

3893 ): 

3894 logging.warning( 

3895 "`evaluate()` received a value for `sample_weight`, but " 

3896 "`weighted_metrics` were not provided. Did you mean to pass " 

3897 "metrics to `weighted_metrics` in `compile()`? If this is " 

3898 "intentional you can pass `weighted_metrics=[]` to `compile()` " 

3899 "in order to silence this warning." 

3900 ) 

3901 

3902 def _set_inputs(self, inputs, outputs=None, training=None): 

3903 """This method is for compat with Modelv1. Only inputs are needed 

3904 here.""" 

3905 self._set_save_spec(inputs) 

3906 

3907 @property 

3908 def _trackable_saved_model_saver(self): 

3909 return model_serialization.ModelSavedModelSaver(self) 

3910 

3911 def _trackable_children(self, save_type="checkpoint", **kwargs): 

3912 if save_type == "savedmodel": 

3913 # SavedModel needs to ignore the execution functions. 

3914 train_function = self.train_function 

3915 test_function = self.test_function 

3916 predict_function = self.predict_function 

3917 train_tf_function = self.train_tf_function 

3918 self.train_function = None 

3919 self.test_function = None 

3920 self.predict_function = None 

3921 self.train_tf_function = None 

3922 

3923 children = super()._trackable_children(save_type, **kwargs) 

3924 

3925 if save_type == "savedmodel": 

3926 self.train_function = train_function 

3927 self.test_function = test_function 

3928 self.predict_function = predict_function 

3929 self.train_tf_function = train_tf_function 

3930 

3931 return children 

3932 

3933 def _should_eval(self, epoch, validation_freq): 

3934 epoch = epoch + 1 # one-index the user-facing epoch. 

3935 if isinstance(validation_freq, int): 

3936 return epoch % validation_freq == 0 

3937 elif isinstance(validation_freq, list): 

3938 return epoch in validation_freq 

3939 else: 

3940 raise ValueError( 

3941 "Expected `validation_freq` to be a list or int. " 

3942 f"Received: validation_freq={validation_freq} of the " 

3943 f"type {type(validation_freq)}." 

3944 ) 
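
# Editor's note: with the one-indexing above, `validation_freq=2` runs
# validation on epochs 2, 4, 6, ..., while `validation_freq=[1, 2, 10]`
# runs it only on those epochs.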

3945 

3946 ###################################################################### 

3947 # Functions below exist only as v1 / v2 compatibility shims. 

3948 ###################################################################### 

3949 

3950 def _get_compile_args(self, user_metrics=True): 

3951 """Used for saving or cloning a Model. 

3952 

3953 Args: 

3954 user_metrics: Whether to return user-supplied metrics or `Metric` 

3955 objects. If True, returns the user-supplied metrics. 

3956 Defaults to `True`. 

3957 

3958 Returns: 

3959 Dictionary of arguments that were used when compiling the model. 

3960 """ 

3961 self._assert_compile_was_called() 

3962 saved_metrics = self.compiled_metrics._user_metrics 

3963 saved_weighted_metrics = self.compiled_metrics._user_weighted_metrics 

3964 

3965 if not user_metrics: 

3966 if saved_metrics is not None: 

3967 saved_metrics = self.compiled_metrics._metrics 

3968 if saved_weighted_metrics is not None: 

3969 saved_weighted_metrics = self.compiled_metrics._weighted_metrics 

3970 

3971 compile_args = { 

3972 "optimizer": self.optimizer, 

3973 "loss": self.compiled_loss._user_losses, 

3974 "metrics": saved_metrics, 

3975 "weighted_metrics": saved_weighted_metrics, 

3976 "loss_weights": self.compiled_loss._user_loss_weights, 

3977 } 

3978 return compile_args 

3979 

3980 def _get_callback_model(self): 

3981 return self 

3982 

3983 def _in_multi_worker_mode(self): 

3984 return self.distribute_strategy.extended._in_multi_worker_mode() 

3985 

3986 @property 

3987 def _compile_was_called(self): 

3988 return self._is_compiled 

3989 

3990 def _save_experimental(self, filepath): 

3991 return saving_lib.save_model(self, filepath) 

3992 

3993 

3994class _TestFunction: 

3995 def __init__(self, function, callbacks): 

3996 self._function = function 

3997 self._callbacks = callbacks 

3998 

3999 def run_step(self, dataset_or_iterator, data_handler, step, unused_shards): 

4000 tmp_logs = self._function(dataset_or_iterator) 

4001 if data_handler.should_sync: 

4002 context.async_wait() 

4003 logs = tmp_logs 

4004 end_step = step + data_handler.step_increment 

4005 self._callbacks.on_test_batch_end(end_step, logs) 

4006 return logs 

4007 

4008 

4009class _ExactTestFunction(_TestFunction): 

4010 def __init__(self, function, callbacks): 

4011 super().__init__(function, callbacks) 

4012 self._logs = [] 

4013 

4014 def run_step(self, dataset_or_iterator, data_handler, step, shards): 

4015 tmp_logs = self._function( 

4016 dataset_or_iterator, 

4017 tf.constant(shards, dtype=tf.int64), 

4018 tf.constant(step, dtype=tf.int64), 

4019 ) 

4020 if data_handler.should_sync: 

4021 context.async_wait() 

4022 self._logs.append(tmp_logs) 

4023 return self._logs 

4024 

4025 

4026def reduce_per_replica(values, strategy, reduction): 

4027 """Attempt to reduce the structure `values` to single values. 

4028 

4029 Given `values` (a `tf.Tensor` or a `PerReplica` structure), 

4030 which represents the values across all the replicas, `reduce_per_replica` 

4031 attempts to "reduce" those values and returns the corresponding structure 

4032 that represents only single values. 

4033 

4034 Currently, `reduce_per_replica` is only used for reducing the metric results 

4035 from `tf.distribute.Strategy.run()`. Depending on the underlying 

4036 `Strategy` implementation, `values` may be a `PerReplica` object, 

4037 which can be thought of as a collection of values across the replicas, 

4038 or a `tf.Tensor`, if the strategy has already conducted the reduction 

4039 for the downstream library. 

4040 

4041 There are five possible outcomes of reduction: 

4042 

4043 1) if the `values` is a structure of simple `tf.Tensor`s, meaning that 

4044 reduction is not actually needed, `reduce_per_replica` returns the 

4045 structure as-is. 

4046 2) else, if `reduction="auto"`, then the best reduction strategy is 

4047 chosen based on the current environment. This should only be used 

4048 for training cases (`fit()`). 

4049 3) else, if `reduction="first"`, then `reduce_per_replica` 

4050 returns the values of the first replica. This is used in the case of 

4051 training and evaluation, where `values` is expected to hold the same 

4052 value across the replicas as a result of `Strategy`'s synchronization 

4053 across the replicas. 

4054 `reduce_per_replica` does not synchronize the values. 

4055 4) else, if `reduction="sum"`, then `reduce_per_replica` returns the sum 

4056 of values for all replicas. This may be used in the custom training loop 

4057 case, where each replica contains different values which are not 

4058 synchronized. 

4059 5) else, if `reduction="concat"`, then `reduce_per_replica` 

4060 returns the concatenation of the values across the replicas, along the 

4061 axis of dimension 0. This is used in the inference case (`predict()`). 

4062 

4063 Args: 

4064 values: Structure of `PerReplica` objects or `tf.Tensor`s. `tf.Tensor`s 

4065 are returned as-is. 

4066 strategy: `tf.distribute.Strategy` object. 

4067 reduction: One of `"auto"`, `"first"`, `"concat"`, or `"sum"`. 

4068 `"auto"` will select `"first"` when used under a TPUStrategy, or 

4069 `"sum"` otherwise. 

4070 

4071 Returns: 

4072 Structure of `Tensor`s, representing the result of reduction. 

4073 

4074 Raises: 

4075 ValueError: if the reduction method is not supported. 
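
Example:

```python
# Editor's sketch. With a single replica the input is a plain tensor
# and passes through unchanged (outcome 1); with several replicas it
# arrives as a PerReplica and is reduced according to `reduction`.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
loss = strategy.run(lambda: tf.constant(0.5))
reduce_per_replica(loss, strategy, reduction="sum")  # -> tf.Tensor(0.5)
```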

4076 """ 

4077 

4078 if reduction == "auto": 

4079 reduction = "first" if backend.is_tpu_strategy(strategy) else "sum" 

4080 

4081 def _reduce(v): 

4082 """Reduce a single `PerReplica` object.""" 

4083 if _collective_all_reduce_multi_worker(strategy): 

4084 if reduction == "concat": 

4085 return _multi_worker_concat(v, strategy) 

4086 elif reduction == "sum": 

4087 return strategy.reduce("SUM", v, axis=None) 

4088 

4089 if _is_dtensor_per_replica_instance(v): 

4090 return _reduce_dtensor_per_replica(v, strategy, reduction) 

4091 elif not _is_per_replica_instance(v): 

4092 return v 

4093 elif reduction == "first": 

4094 return strategy.experimental_local_results(v)[0] 

4095 elif reduction == "concat": 

4096 if _is_tpu_multi_host(strategy): 

4097 return _tpu_multi_host_concat(v, strategy) 

4098 else: 

4099 return concat(strategy.experimental_local_results(v)) 

4100 elif reduction == "sum": 

4101 return tf.reduce_sum(strategy.experimental_local_results(v)) 

4102 else: 

4103 raise ValueError( 

4104 '`reduction` must be "first", "concat", "sum", or "auto". ' 

4105 f"Received: reduction={reduction}." 

4106 ) 

4107 

4108 return tf.nest.map_structure(_reduce, values) 

4109 

4110 

4111def concat(tensors, axis=0): 

4112 """Concatenates `tensors` along `axis`.""" 

4113 if isinstance(tensors[0], tf.SparseTensor): 

4114 return tf.sparse.concat(axis=axis, sp_inputs=tensors) 

4115 elif _is_scalar(tensors[0]): 

4116 return tf.stack(tensors, axis=axis) 

4117 else: 

4118 return tf.concat(tensors, axis=axis) 

4119 

4120 

4121def potentially_ragged_concat(tensors): 

4122 """Concatenates `Tensor`s along their first dimension. 

4123 

4124 Args: 

4125 tensors: List of `Tensor`s. 

4126 

4127 Returns: 

4128 Concatenation of the inputs along the first dimension -- of type `Tensor` 

4129 if all input shapes are compatible, or `RaggedTensor` if not. 
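
Example:

```python
# Editor's sketch of both outcomes, eager mode assumed.
import tensorflow as tf

a, b = tf.ones((2, 3)), tf.ones((4, 3))
potentially_ragged_concat([a, b])  # dense Tensor of shape (6, 3)

c, d = tf.ones((2, 3)), tf.ones((2, 5))
potentially_ragged_concat([c, d])  # RaggedTensor: trailing dims differ
```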

4130 """ 

4131 if len(tensors) == 1: 

4132 return tensors[0] 

4133 if isinstance(tensors[0], tf.SparseTensor): 

4134 return tf.sparse.concat(axis=0, sp_inputs=tensors) 

4135 elif isinstance(tensors[0], tf.RaggedTensor): 

4136 return tf.concat(tensors, axis=0) 

4137 elif not tf.__internal__.tf2.enabled(): 

4138 return tf.concat(tensors, axis=0) 

4139 

4140 non_batch_shapes = tf.stack([tf.shape(tensor)[1:] for tensor in tensors]) 

4141 constant_dims = tf.math.reduce_all( 

4142 non_batch_shapes == non_batch_shapes[:1], axis=0 

4143 ) 

4144 if tf.math.reduce_all(constant_dims).numpy().item(): 

4145 # All non-batch dims are constant 

4146 if _is_scalar(tensors[0]): 

4147 return tf.stack(tensors, axis=0) 

4148 else: 

4149 return tf.concat(tensors, axis=0) 

4150 

4151 # First, identify constant inner dimensions by finding the 

4152 # rightmost dimension that is not constant 

4153 constant_inner_dimensions = ( 

4154 constant_dims.numpy().tolist()[::-1].index(False) 

4155 ) 

4156 # If there are constant inner dimensions, define a constant inner shape 

4157 if constant_inner_dimensions == 0: 

4158 constant_inner_shape = None 

4159 else: 

4160 constant_inner_shape = tensors[0].shape[-constant_inner_dimensions:] 

4161 return tf.ragged.constant( 

4162 [tensor.numpy() for tensor in tensors], inner_shape=constant_inner_shape 

4163 ).merge_dims(0, 1) 

4164 

4165 

4166def _reduce_dtensor_per_replica(value, strategy, reduction): 

4167 # Note that this function could happen in graph, so we can't just access 

4168 # the per-replica.values(), which will trigger unpack in graph and result 

4169 # into error. 

4170 # For now we will perform ops on dtensor instance directly on a global 

4171 # context. 

4172 dtensor = value._dtensor 

4173 if reduction == "first": 

4174 num_replica = strategy.num_replicas_in_sync 

4175 return tf.split(dtensor, num_replica, axis=0)[0] 

4176 elif reduction == "concat": 

4177 # Since dtensor is already in global context, the concat is a no-op 

4178 return dtensor 

4179 elif reduction == "sum": 

4180 return tf.reduce_sum(dtensor) 

4181 else: 

4182 raise ValueError( 

4183 '`reduction` must be one of "first", "concat", "sum", or "auto". ' 

4184 f"Received: reduction={reduction}." 

4185 ) 

4186 

4187 

4188def _get_verbosity(verbose, distribute_strategy): 

4189 """Find the right verbosity value for 'auto'.""" 

4190 if verbose == 1 and distribute_strategy._should_use_with_coordinator: 

4191 raise ValueError( 

4192 "`verbose=1` is not allowed with `ParameterServerStrategy` for " 

4193 f"performance reasons. Received: verbose={verbose}" 

4194 ) 

4195 if verbose == "auto": 

4196 if ( 

4197 distribute_strategy._should_use_with_coordinator 

4198 or not io_utils.is_interactive_logging_enabled() 

4199 ): 

4200 # Defaults to epoch-level logging for PSStrategy or using absl 

4201 # logging. 

4202 return 2 

4203 else: 

4204 return 1 # Defaults to batch-level logging otherwise. 

4205 return verbose 

4206 

4207 

4208def _is_tpu_multi_host(strategy): 

4209 return backend.is_tpu_strategy(strategy) and strategy.extended.num_hosts > 1 

4210 

4211 

4212def _tpu_multi_host_concat(v, strategy): 

4213 """Correctly order TPU PerReplica objects.""" 

4214 replicas = strategy.experimental_local_results(v) 

4215 # When distributed datasets are created from Tensors / NumPy, 

4216 # TPUStrategy.experimental_distribute_dataset shards data in 

4217 # (Replica, Host) order, and TPUStrategy.experimental_local_results returns 

4218 # it in (Host, Replica) order. 

4219 # TODO(b/150317897): Figure out long-term plan here. 

4220 num_replicas_per_host = strategy.extended.num_replicas_per_host 

4221 ordered_replicas = [] 

4222 for replica_id in range(num_replicas_per_host): 

4223 ordered_replicas += replicas[replica_id::num_replicas_per_host] 

4224 return concat(ordered_replicas) 

4225 

4226 

4227def _collective_all_reduce_multi_worker(strategy): 

4228 return ( 

4229 isinstance(strategy, tf.distribute.MultiWorkerMirroredStrategy) 

4230 ) and strategy.extended._in_multi_worker_mode() 

4231 

4232 

4233# TODO(wxinyi): merge this with _tpu_multi_host_concat once we have all_gather 

4234# for all strategies 

4235def _multi_worker_concat(v, strategy): 

4236 """Order PerReplica objects for CollectiveAllReduceStrategy and concat.""" 

4237 replicas = strategy.gather(v, axis=0) 

4238 # v might not have the same shape on different replicas 

4239 if _is_per_replica_instance(v): 

4240 shapes = tf.concat( 

4241 [ 

4242 tf.expand_dims(tf.shape(single_value)[0], axis=0) 

4243 for single_value in v.values 

4244 ], 

4245 axis=0, 

4246 ) 

4247 all_shapes = strategy.gather(shapes, axis=0) 

4248 else: 

4249 # v is a tensor. This may happen when, say, we have 2x1 multi-worker. 

4250 all_shapes = strategy.gather( 

4251 tf.expand_dims(tf.shape(v)[0], axis=0), axis=0 

4252 ) 

4253 

4254 replicas = tf.split( 

4255 replicas, 

4256 num_or_size_splits=all_shapes, 

4257 num=strategy.num_replicas_in_sync, 

4258 ) 

4259 ordered_replicas = [] 

4260 num_replicas_per_worker = len(strategy.extended.worker_devices) 

4261 for replica_id in range(num_replicas_per_worker): 

4262 ordered_replicas += replicas[replica_id::num_replicas_per_worker] 

4263 return concat(ordered_replicas) 

4264 

4265 

4266def _is_scalar(x): 

4267 return isinstance(x, (tf.Tensor, tf.Variable)) and x.shape.rank == 0 

4268 

4269 

4270def _minimum_control_deps(outputs): 

4271 """Returns the minimum control dependencies to ensure step succeeded.""" 

4272 if tf.executing_eagerly(): 

4273 return [] # Control dependencies not needed. 

4274 outputs = tf.nest.flatten(outputs, expand_composites=True) 

4275 for out in outputs: 

4276 # Variables can't be control dependencies. 

4277 if not isinstance(out, tf.Variable): 

4278 return [out] # Return first Tensor or Op from outputs. 

4279 return [] # No viable Tensor or Op to use for control deps. 

4280 

4281 

4282def _disallow_inside_tf_function(method_name): 

4283 if tf.inside_function(): 

4284 error_msg = ( 

4285 "Detected a call to `Model.{method_name}` inside a `tf.function`. " 

4286 "`Model.{method_name}` is a high-level endpoint that manages its " 

4287 "own `tf.function`. Please move the call to `Model.{method_name}` " 

4288 "outside of all enclosing `tf.function`s. Note that you can call a " 

4289 "`Model` directly on `Tensor`s inside a `tf.function` like: " 

4290 "`model(x)`." 

4291 ).format(method_name=method_name) 

4292 raise RuntimeError(error_msg) 

4293 

4294 

4295def flatten_metrics_in_order(logs, metrics_names): 

4296 """Turns the `logs` dict into a list as per key order of `metrics_names`.""" 

4297 results = [] 

4298 for name in metrics_names: 

4299 if name in logs: 

4300 results.append(logs[name]) 

4301 for key in sorted(logs.keys()): 

4302 if key not in metrics_names: 

4303 results.append(logs[key]) 

4304 if len(results) == 1: 

4305 return results[0] 

4306 return results 
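
# Editor's illustration of the ordering contract implemented above:
#   flatten_metrics_in_order({"acc": 0.9, "loss": 0.1}, ["loss", "acc"])
#       -> [0.1, 0.9]   # compiled order first, extra keys sorted after
#   flatten_metrics_in_order({"loss": 0.1}, ["loss"])
#       -> 0.1          # a single result is unwrapped from the list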

4307 

4308 

4309def _is_per_replica_instance(obj): 

4310 return isinstance(obj, tf.distribute.DistributedValues) and isinstance( 

4311 obj, tf.__internal__.CompositeTensor 

4312 ) 

4313 

4314 

4315def _is_dtensor_per_replica_instance(obj): 

4316 # This is a temp check for DTensorDistributedValue, which is not public API 

4317 # yet. 

4318 # TODO(scottzhu): Move to more stable API when dtensor based strategy is 

4319 # ready. 

4320 return isinstance(obj, tf.distribute.DistributedValues) and hasattr( 

4321 obj, "_dtensor" 

4322 ) 

4323 

4324 

4325def disable_multi_worker(method): 

4326 """Decorator that disallows multi-worker use of `method`.""" 

4327 

4328 def _method_wrapper(self, *args, **kwargs): 

4329 if self._in_multi_worker_mode(): 

4330 raise ValueError( 

4331 f"{method.__name__} is not supported in multi-worker " 

4332 "mode. Please use a non-multi-worker " 

4333 "`tf.distribute.Strategy` such as " 

4334 "`tf.distribute.MirroredStrategy`." 

4335 ) 

4336 return method(self, *args, **kwargs) 

4337 

4338 return tf.__internal__.decorator.make_decorator( 

4339 target=method, decorator_func=_method_wrapper 

4340 ) 

4341 

4342 

4343def inject_functional_model_class(cls): 

4344 """Inject `Functional` into the hierarchy of this class if needed.""" 

4345 from keras.src.engine import functional 

4346 from keras.src.engine import training_v1 

4347 

4348 if cls == Model or cls == training_v1.Model: 

4349 return functional.Functional 

4350 # In case there is any multiple inheritance, we stop injecting the 

4351 # class if keras model is not in its class hierarchy. 

4352 if cls == object: 

4353 return object 

4354 

4355 cls.__bases__ = tuple( 

4356 inject_functional_model_class(base) for base in cls.__bases__ 

4357 ) 

4358 # Trigger any `__new__` class swapping that needed to happen on `Functional` 

4359 # but did not because functional was not in the class hierarchy. 

4360 cls.__new__(cls) 

4361 

4362 return cls 

4363 

4364 

4365def is_functional_model_init_params(args, kwargs): 

4366 # Both inputs and outputs in args 

4367 if len(args) == 2: 

4368 return True 

4369 # Inputs in args, outputs in kwargs 

4370 if len(args) == 1 and "outputs" in kwargs: 

4371 return True 

4372 # Both in kwargs 

4373 if "inputs" in kwargs and "outputs" in kwargs: 

4374 return True 

4375 return False 

4376