Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/training.py: 18% of 1163 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training-related part of the Keras engine."""

import copy
import itertools
import json
import warnings
import weakref

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src import callbacks as callbacks_module
from keras.src import optimizers
from keras.src.dtensor import layout_map as layout_map_lib
from keras.src.engine import base_layer
from keras.src.engine import base_layer_utils
from keras.src.engine import compile_utils
from keras.src.engine import data_adapter
from keras.src.engine import input_layer as input_layer_module
from keras.src.engine import training_utils
from keras.src.metrics import base_metric
from keras.src.mixed_precision import loss_scale_optimizer as lso
from keras.src.optimizers import optimizer
from keras.src.optimizers import optimizer_v1
from keras.src.saving import pickle_utils
from keras.src.saving import saving_api
from keras.src.saving import saving_lib
from keras.src.saving import serialization_lib
from keras.src.saving.legacy import serialization
from keras.src.saving.legacy.saved_model import json_utils
from keras.src.saving.legacy.saved_model import model_serialization
from keras.src.utils import generic_utils
from keras.src.utils import io_utils
from keras.src.utils import layer_utils
from keras.src.utils import tf_inspect
from keras.src.utils import tf_utils
from keras.src.utils import traceback_utils
from keras.src.utils import version_utils
from keras.src.utils.mode_keys import ModeKeys

# isort: off
from tensorflow.python.eager import context
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_ops
from tensorflow.tools.docs import doc_controls

try:
    import h5py
except ImportError:
    h5py = None


@keras_export("keras.Model", "keras.models.Model")
class Model(base_layer.Layer, version_utils.ModelVersionSelector):
    """A model grouping layers into an object with training/inference features.

    Args:
        inputs: The input(s) of the model: a `keras.Input` object or a
            combination of `keras.Input` objects in a dict, list or tuple.
        outputs: The output(s) of the model: a tensor that originated from
            `keras.Input` objects or a combination of such tensors in a dict,
            list or tuple. See Functional API example below.
        name: String, the name of the model.

    There are two ways to instantiate a `Model`:

    1 - With the "Functional API", where you start from `Input`,
    you chain layer calls to specify the model's forward pass,
    and finally you create your model from inputs and outputs:

    ```python
    import tensorflow as tf

    inputs = tf.keras.Input(shape=(3,))
    x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
    outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    ```

    Note: Only dicts, lists, and tuples of input tensors are supported. Nested
    inputs are not supported (e.g. lists of lists or dicts of dicts).

    A new Functional API model can also be created by using the
    intermediate tensors. This enables you to quickly extract sub-components
    of the model.

    Example:

    ```python
    inputs = keras.Input(shape=(None, None, 3))
    processed = keras.layers.RandomCrop(width=32, height=32)(inputs)
    conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed)
    pooling = keras.layers.GlobalAveragePooling2D()(conv)
    feature = keras.layers.Dense(10)(pooling)

    full_model = keras.Model(inputs, feature)
    backbone = keras.Model(processed, conv)
    activations = keras.Model(conv, feature)
    ```

    Note that the `backbone` and `activations` models are not
    created with `keras.Input` objects, but with the tensors that originated
    from `keras.Input` objects. Under the hood, the layers and weights will
    be shared across these models, so that users can train the `full_model`,
    and use `backbone` or `activations` to do feature extraction.
    The inputs and outputs of the model can be nested structures of tensors
    as well, and the created models are standard Functional API models that
    support all the existing APIs.

    2 - By subclassing the `Model` class: in that case, you should define your
    layers in `__init__()` and you should implement the model's forward pass
    in `call()`.

    ```python
    import tensorflow as tf

    class MyModel(tf.keras.Model):

        def __init__(self):
            super().__init__()
            self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
            self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

        def call(self, inputs):
            x = self.dense1(inputs)
            return self.dense2(x)

    model = MyModel()
    ```

    If you subclass `Model`, you can optionally have
    a `training` argument (boolean) in `call()`, which you can use to specify
    a different behavior in training and inference:

    ```python
    import tensorflow as tf

    class MyModel(tf.keras.Model):

        def __init__(self):
            super().__init__()
            self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
            self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
            self.dropout = tf.keras.layers.Dropout(0.5)

        def call(self, inputs, training=False):
            x = self.dense1(inputs)
            if training:
                x = self.dropout(x, training=training)
            return self.dense2(x)

    model = MyModel()
    ```

    Once the model is created, you can configure the model with losses and
    metrics with `model.compile()`, train the model with `model.fit()`, or
    use the model to do prediction with `model.predict()`.
    """

    _TF_MODULE_IGNORED_PROPERTIES = frozenset(
        itertools.chain(
            (
                "_train_counter",
                "_test_counter",
                "_predict_counter",
                "_steps_per_execution",
            ),
            base_layer.Layer._TF_MODULE_IGNORED_PROPERTIES,
        )
    )
    _SCALAR_UPRANKING_ON = False

    def __new__(cls, *args, **kwargs):
        # Signature detection
        if is_functional_model_init_params(args, kwargs) and cls == Model:
            # Functional model
            from keras.src.engine import functional

            return functional.Functional(skip_init=True, *args, **kwargs)
        else:
            return super(Model, cls).__new__(cls, *args, **kwargs)

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    @traceback_utils.filter_traceback
    def __init__(self, *args, **kwargs):
        self._is_model_for_instrumentation = True
        base_layer.keras_api_gauge.get_cell("model").set(True)

        # Special case for a subclassed Functional Model, which we couldn't
        # detect when __new__ is called. We only realize it is a functional
        # model when it calls super().__init__() with input and output
        # tensors.
        from keras.src.engine import functional

        if is_functional_model_init_params(args, kwargs) and not isinstance(
            self, functional.Functional
        ):
            # Filter the kwargs for multiple inheritance.
            supported_kwargs = [
                "inputs",
                "outputs",
                "name",
                "trainable",
                "skip_init",
            ]
            model_kwargs = {
                k: kwargs[k] for k in kwargs if k in supported_kwargs
            }
            other_kwargs = {
                k: kwargs[k] for k in kwargs if k not in supported_kwargs
            }
            inject_functional_model_class(self.__class__)
            functional.Functional.__init__(self, *args, **model_kwargs)

            # In case there is any multiple inheritance here, we need to call
            # the __init__ for any class that appears after the Functional
            # class.
            clz_to_init = []
            found_functional_class = False
            for clz in self.__class__.__bases__:
                if issubclass(clz, functional.Functional):
                    found_functional_class = True
                    continue
                if found_functional_class:
                    clz_to_init.append(clz)

            if clz_to_init:
                for clz in clz_to_init:
                    clz.__init__(self, *args, **other_kwargs)
            elif other_kwargs:
                # In case there are unused kwargs, we should raise an error
                # to the user, in case they have a typo in the param name.
                raise TypeError(
                    "The following keyword arguments passed to `Model` aren't "
                    "supported: {}.".format(other_kwargs)
                )
            return

        base_layer.keras_api_gauge.get_cell("Model subclass").set(True)
        # The following are implemented as property functions:
        # self.trainable_weights
        # self.non_trainable_weights
        # `inputs` / `outputs` will only appear in kwargs if either are
        # misspelled.
        generic_utils.validate_kwargs(
            kwargs,
            {
                "trainable",
                "dtype",
                "dynamic",
                "name",
                "autocast",
                "inputs",
                "outputs",
            },
        )
        super().__init__(**kwargs)
        # By default, Model is a subclass model, which is not in graph network.
        self._is_graph_network = False

        self.inputs = None
        self.outputs = None
        self.input_names = None
        self.output_names = None
        # `stop_training` is used by callbacks to stop training when an error
        # happens.
        self.stop_training = False
        self.history = None
        # These objects are used in the default `Model.compile`. They are not
        # guaranteed to be set after `Model.compile` is called, as users can
        # override compile with custom logic.
        self.compiled_loss = None
        self.compiled_metrics = None

        # This is True for Sequential networks and Functional networks.
        self._compute_output_and_mask_jointly = False

        # Don't reset compilation if already done. This may occur if calling
        # `__init__` (or `_init_graph_network`) on an already-compiled model
        # such as a Sequential model. Sequential models may need to rebuild
        # themselves after compilation.
        self._maybe_create_attribute("_is_compiled", False)
        self._maybe_create_attribute("optimizer", None)

        # Model must be created under scope of DistStrat it will be trained
        # with.
        if tf.distribute.has_strategy():
            self._distribution_strategy = tf.distribute.get_strategy()
        else:
            self._distribution_strategy = None
        self._distribute_reduction_method = None

        self._cluster_coordinator = None

        # Defaults to value of `tf.config.experimental_functions_run_eagerly`.
        self._run_eagerly = None
        # Initialize cache attrs.
        self._reset_compile_cache()

        # Fault-tolerance handler. Set in `ModelCheckpoint`.
        self._training_state = None
        self._saved_model_inputs_spec = None
        self._saved_model_arg_spec = None
        self._checkpoint = tf.train.Checkpoint(root=weakref.ref(self))

        self._steps_per_execution = None

        self._init_batch_counters()
        self._base_model_initialized = True

        # `jit_compile` starts off with None as default and gets overwritten by
        # the value specified in `Model.compile`, and this is effective for
        # `fit`, `evaluate`, and `predict`.
        self._jit_compile = None

        self._layout_map = layout_map_lib.get_current_layout_map()

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def _init_batch_counters(self):
        # Untracked Variables, used to keep track of mini-batches seen in
        # `fit`, `evaluate`, and `predict`.
        if not tf.inside_function():
            # Creating variables inside tf.function is not allowed, hence
            # these would otherwise prevent users from creating Keras layers
            # inside tf.function.
            # These variables are not connected to outputs so they have no
            # effect on graph generation anyway.
            agg = tf.VariableAggregation.ONLY_FIRST_REPLICA
            self._train_counter = tf.Variable(0, dtype="int64", aggregation=agg)
            self._test_counter = tf.Variable(0, dtype="int64", aggregation=agg)
            self._predict_counter = tf.Variable(
                0, dtype="int64", aggregation=agg
            )

    def __setattr__(self, name, value):
        if not getattr(self, "_self_setattr_tracking", True):
            super().__setattr__(name, value)
            return

        if all(
            isinstance(v, (base_layer.Layer, tf.Variable))
            or base_layer_utils.has_weights(v)
            for v in tf.nest.flatten(value)
        ):
            try:
                self._base_model_initialized
            except AttributeError:
                raise RuntimeError(
                    "It looks like you are subclassing `Model` and you "
                    "forgot to call `super().__init__()`."
                    " Always start with this line."
                )

        super().__setattr__(name, value)

    def __reduce__(self):
        if self.built:
            return (
                pickle_utils.deserialize_model_from_bytecode,
                (pickle_utils.serialize_model_as_bytecode(self),),
            )
        else:
            # SavedModel (and hence serialize_model_as_bytecode) only supports
            # built models, but if the model is not built,
            # it may be possible to serialize it as a plain Python object,
            # as long as the constituent parts (layers, optimizers, losses,
            # etc.) can be serialized as plain Python objects. Thus we call up
            # the superclass hierarchy to get an implementation of __reduce__
            # that can pickle this Model as a plain Python object.
            return super().__reduce__()

    def __deepcopy__(self, memo):
        if self.built:
            new = pickle_utils.deserialize_model_from_bytecode(
                pickle_utils.serialize_model_as_bytecode(self)
            )
            memo[id(self)] = new
        else:
            # See comment in __reduce__ for explanation
            deserializer, serialized, *rest = super().__reduce__()
            new = deserializer(*serialized)
            memo[id(self)] = new
            if rest:
                state = copy.deepcopy(rest[0], memo=memo)
                new.__setstate__(state)
        return new

    def __copy__(self):
        return self.__deepcopy__({})

    @generic_utils.default
    def build(self, input_shape):
        """Builds the model based on input shapes received.

        This is to be used for subclassed models, which do not know at
        instantiation time what their inputs look like.

        This method only exists for users who want to call `model.build()` in a
        standalone way (as a substitute for calling the model on real data to
        build it). It will never be called by the framework (and thus it will
        never throw unexpected errors in an unrelated workflow).

        Args:
            input_shape: Single tuple, `TensorShape` instance, or list/dict of
                shapes, where shapes are tuples, integers, or `TensorShape`
                instances.

        Raises:
            ValueError:
                1. In case of invalid user-provided data (not of type tuple,
                   list, `TensorShape`, or dict).
                2. If the model requires call arguments that are agnostic
                   to the input shapes (positional or keyword arg in call
                   signature).
                3. If not all layers were properly built.
                4. If float type inputs are not supported within the layers.

            In each of these cases, the user should build their model by
            calling it on real tensor data.
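
        Example (a minimal illustrative sketch for a subclassed model;
        `MyModel` is hypothetical):

        ```python
        class MyModel(tf.keras.Model):

            def __init__(self):
                super().__init__()
                self.dense = tf.keras.layers.Dense(1)

            def call(self, inputs):
                return self.dense(inputs)

        model = MyModel()
        # Builds the weights from the shape alone, without real data.
        model.build((None, 3))
        print([w.shape for w in model.weights])
        ```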
435 """
436 if self._is_graph_network:
437 super().build(input_shape)
438 return
440 if input_shape is None:
441 raise ValueError(
442 "Input shape must be defined when calling `build()` on "
443 "a `Model` subclass."
444 )
445 valid_types = (tuple, list, tf.TensorShape, dict)
446 if not isinstance(input_shape, valid_types):
447 raise ValueError(
448 "Specified input shape is not one of the valid types. "
449 "Please specify a batch input shape of type tuple or "
450 "list of input shapes. User provided "
451 "input type: {}.".format(type(input_shape))
452 )
454 if input_shape and not self.inputs:
455 # We create placeholders for the `None`s in the shape and build the
456 # model in a Graph. Since tf.Variable is compatible with both eager
457 # execution and graph building, the variables created after building
458 # the model in a Graph are still valid when executing eagerly.
459 if tf.executing_eagerly():
460 graph = tf.__internal__.FuncGraph("build_graph")
461 else:
462 graph = backend.get_graph()
463 with graph.as_default():
464 if isinstance(input_shape, list) and all(
465 d is None or isinstance(d, int) for d in input_shape
466 ):
467 input_shape = tuple(input_shape)
468 if isinstance(input_shape, list):
469 x = [
470 base_layer_utils.generate_placeholders_from_shape(shape)
471 for shape in input_shape
472 ]
473 elif isinstance(input_shape, dict):
474 x = {
475 k: base_layer_utils.generate_placeholders_from_shape(
476 shape
477 )
478 for k, shape in input_shape.items()
479 }
480 else:
481 x = base_layer_utils.generate_placeholders_from_shape(
482 input_shape
483 )
485 kwargs = {}
486 call_signature = self._call_spec.full_argspec
487 call_args = call_signature.args
488 # Exclude `self`, `inputs`, and any argument with a default
489 # value.
490 if len(call_args) > 2:
491 if call_signature.defaults:
492 call_args = call_args[2 : -len(call_signature.defaults)]
493 else:
494 call_args = call_args[2:]
495 for arg in call_args:
496 if arg == "training":
497 # Case where `training` is a positional arg with no
498 # default.
499 kwargs["training"] = False
500 else:
501 # Has invalid call signature with unknown positional
502 # arguments.
503 raise ValueError(
504 "Currently, you cannot build your model if it "
505 "has positional or keyword arguments that are "
506 "not inputs to the model, but are required for "
507 "its `call()` method. Instead, in order to "
508 "instantiate and build your model, `call()` "
509 "your model on real tensor data with all "
510 "expected call arguments. The argument "
511 "for `call()` can be a single list/tuple that "
512 "contains multiple inputs."
513 )
514 elif len(call_args) < 2:
515 # Signature without `inputs`.
516 raise ValueError(
517 "You can only call `build()` on a model if its "
518 "`call()` method accepts an `inputs` argument."
519 )
520 try:
521 self.call(x, **kwargs)
522 except (tf.errors.InvalidArgumentError, TypeError) as e:
523 raise ValueError(
524 "You cannot build your model by calling `build` "
525 "if your layers do not support float type inputs. "
526 "Instead, in order to instantiate and build your "
527 "model, call your model on real tensor data (of "
528 "the correct dtype).\n\nThe actual error from "
529 f"`call` is: {e}."
530 )
531 super().build(input_shape)

    @traceback_utils.filter_traceback
    def __call__(self, *args, **kwargs):
        if self._layout_map is not None and not self.built:
            # Note that this method is only overridden for DTensor and layout
            # injection purposes.
            # Capture the inputs and create graph inputs as a replacement, so
            # that the model can initialize its weights first.
            copied_args = copy.copy(args)
            copied_kwargs = copy.copy(kwargs)

            (
                inputs,
                copied_args,
                copied_kwargs,
            ) = self._call_spec.split_out_first_arg(copied_args, copied_kwargs)

            def _convert_to_graph_inputs(x):
                if isinstance(x, (tf.Tensor, np.ndarray, float, int)):
                    x = tf.convert_to_tensor(x)
                    return input_layer_module.Input(x.shape)

            # TODO(scottzhu): maybe better handle mask and training flag.
            inputs = tf.nest.map_structure(_convert_to_graph_inputs, inputs)
            copied_args = tf.nest.map_structure(
                _convert_to_graph_inputs, copied_args
            )
            copied_kwargs = tf.nest.map_structure(
                _convert_to_graph_inputs, copied_kwargs
            )

            with layout_map_lib.layout_map_scope(self._layout_map):
                # We ignore the result here.
                super().__call__(inputs, *copied_args, **copied_kwargs)

            layout_map_lib._map_subclass_model_variable(self, self._layout_map)

        return super().__call__(*args, **kwargs)

    @doc_controls.doc_in_current_and_subclasses
    def call(self, inputs, training=None, mask=None):
        """Calls the model on new inputs and returns the outputs as tensors.

        In this case `call()` just reapplies
        all ops in the graph to the new inputs
        (e.g. building a new computational graph from the provided inputs).

        Note: This method should not be called directly. It is only meant to be
        overridden when subclassing `tf.keras.Model`.
        To call a model on an input, always use the `__call__()` method,
        i.e. `model(inputs)`, which relies on the underlying `call()` method.

        Args:
            inputs: Input tensor, or dict/list/tuple of input tensors.
            training: Boolean or boolean scalar tensor, indicating whether to
                run the `Network` in training mode or inference mode.
            mask: A mask or list of masks. A mask can be either a boolean tensor
                or None (no mask). For more details, check the guide
                [here](https://www.tensorflow.org/guide/keras/masking_and_padding).

        Returns:
            A tensor if there is a single output, or
            a list of tensors if there is more than one output.
        """
        raise NotImplementedError(
            "Unimplemented `tf.keras.Model.call()`: if you "
            "intend to create a `Model` with the Functional "
            "API, please provide `inputs` and `outputs` "
            "arguments. Otherwise, subclass `Model` with an "
            "overridden `call()` method."
        )

    @traceback_utils.filter_traceback
    def compile(
        self,
        optimizer="rmsprop",
        loss=None,
        metrics=None,
        loss_weights=None,
        weighted_metrics=None,
        run_eagerly=None,
        steps_per_execution=None,
        jit_compile=None,
        pss_evaluation_shards=0,
        **kwargs,
    ):
        """Configures the model for training.

        Example:

        ```python
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                      loss=tf.keras.losses.BinaryCrossentropy(),
                      metrics=[tf.keras.metrics.BinaryAccuracy(),
                               tf.keras.metrics.FalseNegatives()])
        ```

        Args:
            optimizer: String (name of optimizer) or optimizer instance. See
                `tf.keras.optimizers`.
            loss: Loss function. May be a string (name of loss function), or
                a `tf.keras.losses.Loss` instance. See `tf.keras.losses`. A
                loss function is any callable with the signature
                `loss = fn(y_true, y_pred)`, where `y_true` are the ground
                truth values, and `y_pred` are the model's predictions.
                `y_true` should have shape
                `(batch_size, d0, .. dN)` (except in the case of
                sparse loss functions such as
                sparse categorical crossentropy which expects integer arrays of
                shape `(batch_size, d0, .. dN-1)`).
                `y_pred` should have shape `(batch_size, d0, .. dN)`.
                The loss function should return a float tensor.
                If a custom `Loss` instance is
                used and reduction is set to `None`, return value has shape
                `(batch_size, d0, .. dN-1)` i.e. per-sample or per-timestep
                loss values; otherwise, it is a scalar. If the model has
                multiple outputs, you can use a different loss on each output
                by passing a dictionary or a list of losses. The loss value
                that will be minimized by the model will then be the sum of
                all individual losses, unless `loss_weights` is specified.
            metrics: List of metrics to be evaluated by the model during
                training and testing. Each of these can be a string (name of
                a built-in function), a function, or a
                `tf.keras.metrics.Metric` instance. See `tf.keras.metrics`.
                Typically you will use `metrics=['accuracy']`.
                A function is any callable with the signature
                `result = fn(y_true, y_pred)`. To specify different metrics
                for different outputs of a multi-output model, you could also
                pass a dictionary, such as
                `metrics={'output_a':'accuracy', 'output_b':['accuracy', 'mse']}`.
                You can also pass a list to specify a metric or a list of
                metrics for each output, such as
                `metrics=[['accuracy'], ['accuracy', 'mse']]`
                or `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass
                the strings 'accuracy' or 'acc', we convert this to one of
                `tf.keras.metrics.BinaryAccuracy`,
                `tf.keras.metrics.CategoricalAccuracy`,
                `tf.keras.metrics.SparseCategoricalAccuracy` based on the
                shapes of the targets and of the model output. We do a
                similar conversion for the strings 'crossentropy' and 'ce' as
                well. The metrics passed here are evaluated without sample
                weighting; if you would like sample weighting to apply, you
                can specify your metrics via the `weighted_metrics` argument
                instead.
            loss_weights: Optional list or dictionary specifying scalar
                coefficients (Python floats) to weight the loss contributions
                of different model outputs. The loss value that will be
                minimized by the model will then be the *weighted sum* of all
                individual losses, weighted by the `loss_weights`
                coefficients. If a list, it is expected to have a 1:1 mapping
                to the model's outputs. If a dict, it is expected to map
                output names (strings) to scalar coefficients.
            weighted_metrics: List of metrics to be evaluated and weighted by
                `sample_weight` or `class_weight` during training and testing.
            run_eagerly: Bool. If `True`, this `Model`'s logic will not be
                wrapped in a `tf.function`. Recommended to leave this as `None`
                unless your `Model` cannot be run inside a `tf.function`.
                `run_eagerly=True` is not supported when using
                `tf.distribute.experimental.ParameterServerStrategy`. Defaults
                to `False`.
            steps_per_execution: Int. The number of batches to
                run during each `tf.function` call. Running multiple batches
                inside a single `tf.function` call can greatly improve
                performance on TPUs or small models with a large Python
                overhead. At most, one full epoch will be run each execution.
                If a number larger than the size of the epoch is passed, the
                execution will be truncated to the size of the epoch. Note
                that if `steps_per_execution` is set to `N`,
                `Callback.on_batch_begin` and `Callback.on_batch_end` methods
                will only be called every `N` batches (i.e. before/after each
                `tf.function` execution). Defaults to `1`.
            jit_compile: If `True`, compile the model training step with XLA.
                [XLA](https://www.tensorflow.org/xla) is an optimizing compiler
                for machine learning.
                `jit_compile` is not enabled by default.
                Note that `jit_compile=True`
                may not necessarily work for all models.
                For more information on supported operations please refer to
                the [XLA documentation](https://www.tensorflow.org/xla).
                Also refer to
                [known XLA issues](https://www.tensorflow.org/xla/known_issues)
                for more details.
            pss_evaluation_shards: Integer or 'auto'. Used for
                `tf.distribute.ParameterServerStrategy` training only. This arg
                sets the number of shards to split the dataset into, to enable
                an exact visitation guarantee for evaluation, meaning the model
                will be applied to each dataset element exactly once, even if
                workers fail. The dataset must be sharded to ensure separate
                workers do not process the same data. The number of shards
                should be at least the number of workers for good performance.
                A value of 'auto' turns on exact evaluation and uses a
                heuristic for the number of shards based on the number of
                workers. A value of 0 means that no visitation guarantee is
                provided. NOTE: Custom implementations of `Model.test_step`
                will be ignored when doing exact evaluation. Defaults to `0`.
            **kwargs: Arguments supported for backwards compatibility only.
        """
        if jit_compile and not tf_utils.can_jit_compile(warn=True):
            jit_compile = False
        base_layer.keras_api_gauge.get_cell("compile").set(True)
        self._compile_config = serialization_lib.Config(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            loss_weights=loss_weights,
            weighted_metrics=weighted_metrics,
            run_eagerly=run_eagerly,
            steps_per_execution=steps_per_execution,
            jit_compile=jit_compile,
        )
        with self.distribute_strategy.scope():
            if "experimental_steps_per_execution" in kwargs:
                logging.warning(
                    "The argument `steps_per_execution` is no longer "
                    "experimental. Pass `steps_per_execution` instead of "
                    "`experimental_steps_per_execution`."
                )
                if not steps_per_execution:
                    steps_per_execution = kwargs.pop(
                        "experimental_steps_per_execution"
                    )

            # When compiling from an already-serialized model, we do not want
            # to reapply some processing steps (e.g. metric renaming for
            # multi-output models, which have prefixes added for each
            # corresponding output name).
            from_serialized = kwargs.pop("from_serialized", False)

            self._validate_compile(optimizer, metrics, **kwargs)
            self._run_eagerly = run_eagerly

            self.optimizer = self._get_optimizer(optimizer)
            if isinstance(loss, compile_utils.LossesContainer):
                self.compiled_loss = loss
            else:
                self.compiled_loss = compile_utils.LossesContainer(
                    loss, loss_weights, output_names=self.output_names
                )
            self.compiled_metrics = compile_utils.MetricsContainer(
                metrics,
                weighted_metrics,
                output_names=self.output_names,
                from_serialized=from_serialized,
            )

            self._configure_steps_per_execution(steps_per_execution or 1)

            self._pss_evaluation_shards = self._infer_exact_eval_shards(
                pss_evaluation_shards
            )

            # Initializes attrs that are reset each time `compile` is called.
            self._reset_compile_cache()
            self._is_compiled = True
            self.loss = loss or {}
            if (self._run_eagerly or self.dynamic) and jit_compile:
                raise ValueError(
                    "You cannot enable `run_eagerly` and `jit_compile` "
                    "at the same time."
                )
            else:
                self._jit_compile = jit_compile

    def _get_optimizer(self, optimizer):
        """Wraps `optimizer` in `LossScaleOptimizer` if necessary."""

        def _get_single_optimizer(opt):
            opt = optimizers.get(opt)
            if self.dtype_policy.name == "mixed_float16" and not isinstance(
                opt, lso.BaseLossScaleOptimizer
            ):
                # Loss scaling is necessary with mixed_float16 for models to
                # converge to the same accuracy as with float32.
                opt = lso.BaseLossScaleOptimizer(opt)
            return opt

        return tf.nest.map_structure(_get_single_optimizer, optimizer)

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def _reset_compile_cache(self):
        self.train_function = None
        self.test_function = None
        self.predict_function = None
        # Used to cache the `tf.function`'ed `train_function` to be logged in
        # TensorBoard, since the original `train_function` is not necessarily
        # a `tf.function` (e.g., with ParameterServerStrategy, the
        # `train_function` is a scheduling of the actual training function to a
        # remote worker).
        self.train_tf_function = None

        # Used to cache `trainable` attr of `Layer`s for `fit`.
        self._compiled_trainable_state = self._get_trainable_state()

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def _configure_steps_per_execution(self, steps_per_execution):
        self._steps_per_execution = tf.Variable(
            steps_per_execution,
            dtype="int64",
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )

    @property
    def _should_compute_mask(self):
        return False

    @property
    def metrics(self):
        """Return metrics added using `compile()` or `add_metric()`.

        Note: Metrics passed to `compile()` are available only after a
        `keras.Model` has been trained/evaluated on actual data.

        Examples:

        >>> inputs = tf.keras.layers.Input(shape=(3,))
        >>> outputs = tf.keras.layers.Dense(2)(inputs)
        >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
        >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
        >>> [m.name for m in model.metrics]
        []

        >>> x = np.random.random((2, 3))
        >>> y = np.random.randint(0, 2, (2, 2))
        >>> model.fit(x, y)
        >>> [m.name for m in model.metrics]
        ['loss', 'mae']

        >>> inputs = tf.keras.layers.Input(shape=(3,))
        >>> d = tf.keras.layers.Dense(2, name='out')
        >>> output_1 = d(inputs)
        >>> output_2 = d(inputs)
        >>> model = tf.keras.models.Model(
        ...     inputs=inputs, outputs=[output_1, output_2])
        >>> model.add_metric(
        ...     tf.reduce_sum(output_2), name='mean', aggregation='mean')
        >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
        >>> model.fit(x, (y, y))
        >>> [m.name for m in model.metrics]
        ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae',
        'out_1_acc', 'mean']

        """
        metrics = []
        if self._is_compiled:
            if self.compiled_loss is not None:
                metrics += self.compiled_loss.metrics
            if self.compiled_metrics is not None:
                metrics += self.compiled_metrics.metrics

        for l in self._flatten_layers():
            metrics.extend(l._metrics)
        return metrics

    @property
    def metrics_names(self):
        """Returns the model's display labels for all outputs.

        Note: `metrics_names` are available only after a `keras.Model` has been
        trained/evaluated on actual data.

        Examples:

        >>> inputs = tf.keras.layers.Input(shape=(3,))
        >>> outputs = tf.keras.layers.Dense(2)(inputs)
        >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
        >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
        >>> model.metrics_names
        []

        >>> x = np.random.random((2, 3))
        >>> y = np.random.randint(0, 2, (2, 2))
        >>> model.fit(x, y)
        >>> model.metrics_names
        ['loss', 'mae']

        >>> inputs = tf.keras.layers.Input(shape=(3,))
        >>> d = tf.keras.layers.Dense(2, name='out')
        >>> output_1 = d(inputs)
        >>> output_2 = d(inputs)
        >>> model = tf.keras.models.Model(
        ...     inputs=inputs, outputs=[output_1, output_2])
        >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
        >>> model.fit(x, (y, y))
        >>> model.metrics_names
        ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae',
        'out_1_acc']

        """

        # This property includes all output names including `loss` and
        # per-output losses for backward compatibility.
        return [m.name for m in self.metrics]

    @property
    def distribute_strategy(self):
        """The `tf.distribute.Strategy` this model was created under."""
        return self._distribution_strategy or tf.distribute.get_strategy()

    @property
    def run_eagerly(self):
        """Settable attribute indicating whether the model should run eagerly.

        Running eagerly means that your model will be run step by step,
        like Python code. Your model might run slower, but it should become
        easier for you to debug it by stepping into individual layer calls.

        By default, we will attempt to compile your model to a static graph to
        deliver the best execution performance.
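
        For instance (an illustrative sketch), eager execution can be
        turned on to ease debugging:

        ```python
        model.compile(optimizer="adam", loss="mse", run_eagerly=True)
        # or, equivalently, after compiling:
        model.run_eagerly = True
        ```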

        Returns:
            Boolean, whether the model should run eagerly.
        """
        if self.dynamic and self._run_eagerly is False:
            # TODO(fchollet): consider using py_func to enable this.
            raise ValueError(
                "Your model contains layers that can only be "
                "successfully run in eager execution (layers "
                "constructed with `dynamic=True`). "
                "You cannot set `run_eagerly=False`."
            )

        if self._cluster_coordinator and self._run_eagerly:
            raise ValueError(
                "When using `Model` with `ParameterServerStrategy`, "
                "`run_eagerly` is not supported."
            )

        # Run eagerly logic, by priority:
        # (1) Dynamic models must be run eagerly.
        # (2) Explicitly setting run_eagerly causes a Model to be run eagerly.
        # (3) Not explicitly setting run_eagerly defaults to TF's global
        #     setting.
        return (
            self.dynamic
            or self._run_eagerly
            or (tf.config.functions_run_eagerly() and self._run_eagerly is None)
        )

    @run_eagerly.setter
    def run_eagerly(self, value):
        self._run_eagerly = value

    @property
    def jit_compile(self):
        """Specify whether to compile the model with XLA.

        [XLA](https://www.tensorflow.org/xla) is an optimizing compiler
        for machine learning. `jit_compile` is not enabled by default.
        Note that `jit_compile=True` may not necessarily work for all models.

        For more information on supported operations please refer to the
        [XLA documentation](https://www.tensorflow.org/xla). Also refer to
        [known XLA issues](https://www.tensorflow.org/xla/known_issues)
        for more details.
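
        For example (an illustrative sketch), XLA compilation can be
        requested at compile time or toggled afterwards:

        ```python
        model.compile(optimizer="adam", loss="mse", jit_compile=True)
        model.jit_compile = False  # invalidates previously cached functions
        ```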
983 """
984 return self._jit_compile
986 @jit_compile.setter
987 def jit_compile(self, value):
988 # Function remains cached with previous jit_compile settings
989 if self._jit_compile == value:
990 # Avoid resetting compiler cache if possible if the value is the
991 # same
992 return
993 # Check if TensorFlow is compiled with XLA before setting the value
994 if value and not tf_utils.can_jit_compile(warn=True):
995 self._jit_compile = False
996 return
998 self._jit_compile = value
999 # Setting `jit_compile` should invalidate previously cached functions.
1000 self._reset_compile_cache()
1002 @property
1003 def distribute_reduction_method(self):
1004 """The method employed to reduce per-replica values during training.
1006 Unless specified, the value "auto" will be assumed, indicating that
1007 the reduction strategy should be chosen based on the current
1008 running environment.
1009 See `reduce_per_replica` function for more details.
1011 """
1012 return self._distribute_reduction_method or "auto"
1014 @distribute_reduction_method.setter
1015 def distribute_reduction_method(self, value):
1016 self._distribute_reduction_method = value

    def _validate_target_and_loss(self, y, loss):
        """Raises error if target or loss is not found.

        This method verifies that the target and loss are properly populated
        when applicable, or raises errors.

        Args:
            y: the target for training.
            loss: the total loss tensor including loss added via `compile`
                and `add_loss`.
        """

        # `self.loss` references the loss added via `compile` call. If users
        # have provided such, the target must be provided; otherwise it's a
        # user error. Note that `self.loss` does not include losses added via
        # `add_loss`, and it is a valid use when such loss from `add_loss`
        # exists and target does not.
        if self.loss and y is None:
            raise ValueError(
                "Target data is missing. Your model was compiled with "
                f"loss={self.loss}, "
                "and therefore expects target data to be provided in `fit()`."
            )
        # For training, there must be compiled loss or regularization loss to
        # exist in order to apply the gradients. If one is not found, it means
        # no loss was supplied via `compile` or `add_loss`.
        elif loss is None:
            raise ValueError(
                "No loss found. You may have forgotten to provide a `loss` "
                "argument in the `compile()` method."
            )

    def train_step(self, data):
        """The logic for one training step.

        This method can be overridden to support custom training logic.
        For concrete examples of how to override this method see
        [Customizing what happens in fit](
        https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit).
        This method is called by `Model.make_train_function`.

        This method should contain the mathematical logic for one step of
        training. This typically includes the forward pass, loss calculation,
        backpropagation, and metric updates.

        Configuration details for *how* this logic is run (e.g. `tf.function`
        and `tf.distribute.Strategy` settings) should be left to
        `Model.make_train_function`, which can also be overridden.
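
        For example, an override that mirrors the default behavior (a
        minimal illustrative sketch) could look like:

        ```python
        class MyModel(tf.keras.Model):

            def train_step(self, data):
                x, y, sample_weight = (
                    tf.keras.utils.unpack_x_y_sample_weight(data))
                with tf.GradientTape() as tape:
                    y_pred = self(x, training=True)
                    loss = self.compute_loss(x, y, y_pred, sample_weight)
                self.optimizer.minimize(
                    loss, self.trainable_variables, tape=tape)
                return self.compute_metrics(x, y, y_pred, sample_weight)
        ```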

        Args:
            data: A nested structure of `Tensor`s.

        Returns:
            A `dict` containing values that will be passed to
            `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically,
            the values of the `Model`'s metrics are returned. Example:
            `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
        # Run forward pass.
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(x, y, y_pred, sample_weight)
        self._validate_target_and_loss(y, loss)
        # Run backwards pass.
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
        return self.compute_metrics(x, y, y_pred, sample_weight)

    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
        """Compute the total loss, validate it, and return it.

        Subclasses can optionally override this method to provide custom loss
        computation logic.

        Example:
        ```python
        class MyModel(tf.keras.Model):

            def __init__(self, *args, **kwargs):
                super(MyModel, self).__init__(*args, **kwargs)
                self.loss_tracker = tf.keras.metrics.Mean(name='loss')

            def compute_loss(self, x, y, y_pred, sample_weight):
                loss = tf.reduce_mean(tf.math.squared_difference(y_pred, y))
                loss += tf.add_n(self.losses)
                self.loss_tracker.update_state(loss)
                return loss

            def reset_metrics(self):
                self.loss_tracker.reset_states()

            @property
            def metrics(self):
                return [self.loss_tracker]

        tensors = tf.random.uniform((10, 10)), tf.random.uniform((10,))
        dataset = tf.data.Dataset.from_tensor_slices(tensors).repeat().batch(1)

        inputs = tf.keras.layers.Input(shape=(10,), name='my_input')
        outputs = tf.keras.layers.Dense(10)(inputs)
        model = MyModel(inputs, outputs)
        model.add_loss(tf.reduce_sum(outputs))

        optimizer = tf.keras.optimizers.SGD()
        model.compile(optimizer, loss='mse', steps_per_execution=10)
        model.fit(dataset, epochs=2, steps_per_epoch=10)
        print('My custom loss: ', model.loss_tracker.result().numpy())
        ```

        Args:
            x: Input data.
            y: Target data.
            y_pred: Predictions returned by the model (output of `model(x)`)
            sample_weight: Sample weights for weighting the loss function.

        Returns:
            The total loss as a `tf.Tensor`, or `None` if no loss results
            (which is the case when called by `Model.test_step`).
        """
        del x  # The default implementation does not use `x`.
        return self.compiled_loss(
            y, y_pred, sample_weight, regularization_losses=self.losses
        )

    def compute_metrics(self, x, y, y_pred, sample_weight):
        """Update metric states and collect all metrics to be returned.

        Subclasses can optionally override this method to provide custom metric
        updating and collection logic.

        Example:
        ```python
        class MyModel(tf.keras.Sequential):

            def compute_metrics(self, x, y, y_pred, sample_weight):

                # This super call updates `self.compiled_metrics` and returns
                # results for all metrics listed in `self.metrics`.
                metric_results = super(MyModel, self).compute_metrics(
                    x, y, y_pred, sample_weight)

                # Note that `self.custom_metric` is not listed
                # in `self.metrics`.
                self.custom_metric.update_state(x, y, y_pred, sample_weight)
                metric_results['custom_metric_name'] = self.custom_metric.result()
                return metric_results
        ```

        Args:
            x: Input data.
            y: Target data.
            y_pred: Predictions returned by the model (output of
                `model.call(x)`)
            sample_weight: Sample weights for weighting the loss function.

        Returns:
            A `dict` containing values that will be passed to
            `tf.keras.callbacks.CallbackList.on_train_batch_end()`. Typically,
            the values of the metrics listed in `self.metrics` are returned.
            Example: `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        del x  # The default implementation does not use `x`.
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        return self.get_metrics_result()

    def get_metrics_result(self):
        """Returns the model's metrics values as a dict.

        If any of the metric results is a dict (containing multiple metrics),
        each of them gets added to the top-level returned dict of this method.

        Returns:
            A `dict` containing values of the metrics listed in `self.metrics`.
            Example:
            `{'loss': 0.2, 'accuracy': 0.7}`.
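
        For instance (an illustrative sketch), after a training run:

        ```python
        model.fit(x, y, epochs=1)
        # Returns e.g. {'loss': 0.15, 'mae': 0.3}, depending on the
        # compiled metrics.
        results = model.get_metrics_result()
        ```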
1192 """
1193 # Collect metrics to return
1194 return_metrics = {}
1195 for metric in self.metrics:
1196 result = metric.result()
1197 if isinstance(result, dict):
1198 return_metrics.update(result)
1199 else:
1200 return_metrics[metric.name] = result
1201 return return_metrics

    def _validate_and_get_metrics_result(self, logs):
        """Returns model metrics as a dict if the keys match with input logs.

        When the training / evaluation is performed with asynchronous steps,
        such as the case with `tf.distribute.ParameterServerStrategy`, the last
        scheduled `train / test_step` may not give the latest metrics because
        it is not guaranteed to be executed last. This method gets metrics from
        the model directly instead of relying on the return from the last step
        function.

        It logs a warning if the metric results could not be overridden when
        used with `tf.distribute.ParameterServerStrategy`.

        When the user has custom train / test step functions, the metrics
        returned may be different from `Model.metrics`. In those instances,
        this function will be a no-op and return the logs.

        Args:
            logs: A `dict` of metrics returned by train / test step function.

        Returns:
            A `dict` containing values of the metrics listed in `self.metrics`
            when logs and model metrics keys match. Otherwise it returns input
            `logs`.
        """
        PSS_WARN_MSG = "Could not get Model metric results. \
        Using the results of last step function could lead to incorrect \
        results when used with ParameterServerStrategy"
        try:
            metric_logs = self.get_metrics_result()
        except TypeError:
            if self._cluster_coordinator:
                logging.warning(PSS_WARN_MSG)
        else:
            # Verify that train / test step logs passed and metric logs have
            # matching keys. Could be different when using custom step
            # functions.
            if isinstance(logs, dict) and set(logs.keys()) == set(
                metric_logs.keys()
            ):
                logs = tf_utils.sync_to_numpy_or_python_type(metric_logs)
            elif self._cluster_coordinator:
                logging.warning(PSS_WARN_MSG)
        return logs

    def _aggregate_exact_metrics(self, logs):
        # When doing exact evaluation, `logs` is a list of each data shard's
        # metric variables, which will be used to update the metrics.
        for shard_result in logs:
            for metric in self.metrics:
                if metric.name not in shard_result.keys():
                    logging.log_first_n(
                        logging.WARN,
                        f"No matching result found for metric {metric.name}. "
                        "This metric's computed result may be incorrect.",
                        3,
                    )
                    continue
                metric_result = shard_result[metric.name]
                if len(metric_result) != len(metric.weights):
                    raise ValueError(
                        f"Expected {len(metric.weights)} variables in result "
                        f"for metric {metric.name}, but found "
                        f"{len(metric_result)}."
                    )
                for weight, val in zip(metric.weights, metric_result):
                    weight.assign_add(val)
        return self.get_metrics_result()

    def make_train_function(self, force=False):
        """Creates a function that executes one step of training.

        This method can be overridden to support custom training logic.
        This method is called by `Model.fit` and `Model.train_on_batch`.

        Typically, this method directly controls `tf.function` and
        `tf.distribute.Strategy` settings, and delegates the actual training
        logic to `Model.train_step`.

        This function is cached the first time `Model.fit` or
        `Model.train_on_batch` is called. The cache is cleared whenever
        `Model.compile` is called. You can skip the cache and regenerate the
        function with `force=True`.
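
        For example (an illustrative sketch; `dataset` is assumed to be a
        `tf.data.Dataset` yielding `(x, y)` batches):

        ```python
        train_fn = model.make_train_function(force=True)
        # The returned function accepts a `tf.data.Iterator`.
        logs = train_fn(iter(dataset))
        ```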

        Args:
            force: Whether to regenerate the train function and skip the
                cached function if available.

        Returns:
            Function. The function created by this method should accept a
            `tf.data.Iterator`, and return a `dict` containing values that
            will be passed to `tf.keras.Callbacks.on_train_batch_end`, such as
            `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        if self.train_function is not None and not force:
            return self.train_function

        def step_function(model, iterator):
            """Runs a single training step."""

            def run_step(data):
                outputs = model.train_step(data)
                # Ensure counter is updated only if `train_step` succeeds.
                with tf.control_dependencies(_minimum_control_deps(outputs)):
                    model._train_counter.assign_add(1)
                return outputs

            if self.jit_compile and not isinstance(
                model.distribute_strategy,
                (
                    tf.compat.v1.distribute.experimental.TPUStrategy,
                    tf.distribute.TPUStrategy,
                ),
            ):
                # TODO(b/258249546): Explicit `jit_compile=True` on TPU causes
                # unexpected behavior, so we skip TPU training now.
                run_step = tf.function(
                    run_step, jit_compile=True, reduce_retracing=True
                )
            data = next(iterator)
            outputs = model.distribute_strategy.run(run_step, args=(data,))
            outputs = reduce_per_replica(
                outputs,
                self.distribute_strategy,
                reduction=self.distribute_reduction_method,
            )
            return outputs

        # Special case if steps_per_execution is one.
        if (
            self._steps_per_execution is None
            or self._steps_per_execution.numpy().item() == 1
        ):

            def train_function(iterator):
                """Runs a training execution with a single step."""
                return step_function(self, iterator)

            if not self.run_eagerly:
                train_function = tf.function(
                    train_function, reduce_retracing=True
                )
                self.train_tf_function = train_function

            if self._cluster_coordinator:
                self.train_function = (
                    lambda it: self._cluster_coordinator.schedule(
                        train_function, args=(it,)
                    )
                )
            else:
                self.train_function = train_function

        # If we're using a coordinator, use the value of
        # self._steps_per_execution at the time the function is
        # called/scheduled, and not when it is actually executed.
        elif self._cluster_coordinator:

            def train_function(iterator, steps_per_execution):
                """Runs a training execution with multiple steps."""
                for _ in tf.range(steps_per_execution):
                    outputs = step_function(self, iterator)
                return outputs

            if not self.run_eagerly:
                train_function = tf.function(
                    train_function, reduce_retracing=True
                )
                self.train_tf_function = train_function

            self.train_function = lambda it: self._cluster_coordinator.schedule(
                train_function, args=(it, self._steps_per_execution.value())
            )
        else:

            def train_function(iterator):
                """Runs a training execution with multiple steps."""
                for _ in tf.range(self._steps_per_execution):
                    outputs = step_function(self, iterator)
                return outputs

            if not self.run_eagerly:
                train_function = tf.function(
                    train_function, reduce_retracing=True
                )
                self.train_tf_function = train_function
            self.train_function = train_function

        return self.train_function

    @traceback_utils.filter_traceback
    def fit(
        self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose="auto",
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
    ):
        """Trains the model for a fixed number of epochs (dataset iterations).

        Args:
            x: Input data. It could be:
              - A Numpy array (or array-like), or a list of arrays
                (in case the model has multiple inputs).
              - A TensorFlow tensor, or a list of tensors
                (in case the model has multiple inputs).
              - A dict mapping input names to the corresponding array/tensors,
                if the model has named inputs.
              - A `tf.data` dataset. Should return a tuple
                of either `(inputs, targets)` or
                `(inputs, targets, sample_weights)`.
              - A generator or `keras.utils.Sequence` returning `(inputs,
                targets)` or `(inputs, targets, sample_weights)`.
              - A `tf.keras.utils.experimental.DatasetCreator`, which wraps a
                callable that takes a single argument of type
                `tf.distribute.InputContext`, and returns a `tf.data.Dataset`.
                `DatasetCreator` should be used when users prefer to specify
                the per-replica batching and sharding logic for the `Dataset`.
                See `tf.keras.utils.experimental.DatasetCreator` doc for more
                information.
                A more detailed description of unpacking behavior for iterator
                types (Dataset, generator, Sequence) is given below. If these
                include `sample_weights` as a third component, note that
                sample weighting applies to the `weighted_metrics` argument
                but not the `metrics` argument in `compile()`. If using
                `tf.distribute.experimental.ParameterServerStrategy`, only
                `DatasetCreator` type is supported for `x`.
            y: Target data. Like the input data `x`,
                it could be either Numpy array(s) or TensorFlow tensor(s).
                It should be consistent with `x` (you cannot have Numpy inputs
                and tensor targets, or inversely). If `x` is a dataset,
                generator, or `keras.utils.Sequence` instance, `y` should
                not be specified (since targets will be obtained from `x`).
            batch_size: Integer or `None`.
                Number of samples per gradient update.
                If unspecified, `batch_size` will default to 32.
                Do not specify the `batch_size` if your data is in the
                form of datasets, generators, or `keras.utils.Sequence`
                instances (since they generate batches).
            epochs: Integer. Number of epochs to train the model.
                An epoch is an iteration over the entire `x` and `y`
                data provided
                (unless the `steps_per_epoch` flag is set to
                something other than None).
                Note that in conjunction with `initial_epoch`,
                `epochs` is to be understood as "final epoch".
                The model is not trained for a number of iterations
                given by `epochs`, but merely until the epoch
                of index `epochs` is reached.
            verbose: 'auto', 0, 1, or 2. Verbosity mode.
                0 = silent, 1 = progress bar, 2 = one line per epoch.
                'auto' becomes 1 for most cases, but 2 when used with
                `ParameterServerStrategy`. Note that the progress bar is not
                particularly useful when logged to a file, so verbose=2 is
                recommended when not running interactively (e.g., in a
                production environment). Defaults to 'auto'.
1473 callbacks: List of `keras.callbacks.Callback` instances.
1474 List of callbacks to apply during training.
1475 See `tf.keras.callbacks`. Note
1476 `tf.keras.callbacks.ProgbarLogger` and
1477 `tf.keras.callbacks.History` callbacks are created automatically
1478 and need not be passed into `model.fit`.
1479 `tf.keras.callbacks.ProgbarLogger` is created or not based on
1480 `verbose` argument to `model.fit`.
1481 Callbacks with batch-level calls are currently unsupported with
1482 `tf.distribute.experimental.ParameterServerStrategy`, and users
1483 are advised to implement epoch-level calls instead with an
1484 appropriate `steps_per_epoch` value.
1485 validation_split: Float between 0 and 1.
1486 Fraction of the training data to be used as validation data.
1487 The model will set apart this fraction of the training data,
1488 will not train on it, and will evaluate
1489 the loss and any model metrics
1490 on this data at the end of each epoch.
1491 The validation data is selected from the last samples
1492 in the `x` and `y` data provided, before shuffling. This
1493 argument is not supported when `x` is a dataset, generator or
1494 `keras.utils.Sequence` instance.
1495 If both `validation_data` and `validation_split` are provided,
1496 `validation_data` will override `validation_split`.
1497 `validation_split` is not yet supported with
1498 `tf.distribute.experimental.ParameterServerStrategy`.
1499 validation_data: Data on which to evaluate
1500 the loss and any model metrics at the end of each epoch.
1501 The model will not be trained on this data. Thus, note the fact
1502 that the validation loss of data provided using
1503 `validation_split` or `validation_data` is not affected by
1504 regularization layers like noise and dropout.
1505 `validation_data` will override `validation_split`.
1506 `validation_data` could be:
1507 - A tuple `(x_val, y_val)` of Numpy arrays or tensors.
1508 - A tuple `(x_val, y_val, val_sample_weights)` of NumPy
1509 arrays.
1510 - A `tf.data.Dataset`.
1511 - A Python generator or `keras.utils.Sequence` returning
1512 `(inputs, targets)` or `(inputs, targets, sample_weights)`.
1513 `validation_data` is not yet supported with
1514 `tf.distribute.experimental.ParameterServerStrategy`.
1515 shuffle: Boolean (whether to shuffle the training data
1516 before each epoch) or str (for 'batch'). This argument is
1517 ignored when `x` is a generator or a `tf.data.Dataset` object.
1518 'batch' is a special option for dealing
1519 with the limitations of HDF5 data; it shuffles in batch-sized
1520 chunks. Has no effect when `steps_per_epoch` is not `None`.
1521 class_weight: Optional dictionary mapping class indices (integers)
1522 to a weight (float) value, used for weighting the loss function
1523 (during training only).
1524 This can be useful to tell the model to
1525 "pay more attention" to samples from
1526 an under-represented class. When `class_weight` is specified
1527 and targets have a rank of 2 or greater, either `y` must be
1528 one-hot encoded, or an explicit final dimension of `1` must
1529 be included for sparse class labels.
1530 sample_weight: Optional Numpy array of weights for
1531 the training samples, used for weighting the loss function
1532 (during training only). You can either pass a flat (1D)
1533 Numpy array with the same length as the input samples
1534 (1:1 mapping between weights and samples),
1535 or in the case of temporal data,
1536 you can pass a 2D array with shape
1537 `(samples, sequence_length)`,
1538 to apply a different weight to every timestep of every sample.
1539 This argument is not supported when `x` is a dataset, generator,
1540 or `keras.utils.Sequence` instance; instead, provide the
1541 sample weights as the third element of `x`.
1542 Note that sample weighting does not apply to metrics specified
1543 via the `metrics` argument in `compile()`. To apply sample
1544 weighting to your metrics, you can specify them via the
1545 `weighted_metrics` in `compile()` instead.
1546 initial_epoch: Integer.
1547 Epoch at which to start training
1548 (useful for resuming a previous training run).
1549 steps_per_epoch: Integer or `None`.
1550 Total number of steps (batches of samples)
1551 before declaring one epoch finished and starting the
1552 next epoch. When training with input tensors such as
1553 TensorFlow data tensors, the default `None` is equal to
1554 the number of samples in your dataset divided by
1555 the batch size, or 1 if that cannot be determined. If `x` is a
1556 `tf.data` dataset and `steps_per_epoch`
1557 is `None`, the epoch will run until the input dataset is
1558 exhausted. When passing an infinitely repeating dataset, you
1559 must specify the `steps_per_epoch` argument. If
1560 `steps_per_epoch=-1` the training will run indefinitely with an
1561 infinitely repeating dataset. This argument is not supported
1562 with array inputs.
1563 When using `tf.distribute.experimental.ParameterServerStrategy`:
1564 * `steps_per_epoch=None` is not supported.
1565 validation_steps: Only relevant if `validation_data` is provided and
1566 is a `tf.data` dataset. Total number of steps (batches of
1567 samples) to draw before stopping when performing validation
1568 at the end of every epoch. If `validation_steps` is `None`,
1569 validation will run until the `validation_data` dataset is
1570 exhausted. In the case of an infinitely repeated dataset, it
1571 will run into an infinite loop. If `validation_steps` is
1572 specified and only part of the dataset is consumed, the
1573 evaluation will start from the beginning of the dataset at each
1574 epoch. This ensures that the same validation samples are used
1575 every time.
1576 validation_batch_size: Integer or `None`.
1577 Number of samples per validation batch.
1578 If unspecified, will default to `batch_size`.
1579 Do not specify the `validation_batch_size` if your data is in
1580 the form of datasets, generators, or `keras.utils.Sequence`
1581 instances (since they generate batches).
1582 validation_freq: Only relevant if validation data is provided.
1583 Integer or `collections.abc.Container` instance (e.g. list, tuple,
1584 etc.). If an integer, specifies how many training epochs to run
1585 before a new validation run is performed, e.g. `validation_freq=2`
1586 runs validation every 2 epochs. If a Container, specifies the
1587 epochs on which to run validation, e.g.
1588 `validation_freq=[1, 2, 10]` runs validation at the end of the
1589 1st, 2nd, and 10th epochs.
1590 max_queue_size: Integer. Used for generator or
1591 `keras.utils.Sequence` input only. Maximum size for the generator
1592 queue. If unspecified, `max_queue_size` will default to 10.
1593 workers: Integer. Used for generator or `keras.utils.Sequence` input
1594 only. Maximum number of processes to spin up
1595 when using process-based threading. If unspecified, `workers`
1596 will default to 1.
1597 use_multiprocessing: Boolean. Used for generator or
1598 `keras.utils.Sequence` input only. If `True`, use process-based
1599 threading. If unspecified, `use_multiprocessing` will default to
1600 `False`. Note that because this implementation relies on
1601 multiprocessing, you should not pass non-picklable arguments to
1602 the generator as they can't be passed easily to children
1603 processes.
1605 Unpacking behavior for iterator-like inputs:
1606 A common pattern is to pass a tf.data.Dataset, generator, or
1607 tf.keras.utils.Sequence to the `x` argument of fit, which will in fact
1608 yield not only features (x) but optionally targets (y) and sample
1609 weights. Keras requires that the output of such iterator-likes be
1610 unambiguous. The iterator should return a tuple of length 1, 2, or 3,
1611 where the optional second and third elements will be used for y and
1612 sample_weight respectively. Any other type provided will be wrapped in
1613 a length one tuple, effectively treating everything as 'x'. When
1614 yielding dicts, they should still adhere to the top-level tuple
1615 structure.
1616 e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate
1617 features, targets, and weights from the keys of a single dict.
1618 A notable unsupported data type is the namedtuple. The reason is
1619 that it behaves like both an ordered datatype (tuple) and a mapping
1620 datatype (dict). So given a namedtuple of the form:
1621 `namedtuple("example_tuple", ["y", "x"])`
1622 it is ambiguous whether to reverse the order of the elements when
1623 interpreting the value. Even worse is a tuple of the form:
1624 `namedtuple("other_tuple", ["x", "y", "z"])`
1625 where it is unclear if the tuple was intended to be unpacked into x,
1626 y, and sample_weight, or passed through as a single element to `x`. As
1627 a result, the data processing code will simply raise a ValueError if it
1628 encounters a namedtuple, along with instructions to remedy the
1629 issue.
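Example (a minimal sketch of the tuple convention above; the input
names `x0`/`x1`, shapes, and random data are illustrative):

```python
import numpy as np
import tensorflow as tf

# Two named inputs; each dataset element is ({"x0": ..., "x1": ...}, y),
# so element 1 is unpacked as features and element 2 as targets.
x0 = tf.keras.Input(shape=(4,), name="x0")
x1 = tf.keras.Input(shape=(4,), name="x1")
concat = tf.keras.layers.concatenate([x0, x1])
out = tf.keras.layers.Dense(1)(concat)
model = tf.keras.Model(inputs={"x0": x0, "x1": x1}, outputs=out)
model.compile(optimizer="adam", loss="mse")

features = {"x0": np.random.random((8, 4)),
            "x1": np.random.random((8, 4))}
targets = np.random.random((8, 1))
dataset = tf.data.Dataset.from_tensor_slices((features, targets)).batch(4)
model.fit(dataset, epochs=1)  # `y` is not passed; it comes from `x`
```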
1631 Returns:
1632 A `History` object. Its `History.history` attribute is
1633 a record of training loss values and metrics values
1634 at successive epochs, as well as validation loss values
1635 and validation metrics values (if applicable).
1637 Raises:
1638 RuntimeError: 1. If the model was never compiled, or
1639 2. If `model.fit` is wrapped in `tf.function`.
1641 ValueError: In case of mismatch between the provided input data
1642 and what the model expects, or when the input data is empty.
1643 """
1644 base_layer.keras_api_gauge.get_cell("fit").set(True)
1645 # Legacy graph support is contained in `training_v1.Model`.
1646 version_utils.disallow_legacy_graph("Model", "fit")
1647 self._assert_compile_was_called()
1648 self._check_call_args("fit")
1649 _disallow_inside_tf_function("fit")
1651 verbose = _get_verbosity(verbose, self.distribute_strategy)
1653 if validation_split and validation_data is None:
1654 # Create the validation data using the training data. Only supported
1655 # for `Tensor` and `NumPy` input.
1656 (
1657 x,
1658 y,
1659 sample_weight,
1660 ), validation_data = data_adapter.train_validation_split(
1661 (x, y, sample_weight), validation_split=validation_split
1662 )
1664 if validation_data:
1665 (
1666 val_x,
1667 val_y,
1668 val_sample_weight,
1669 ) = data_adapter.unpack_x_y_sample_weight(validation_data)
1671 if self.distribute_strategy._should_use_with_coordinator:
1672 self._cluster_coordinator = (
1673 tf.distribute.experimental.coordinator.ClusterCoordinator(
1674 self.distribute_strategy
1675 )
1676 )
1678 with self.distribute_strategy.scope(), training_utils.RespectCompiledTrainableState( # noqa: E501
1679 self
1680 ):
1681 # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
1682 data_handler = data_adapter.get_data_handler(
1683 x=x,
1684 y=y,
1685 sample_weight=sample_weight,
1686 batch_size=batch_size,
1687 steps_per_epoch=steps_per_epoch,
1688 initial_epoch=initial_epoch,
1689 epochs=epochs,
1690 shuffle=shuffle,
1691 class_weight=class_weight,
1692 max_queue_size=max_queue_size,
1693 workers=workers,
1694 use_multiprocessing=use_multiprocessing,
1695 model=self,
1696 steps_per_execution=self._steps_per_execution,
1697 )
1699 # Container that configures and calls `tf.keras.Callback`s.
1700 if not isinstance(callbacks, callbacks_module.CallbackList):
1701 callbacks = callbacks_module.CallbackList(
1702 callbacks,
1703 add_history=True,
1704 add_progbar=verbose != 0,
1705 model=self,
1706 verbose=verbose,
1707 epochs=epochs,
1708 steps=data_handler.inferred_steps,
1709 )
1711 self.stop_training = False
1712 self.train_function = self.make_train_function()
1713 self._train_counter.assign(0)
1714 callbacks.on_train_begin()
1715 training_logs = None
1716 # Handle fault-tolerance for multi-worker.
1717 # TODO(omalleyt): Fix the ordering issues that mean this has to
1718 # happen after `callbacks.on_train_begin`.
1719 steps_per_epoch_inferred = (
1720 steps_per_epoch or data_handler.inferred_steps
1721 )
1722 (
1723 data_handler._initial_epoch,
1724 data_handler._initial_step,
1725 ) = self._maybe_load_initial_counters_from_ckpt(
1726 steps_per_epoch_inferred, initial_epoch
1727 )
1728 logs = None
1729 for epoch, iterator in data_handler.enumerate_epochs():
1730 self.reset_metrics()
1731 callbacks.on_epoch_begin(epoch)
1732 with data_handler.catch_stop_iteration():
1733 for step in data_handler.steps():
1734 with tf.profiler.experimental.Trace(
1735 "train",
1736 epoch_num=epoch,
1737 step_num=step,
1738 batch_size=batch_size,
1739 _r=1,
1740 ):
1741 callbacks.on_train_batch_begin(step)
1742 tmp_logs = self.train_function(iterator)
1743 if data_handler.should_sync:
1744 context.async_wait()
1745 # No error, now safe to assign to logs.
1746 logs = tmp_logs
1747 end_step = step + data_handler.step_increment
1748 callbacks.on_train_batch_end(end_step, logs)
1749 if self.stop_training:
1750 break
1752 logs = tf_utils.sync_to_numpy_or_python_type(logs)
1753 if logs is None:
1754 raise ValueError(
1755 "Unexpected result of `train_function` "
1756 "(Empty logs). This could be due to issues in input "
1757 "pipeline that resulted in an empty dataset. "
1758 "Otherwise, please use "
1759 "`Model.compile(..., run_eagerly=True)`, or "
1760 "`tf.config.run_functions_eagerly(True)` for more "
1761 "information of where went wrong, or file a "
1762 "issue/bug to `tf.keras`."
1763 )
1764 # Override with model metrics instead of last step logs
1765 logs = self._validate_and_get_metrics_result(logs)
1766 epoch_logs = copy.copy(logs)
1768 # Run validation.
1769 if validation_data and self._should_eval(
1770 epoch, validation_freq
1771 ):
1772 if self._pss_evaluation_shards:
1773 self._disallow_exact_eval_with_add_metrics()
1774 # Create data_handler for evaluation and cache it.
1775 if getattr(self, "_eval_data_handler", None) is None:
1776 self._eval_data_handler = data_adapter.get_data_handler(
1777 x=val_x,
1778 y=val_y,
1779 sample_weight=val_sample_weight,
1780 batch_size=validation_batch_size or batch_size,
1781 steps_per_epoch=validation_steps,
1782 initial_epoch=0,
1783 epochs=1,
1784 max_queue_size=max_queue_size,
1785 workers=workers,
1786 use_multiprocessing=use_multiprocessing,
1787 model=self,
1788 steps_per_execution=self._steps_per_execution,
1789 pss_evaluation_shards=self._pss_evaluation_shards,
1790 )
1791 val_logs = self.evaluate(
1792 x=val_x,
1793 y=val_y,
1794 sample_weight=val_sample_weight,
1795 batch_size=validation_batch_size or batch_size,
1796 steps=validation_steps,
1797 callbacks=callbacks,
1798 max_queue_size=max_queue_size,
1799 workers=workers,
1800 use_multiprocessing=use_multiprocessing,
1801 return_dict=True,
1802 _use_cached_eval_dataset=True,
1803 )
1804 val_logs = {
1805 "val_" + name: val for name, val in val_logs.items()
1806 }
1807 epoch_logs.update(val_logs)
1809 callbacks.on_epoch_end(epoch, epoch_logs)
1810 training_logs = epoch_logs
1811 if self.stop_training:
1812 break
1814 if isinstance(self.optimizer, optimizer.Optimizer) and epochs > 0:
1815 self.optimizer.finalize_variable_values(
1816 self.trainable_variables
1817 )
1819 # If eval data_handler exists, delete it after all epochs are done.
1820 if getattr(self, "_eval_data_handler", None) is not None:
1821 del self._eval_data_handler
1822 callbacks.on_train_end(logs=training_logs)
1823 return self.history
1825 def test_step(self, data):
1826 """The logic for one evaluation step.
1828 This method can be overridden to support custom evaluation logic.
1829 This method is called by `Model.make_test_function`.
1831 This function should contain the mathematical logic for one step of
1832 evaluation.
1833 This typically includes the forward pass, loss calculation, and metrics
1834 updates.
1836 Configuration details for *how* this logic is run (e.g. `tf.function`
1837 and `tf.distribute.Strategy` settings), should be left to
1838 `Model.make_test_function`, which can also be overridden.
1840 Args:
1841 data: A nested structure of `Tensor`s.
1843 Returns:
1844 A `dict` containing values that will be passed to
1845 `tf.keras.callbacks.CallbackList.on_test_batch_end`. Typically, the
1846 values of the `Model`'s metrics are returned.
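Example (a minimal sketch of a custom override for a subclassed
model, mirroring the default logic below; the class name is
illustrative):

```python
import tensorflow as tf

class MyModel(tf.keras.Model):
    def test_step(self, data):
        # Unpack the data, run the forward pass in inference mode,
        # update the stateful loss metrics, and return metric results.
        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
        y_pred = self(x, training=False)
        self.compute_loss(x, y, y_pred, sample_weight)
        return self.compute_metrics(x, y, y_pred, sample_weight)
```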
1847 """
1848 x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
1850 y_pred = self(x, training=False)
1851 # Updates stateful loss metrics.
1852 self.compute_loss(x, y, y_pred, sample_weight)
1853 return self.compute_metrics(x, y, y_pred, sample_weight)
1855 def _make_test_function_exact(self):
1856 if getattr(self, "_shard_test_function", None):
1857 return self._shard_test_function
1859 def step_function(batch):
1860 def run_step(data):
1861 # TODO(b/272050910): Use sample_weight for weighted metrics.
1862 x, y, _ = data_adapter.unpack_x_y_sample_weight(data)
1863 y_pred = self(x, training=False)
1864 return x, y, y_pred
1866 if self._jit_compile:
1867 run_step = tf.function(
1868 run_step, jit_compile=True, reduce_retracing=True
1869 )
1871 outputs = self.distribute_strategy.run(run_step, args=(batch,))
1872 outputs = reduce_per_replica(
1873 outputs,
1874 self.distribute_strategy,
1875 reduction=self.distribute_reduction_method,
1876 )
1877 return outputs
1879 def shard_test_function(dataset, total_shards, shard_idx):
1880 local_metrics = []
1881 with tf_utils.with_metric_local_vars_scope():
1882 for metric in self.compiled_metrics.metrics:
1883 local_metrics.append(base_metric.clone_metric(metric))
1884 for metric in self.compiled_loss.metrics:
1885 local_metrics.append(base_metric.clone_metric(metric))
1886 dataset = input_ops.auto_shard_dataset(
1887 dataset, total_shards, shard_idx
1888 )
1889 iterator = iter(dataset)
1890 with distribute_utils.cache_variable_reads():
1891 for batch in iterator:
1892 x, y, y_pred = step_function(batch)
1893 for local_metric in local_metrics:
1894 local_metric.update_state(y, y_pred)
1895 outputs = {metric.name: metric.weights for metric in local_metrics}
1896 with tf.control_dependencies(_minimum_control_deps(outputs)):
1897 self._test_counter.assign_add(1)
1898 return outputs
1900 if not self.run_eagerly:
1901 shard_test_function = tf.function(
1902 shard_test_function, reduce_retracing=True
1903 )
1905 self._shard_test_function = (
1906 lambda *args: self._cluster_coordinator.schedule(
1907 shard_test_function,
1908 args=args,
1909 )
1910 )
1911 return self._shard_test_function
1913 def make_test_function(self, force=False):
1914 """Creates a function that executes one step of evaluation.
1916 This method can be overridden to support custom evaluation logic.
1917 This method is called by `Model.evaluate` and `Model.test_on_batch`.
1919 Typically, this method directly controls `tf.function` and
1920 `tf.distribute.Strategy` settings, and delegates the actual evaluation
1921 logic to `Model.test_step`.
1923 This function is cached the first time `Model.evaluate` or
1924 `Model.test_on_batch` is called. The cache is cleared whenever
1925 `Model.compile` is called. You can skip the cache and regenerate
1926 the function with `force=True`.
1928 Args:
1929 force: Whether to regenerate the test function and skip the cached
1930 function if available.
1932 Returns:
1933 Function. The function created by this method should accept a
1934 `tf.data.Iterator`, and return a `dict` containing values that will
1935 be passed to `tf.keras.callbacks.CallbackList.on_test_batch_end`.
1936 """
1937 if self.test_function is not None and not force:
1938 return self.test_function
1940 def step_function(model, iterator):
1941 """Runs a single evaluation step."""
1943 def run_step(data):
1944 outputs = model.test_step(data)
1945 # Ensure counter is updated only if `test_step` succeeds.
1946 with tf.control_dependencies(_minimum_control_deps(outputs)):
1947 model._test_counter.assign_add(1)
1948 return outputs
1950 if self.jit_compile:
1951 run_step = tf.function(
1952 run_step, jit_compile=True, reduce_retracing=True
1953 )
1955 data = next(iterator)
1956 outputs = model.distribute_strategy.run(run_step, args=(data,))
1957 outputs = reduce_per_replica(
1958 outputs,
1959 self.distribute_strategy,
1960 reduction=self.distribute_reduction_method,
1961 )
1962 return outputs
1964 # Special case if steps_per_execution is one.
1965 if (
1966 self._steps_per_execution is None
1967 or self._steps_per_execution.numpy().item() == 1
1968 ):
1970 def test_function(iterator):
1971 """Runs a test execution with a single step."""
1972 return step_function(self, iterator)
1974 if not self.run_eagerly:
1975 test_function = tf.function(
1976 test_function, reduce_retracing=True
1977 )
1979 if self._cluster_coordinator:
1980 self.test_function = (
1981 lambda it: self._cluster_coordinator.schedule(
1982 test_function, args=(it,)
1983 )
1984 )
1985 else:
1986 self.test_function = test_function
1988 # If we're using a coordinator, use the value of
1989 # self._steps_per_execution at the time the function is
1990 # called/scheduled, and not when it is actually executed.
1991 elif self._cluster_coordinator:
1993 def test_function(iterator, steps_per_execution):
1994 """Runs a test execution with multiple steps."""
1995 for _ in tf.range(steps_per_execution):
1996 outputs = step_function(self, iterator)
1997 return outputs
1999 if not self.run_eagerly:
2000 test_function = tf.function(
2001 test_function, reduce_retracing=True
2002 )
2004 self.test_function = lambda it: self._cluster_coordinator.schedule(
2005 test_function, args=(it, self._steps_per_execution.value())
2006 )
2007 else:
2009 def test_function(iterator):
2010 """Runs a test execution with multiple steps."""
2011 for _ in tf.range(self._steps_per_execution):
2012 outputs = step_function(self, iterator)
2013 return outputs
2015 if not self.run_eagerly:
2016 test_function = tf.function(
2017 test_function, reduce_retracing=True
2018 )
2019 self.test_function = test_function
2021 return self.test_function
2023 @traceback_utils.filter_traceback
2024 def evaluate(
2025 self,
2026 x=None,
2027 y=None,
2028 batch_size=None,
2029 verbose="auto",
2030 sample_weight=None,
2031 steps=None,
2032 callbacks=None,
2033 max_queue_size=10,
2034 workers=1,
2035 use_multiprocessing=False,
2036 return_dict=False,
2037 **kwargs,
2038 ):
2039 """Returns the loss value & metrics values for the model in test mode.
2041 Computation is done in batches (see the `batch_size` arg.)
2043 Args:
2044 x: Input data. It could be:
2045 - A Numpy array (or array-like), or a list of arrays
2046 (in case the model has multiple inputs).
2047 - A TensorFlow tensor, or a list of tensors
2048 (in case the model has multiple inputs).
2049 - A dict mapping input names to the corresponding array/tensors,
2050 if the model has named inputs.
2051 - A `tf.data` dataset. Should return a tuple
2052 of either `(inputs, targets)` or
2053 `(inputs, targets, sample_weights)`.
2054 - A generator or `keras.utils.Sequence` returning `(inputs,
2055 targets)` or `(inputs, targets, sample_weights)`.
2056 A more detailed description of unpacking behavior for iterator
2057 types (Dataset, generator, Sequence) is given in the `Unpacking
2058 behavior for iterator-like inputs` section of `Model.fit`.
2059 y: Target data. Like the input data `x`, it could be either Numpy
2060 array(s) or TensorFlow tensor(s). It should be consistent with `x`
2061 (you cannot have Numpy inputs and tensor targets, or inversely).
2062 If `x` is a dataset, generator or `keras.utils.Sequence` instance,
2063 `y` should not be specified (since targets will be obtained from
2064 the iterator/dataset).
2065 batch_size: Integer or `None`. Number of samples per batch of
2066 computation. If unspecified, `batch_size` will default to 32. Do
2067 not specify the `batch_size` if your data is in the form of
2068 datasets, generators, or `keras.utils.Sequence` instances (since
2069 they generate batches).
2070 verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
2071 0 = silent, 1 = progress bar, 2 = single line.
2072 `"auto"` becomes 1 for most cases, and to 2 when used with
2073 `ParameterServerStrategy`. Note that the progress bar is not
2074 particularly useful when logged to a file, so `verbose=2` is
2075 recommended when not running interactively (e.g. in a production
2076 environment). Defaults to 'auto'.
2077 sample_weight: Optional Numpy array of weights for the test samples,
2078 used for weighting the loss function. You can either pass a flat
2079 (1D) Numpy array with the same length as the input samples
2080 (1:1 mapping between weights and samples), or in the case of
2081 temporal data, you can pass a 2D array with shape `(samples,
2082 sequence_length)`, to apply a different weight to every
2083 timestep of every sample. This argument is not supported when
2084 `x` is a dataset, instead pass sample weights as the third
2085 element of `x`.
2086 steps: Integer or `None`. Total number of steps (batches of samples)
2087 before declaring the evaluation round finished. Ignored with the
2088 default value of `None`. If `x` is a `tf.data` dataset and `steps`
2089 is `None`, `evaluate` will run until the dataset is exhausted. This
2090 argument is not supported with array inputs.
2091 callbacks: List of `keras.callbacks.Callback` instances. List of
2092 callbacks to apply during evaluation. See
2093 [callbacks](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks).
2094 max_queue_size: Integer. Used for generator or
2095 `keras.utils.Sequence` input only. Maximum size for the generator
2096 queue. If unspecified, `max_queue_size` will default to 10.
2097 workers: Integer. Used for generator or `keras.utils.Sequence` input
2098 only. Maximum number of processes to spin up when using
2099 process-based threading. If unspecified, `workers` will default to
2100 1.
2101 use_multiprocessing: Boolean. Used for generator or
2102 `keras.utils.Sequence` input only. If `True`, use process-based
2103 threading. If unspecified, `use_multiprocessing` will default to
2104 `False`. Note that because this implementation relies on
2105 multiprocessing, you should not pass non-picklable arguments to
2106 the generator as they can't be passed easily to children
2107 processes.
2108 return_dict: If `True`, loss and metric results are returned as a
2109 dict, with each key being the name of the metric. If `False`, they
2110 are returned as a list.
2111 **kwargs: Unused at this time.
2113 See the discussion of `Unpacking behavior for iterator-like inputs` for
2114 `Model.fit`.
2116 Returns:
2117 Scalar test loss (if the model has a single output and no metrics)
2118 or list of scalars (if the model has multiple outputs
2119 and/or metrics). The attribute `model.metrics_names` will give you
2120 the display labels for the scalar outputs.
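Example (a minimal sketch; the layer sizes and data shapes are
illustrative):

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

x = np.random.random((32, 3))
y = np.random.random((32, 1))
results = model.evaluate(x, y, batch_size=8, return_dict=True)
# With return_dict=True, results is a dict such as
# {"loss": ..., "mae": ...}; with return_dict=False a list ordered
# per model.metrics_names is returned instead.
```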
2122 Raises:
2123 RuntimeError: If `model.evaluate` is wrapped in a `tf.function`.
2124 """
2125 base_layer.keras_api_gauge.get_cell("evaluate").set(True)
2126 version_utils.disallow_legacy_graph("Model", "evaluate")
2127 self._assert_compile_was_called()
2128 self._check_call_args("evaluate")
2129 self._check_sample_weight_warning(x, sample_weight)
2130 _disallow_inside_tf_function("evaluate")
2131 use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
2132 if kwargs:
2133 raise TypeError(f"Invalid keyword arguments: {list(kwargs.keys())}")
2135 if self.distribute_strategy._should_use_with_coordinator:
2136 self._cluster_coordinator = (
2137 tf.distribute.experimental.coordinator.ClusterCoordinator(
2138 self.distribute_strategy
2139 )
2140 )
2142 verbose = _get_verbosity(verbose, self.distribute_strategy)
2143 if self._pss_evaluation_shards:
2144 self._disallow_exact_eval_with_add_metrics()
2145 with self.distribute_strategy.scope():
2146 # Use cached evaluation data only when it's called in `Model.fit`
2147 if (
2148 use_cached_eval_dataset
2149 and getattr(self, "_eval_data_handler", None) is not None
2150 ):
2151 data_handler = self._eval_data_handler
2152 else:
2153 # Creates a `tf.data.Dataset` and handles batch and epoch
2154 # iteration.
2155 data_handler = data_adapter.get_data_handler(
2156 x=x,
2157 y=y,
2158 sample_weight=sample_weight,
2159 batch_size=batch_size,
2160 steps_per_epoch=steps,
2161 initial_epoch=0,
2162 epochs=1,
2163 max_queue_size=max_queue_size,
2164 workers=workers,
2165 use_multiprocessing=use_multiprocessing,
2166 model=self,
2167 steps_per_execution=self._steps_per_execution,
2168 pss_evaluation_shards=self._pss_evaluation_shards,
2169 )
2171 # Container that configures and calls `tf.keras.Callback`s.
2172 if not isinstance(callbacks, callbacks_module.CallbackList):
2173 callbacks = callbacks_module.CallbackList(
2174 callbacks,
2175 add_history=True,
2176 add_progbar=verbose != 0,
2177 model=self,
2178 verbose=verbose,
2179 epochs=1,
2180 steps=data_handler.inferred_steps,
2181 )
2183 # Initialize to prevent errors if 0 epochs are evaluated.
2184 logs = {}
2186 test_function_runner = self._get_test_function_runner(callbacks)
2187 self._test_counter.assign(0)
2188 callbacks.on_test_begin()
2189 for (
2190 _,
2191 dataset_or_iterator,
2192 ) in data_handler.enumerate_epochs(): # Single epoch.
2193 self.reset_metrics()
2194 with data_handler.catch_stop_iteration():
2195 for step in data_handler.steps():
2196 with tf.profiler.experimental.Trace(
2197 "test", step_num=step, _r=1
2198 ):
2199 callbacks.on_test_batch_begin(step)
2200 logs = test_function_runner.run_step(
2201 dataset_or_iterator,
2202 data_handler,
2203 step,
2204 self._pss_evaluation_shards,
2205 )
2207 logs = tf_utils.sync_to_numpy_or_python_type(logs)
2208 # Override with model metrics instead of last step logs
2209 if self._pss_evaluation_shards:
2210 logs = self._aggregate_exact_metrics(logs)
2211 else:
2212 logs = self._validate_and_get_metrics_result(logs)
2213 callbacks.on_test_end(logs=logs)
2215 if return_dict:
2216 return logs
2217 else:
2218 return flatten_metrics_in_order(logs, self.metrics_names)
2220 def _disallow_exact_eval_with_add_metrics(self):
2221 metrics_from_add_metric = [
2222 metric
2223 for layer in self._flatten_layers()
2224 for metric in layer._metrics
2225 ]
2226 compiled_metrics = self.compiled_metrics.metrics
2227 if any(
2229 metric not in compiled_metrics
2230 for metric in metrics_from_add_metric
2232 ):
2233 raise ValueError(
2234 "Detected that a metric was added to this model "
2235 "via `Model.add_metric`. This is not currently "
2236 "supported when using exact evaluation with "
2237 "`tf.distribute.ParameterServerStrategy`."
2238 )
2240 def _infer_exact_eval_shards(self, pss_evaluation_shards):
2241 if not self.distribute_strategy._should_use_with_coordinator:
2242 return 0
2243 if pss_evaluation_shards == "auto":
2244 # TODO(b/264265138) evaluate and improve this heuristic
2245 return self.distribute_strategy._num_workers * 5
2246 return pss_evaluation_shards
2248 def _get_test_function_runner(self, callbacks):
2249 if (
2250 self._pss_evaluation_shards
2251 and self.distribute_strategy._should_use_with_coordinator
2252 ):
2253 self.test_function = self._make_test_function_exact()
2254 test_function_runner = _ExactTestFunction(
2255 self.test_function, callbacks
2256 )
2257 else:
2258 self.test_function = self.make_test_function()
2259 test_function_runner = _TestFunction(self.test_function, callbacks)
2260 return test_function_runner
2262 def predict_step(self, data):
2263 """The logic for one inference step.
2265 This method can be overridden to support custom inference logic.
2266 This method is called by `Model.make_predict_function`.
2268 This method should contain the mathematical logic for one step of
2269 inference. This typically includes the forward pass.
2271 Configuration details for *how* this logic is run (e.g. `tf.function`
2272 and `tf.distribute.Strategy` settings), should be left to
2273 `Model.make_predict_function`, which can also be overridden.
2275 Args:
2276 data: A nested structure of `Tensor`s.
2278 Returns:
2279 The result of one inference step, typically the output of calling the
2280 `Model` on data.
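Example (a minimal sketch of a custom override, mirroring the
default logic below; the class name is illustrative):

```python
import tensorflow as tf

class MyModel(tf.keras.Model):
    def predict_step(self, data):
        # Unpack only the features and run the forward pass in
        # inference mode.
        x, _, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
        return self(x, training=False)
```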
2281 """
2282 x, _, _ = data_adapter.unpack_x_y_sample_weight(data)
2283 return self(x, training=False)
2285 def make_predict_function(self, force=False):
2286 """Creates a function that executes one step of inference.
2288 This method can be overridden to support custom inference logic.
2289 This method is called by `Model.predict` and `Model.predict_on_batch`.
2291 Typically, this method directly controls `tf.function` and
2292 `tf.distribute.Strategy` settings, and delegates the actual inference
2293 logic to `Model.predict_step`.
2295 This function is cached the first time `Model.predict` or
2296 `Model.predict_on_batch` is called. The cache is cleared whenever
2297 `Model.compile` is called. You can skip the cache and generate again the
2298 function with `force=True`.
2300 Args:
2301 force: Whether to regenerate the predict function and skip the cached
2302 function if available.
2304 Returns:
2305 Function. The function created by this method should accept a
2306 `tf.data.Iterator`, and return the outputs of the `Model`.
2307 """
2308 if self.predict_function is not None and not force:
2309 return self.predict_function
2311 def step_function(model, iterator):
2312 """Runs a single evaluation step."""
2314 def run_step(data):
2315 outputs = model.predict_step(data)
2316 # Ensure counter is updated only if `predict_step` succeeds.
2317 with tf.control_dependencies(_minimum_control_deps(outputs)):
2318 model._predict_counter.assign_add(1)
2319 return outputs
2321 if self.jit_compile:
2322 run_step = tf.function(
2323 run_step, jit_compile=True, reduce_retracing=True
2324 )
2326 data = next(iterator)
2327 outputs = model.distribute_strategy.run(run_step, args=(data,))
2328 outputs = reduce_per_replica(
2329 outputs, self.distribute_strategy, reduction="concat"
2330 )
2331 return outputs
2333 # Special case if steps_per_execution is one.
2334 if (
2335 self._steps_per_execution is None
2336 or self._steps_per_execution.numpy().item() == 1
2337 ):
2339 def predict_function(iterator):
2340 """Runs an evaluation execution with a single step."""
2341 return step_function(self, iterator)
2343 else:
2345 def predict_function(iterator):
2346 """Runs an evaluation execution with multiple steps."""
2347 outputs = step_function(self, iterator)
2348 for _ in tf.range(self._steps_per_execution - 1):
2349 tf.autograph.experimental.set_loop_options(
2350 shape_invariants=[
2351 (
2352 outputs,
2353 tf.nest.map_structure(
2354 lambda t: tf_utils.get_tensor_spec(
2355 t, dynamic_batch=True
2356 ).shape,
2357 outputs,
2358 ),
2359 )
2360 ]
2361 )
2362 step_outputs = step_function(self, iterator)
2363 outputs = tf.nest.map_structure(
2364 lambda t1, t2: concat([t1, t2]), outputs, step_outputs
2365 )
2366 return outputs
2368 if not self.run_eagerly:
2369 predict_function = tf.function(
2370 predict_function, reduce_retracing=True
2371 )
2372 self.predict_function = predict_function
2374 return self.predict_function
2376 @traceback_utils.filter_traceback
2377 def predict(
2378 self,
2379 x,
2380 batch_size=None,
2381 verbose="auto",
2382 steps=None,
2383 callbacks=None,
2384 max_queue_size=10,
2385 workers=1,
2386 use_multiprocessing=False,
2387 ):
2388 """Generates output predictions for the input samples.
2390 Computation is done in batches. This method is designed for batch
2391 processing of large numbers of inputs. It is not intended for use inside
2392 of loops that iterate over your data and process small numbers of inputs
2393 at a time.
2395 For small numbers of inputs that fit in one batch,
2396 directly use `__call__()` for faster execution, e.g.,
2397 `model(x)`, or `model(x, training=False)` if you have layers such as
2398 `tf.keras.layers.BatchNormalization` that behave differently during
2399 inference. You may pair the individual model call with a `tf.function`
2400 for additional performance inside your inner loop.
2401 If you need access to numpy array values instead of tensors after your
2402 model call, you can use `tensor.numpy()` to get the numpy array value of
2403 an eager tensor.
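Example (a minimal sketch contrasting the two paths; layer sizes
and data shapes are illustrative):

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.Model(inputs, outputs)

# Small batch in a tight loop: call the model directly.
small = tf.random.uniform((4, 3))
y_small = model(small, training=False).numpy()

# Large input: use predict() for batched numpy output.
big = np.random.random((10000, 3))
y_big = model.predict(big, batch_size=256)
```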
2405 Also, note that test loss is not affected by
2406 regularization layers such as noise and dropout.
2408 Note: See [this FAQ entry](
2409 https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call)
2410 for more details about the difference between `Model` methods
2411 `predict()` and `__call__()`.
2413 Args:
2414 x: Input samples. It could be:
2415 - A Numpy array (or array-like), or a list of arrays
2416 (in case the model has multiple inputs).
2417 - A TensorFlow tensor, or a list of tensors
2418 (in case the model has multiple inputs).
2419 - A `tf.data` dataset.
2420 - A generator or `keras.utils.Sequence` instance.
2421 A more detailed description of unpacking behavior for iterator
2422 types (Dataset, generator, Sequence) is given in the `Unpacking
2423 behavior for iterator-like inputs` section of `Model.fit`.
2424 batch_size: Integer or `None`.
2425 Number of samples per batch.
2426 If unspecified, `batch_size` will default to 32.
2427 Do not specify the `batch_size` if your data is in the
2428 form of datasets, generators, or `keras.utils.Sequence` instances
2429 (since they generate batches).
2430 verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
2431 0 = silent, 1 = progress bar, 2 = single line.
2432 `"auto"` becomes 1 for most cases, and to 2 when used with
2433 `ParameterServerStrategy`. Note that the progress bar is not
2434 particularly useful when logged to a file, so `verbose=2` is
2435 recommended when not running interactively (e.g. in a production
2436 environment). Defaults to 'auto'.
2437 steps: Total number of steps (batches of samples)
2438 before declaring the prediction round finished.
2439 Ignored with the default value of `None`. If `x` is a `tf.data`
2440 dataset and `steps` is `None`, `predict()` will
2441 run until the input dataset is exhausted.
2442 callbacks: List of `keras.callbacks.Callback` instances.
2443 List of callbacks to apply during prediction.
2444 See [callbacks](
2445 https://www.tensorflow.org/api_docs/python/tf/keras/callbacks).
2446 max_queue_size: Integer. Used for generator or
2447 `keras.utils.Sequence` input only. Maximum size for the
2448 generator queue. If unspecified, `max_queue_size` will default
2449 to 10.
2450 workers: Integer. Used for generator or `keras.utils.Sequence` input
2451 only. Maximum number of processes to spin up when using
2452 process-based threading. If unspecified, `workers` will default
2453 to 1.
2454 use_multiprocessing: Boolean. Used for generator or
2455 `keras.utils.Sequence` input only. If `True`, use process-based
2456 threading. If unspecified, `use_multiprocessing` will default to
2457 `False`. Note that because this implementation relies on
2458 multiprocessing, you should not pass non-picklable arguments to
2459 the generator as they can't be passed easily to children
2460 processes.
2462 See the discussion of `Unpacking behavior for iterator-like inputs` for
2463 `Model.fit`. Note that `Model.predict` uses the same interpretation
2464 rules as `Model.fit` and `Model.evaluate`, so inputs must be
2465 unambiguous for all three methods.
2467 Returns:
2468 Numpy array(s) of predictions.
2470 Raises:
2471 RuntimeError: If `model.predict` is wrapped in a `tf.function`.
2472 ValueError: In case of mismatch between the provided
2473 input data and the model's expectations,
2474 or in case a stateful model receives a number of samples
2475 that is not a multiple of the batch size.
2476 """
2477 base_layer.keras_api_gauge.get_cell("predict").set(True)
2478 version_utils.disallow_legacy_graph("Model", "predict")
2479 self._check_call_args("predict")
2480 _disallow_inside_tf_function("predict")
2482 # TODO(yashkatariya): Cache model on the coordinator for faster
2483 # prediction. If running under PSS, then swap it with OneDeviceStrategy
2484 # so that execution will run on the coordinator.
2485 original_pss_strategy = None
2486 if self.distribute_strategy._should_use_with_coordinator:
2487 original_pss_strategy = self.distribute_strategy
2488 self._distribution_strategy = None
2490 # The cluster coordinator is set by `.fit()` and `.evaluate()`; it
2491 # is not needed in `.predict()` because all the predictions happen
2492 # on the coordinator/locally.
2493 if self._cluster_coordinator:
2494 self._cluster_coordinator = None
2496 verbose = _get_verbosity(verbose, self.distribute_strategy)
2497 outputs = None
2498 with self.distribute_strategy.scope():
2499 # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
2500 dataset_types = (tf.compat.v1.data.Dataset, tf.data.Dataset)
2501 if (
2502 self._in_multi_worker_mode()
2503 or _is_tpu_multi_host(self.distribute_strategy)
2504 ) and isinstance(x, dataset_types):
2505 try:
2506 options = tf.data.Options()
2507 data_option = tf.data.experimental.AutoShardPolicy.DATA
2508 options.experimental_distribute.auto_shard_policy = (
2509 data_option
2510 )
2511 x = x.with_options(options)
2512 except ValueError:
2513 warnings.warn(
2514 "Using Model.predict with MultiWorkerMirroredStrategy "
2515 "or TPUStrategy and AutoShardPolicy.FILE might lead to "
2516 "out-of-order result. Consider setting it to "
2517 "AutoShardPolicy.DATA.",
2518 stacklevel=2,
2519 )
2521 data_handler = data_adapter.get_data_handler(
2522 x=x,
2523 batch_size=batch_size,
2524 steps_per_epoch=steps,
2525 initial_epoch=0,
2526 epochs=1,
2527 max_queue_size=max_queue_size,
2528 workers=workers,
2529 use_multiprocessing=use_multiprocessing,
2530 model=self,
2531 steps_per_execution=self._steps_per_execution,
2532 )
2534 # Container that configures and calls `tf.keras.Callback`s.
2535 if not isinstance(callbacks, callbacks_module.CallbackList):
2536 callbacks = callbacks_module.CallbackList(
2537 callbacks,
2538 add_history=True,
2539 add_progbar=verbose != 0,
2540 model=self,
2541 verbose=verbose,
2542 epochs=1,
2543 steps=data_handler.inferred_steps,
2544 )
2546 self.predict_function = self.make_predict_function()
2547 self._predict_counter.assign(0)
2548 callbacks.on_predict_begin()
2549 batch_outputs = None
2550 for _, iterator in data_handler.enumerate_epochs(): # Single epoch.
2551 with data_handler.catch_stop_iteration():
2552 for step in data_handler.steps():
2553 callbacks.on_predict_batch_begin(step)
2554 tmp_batch_outputs = self.predict_function(iterator)
2555 if data_handler.should_sync:
2556 context.async_wait()
2557 batch_outputs = (
2558 tmp_batch_outputs # No error, now safe to assign.
2559 )
2560 if outputs is None:
2561 outputs = tf.nest.map_structure(
2562 lambda batch_output: [batch_output],
2563 batch_outputs,
2564 )
2565 else:
2566 tf.__internal__.nest.map_structure_up_to(
2567 batch_outputs,
2568 lambda output, batch_output: output.append(
2569 batch_output
2570 ),
2571 outputs,
2572 batch_outputs,
2573 )
2574 end_step = step + data_handler.step_increment
2575 callbacks.on_predict_batch_end(
2576 end_step, {"outputs": batch_outputs}
2577 )
2578 if batch_outputs is None:
2579 raise ValueError(
2580 "Unexpected result of `predict_function` "
2581 "(Empty batch_outputs). Please use "
2582 "`Model.compile(..., run_eagerly=True)`, or "
2583 "`tf.config.run_functions_eagerly(True)` for more "
2584 "information of where went wrong, or file a "
2585 "issue/bug to `tf.keras`."
2586 )
2587 callbacks.on_predict_end()
2588 all_outputs = tf.__internal__.nest.map_structure_up_to(
2589 batch_outputs, potentially_ragged_concat, outputs
2590 )
2592 # If the PSS strategy was originally used, replace it back, since
2593 # predict ran under `OneDeviceStrategy` after the swap; once it is
2594 # done we need to restore the PSS strategy.
2595 if original_pss_strategy is not None:
2596 self._distribution_strategy = original_pss_strategy
2598 return tf_utils.sync_to_numpy_or_python_type(all_outputs)
2600 def reset_metrics(self):
2601 """Resets the state of all the metrics in the model.
2603 Examples:
2605 >>> inputs = tf.keras.layers.Input(shape=(3,))
2606 >>> outputs = tf.keras.layers.Dense(2)(inputs)
2607 >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
2608 >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
2610 >>> x = np.random.random((2, 3))
2611 >>> y = np.random.randint(0, 2, (2, 2))
2612 >>> _ = model.fit(x, y, verbose=0)
2613 >>> assert all(float(m.result()) for m in model.metrics)
2615 >>> model.reset_metrics()
2616 >>> assert all(float(m.result()) == 0 for m in model.metrics)
2618 """
2619 for m in self.metrics:
2620 m.reset_state()
2622 def train_on_batch(
2623 self,
2624 x,
2625 y=None,
2626 sample_weight=None,
2627 class_weight=None,
2628 reset_metrics=True,
2629 return_dict=False,
2630 ):
2631 """Runs a single gradient update on a single batch of data.
2633 Args:
2634 x: Input data. It could be:
2635 - A Numpy array (or array-like), or a list of arrays
2636 (in case the model has multiple inputs).
2637 - A TensorFlow tensor, or a list of tensors
2638 (in case the model has multiple inputs).
2639 - A dict mapping input names to the corresponding array/tensors,
2640 if the model has named inputs.
2641 y: Target data. Like the input data `x`, it could be either Numpy
2642 array(s) or TensorFlow tensor(s).
2643 sample_weight: Optional array of the same length as x, containing
2644 weights to apply to the model's loss for each sample. In the case
2645 of temporal data, you can pass a 2D array with shape (samples,
2646 sequence_length), to apply a different weight to every timestep of
2647 every sample.
2648 class_weight: Optional dictionary mapping class indices (integers)
2649 to a weight (float) to apply to the model's loss for the samples
2650 from this class during training. This can be useful to tell the
2651 model to "pay more attention" to samples from an under-represented
2652 class. When `class_weight` is specified and targets have a rank of
2653 2 or greater, either `y` must be one-hot encoded, or an explicit
2654 final dimension of `1` must be included for sparse class labels.
2655 reset_metrics: If `True`, the metrics returned will be only for this
2656 batch. If `False`, the metrics will be statefully accumulated
2657 across batches.
2658 return_dict: If `True`, loss and metric results are returned as a
2659 dict, with each key being the name of the metric. If `False`, they
2660 are returned as a list.
2662 Returns:
2663 Scalar training loss
2664 (if the model has a single output and no metrics)
2665 or list of scalars (if the model has multiple outputs
2666 and/or metrics). The attribute `model.metrics_names` will give you
2667 the display labels for the scalar outputs.
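Example (a minimal sketch; layer sizes and data shapes are
illustrative):

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="sgd", loss="mse", metrics=["mae"])

x = np.random.random((8, 3))
y = np.random.random((8, 1))
logs = model.train_on_batch(x, y, return_dict=True)  # one update
```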
2669 Raises:
2670 RuntimeError: If `model.train_on_batch` is wrapped in a `tf.function`.
2671 """
2672 self._assert_compile_was_called()
2673 self._check_call_args("train_on_batch")
2674 _disallow_inside_tf_function("train_on_batch")
2675 if reset_metrics:
2676 self.reset_metrics()
2677 with self.distribute_strategy.scope(), training_utils.RespectCompiledTrainableState( # noqa: E501
2678 self
2679 ):
2680 iterator = data_adapter.single_batch_iterator(
2681 self.distribute_strategy, x, y, sample_weight, class_weight
2682 )
2683 self.train_function = self.make_train_function()
2684 logs = self.train_function(iterator)
2686 logs = tf_utils.sync_to_numpy_or_python_type(logs)
2687 if return_dict:
2688 return logs
2689 else:
2690 return flatten_metrics_in_order(logs, self.metrics_names)
2692 def test_on_batch(
2693 self,
2694 x,
2695 y=None,
2696 sample_weight=None,
2697 reset_metrics=True,
2698 return_dict=False,
2699 ):
2700 """Test the model on a single batch of samples.
2702 Args:
2703 x: Input data. It could be:
2704 - A Numpy array (or array-like), or a list of arrays (in case the
2705 model has multiple inputs).
2706 - A TensorFlow tensor, or a list of tensors (in case the model has
2707 multiple inputs).
2708 - A dict mapping input names to the corresponding array/tensors,
2709 if the model has named inputs.
2710 y: Target data. Like the input data `x`, it could be either Numpy
2711 array(s) or TensorFlow tensor(s). It should be consistent with `x`
2712 (you cannot have Numpy inputs and tensor targets, or inversely).
2713 sample_weight: Optional array of the same length as x, containing
2714 weights to apply to the model's loss for each sample. In the case
2715 of temporal data, you can pass a 2D array with shape (samples,
2716 sequence_length), to apply a different weight to every timestep of
2717 every sample.
2718 reset_metrics: If `True`, the metrics returned will be only for this
2719 batch. If `False`, the metrics will be statefully accumulated
2720 across batches.
2721 return_dict: If `True`, loss and metric results are returned as a
2722 dict, with each key being the name of the metric. If `False`, they
2723 are returned as a list.
2725 Returns:
2726 Scalar test loss (if the model has a single output and no metrics)
2727 or list of scalars (if the model has multiple outputs
2728 and/or metrics). The attribute `model.metrics_names` will give you
2729 the display labels for the scalar outputs.
2731 Raises:
2732 RuntimeError: If `model.test_on_batch` is wrapped in a
2733 `tf.function`.
2734 """
2735 self._assert_compile_was_called()
2736 self._check_call_args("test_on_batch")
2737 _disallow_inside_tf_function("test_on_batch")
2738 if reset_metrics:
2739 self.reset_metrics()
2740 with self.distribute_strategy.scope():
2741 iterator = data_adapter.single_batch_iterator(
2742 self.distribute_strategy, x, y, sample_weight
2743 )
2744 self.test_function = self.make_test_function()
2745 logs = self.test_function(iterator)
2747 logs = tf_utils.sync_to_numpy_or_python_type(logs)
2748 if return_dict:
2749 return logs
2750 else:
2751 return flatten_metrics_in_order(logs, self.metrics_names)
2753 def predict_on_batch(self, x):
2754 """Returns predictions for a single batch of samples.
2756 Args:
2757 x: Input data. It could be:
2758 - A Numpy array (or array-like), or a list of arrays (in case the
2759 model has multiple inputs).
2760 - A TensorFlow tensor, or a list of tensors (in case the model has
2761 multiple inputs).
2763 Returns:
2764 Numpy array(s) of predictions.
2766 Raises:
2767 RuntimeError: If `model.predict_on_batch` is wrapped in a
2768 `tf.function`.
2769 """
2770 self._check_call_args("predict_on_batch")
2771 _disallow_inside_tf_function("predict_on_batch")
2772 with self.distribute_strategy.scope():
2773 iterator = data_adapter.single_batch_iterator(
2774 self.distribute_strategy, x
2775 )
2776 self.predict_function = self.make_predict_function()
2777 outputs = self.predict_function(iterator)
2778 return tf_utils.sync_to_numpy_or_python_type(outputs)
2780 @doc_controls.do_not_generate_docs
2781 def fit_generator(
2782 self,
2783 generator,
2784 steps_per_epoch=None,
2785 epochs=1,
2786 verbose=1,
2787 callbacks=None,
2788 validation_data=None,
2789 validation_steps=None,
2790 validation_freq=1,
2791 class_weight=None,
2792 max_queue_size=10,
2793 workers=1,
2794 use_multiprocessing=False,
2795 shuffle=True,
2796 initial_epoch=0,
2797 ):
2798 """Fits the model on data yielded batch-by-batch by a Python generator.
2800 DEPRECATED:
2801 `Model.fit` now supports generators, so there is no longer any need to
2802 use this endpoint.
2803 """
2804 warnings.warn(
2805 "`Model.fit_generator` is deprecated and "
2806 "will be removed in a future version. "
2807 "Please use `Model.fit`, which supports generators.",
2808 stacklevel=2,
2809 )
2810 return self.fit(
2811 generator,
2812 steps_per_epoch=steps_per_epoch,
2813 epochs=epochs,
2814 verbose=verbose,
2815 callbacks=callbacks,
2816 validation_data=validation_data,
2817 validation_steps=validation_steps,
2818 validation_freq=validation_freq,
2819 class_weight=class_weight,
2820 max_queue_size=max_queue_size,
2821 workers=workers,
2822 use_multiprocessing=use_multiprocessing,
2823 shuffle=shuffle,
2824 initial_epoch=initial_epoch,
2825 )
2827 @doc_controls.do_not_generate_docs
2828 def evaluate_generator(
2829 self,
2830 generator,
2831 steps=None,
2832 callbacks=None,
2833 max_queue_size=10,
2834 workers=1,
2835 use_multiprocessing=False,
2836 verbose=0,
2837 ):
2838 """Evaluates the model on a data generator.
2840 DEPRECATED:
2841 `Model.evaluate` now supports generators, so there is no longer any
2842 need to use this endpoint.
2843 """
2844 warnings.warn(
2845 "`Model.evaluate_generator` is deprecated and "
2846 "will be removed in a future version. "
2847 "Please use `Model.evaluate`, which supports generators.",
2848 stacklevel=2,
2849 )
2850 self._check_call_args("evaluate_generator")
2852 return self.evaluate(
2853 generator,
2854 steps=steps,
2855 max_queue_size=max_queue_size,
2856 workers=workers,
2857 use_multiprocessing=use_multiprocessing,
2858 verbose=verbose,
2859 callbacks=callbacks,
2860 )
2862 @doc_controls.do_not_generate_docs
2863 def predict_generator(
2864 self,
2865 generator,
2866 steps=None,
2867 callbacks=None,
2868 max_queue_size=10,
2869 workers=1,
2870 use_multiprocessing=False,
2871 verbose=0,
2872 ):
2873 """Generates predictions for the input samples from a data generator.
2875 DEPRECATED:
2876 `Model.predict` now supports generators, so there is no longer any
2877 need to use this endpoint.
2878 """
2879 warnings.warn(
2880 "`Model.predict_generator` is deprecated and "
2881 "will be removed in a future version. "
2882 "Please use `Model.predict`, which supports generators.",
2883 stacklevel=2,
2884 )
2885 return self.predict(
2886 generator,
2887 steps=steps,
2888 max_queue_size=max_queue_size,
2889 workers=workers,
2890 use_multiprocessing=use_multiprocessing,
2891 verbose=verbose,
2892 callbacks=callbacks,
2893 )
2895 ######################################################################
2896 # Functions below are not training related. They are for model weights
2897 # tracking, save/load, serialization, etc.
2898 ######################################################################
2900 @property
2901 def trainable_weights(self):
2902 self._assert_weights_created()
2903 if not self._trainable:
2904 return []
2905 trainable_variables = []
2906 for trackable_obj in self._self_tracked_trackables:
2907 trainable_variables += trackable_obj.trainable_variables
2908 trainable_variables += self._trainable_weights
2909 return self._dedup_weights(trainable_variables)
2911 @property
2912 def non_trainable_weights(self):
2913 self._assert_weights_created()
2914 non_trainable_variables = []
2915 for trackable_obj in self._self_tracked_trackables:
2916 non_trainable_variables += trackable_obj.non_trainable_variables
2918 if not self._trainable:
2919 # Return order is all trainable vars, then all non-trainable vars.
2920 trainable_variables = []
2921 for trackable_obj in self._self_tracked_trackables:
2922 trainable_variables += trackable_obj.trainable_variables
2924 non_trainable_variables = (
2925 trainable_variables
2926 + self._trainable_weights
2927 + non_trainable_variables
2928 + self._non_trainable_weights
2929 )
2930 else:
2931 non_trainable_variables = (
2932 non_trainable_variables + self._non_trainable_weights
2933 )
2935 return self._dedup_weights(non_trainable_variables)
2937 def get_weights(self):
2938 """Retrieves the weights of the model.
2940 Returns:
2941 A flat list of Numpy arrays.
2942 """
2943 with self.distribute_strategy.scope():
2944 return super().get_weights()
2946 @traceback_utils.filter_traceback
2947 def save(self, filepath, overwrite=True, save_format=None, **kwargs):
2948 """Saves a model as a TensorFlow SavedModel or HDF5 file.
2950 See the [Serialization and Saving guide](
2951 https://keras.io/guides/serialization_and_saving/) for details.
2953 Args:
2955 filepath: `str` or `pathlib.Path` object. Path where to save the
2956 model.
2957 overwrite: Whether we should overwrite any existing model at the
2958 target location, or instead ask the user via an interactive
2959 prompt.
2960 save_format: Either `"keras"`, `"tf"`, `"h5"`,
2961 indicating whether to save the model
2962 in the native Keras format (`.keras`),
2963 in the TensorFlow SavedModel format
2964 (referred to as "SavedModel" below),
2965 or in the legacy HDF5 format (`.h5`).
2966 Defaults to `"tf"` in TF 2.X, and `"h5"` in TF 1.X.
2968 SavedModel format arguments:
2969 include_optimizer: Only applied to SavedModel and legacy HDF5
2970 formats. If False, do not save the optimizer state.
2971 Defaults to `True`.
2972 signatures: Only applies to SavedModel format. Signatures to save
2973 with the SavedModel. See the `signatures` argument in
2974 `tf.saved_model.save` for details.
2975 options: Only applies to SavedModel format.
2976 `tf.saved_model.SaveOptions` object that specifies SavedModel
2977 saving options.
2978 save_traces: Only applies to SavedModel format. When enabled, the
2979 SavedModel will store the function traces for each layer. This
2980 can be disabled, so that only the configs of each layer are
2981 stored. Defaults to `True`.
2982 Disabling this will decrease serialization time
2983 and reduce file size, but it requires that all custom
2984 layers/models implement a `get_config()` method.
2986 Example:
2988 ```python
2989 model = tf.keras.Sequential([
2990 tf.keras.layers.Dense(5, input_shape=(3,)),
2991 tf.keras.layers.Softmax()])
2992 model.save("model.keras")
2993 loaded_model = tf.keras.models.load_model("model.keras")
2994 x = tf.random.uniform((10, 3))
2995 assert np.allclose(model.predict(x), loaded_model.predict(x))
2996 ```
2998 Note that `model.save()` is an alias for `tf.keras.models.save_model()`.
2999 """
3000 saving_api.save_model(
3001 self,
3002 filepath=filepath,
3003 overwrite=overwrite,
3004 save_format=save_format,
3005 **kwargs,
3006 )
3008 @traceback_utils.filter_traceback
3009 def save_weights(
3010 self, filepath, overwrite=True, save_format=None, options=None
3011 ):
3012 """Saves all layer weights.
3014 Either saves in HDF5 or in TensorFlow format based on the `save_format`
3015 argument.
3017 When saving in HDF5 format, the weight file has:
3018 - `layer_names` (attribute), a list of strings
3019 (ordered names of model layers).
3020 - For every layer, a `group` named `layer.name`
3021 - For every such layer group, a group attribute `weight_names`,
3022 a list of strings
3023 (ordered names of weights tensor of the layer).
3024 - For every weight in the layer, a dataset
3025 storing the weight value, named after the weight tensor.
3027 When saving in TensorFlow format, all objects referenced by the network
3028 are saved in the same format as `tf.train.Checkpoint`, including any
3029 `Layer` instances or `Optimizer` instances assigned to object
3030 attributes. For networks constructed from inputs and outputs using
3031 `tf.keras.Model(inputs, outputs)`, `Layer` instances used by the network
3032 are tracked/saved automatically. For user-defined classes which inherit
3033 from `tf.keras.Model`, `Layer` instances must be assigned to object
3034 attributes, typically in the constructor. See the documentation of
3035 `tf.train.Checkpoint` and `tf.keras.Model` for details.
3037 While the formats are the same, do not mix `save_weights` and
3038 `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should
3039 be loaded using `Model.load_weights`. Checkpoints saved using
3040 `tf.train.Checkpoint.save` should be restored using the corresponding
3041 `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over
3042 `save_weights` for training checkpoints.
3044 The TensorFlow format matches objects and variables by starting at a
3045 root object, `self` for `save_weights`, and greedily matching attribute
3046 names. For `Model.save` this is the `Model`, and for `Checkpoint.save`
3047 this is the `Checkpoint` even if the `Checkpoint` has a model attached.
3048 This means saving a `tf.keras.Model` using `save_weights` and loading
3049 into a `tf.train.Checkpoint` with a `Model` attached (or vice versa)
3050 will not match the `Model`'s variables. See the
3051 [guide to training checkpoints](
3052 https://www.tensorflow.org/guide/checkpoint) for details on
3053 the TensorFlow format.
3055 Args:
3056 filepath: String or PathLike, path to the file to save the weights
3057 to. When saving in TensorFlow format, this is the prefix used
3058 for checkpoint files (multiple files are generated). Note that
3059 the '.h5' suffix causes weights to be saved in HDF5 format.
3060 overwrite: Whether to silently overwrite any existing file at the
3061 target location, or provide the user with a manual prompt.
3062 save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
3063 '.keras' will default to HDF5 if `save_format` is `None`.
3064 Otherwise, `None` becomes 'tf'. Defaults to `None`.
3065 options: Optional `tf.train.CheckpointOptions` object that specifies
3066 options for saving weights.
3068 Raises:
3069 ImportError: If `h5py` is not available when attempting to save in
3070 HDF5 format.
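Example (a minimal sketch; the checkpoint paths are illustrative):

```python
# TF format: `filepath` acts as a prefix; several checkpoint files
# are written next to it.
model.save_weights("ckpt/weights")
# HDF5 format: selected by the ".h5" suffix (requires h5py).
model.save_weights("weights.h5")
```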
3071 """
3072 saving_api.save_weights(
3073 self,
3074 filepath=filepath,
3075 overwrite=overwrite,
3076 save_format=save_format,
3077 options=options,
3078 )
3080 @traceback_utils.filter_traceback
3081 def load_weights(
3082 self, filepath, skip_mismatch=False, by_name=False, options=None
3083 ):
3084 """Loads all layer weights from a saved files.
3086 The saved file could be a SavedModel file, a `.keras` file (v3 saving
3087 format), or a file created via `model.save_weights()`.
3089 By default, weights are loaded based on the network's
3090 topology. This means the architecture should be the same as when the
3091 weights were saved. Note that layers that don't have weights are not
3092 taken into account in the topological ordering, so adding or removing
3093 layers is fine as long as they don't have weights.
3095 **Partial weight loading**
3097 If you have modified your model, for instance by adding a new layer
3098 (with weights) or by changing the shape of the weights of a layer,
3099 you can choose to ignore errors and continue loading
3100 by setting `skip_mismatch=True`. In this case any layer with
3101 mismatching weights will be skipped. A warning will be displayed
3102 for each skipped layer.
3104 **Weight loading by name**
3106 If your weights are saved as a `.h5` file created
3107 via `model.save_weights()`, you can use the argument `by_name=True`.
3109 In this case, weights are loaded into layers only if they share
3110 the same name. This is useful for fine-tuning or transfer-learning
3111 models where some of the layers have changed.
3113 Note that only topological loading (`by_name=False`) is supported when
3114 loading weights from the `.keras` v3 format or from the TensorFlow
3115 SavedModel format.
3117 Args:
3118 filepath: String, path to the weights file to load. For weight files
3119 in TensorFlow format, this is the file prefix (the same as was
3120 passed to `save_weights()`). This can also be a path to a
3121 SavedModel or a `.keras` file (v3 saving format) saved
3122 via `model.save()`.
3123 skip_mismatch: Boolean, whether to skip loading of layers where
3124 there is a mismatch in the number of weights, or a mismatch in
3125 the shape of the weights.
3126 by_name: Boolean, whether to load weights by name or by topological
3127 order. Only topological loading is supported for weight files in
3128 the `.keras` v3 format or in the TensorFlow SavedModel format.
3129 options: Optional `tf.train.CheckpointOptions` object that specifies
3130 options for loading weights (only valid for a SavedModel file).
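Example (a sketch of partial, by-name loading; the file path and the
models are illustrative, and `new_model` is assumed to share some layer
names with `base_model`):

```python
base_model.save_weights("base.h5")
# Layers whose names match are loaded; mismatching layers are skipped
# with a warning.
new_model.load_weights("base.h5", by_name=True, skip_mismatch=True)
```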
3131 """
3132 return saving_api.load_weights(
3133 self,
3134 filepath=filepath,
3135 by_name=by_name,
3136 skip_mismatch=skip_mismatch,
3137 options=options,
3138 )
3140 def _updated_config(self):
3141 """Util shared between different serialization methods.
3143 Returns:
3144 Model config with Keras version information added.
3145 """
3146 from keras.src import __version__ as keras_version
3148 config = self.get_config()
3149 model_config = {
3150 "class_name": self.__class__.__name__,
3151 "config": config,
3152 "keras_version": keras_version,
3153 "backend": backend.backend(),
3154 }
3155 return model_config
3157 @generic_utils.default
3158 def get_config(self):
3159 """Returns the config of the `Model`.
3161 Config is a Python dictionary (serializable) containing the
3162 configuration of an object, which in this case is a `Model`. This allows
3163 the `Model` to be reinstantiated later (without its trained weights)
3164 from this configuration.
3166 Note that `get_config()` is not guaranteed to return a fresh copy of
3167 the dict every time it is called. Callers should make a copy of the
3168 returned dict if they want to modify it.
3170 Developers of subclassed `Model`s are advised to override this method,
3171 updating the dict returned by `super(MyModel, self).get_config()`
3172 to provide the proper configuration of this `Model`. The default config
3173 returns a config dict of the `__init__` parameters if they are basic
3174 types, and raises `NotImplementedError` in cases where a custom
3175 `get_config()` implementation is required for the subclassed model.
3177 Returns:
3178 Python dictionary containing the configuration of this `Model`.
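Example (a sketch of the recommended override pattern, using a
hypothetical subclassed model):

```python
class MyModel(tf.keras.Model):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def get_config(self):
        # Extend, rather than replace, the parent config.
        config = super().get_config()
        config.update({"units": self.units})
        return config
```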
3179 """
3180 # If the subclass doesn't implement `get_config()`, try to parse the
3181 # config from the init args; otherwise default to an empty dict.
3182 if generic_utils.is_default(self.get_config):
3183 try:
3184 config = base_layer.Layer.get_config(self)
3185 except NotImplementedError:
3186 config = {}
3187 logging.warning(
3188 "Model's `__init__()` arguments contain non-serializable "
3189 "objects. Please implement a `get_config()` method in the "
3190 "subclassed Model for proper saving and loading. "
3191 "Defaulting to empty config."
3192 )
3193 else:
3194 config = {}
3195 return config
3197 @classmethod
3198 def from_config(cls, config, custom_objects=None):
3199 # `from_config` assumes `cls` is either `Functional` or a child class of
3200 # `Functional`. In the case that `cls` is meant to behave like a child
3201 # class of `Functional` but only inherits from the `Model` class, we
3202 # have to call `cls(...)` instead of `Functional.from_config`.
3203 from keras.src.engine import functional
3205 with serialization.SharedObjectLoadingScope():
3206 functional_config_keys = [
3207 "name",
3208 "layers",
3209 "input_layers",
3210 "output_layers",
3211 ]
3212 is_functional_config = all(
3213 key in config for key in functional_config_keys
3214 )
3215 argspec = tf_inspect.getfullargspec(cls.__init__)
3216 functional_init_args = tf_inspect.getfullargspec(
3217 functional.Functional.__init__
3218 ).args[1:]
3219 revivable_as_functional = (
3220 cls in {functional.Functional, Model}
3221 or argspec.args[1:] == functional_init_args
3222 or (argspec.varargs == "args" and argspec.varkw == "kwargs")
3223 )
3224 if is_functional_config and revivable_as_functional:
3225 # Revive Functional model
3226 # (but not Functional subclasses with a custom __init__)
3227 inputs, outputs, layers = functional.reconstruct_from_config(
3228 config, custom_objects
3229 )
3230 model = cls(
3231 inputs=inputs, outputs=outputs, name=config.get("name")
3232 )
3233 functional.connect_ancillary_layers(model, layers)
3235 else:
3236 # Either the model has a custom __init__, or the config
3237 # does not contain all the information necessary to
3238 # revive a Functional model. This happens when the user creates
3239 # subclassed models where `get_config()` is returning
3240 # insufficient information to be considered a Functional model.
3241 # In this case, we fall back to provide all config into the
3242 # constructor of the class.
3243 try:
3244 model = cls(**config)
3245 except TypeError as e:
3246 raise TypeError(
3247 "Unable to revive model from config. When overriding "
3248 "the `get_config()` method, make sure that the "
3249 "returned config contains all items used as arguments "
3250 f"in the constructor to {cls}, "
3251 "which is the default behavior. "
3252 "You can override this default behavior by defining a "
3253 "`from_config(cls, config)` class method to specify "
3254 "how to create an "
3255 f"instance of {cls.__name__} from its config.\n\n"
3256 f"Received config={config}\n\n"
3257 f"Error encountered during deserialization: {e}"
3258 )
3259 return model
3261 def to_json(self, **kwargs):
3262 """Returns a JSON string containing the network configuration.
3264 To load a network from a JSON save file, use
3265 `keras.models.model_from_json(json_string, custom_objects={})`.
3267 Args:
3268 **kwargs: Additional keyword arguments to be passed to
3269 `json.dumps()`.
3271 Returns:
3272 A JSON string.
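Example (a round-trip sketch; only the architecture is restored, not
the weights):

```python
json_string = model.to_json()
fresh_model = tf.keras.models.model_from_json(json_string)
```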
3273 """
3274 model_config = self._updated_config()
3275 return json.dumps(
3276 model_config, default=json_utils.get_json_type, **kwargs
3277 )
3279 def to_yaml(self, **kwargs):
3280 """Returns a yaml string containing the network configuration.
3282 Note: Since TF 2.6, this method is no longer supported and will raise a
3283 RuntimeError.
3285 To load a network from a yaml save file, use
3286 `keras.models.model_from_yaml(yaml_string, custom_objects={})`.
3288 `custom_objects` should be a dictionary mapping
3289 the names of custom losses / layers / etc to the corresponding
3290 functions / classes.
3292 Args:
3293 **kwargs: Additional keyword arguments
3294 to be passed to `yaml.dump()`.
3296 Returns:
3297 A YAML string.
3299 Raises:
3300 RuntimeError: announces that the method poses a security risk.
3301 """
3302 raise RuntimeError(
3303 "Method `model.to_yaml()` has been removed due to security risk of "
3304 "arbitrary code execution. Please use `model.to_json()` instead."
3305 )
3307 def reset_states(self):
3308 for layer in self.layers:
3309 if hasattr(layer, "reset_states") and getattr(
3310 layer, "stateful", False
3311 ):
3312 layer.reset_states()
3314 @property
3315 @doc_controls.do_not_generate_docs
3316 def state_updates(self):
3317 """Deprecated, do NOT use!
3319 Returns the `updates` from all layers that are stateful.
3321 This is useful for separating training updates and
3322 state updates, e.g. when we need to update a layer's internal state
3323 during prediction.
3325 Returns:
3326 A list of update ops.
3327 """
3328 warnings.warn(
3329 "`Model.state_updates` will be removed in a future version. "
3330 "This property should not be used in TensorFlow 2.0, "
3331 "as `updates` are applied automatically.",
3332 stacklevel=2,
3333 )
3334 state_updates = []
3335 for layer in self.layers:
3336 if getattr(layer, "stateful", False):
3337 if hasattr(layer, "updates"):
3338 state_updates += layer.updates
3339 return state_updates
3341 @property
3342 def weights(self):
3343 """Returns the list of all layer variables/weights.
3345 Note: This will not track the weights of nested `tf.Modules` that are
3346 not themselves Keras layers.
3348 Returns:
3349 A list of variables.
3350 """
3351 return self._dedup_weights(self._undeduplicated_weights)
3353 @property
3354 def _undeduplicated_weights(self):
3355 """Returns the undeduplicated list of all layer variables/weights."""
3356 self._assert_weights_created()
3357 weights = []
3358 for layer in self._self_tracked_trackables:
3359 weights += layer.variables
3360 weights += self._trainable_weights + self._non_trainable_weights
3361 return weights
3363 def summary(
3364 self,
3365 line_length=None,
3366 positions=None,
3367 print_fn=None,
3368 expand_nested=False,
3369 show_trainable=False,
3370 layer_range=None,
3371 ):
3372 """Prints a string summary of the network.
3374 Args:
3375 line_length: Total length of printed lines
3376 (e.g. set this to adapt the display to different
3377 terminal window sizes).
3378 positions: Relative or absolute positions of log elements
3379 in each line. If not provided, becomes
3380 `[0.3, 0.6, 0.70, 1.]`. Defaults to `None`.
3381 print_fn: Print function to use. By default, prints to `stdout`.
3382 If `stdout` doesn't work in your environment, change to `print`.
3383 It will be called on each line of the summary.
3384 You can set it to a custom function
3385 in order to capture the string summary.
3386 expand_nested: Whether to expand the nested models.
3387 Defaults to `False`.
3388 show_trainable: Whether to show if a layer is trainable.
3389 Defaults to `False`.
3390 layer_range: a list or tuple of 2 strings,
3391 which is the starting layer name and ending layer name
3392 (both inclusive) indicating the range of layers to be printed
3393 in summary. It also accepts regex patterns instead of exact
3394 names. In that case, the start predicate is the first element
3395 that matches `layer_range[0]` and the end predicate is the
3396 last element that matches `layer_range[1]`.
3397 Defaults to `None`, which considers all layers of the model.
3399 Raises:
3400 ValueError: if `summary()` is called before the model is built.
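Example (a sketch of using `print_fn` to capture the summary as a
string instead of printing it):

```python
lines = []
model.summary(print_fn=lines.append)
summary_text = "\n".join(lines)
```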
3401 """
3402 if not self.built:
3403 raise ValueError(
3404 "This model has not yet been built. "
3405 "Build the model first by calling `build()` or by calling "
3406 "the model on a batch of data."
3407 )
3408 layer_utils.print_summary(
3409 self,
3410 line_length=line_length,
3411 positions=positions,
3412 print_fn=print_fn,
3413 expand_nested=expand_nested,
3414 show_trainable=show_trainable,
3415 layer_range=layer_range,
3416 )
3418 @property
3419 def layers(self):
3420 return list(self._flatten_layers(include_self=False, recursive=False))
3422 @layers.setter
3423 def layers(self, _):
3424 raise AttributeError(
3425 "`Model.layers` attribute is reserved and should not be used. "
3426 "Please use another name."
3427 )
3429 def get_layer(self, name=None, index=None):
3430 """Retrieves a layer based on either its name (unique) or index.
3432 `name` and `index` are mutually exclusive; providing both raises a `ValueError`.
3433 Indices are based on order of horizontal graph traversal (bottom-up).
3435 Args:
3436 name: String, name of layer.
3437 index: Integer, index of layer.
3439 Returns:
3440 A layer instance.
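Example (a sketch; the layer name is illustrative):

```python
first = model.get_layer(index=0)
dense = model.get_layer(name="dense_1")
```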
3441 """
3442 # TODO(fchollet): We could build a dictionary based on layer names
3443 # since they are constant, but we have not done that yet.
3444 if index is not None and name is not None:
3445 raise ValueError(
3446 "Provide only a layer name or a layer index. Received: "
3447 f"index={index}, name={name}."
3448 )
3450 if index is not None:
3451 if len(self.layers) <= index:
3452 raise ValueError(
3453 f"Was asked to retrieve layer at index {index}"
3454 f" but model only has {len(self.layers)}"
3455 " layers."
3456 )
3457 else:
3458 return self.layers[index]
3460 if name is not None:
3461 for layer in self.layers:
3462 if layer.name == name:
3463 return layer
3464 raise ValueError(
3465 f"No such layer: {name}. Existing layers are: "
3466 f"{list(layer.name for layer in self.layers)}."
3467 )
3468 raise ValueError(
3469 "Provide either a layer name or layer index at `get_layer`."
3470 )
3472 def get_weight_paths(self):
3473 """Retrieve all the variables and their paths for the model.
3475 The variable path (string) is a stable key to identify a `tf.Variable`
3476 instance owned by the model. It can be used to specify variable-specific
3477 configurations (e.g. DTensor, quantization) from a global view.
3479 This method returns a dict with weight object paths as keys
3480 and the corresponding `tf.Variable` instances as values.
3482 Note that if the model is a subclassed model and the weights haven't
3483 been initialized, an empty dict will be returned.
3485 Returns:
3486 A dict where keys are variable paths and values are `tf.Variable`
3487 instances.
3489 Example:
3491 ```python
3492 class SubclassModel(tf.keras.Model):
3494 def __init__(self, name=None):
3495 super().__init__(name=name)
3496 self.d1 = tf.keras.layers.Dense(10)
3497 self.d2 = tf.keras.layers.Dense(20)
3499 def call(self, inputs):
3500 x = self.d1(inputs)
3501 return self.d2(x)
3503 model = SubclassModel()
3504 model(tf.zeros((10, 10)))
3505 weight_paths = model.get_weight_paths()
3506 # weight_paths:
3507 # {
3508 # 'd1.kernel': model.d1.kernel,
3509 # 'd1.bias': model.d1.bias,
3510 # 'd2.kernel': model.d2.kernel,
3511 # 'd2.bias': model.d2.bias,
3512 # }
3514 # Functional model
3515 inputs = tf.keras.Input((10,), batch_size=10)
3516 x = tf.keras.layers.Dense(20, name='d1')(inputs)
3517 output = tf.keras.layers.Dense(30, name='d2')(x)
3518 model = tf.keras.Model(inputs, output)
3519 d1 = model.layers[1]
3520 d2 = model.layers[2]
3521 weight_paths = model.get_weight_paths()
3522 # weight_paths:
3523 # {
3524 # 'd1.kernel': d1.kernel,
3525 # 'd1.bias': d1.bias,
3526 # 'd2.kernel': d2.kernel,
3527 # 'd2.bias': d2.bias,
3528 # }
3529 ```
3530 """
3531 result = {}
3532 (
3533 descendants,
3534 object_paths_dict,
3535 ) = tf.__internal__.tracking.ObjectGraphView(
3536 self
3537 ).breadth_first_traversal()
3538 for descendant in descendants:
3539 if isinstance(descendant, tf.Variable):
3540 trackable_references = object_paths_dict[descendant]
3541 object_path = ".".join([t.name for t in trackable_references])
3542 result[object_path] = descendant
3543 return result
3545 def get_compile_config(self):
3546 """Returns a serialized config with information for compiling the model.
3548 This method returns a config dictionary containing all the information
3549 (optimizer, loss, metrics, etc.) with which the model was compiled.
3551 Returns:
3552 A dict containing information for compiling the model, or `None`
3553 if the model has not been compiled.
3553 """
3554 if self._is_compiled and hasattr(self, "_compile_config"):
3555 return self._compile_config.serialize()
3557 def compile_from_config(self, config):
3558 """Compiles the model with the information given in config.
3560 This method uses the information in the config (optimizer, loss,
3561 metrics, etc.) to compile the model.
3563 Args:
3564 config: Dict containing information for compiling the model.
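Example (a round-trip sketch, assuming `fresh_model` is an uncompiled
model with the same architecture as `model`):

```python
config = model.get_compile_config()
fresh_model.compile_from_config(config)
```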
3565 """
3566 has_overridden_compile = self.__class__.compile != Model.compile
3567 if has_overridden_compile:
3568 logging.warning(
3569 "`compile()` was not called as part of model loading "
3570 "because the model's `compile()` method is custom. "
3571 "All subclassed Models that have `compile()` "
3572 "overridden should also override "
3573 "`get_compile_config()` and `compile_from_config(config)`. "
3574 "Alternatively, you can "
3575 "call `compile()` manually after loading."
3576 )
3577 return
3578 config = saving_lib.deserialize_keras_object(config)
3579 self.compile(**config)
3580 if hasattr(self, "optimizer") and self.built:
3581 # Create optimizer variables.
3582 self.optimizer.build(self.trainable_variables)
3584 def export(self, filepath):
3585 """Create a SavedModel artifact for inference (e.g. via TF-Serving).
3587 This method lets you export a model to a lightweight SavedModel artifact
3588 that contains the model's forward pass only (its `call()` method)
3589 and can be served via e.g. TF-Serving. The forward pass is registered
3590 under the name `serve()` (see example below).
3592 The original code of the model (including any custom layers you may
3593 have used) is *no longer* necessary to reload the artifact -- it is
3594 entirely standalone.
3596 Args:
3597 filepath: `str` or `pathlib.Path` object. Path where to save
3598 the artifact.
3600 Example:
3602 ```python
3603 # Create the artifact
3604 model.export("path/to/location")
3606 # Later, in a different process / environment...
3607 reloaded_artifact = tf.saved_model.load("path/to/location")
3608 predictions = reloaded_artifact.serve(input_data)
3609 ```
3611 If you would like to customize your serving endpoints, you can
3612 use the lower-level `keras.export.ExportArchive` class. The `export()`
3613 method relies on `ExportArchive` internally.
3614 """
3615 from keras.src.export import export_lib
3617 export_lib.export_model(self, filepath)
3619 @tf.__internal__.tracking.no_automatic_dependency_tracking
3620 def _set_save_spec(self, inputs, args=None, kwargs=None):
3621 """Defines the save spec so that serialization can trace `call()`.
3623 The TensorSpecs of the call function `inputs`, `args`, and `kwargs` are
3624 saved into a tuple of `([inputs] + args, kwargs)`. The input
3625 `TensorSpec` names are updated to match the built `input_names`.
3627 The specs can be retrieved with the `save_spec` property.
3629 Args:
3630 inputs: possibly nested inputs passed into the call function.
3631 args: a list of positional arguments passed into call.
3632 kwargs: a dictionary of keyword arguments passed into call.
3633 """
3634 if self._saved_model_inputs_spec is not None:
3635 return # Already set.
3636 args = args or []
3637 kwargs = kwargs or {}
3639 input_names = self.input_names
3640 if not input_names:
3641 input_names = compile_utils.create_pseudo_input_names(inputs)
3643 flat_inputs = tf.nest.flatten(inputs)
3644 inputs_spec = []
3645 for name, tensor in zip(input_names, flat_inputs):
3646 inputs_spec.append(
3647 tf_utils.get_tensor_spec(tensor, dynamic_batch=False, name=name)
3648 )
3649 inputs_spec = tf.nest.pack_sequence_as(inputs, inputs_spec)
3650 super()._set_save_spec(inputs_spec, args, kwargs)
3652 # Store the input shapes
3653 if (
3654 self.__class__.__name__ == "Sequential"
3655 and self._build_input_shape is None
3656 ):
3657 self._build_input_shape = tf.nest.map_structure(
3658 lambda x: None if x is None else x.shape, inputs_spec
3659 )
3661 def save_spec(self, dynamic_batch=True):
3662 """Returns the `tf.TensorSpec` of call args as a tuple `(args, kwargs)`.
3664 This value is automatically defined after calling the model for the
3665 first time. Afterwards, you can use it when exporting the model for
3666 serving:
3668 ```python
3669 model = tf.keras.Model(...)
3671 @tf.function
3672 def serve(*args, **kwargs):
3673 outputs = model(*args, **kwargs)
3674 # Apply postprocessing steps, or add additional outputs.
3675 ...
3676 return outputs
3678 # arg_specs is `[tf.TensorSpec(...), ...]`. kwarg_specs, in this
3679 # example, is an empty dict since functional models do not use keyword
3680 # arguments.
3681 arg_specs, kwarg_specs = model.save_spec()
3683 model.save(path, signatures={
3684 'serving_default': serve.get_concrete_function(*arg_specs,
3685 **kwarg_specs)
3686 })
3687 ```
3689 Args:
3690 dynamic_batch: Whether to set the batch sizes of all the returned
3691 `tf.TensorSpec` to `None`. (Note that when defining functional or
3692 Sequential models with `tf.keras.Input([...], batch_size=X)`, the
3693 batch size will always be preserved). Defaults to `True`.
3694 Returns:
3695 If the model inputs are defined, returns a tuple `(args, kwargs)`. All
3696 elements in `args` and `kwargs` are `tf.TensorSpec`.
3697 If the model inputs are not defined, returns `None`.
3698 The model inputs are automatically set when calling the model,
3699 `model.fit`, `model.evaluate` or `model.predict`.
3700 """
3701 return self._get_save_spec(dynamic_batch, inputs_only=False)
3703 def _assert_weights_created(self):
3704 """Asserts that all the weights for the model have been created.
3706 For a non-dynamic model, the weights must already be created after the
3707 layer has been called. For a dynamic model, the exact list of weights
3708 can never be known for certain since it may change at any time during
3709 execution.
3711 We run this check right before accessing weights or getting the Numpy
3712 value for the current weights. Otherwise, if the layer has never been
3713 called, the user would just get an empty list, which is misleading.
3715 Raises:
3716 ValueError: if the weights of the network have not yet been created.
3717 """
3718 if self.dynamic:
3719 return
3721 if (
3722 "build" in self.__class__.__dict__
3723 and self.__class__ != Model
3724 and not self.built
3725 ):
3726 # This covers any model that has a customized build() method but
3727 # hasn't been invoked yet, for both Sequential and subclassed models.
3728 # Also make sure to exclude the Model class itself, which has build()
3729 # defined.
3730 raise ValueError(
3731 f"Weights for model '{self.name}' have not yet been "
3732 "created. "
3733 "Weights are created when the model is first called on "
3734 "inputs or `build()` is called with an `input_shape`."
3735 )
3737 def _check_call_args(self, method_name):
3738 """Check that `call()` has only one positional arg."""
3739 # Always allow first arg, regardless of arg name.
3740 fullargspec = self._call_spec.full_argspec
3741 if fullargspec.defaults:
3742 positional_args = fullargspec.args[: -len(fullargspec.defaults)]
3743 else:
3744 positional_args = fullargspec.args
3745 if "training" in positional_args:
3746 positional_args.remove("training")
3748 # self and first arg can be positional.
3749 if len(positional_args) > 2:
3750 extra_args = positional_args[2:]
3751 raise ValueError(
3752 f"Models passed to `{method_name}` can only have `training` "
3753 "and the first argument in `call()` as positional arguments, "
3754 f"found: {extra_args}."
3755 )
3757 def _validate_compile(self, optimizer, metrics, **kwargs):
3758 """Performs validation checks for the default `compile()`."""
3759 if any(
3760 isinstance(opt, optimizer_v1.Optimizer)
3761 for opt in tf.nest.flatten(optimizer)
3762 ):
3763 raise ValueError(
3764 f"`tf.compat.v1.keras` Optimizer ({optimizer}) is "
3765 "not supported when eager execution is enabled. Use a "
3766 "`tf.keras` Optimizer instead, or disable eager "
3767 "execution."
3768 )
3770 kwargs.pop("cloning", None) # Legacy DistStrat argument, never used.
3771 kwargs.pop("experimental_run_tf_function", None) # Always `True`.
3772 distribute_arg = kwargs.pop("distribute", None)
3773 if distribute_arg is not None:
3774 raise ValueError(
3775 "`distribute` argument in compile is not available in TF 2.0. "
3776 "Please create the model under the `strategy.scope()`. "
3777 f"Received: {distribute_arg}."
3778 )
3779 target_tensor_arg = kwargs.pop("target_tensors", None)
3780 if target_tensor_arg is not None:
3781 raise ValueError(
3782 "`target_tensors` argument is not supported when executing "
3783 f"eagerly. Received: {target_tensor_arg}."
3784 )
3785 invalid_kwargs = set(kwargs) - {"sample_weight_mode"}
3786 if invalid_kwargs:
3787 raise TypeError(
3788 "Invalid keyword argument(s) in `compile()`: "
3789 f"{(invalid_kwargs,)}. Valid keyword arguments include "
3790 '"cloning", "experimental_run_tf_function", "distribute",'
3791 ' "target_tensors", or "sample_weight_mode".'
3792 )
3794 # Model must be created and compiled with the same DistStrat.
3795 if self.built and tf.distribute.has_strategy():
3796 strategy = tf.distribute.get_strategy()
3797 for v in self.variables:
3798 if not strategy.extended.variable_created_in_scope(v):
3799 raise ValueError(
3800 f"Variable ({v}) was not created in the distribution "
3801 f"strategy scope of ({strategy}). It is most likely "
3802 "because some layers, model, or optimizer was being "
3803 "created outside the distribution strategy scope. Try "
3804 "to make sure your code looks similar "
3805 "to the following.\nwith strategy.scope():\n"
3806 " model=_create_model()\n"
3807 " model.compile(...)"
3808 )
3810 # Model metrics must be created in the same distribution strategy scope
3811 # as the model.
3812 strategy = self.distribute_strategy
3813 for metric in tf.nest.flatten(metrics):
3814 for v in getattr(metric, "variables", []):
3815 if not strategy.extended.variable_created_in_scope(v):
3816 raise ValueError(
3817 f"Metric ({metric}) passed to `model.compile` was "
3818 "created inside a different distribution strategy "
3819 "scope than the model. All metrics must be created "
3820 "in the same distribution strategy "
3821 f"scope as the model (in this case {strategy}). "
3822 "If you pass in a string identifier for a metric to "
3823 "compile, the metric will automatically be created "
3824 "in the correct distribution strategy scope."
3825 )
3827 # The optimizer must be created in the same distribution strategy
3828 # scope as the model.
3829 for opt in tf.nest.flatten(optimizer):
3830 for v in getattr(opt, "_weights", []):
3831 if not strategy.extended.variable_created_in_scope(v):
3832 raise ValueError(
3833 f"Optimizer ({optimizer}) passed to `model.compile` "
3834 "was created inside a different distribution strategy "
3835 "scope than the model. All optimizers must be created "
3836 "in the same distribution strategy scope as the model "
3837 f"(in this case {strategy}). If you pass in a string "
3838 "identifier for an optimizer to compile, the optimizer "
3839 "will automatically be created in the correct "
3840 "distribution strategy scope."
3841 )
3843 def _maybe_load_initial_counters_from_ckpt(
3844 self, steps_per_epoch, initial_epoch
3845 ):
3846 """Maybe load initial epoch from ckpt, considering worker recovery.
3848 Refer to tensorflow/python/keras/distribute/worker_training_state.py
3849 for more information.
3851 Args:
3852 steps_per_epoch: The number of steps per epoch.
3853 initial_epoch: The original `initial_epoch` the user passes in
3854 `fit()`.
3856 Returns:
3857 If the training is recovering from previous failure under multi-worker
3858 training setting, return the (epoch, step) the training is supposed to
3859 continue at. Otherwise, return the `initial_epoch, initial_step` the
3860 user passes in.
3861 """
3862 initial_step = 0
3863 if self._training_state is not None:
3864 return self._training_state.maybe_load_initial_counters_from_ckpt(
3865 steps_per_epoch, initial_epoch, mode=ModeKeys.TRAIN
3866 )
3867 return (initial_epoch, initial_step)
3869 def _assert_compile_was_called(self):
3870 # Checks whether `compile` has been called. If it has been called,
3871 # then the optimizer is set. This is different from whether the
3872 # model is compiled
3873 # (i.e. whether the model is built and its inputs/outputs are set).
3874 if not self._is_compiled:
3875 raise RuntimeError(
3876 "You must compile your model before "
3877 "training/testing. "
3878 "Use `model.compile(optimizer, loss)`."
3879 )
3881 def _check_sample_weight_warning(self, x, sample_weight):
3882 # Datasets can include sample weight, by returning a tuple with the
3883 # structure of `(x, y, sample_weight)`.
3884 sample_weight_present = sample_weight is not None or (
3885 isinstance(x, tf.data.Dataset)
3886 and isinstance(x.element_spec, tuple)
3887 and len(x.element_spec) == 3
3888 )
3890 if (
3891 sample_weight_present
3892 and self.compiled_metrics._user_weighted_metrics is None
3893 ):
3894 logging.warning(
3895 "`evaluate()` received a value for `sample_weight`, but "
3896 "`weighted_metrics` were not provided. Did you mean to pass "
3897 "metrics to `weighted_metrics` in `compile()`? If this is "
3898 "intentional you can pass `weighted_metrics=[]` to `compile()` "
3899 "in order to silence this warning."
3900 )
3902 def _set_inputs(self, inputs, outputs=None, training=None):
3903 """This method is for compat with Modelv1. Only inputs are needed
3904 here."""
3905 self._set_save_spec(inputs)
3907 @property
3908 def _trackable_saved_model_saver(self):
3909 return model_serialization.ModelSavedModelSaver(self)
3911 def _trackable_children(self, save_type="checkpoint", **kwargs):
3912 if save_type == "savedmodel":
3913 # SavedModel needs to ignore the execution functions.
3914 train_function = self.train_function
3915 test_function = self.test_function
3916 predict_function = self.predict_function
3917 train_tf_function = self.train_tf_function
3918 self.train_function = None
3919 self.test_function = None
3920 self.predict_function = None
3921 self.train_tf_function = None
3923 children = super()._trackable_children(save_type, **kwargs)
3925 if save_type == "savedmodel":
3926 self.train_function = train_function
3927 self.test_function = test_function
3928 self.predict_function = predict_function
3929 self.train_tf_function = train_tf_function
3931 return children
3933 def _should_eval(self, epoch, validation_freq):
3934 epoch = epoch + 1 # one-index the user-facing epoch.
3935 if isinstance(validation_freq, int):
3936 return epoch % validation_freq == 0
3937 elif isinstance(validation_freq, list):
3938 return epoch in validation_freq
3939 else:
3940 raise ValueError(
3941 "Expected `validation_freq` to be a list or int. "
3942 f"Received: validation_freq={validation_freq} of the "
3943 f"type {type(validation_freq)}."
3944 )
3946 ######################################################################
3947 # Functions below exist only as v1 / v2 compatibility shims.
3948 ######################################################################
3950 def _get_compile_args(self, user_metrics=True):
3951 """Used for saving or cloning a Model.
3953 Args:
3954 user_metrics: Whether to return user-supplied metrics or `Metric`
3955 objects. If True, returns the user-supplied metrics.
3956 Defaults to `True`.
3958 Returns:
3959 Dictionary of arguments that were used when compiling the model.
3960 """
3961 self._assert_compile_was_called()
3962 saved_metrics = self.compiled_metrics._user_metrics
3963 saved_weighted_metrics = self.compiled_metrics._user_weighted_metrics
3965 if not user_metrics:
3966 if saved_metrics is not None:
3967 saved_metrics = self.compiled_metrics._metrics
3968 if saved_weighted_metrics is not None:
3969 saved_weighted_metrics = self.compiled_metrics._weighted_metrics
3971 compile_args = {
3972 "optimizer": self.optimizer,
3973 "loss": self.compiled_loss._user_losses,
3974 "metrics": saved_metrics,
3975 "weighted_metrics": saved_weighted_metrics,
3976 "loss_weights": self.compiled_loss._user_loss_weights,
3977 }
3978 return compile_args
3980 def _get_callback_model(self):
3981 return self
3983 def _in_multi_worker_mode(self):
3984 return self.distribute_strategy.extended._in_multi_worker_mode()
3986 @property
3987 def _compile_was_called(self):
3988 return self._is_compiled
3990 def _save_experimental(self, filepath):
3991 return saving_lib.save_model(self, filepath)
3994class _TestFunction:
3995 def __init__(self, function, callbacks):
3996 self._function = function
3997 self._callbacks = callbacks
3999 def run_step(self, dataset_or_iterator, data_handler, step, unused_shards):
4000 tmp_logs = self._function(dataset_or_iterator)
4001 if data_handler.should_sync:
4002 context.async_wait()
4003 logs = tmp_logs
4004 end_step = step + data_handler.step_increment
4005 self._callbacks.on_test_batch_end(end_step, logs)
4006 return logs
4009class _ExactTestFunction(_TestFunction):
4010 def __init__(self, function, callbacks):
4011 super().__init__(function, callbacks)
4012 self._logs = []
4014 def run_step(self, dataset_or_iterator, data_handler, step, shards):
4015 tmp_logs = self._function(
4016 dataset_or_iterator,
4017 tf.constant(shards, dtype=tf.int64),
4018 tf.constant(step, dtype=tf.int64),
4019 )
4020 if data_handler.should_sync:
4021 context.async_wait()
4022 self._logs.append(tmp_logs)
4023 return self._logs
4026def reduce_per_replica(values, strategy, reduction):
4027 """Attempt to reduce the structure `values` to single values.
4029 Given `values` (a `tf.Tensor` or a `PerReplica` structure),
4030 which represents the values across all the replicas, `reduce_per_replica`
4031 attempts to "reduce" those values and returns the corresponding structure
4032 that represents only single values.
4034 Currently, `reduce_per_replica` is only used for reducing the metric results
4035 from `tf.distribute.Strategy.run()`. Depending on the underlying
4036 `Strategy` implementation, `values` may be a `PerReplica` object,
4037 which can be thought of as a collection of values across the replicas,
4038 or a `tf.Tensor`, if the strategy has already conducted the reduction
4039 for the downstream library.
4041 There are five possible outcomes of reduction:
4043 1) if `values` is a structure of simple `tf.Tensor`s, meaning that
4044 reduction is not actually needed, `reduce_per_replica` returns the
4045 structure as-is.
4046 2) else, if `reduction="auto"`, then the best reduction strategy is
4047 chosen based on the current environment. This should only be used
4048 for training cases (`fit()`).
4049 3) else, if `reduction="first"`, then `reduce_per_replica`
4050 returns the values of the first replica. This is used in the case of
4051 training and evaluation, where `values` is expected to hold the same
4052 value across the replicas as a result of `Strategy`'s synchronization
4053 across the replicas.
4054 `reduce_per_replica` does not synchronize the values.
4055 4) else, if `reduction="sum"`, then `reduce_per_replica` returns the sum
4056 of values for all replicas. This may be used in the custom training loop
4057 case, where each replica contains different values which are not
4058 synchronized.
4059 5) else, if `reduction="concat"`, then `reduce_per_replica`
4060 returns the concatenation of the values across the replicas, along the
4061 axis of dimension 0. This is used in the inference case (`predict()`).
4063 Args:
4064 values: Structure of `PerReplica` objects or `tf.Tensor`s. `tf.Tensor`s
4065 are returned as-is.
4066 strategy: `tf.distribute.Strategy` object.
4067 reduction: One of `"auto"`, `"first"`, `"concat"`, or `"sum"`.
4068 `"auto"` will select `"first"` when used under a TPUStrategy, or
4069 `"sum"` otherwise.
4071 Returns:
4072 Structure of `Tensor`s, representing the result of reduction.
4074 Raises:
4075 ValueError: if the reduction method is not supported.
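Example (a sketch of the `predict()`-style usage, assuming `strategy`,
`model`, and `x` are already defined):

```python
per_replica_outputs = strategy.run(model, args=(x,))
merged = reduce_per_replica(
    per_replica_outputs, strategy, reduction="concat"
)
```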
4076 """
4078 if reduction == "auto":
4079 reduction = "first" if backend.is_tpu_strategy(strategy) else "sum"
4081 def _reduce(v):
4082 """Reduce a single `PerReplica` object."""
4083 if _collective_all_reduce_multi_worker(strategy):
4084 if reduction == "concat":
4085 return _multi_worker_concat(v, strategy)
4086 elif reduction == "sum":
4087 return strategy.reduce("SUM", v, axis=None)
4089 if _is_dtensor_per_replica_instance(v):
4090 return _reduce_dtensor_per_replica(v, strategy, reduction)
4091 elif not _is_per_replica_instance(v):
4092 return v
4093 elif reduction == "first":
4094 return strategy.experimental_local_results(v)[0]
4095 elif reduction == "concat":
4096 if _is_tpu_multi_host(strategy):
4097 return _tpu_multi_host_concat(v, strategy)
4098 else:
4099 return concat(strategy.experimental_local_results(v))
4100 elif reduction == "sum":
4101 return tf.reduce_sum(strategy.experimental_local_results(v))
4102 else:
4103 raise ValueError(
4104 '`reduction` must be "first", "concat", "sum", or "auto". '
4105 f"Received: reduction={reduction}."
4106 )
4108 return tf.nest.map_structure(_reduce, values)
4111def concat(tensors, axis=0):
4112 """Concats `tensor`s along `axis`."""
4113 if isinstance(tensors[0], tf.SparseTensor):
4114 return tf.sparse.concat(axis=axis, sp_inputs=tensors)
4115 elif _is_scalar(tensors[0]):
4116 return tf.stack(tensors, axis=axis)
4117 else:
4118 return tf.concat(tensors, axis=axis)
4121def potentially_ragged_concat(tensors):
4122 """Concats `Tensor`s along their first dimension.
4124 Args:
4125 tensors: List of `Tensor`s.
4127 Returns:
4128 Concatenation of the inputs along the first dimension -- of type `Tensor`
4129 if all input shapes are compatible, or `RaggedTensor` if not.
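Example (a sketch; in eager TF2, incompatible non-batch dimensions
produce a `tf.RaggedTensor`):

```python
a = tf.zeros((2, 3))
b = tf.zeros((2, 5))
# Rows keep their differing lengths: a RaggedTensor of 4 rows.
out = potentially_ragged_concat([a, b])
```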
4130 """
4131 if len(tensors) == 1:
4132 return tensors[0]
4133 if isinstance(tensors[0], tf.SparseTensor):
4134 return tf.sparse.concat(axis=0, sp_inputs=tensors)
4135 elif isinstance(tensors[0], tf.RaggedTensor):
4136 return tf.concat(tensors, axis=0)
4137 elif not tf.__internal__.tf2.enabled():
4138 return tf.concat(tensors, axis=0)
4140 non_batch_shapes = tf.stack([tf.shape(tensor)[1:] for tensor in tensors])
4141 constant_dims = tf.math.reduce_all(
4142 non_batch_shapes == non_batch_shapes[:1], axis=0
4143 )
4144 if tf.math.reduce_all(constant_dims).numpy().item():
4145 # All non-batch dims are constant
4146 if _is_scalar(tensors[0]):
4147 return tf.stack(tensors, axis=0)
4148 else:
4149 return tf.concat(tensors, axis=0)
4151 # First, identify constant inner dimensions by finding the
4152 # rightmost dimension that is not constant
4153 constant_inner_dimensions = (
4154 constant_dims.numpy().tolist()[::-1].index(False)
4155 )
4156 # If there are constant inner dimensions, define a constant inner shape
4157 if constant_inner_dimensions == 0:
4158 constant_inner_shape = None
4159 else:
4160 constant_inner_shape = tensors[0].shape[-constant_inner_dimensions:]
4161 return tf.ragged.constant(
4162 [tensor.numpy() for tensor in tensors], inner_shape=constant_inner_shape
4163 ).merge_dims(0, 1)
4166def _reduce_dtensor_per_replica(value, strategy, reduction):
4167 # Note that this function could run in graph mode, so we can't just
4168 # access per_replica.values(), which would trigger an unpack in graph
4169 # mode and result in an error.
4170 # For now we perform ops on the dtensor instance directly in a global
4171 # context.
4172 dtensor = value._dtensor
4173 if reduction == "first":
4174 num_replica = strategy.num_replicas_in_sync
4175 return tf.split(dtensor, num_replica, axis=0)[0]
4176 elif reduction == "concat":
4177 # Since dtensor is already in global context, the concat is a no-op
4178 return dtensor
4179 elif reduction == "sum":
4180 return tf.reduce_sum(dtensor)
4181 else:
4182 raise ValueError(
4183 '`reduction` must be one of "first", "concat", "sum", or "auto". '
4184 f"Received: reduction={reduction}."
4185 )
4188def _get_verbosity(verbose, distribute_strategy):
4189 """Find the right verbosity value for 'auto'."""
4190 if verbose == 1 and distribute_strategy._should_use_with_coordinator:
4191 raise ValueError(
4192 "`verbose=1` is not allowed with `ParameterServerStrategy` for "
4193 f"performance reasons. Received: verbose={verbose}"
4194 )
4195 if verbose == "auto":
4196 if (
4197 distribute_strategy._should_use_with_coordinator
4198 or not io_utils.is_interactive_logging_enabled()
4199 ):
4200 # Defaults to epoch-level logging for PSStrategy or using absl
4201 # logging.
4202 return 2
4203 else:
4204 return 1 # Defaults to batch-level logging otherwise.
4205 return verbose
4208def _is_tpu_multi_host(strategy):
4209 return backend.is_tpu_strategy(strategy) and strategy.extended.num_hosts > 1
4212def _tpu_multi_host_concat(v, strategy):
4213 """Correctly order TPU PerReplica objects."""
4214 replicas = strategy.experimental_local_results(v)
4215 # When distributed datasets are created from Tensors / NumPy,
4216 # TPUStrategy.experimental_distribute_dataset shards data in
4217 # (Replica, Host) order, and TPUStrategy.experimental_local_results returns
4218 # it in (Host, Replica) order.
4219 # TODO(b/150317897): Figure out long-term plan here.
4220 num_replicas_per_host = strategy.extended.num_replicas_per_host
4221 ordered_replicas = []
4222 for replica_id in range(num_replicas_per_host):
4223 ordered_replicas += replicas[replica_id::num_replicas_per_host]
4224 return concat(ordered_replicas)
4227def _collective_all_reduce_multi_worker(strategy):
4228 return (
4229 isinstance(strategy, tf.distribute.MultiWorkerMirroredStrategy)
4230 ) and strategy.extended._in_multi_worker_mode()
4233# TODO(wxinyi): merge this with _tpu_multi_host_concat once we have all_gather
4234# for all strategies
4235def _multi_worker_concat(v, strategy):
4236 """Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
4237 replicas = strategy.gather(v, axis=0)
4238 # v might not have the same shape on different replicas
4239 if _is_per_replica_instance(v):
4240 shapes = tf.concat(
4241 [
4242 tf.expand_dims(tf.shape(single_value)[0], axis=0)
4243 for single_value in v.values
4244 ],
4245 axis=0,
4246 )
4247 all_shapes = strategy.gather(shapes, axis=0)
4248 else:
4249 # v is a tensor. This may happen when, say, we have 2x1 multi-worker.
4250 all_shapes = strategy.gather(
4251 tf.expand_dims(tf.shape(v)[0], axis=0), axis=0
4252 )
4254 replicas = tf.split(
4255 replicas,
4256 num_or_size_splits=all_shapes,
4257 num=strategy.num_replicas_in_sync,
4258 )
4259 ordered_replicas = []
4260 num_replicas_per_worker = len(strategy.extended.worker_devices)
4261 for replica_id in range(num_replicas_per_worker):
4262 ordered_replicas += replicas[replica_id::num_replicas_per_worker]
4263 return concat(ordered_replicas)
4266def _is_scalar(x):
4267 return isinstance(x, (tf.Tensor, tf.Variable)) and x.shape.rank == 0
4270def _minimum_control_deps(outputs):
4271 """Returns the minimum control dependencies to ensure step succeeded."""
4272 if tf.executing_eagerly():
4273 return [] # Control dependencies not needed.
4274 outputs = tf.nest.flatten(outputs, expand_composites=True)
4275 for out in outputs:
4276 # Variables can't be control dependencies.
4277 if not isinstance(out, tf.Variable):
4278 return [out] # Return first Tensor or Op from outputs.
4279 return [] # No viable Tensor or Op to use for control deps.
4282def _disallow_inside_tf_function(method_name):
4283 if tf.inside_function():
4284 error_msg = (
4285 "Detected a call to `Model.{method_name}` inside a `tf.function`. "
4286 "`Model.{method_name} is a high-level endpoint that manages its "
4287 "own `tf.function`. Please move the call to `Model.{method_name}` "
4288 "outside of all enclosing `tf.function`s. Note that you can call a "
4289 "`Model` directly on `Tensor`s inside a `tf.function` like: "
4290 "`model(x)`."
4291 ).format(method_name=method_name)
4292 raise RuntimeError(error_msg)
4295def flatten_metrics_in_order(logs, metrics_names):
4296 """Turns the `logs` dict into a list as per key order of `metrics_names`."""
4297 results = []
4298 for name in metrics_names:
4299 if name in logs:
4300 results.append(logs[name])
4301 for key in sorted(logs.keys()):
4302 if key not in metrics_names:
4303 results.append(logs[key])
4304 if len(results) == 1:
4305 return results[0]
4306 return results
4309def _is_per_replica_instance(obj):
4310 return isinstance(obj, tf.distribute.DistributedValues) and isinstance(
4311 obj, tf.__internal__.CompositeTensor
4312 )
4315def _is_dtensor_per_replica_instance(obj):
4316 # This is a temp check for DTensorDistributedValue, which is not public API
4317 # yet.
4318 # TODO(scottzhu): Move to more stable API when dtensor based strategy is
4319 # ready.
4320 return isinstance(obj, tf.distribute.DistributedValues) and hasattr(
4321 obj, "_dtensor"
4322 )
4325def disable_multi_worker(method):
4326 """Decorator that disallows multi-worker use of `method`."""
4328 def _method_wrapper(self, *args, **kwargs):
4329 if self._in_multi_worker_mode():
4330 raise ValueError(
4331 f"{method.__name__} is not supported in multi-worker "
4332 "mode. Please use a non-multi-worker "
4333 "`tf.distribute.Strategy` such as "
4334 "`tf.distribute.MirroredStrategy`."
4335 )
4336 return method(self, *args, **kwargs)
4338 return tf.__internal__.decorator.make_decorator(
4339 target=method, decorator_func=_method_wrapper
4340 )
4343def inject_functional_model_class(cls):
4344 """Inject `Functional` into the hierarchy of this class if needed."""
4345 from keras.src.engine import functional
4346 from keras.src.engine import training_v1
4348 if cls == Model or cls == training_v1.Model:
4349 return functional.Functional
4350 # In case there is any multiple inheritance, we stop injecting the
4351 # class if the Keras model is not in its class hierarchy.
4352 if cls == object:
4353 return object
4355 cls.__bases__ = tuple(
4356 inject_functional_model_class(base) for base in cls.__bases__
4357 )
4358 # Trigger any `__new__` class swapping that needed to happen on `Functional`
4359 # but did not because functional was not in the class hierarchy.
4360 cls.__new__(cls)
4362 return cls
4365def is_functional_model_init_params(args, kwargs):
4366 # Both inputs and outputs in args
4367 if len(args) == 2:
4368 return True
4369 # Inputs in args, outputs in kwargs
4370 if len(args) == 1 and "outputs" in kwargs:
4371 return True
4372 # Both in kwargs
4373 if "inputs" in kwargs and "outputs" in kwargs:
4374 return True
4375 return False