Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/training_v1.py: 18%
1053 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""V1 Training-related part of the Keras engine."""
16import collections
17import warnings
19import numpy as np
20import tensorflow.compat.v2 as tf
22from keras.src import backend
23from keras.src import losses
24from keras.src import metrics as metrics_module
25from keras.src import optimizers
26from keras.src.distribute import distributed_training_utils
27from keras.src.distribute import distributed_training_utils_v1
28from keras.src.engine import base_layer
29from keras.src.engine import training as training_lib
30from keras.src.engine import training_arrays_v1
31from keras.src.engine import training_distributed_v1
32from keras.src.engine import training_eager_v1
33from keras.src.engine import training_generator_v1
34from keras.src.engine import training_utils
35from keras.src.engine import training_utils_v1
36from keras.src.mixed_precision import loss_scale_optimizer
37from keras.src.optimizers import optimizer_v1
38from keras.src.optimizers.legacy import optimizer_v2
39from keras.src.saving.legacy import saving_utils
40from keras.src.saving.legacy.saved_model import model_serialization
41from keras.src.utils import data_utils
42from keras.src.utils import layer_utils
43from keras.src.utils import losses_utils
44from keras.src.utils import tf_inspect
45from keras.src.utils import tf_utils
46from keras.src.utils.mode_keys import ModeKeys
48# isort: off
49from tensorflow.python.platform import tf_logging as logging
51try:
52 from scipy.sparse import issparse
53except ImportError:
54 issparse = None
57class Model(training_lib.Model):
58 """A model groups layers into an object with training & inference features.
60 There are two ways to instantiate a `Model`:
62 1 - With the "functional API", where you start from `Input`,
63 you chain layer calls to specify the model's forward pass,
64 and finally you create your model from inputs and outputs:
66 ```python
67 import tensorflow as tf
69 inputs = tf.keras.Input(shape=(3,))
70 x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
71 outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
72 model = tf.keras.Model(inputs=inputs, outputs=outputs)
73 ```
75 2 - By subclassing the `Model` class: in that case, you should define your
76 layers in `__init__` and you should implement the model's forward pass
77 in `call`.
79 ```python
80 import tensorflow as tf
82 class MyModel(tf.keras.Model):
84 def __init__(self):
85 super().__init__()
86 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
87 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
89 def call(self, inputs):
90 x = self.dense1(inputs)
91 return self.dense2(x)
93 model = MyModel()
94 ```
96 If you subclass `Model`, you can optionally have
97 a `training` argument (boolean) in `call`, which you can use to specify
98 a different behavior in training and inference:
100 ```python
101 import tensorflow as tf
103 class MyModel(tf.keras.Model):
105 def __init__(self):
106 super().__init__()
107 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
108 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
109 self.dropout = tf.keras.layers.Dropout(0.5)
111 def call(self, inputs, training=False):
112 x = self.dense1(inputs)
113 if training:
114 x = self.dropout(x, training=training)
115 return self.dense2(x)
117 model = MyModel()
118 ```
119 """
121 def __init__(self, *args, **kwargs):
122 super().__init__(*args, **kwargs)
123 # initializing _distribution_strategy here since it is possible to call
124 # predict on a model without compiling it.
125 self._distribution_strategy = None
126 self._compile_time_distribution_strategy = None
127 if (
128 tf.compat.v1.executing_eagerly_outside_functions()
129 and tf.distribute.has_strategy()
130 ):
131 self._set_strategy(tf.distribute.get_strategy())
133 # This flag is used to track if the user is using the deprecated path of
134 # passing distribution strategy to compile rather than creating the
135 # model under distribution strategy scope.
136 self._compile_distribution = False
138 self._run_eagerly = None
139 self._experimental_run_tf_function = (
140 tf.compat.v1.executing_eagerly_outside_functions()
141 )
143 self._v1_compile_was_called = False
145 def _init_batch_counters(self):
146 pass # Batch counters should not be created in legacy graph mode.
148 @tf.__internal__.tracking.no_automatic_dependency_tracking
149 def _set_strategy(self, strategy):
150 self._compile_time_distribution_strategy = strategy
152 def get_weights(self):
153 """Retrieves the weights of the model.
155 Returns:
156 A flat list of Numpy arrays.
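Example (a minimal sketch; `model` stands for any already-built Keras model
and is not defined in this snippet):

```python
weights = model.get_weights()  # flat list of Numpy arrays
model.set_weights(weights)     # restore the same values
```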
157 """
158 strategy = (
159 self._distribution_strategy
160 or self._compile_time_distribution_strategy
161 )
162 if strategy:
163 with strategy.scope():
164 return base_layer.Layer.get_weights(self)
165 return base_layer.Layer.get_weights(self)
167 def load_weights(self, filepath, by_name=False, skip_mismatch=False):
168 """Loads all layer weights, either from a TensorFlow or an HDF5 file.
170 If `by_name` is False, weights are loaded based on the network's
171 topology. This means the architecture should be the same as when the
172 weights were saved. Note that layers that don't have weights are not
173 taken into account in the topological ordering, so adding or removing
174 layers is fine as long as they don't have weights.
176 If `by_name` is True, weights are loaded into layers only if they share
177 the same name. This is useful for fine-tuning or transfer-learning
178 models where some of the layers have changed.
180 Only topological loading (`by_name=False`) is supported when loading
181 weights from the TensorFlow format. Note that topological loading
182 differs slightly between TensorFlow and HDF5 formats for user-defined
183 classes inheriting from `tf.keras.Model`: HDF5 loads based on a
184 flattened list of weights, while the TensorFlow format loads based on
185 the object-local names of attributes to which layers are assigned in the
186 `Model`'s constructor.
188 Args:
189 filepath: String, path to the weights file to load. For weight files
190 in TensorFlow format, this is the file prefix (the same as was
191 passed to `save_weights`).
192 by_name: Boolean, whether to load weights by name or by topological
193 order. Only topological loading is supported for weight files in
194 TensorFlow format.
195 skip_mismatch: Boolean, whether to skip loading of layers where
196 there is a mismatch in the number of weights, or a mismatch in
197 the shape of the weight (only valid when `by_name=True`).
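Example (an illustrative sketch; the file paths below are placeholders
chosen for this example):

```python
# HDF5 format: the path ends in ".h5" or ".hdf5".
model.save_weights("weights.h5")
model.load_weights("weights.h5", by_name=True)

# TensorFlow format: the path is a checkpoint prefix.
model.save_weights("./ckpt/weights")
model.load_weights("./ckpt/weights")
```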
199 Returns:
200 When loading a weight file in TensorFlow format, returns the same
201 status object as `tf.train.Checkpoint.restore`. When graph building,
202 restore ops are run automatically as soon as the network is built
203 (on first call for user-defined classes inheriting from `Model`,
204 immediately if it is already built).
206 When loading weights in HDF5 format, returns `None`.
208 Raises:
209 ImportError: If h5py is not available and the weight file is in HDF5
210 format.
211 ValueError: If `skip_mismatch` is set to `True` when `by_name` is
212 `False`.
213 """
214 if backend.is_tpu_strategy(self._distribution_strategy):
215 if self._distribution_strategy.extended.steps_per_run > 1 and (
216 not saving_utils.is_hdf5_filepath(filepath)
217 ):
218 raise ValueError(
219 "Load weights is not yet supported with TPUStrategy "
220 "with steps_per_run greater than 1."
221 )
222 return super().load_weights(
223 filepath, by_name=by_name, skip_mismatch=skip_mismatch
224 )
226 @tf.__internal__.tracking.no_automatic_dependency_tracking
227 def compile(
228 self,
229 optimizer="rmsprop",
230 loss=None,
231 metrics=None,
232 loss_weights=None,
233 sample_weight_mode=None,
234 weighted_metrics=None,
235 target_tensors=None,
236 distribute=None,
237 **kwargs,
238 ):
239 """Configures the model for training.
241 Args:
242 optimizer: String (name of optimizer) or optimizer instance.
243 See `tf.keras.optimizers`.
244 loss: String (name of objective function), objective function or
245 `tf.keras.losses.Loss` instance. See `tf.keras.losses`. An
246 objective function is any callable with the signature
247 `scalar_loss = fn(y_true, y_pred)`. If the model has multiple
248 outputs, you can use a different loss on each output by passing
249 a dictionary or a list of losses. The loss value that will be
250 minimized by the model will then be the sum of all individual
251 losses.
252 metrics: List of metrics to be evaluated by the model during
253 training and testing. Typically you will use
254 `metrics=['accuracy']`. To specify different metrics for
255 different outputs of a multi-output model, you could also pass a
256 dictionary, such as `metrics={'output_a': 'accuracy',
257 'output_b': ['accuracy', 'mse']}`. You can also pass a list
258 (len = len(outputs)) of lists of metrics such as
259 `metrics=[['accuracy'], ['accuracy', 'mse']]` or
260 `metrics=['accuracy', ['accuracy', 'mse']]`.
261 loss_weights: Optional list or dictionary specifying scalar
262 coefficients (Python floats) to weight the loss contributions
263 of different model outputs.
264 The loss value that will be minimized by the model
265 will then be the *weighted sum* of all individual losses,
266 weighted by the `loss_weights` coefficients.
267 If a list, it is expected to have a 1:1 mapping
268 to the model's outputs. If a dict, it is expected to map
269 output names (strings) to scalar coefficients.
270 sample_weight_mode: If you need to do timestep-wise
271 sample weighting (2D weights), set this to `"temporal"`.
272 `None` defaults to sample-wise weights (1D).
273 If the model has multiple outputs, you can use a different
274 `sample_weight_mode` on each output by passing a
275 dictionary or a list of modes. Defaults to `None`.
276 weighted_metrics: List of metrics to be evaluated and weighted
277 by sample_weight or class_weight during training and testing.
278 target_tensors: By default, Keras will create placeholders for the
279 model's target, which will be fed with the target data during
280 training. If instead you would like to use your own
281 target tensors (in turn, Keras will not expect external
282 Numpy data for these targets at training time), you
283 can specify them via the `target_tensors` argument. It can be
284 a single tensor (for a single-output model), a list of tensors,
285 or a dict mapping output names to target tensors.
286 distribute: NOT SUPPORTED IN TF 2.0. Please create and compile the
287 model under the distribution strategy scope instead of passing it
288 to `compile`.
289 **kwargs: Any additional arguments.
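Example (a minimal sketch of a multi-output compile call; the output names
"output_a" and "output_b" are placeholders and must match the names of the
model's outputs):

```python
model.compile(
    optimizer="rmsprop",
    loss={"output_a": "binary_crossentropy", "output_b": "mse"},
    loss_weights={"output_a": 1.0, "output_b": 0.5},
    metrics={"output_a": "accuracy", "output_b": ["mae", "mse"]},
)
```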
291 Raises:
292 ValueError: In case of invalid arguments for
293 `optimizer`, `loss`, `metrics` or `sample_weight_mode`.
294 """
295 self._assert_built_as_v1()
296 self._run_eagerly = kwargs.pop("run_eagerly", None)
297 self._experimental_run_tf_function = kwargs.pop(
298 "experimental_run_tf_function", True
299 )
300 self._v1_compile_was_called = True
302 # Prepare Session arguments (legacy).
303 kwargs.pop("cloning", None) # Legacy DistStrat argument, never used.
304 self._from_serialized = kwargs.pop("from_serialized", False)
305 allowed_kwargs = {"feed_dict", "fetches", "options", "run_metadata"}
306 unknown_kwargs = set(kwargs.keys()) - allowed_kwargs
307 if unknown_kwargs:
308 raise TypeError(
309 f"Invalid keyword argument(s) in `compile`: {unknown_kwargs}"
310 )
311 self._function_kwargs = kwargs
312 if self._function_kwargs:
313 self._experimental_run_tf_function = False
314 if self.run_eagerly:
315 raise ValueError(
316 "Session keyword arguments are not supported "
317 "when `run_eagerly=True`. You passed the following "
318 "Session arguments: %s" % (self._function_kwargs,)
319 )
321 self._set_optimizer(optimizer)
322 is_any_keras_optimizer_v1 = any(
323 (
324 isinstance(opt, optimizer_v1.Optimizer)
325 and not isinstance(opt, optimizer_v1.TFOptimizer)
326 )
327 for opt in tf.nest.flatten(self.optimizer)
328 )
330 if (
331 is_any_keras_optimizer_v1
332 and tf.compat.v1.executing_eagerly_outside_functions()
333 ):
334 raise ValueError(
335 "`tf.compat.v1.keras` Optimizer ("
336 f"{optimizer}"
337 ") is "
338 "not supported when eager execution is enabled. Use a "
339 "`tf.keras` Optimizer instead, or disable eager "
340 "execution."
341 )
343 if (
344 target_tensors is not None
345 ) or not tf.compat.v1.executing_eagerly_outside_functions():
346 # Fallback out of things that aren't supported with v2 loops
347 self._experimental_run_tf_function = False
349 if distribute is not None:
350 if (
351 tf.__internal__.tf2.enabled()
352 or self._experimental_run_tf_function
353 ):
354 raise ValueError(
355 "Distribute argument in compile is not available in TF 2.0. "
356 "Please create the model under the distribution strategy "
357 "scope."
358 )
359 logging.warning(
360 "Distribute argument in compile is deprecated. Please "
361 "create the model under the distribution strategy scope."
362 )
363 self._distribution_strategy = distribute
364 self._compile_distribution = True
365 else:
366 if tf.distribute.has_strategy():
367 # When the user builds the model in the DS scope and
368 # cross-replica context, we want the distribution strategy to be
369 # set; but when building the replica copies of the models
370 # internally, we should not compile with the distribution
371 # strategy and should use the default compilation path instead.
372 if tf.distribute.in_cross_replica_context():
373 self._distribution_strategy = tf.distribute.get_strategy()
375 if isinstance(
376 self._distribution_strategy,
377 tf.compat.v1.distribute.experimental.ParameterServerStrategy,
378 ):
379 raise NotImplementedError(
380 "`tf.compat.v1.distribute.experimental.ParameterServerStrategy`"
381 " currently only works with the tf.Estimator API"
382 )
384 if isinstance(
385 self._distribution_strategy,
386 tf.distribute.experimental.ParameterServerStrategy,
387 ):
388 raise NotImplementedError(
389 "`tf.distribute.experimental.ParameterServerStrategy` is only "
390 "supported in TF2."
391 )
393 if not self._experimental_run_tf_function:
394 self._validate_compile_param_for_distribution_strategy(
395 self.run_eagerly,
396 sample_weight_mode,
397 target_tensors,
398 weighted_metrics,
399 )
400 # We've disabled automatic dependency tracking for this method, but do
401 # want to add a checkpoint dependency on the optimizer if it's
402 # trackable.
403 if isinstance(self.optimizer, tf.__internal__.tracking.Trackable):
404 self._track_trackable(
405 self.optimizer, name="optimizer", overwrite=True
406 )
407 self.loss = loss or {}
408 self.loss_weights = loss_weights
409 self.sample_weight_mode = sample_weight_mode
410 self._compile_metrics = metrics or []
411 self._compile_weighted_metrics = weighted_metrics
412 if self.run_eagerly and target_tensors is not None:
413 raise ValueError(
414 "target_tensors argument is not supported when "
415 "running a model eagerly."
416 )
418 # _training_endpoints contains a list of _TrainingEndpoint objects,
419 # which hold all the model output/target/loss and related metadata.
420 self._training_endpoints = []
422 # Used to freeze the behavior of the Model once `compile` has been
423 # called.
424 self._compiled_trainable_state = self._get_trainable_state()
426 # Set tf.distribute.Strategy specific parameters.
427 self._distributed_model_cache = {}
428 self._distributed_function_cache = {}
430 # Clear any `_eager_losses` that was added.
431 self._clear_losses()
433 if (
434 not tf.executing_eagerly()
435 and self._distribution_strategy is not None
436 ):
437 # Ensures a Session is created and configured correctly for
438 # Distribution Strategy.
439 backend.configure_and_create_distributed_session(
440 self._distribution_strategy
441 )
442 # Initialize model metric attributes.
443 self._init_metric_attributes()
444 if not self.built or not self.inputs or not self.outputs:
445 # Model is not compilable because it does not know its number of
446 # inputs and outputs, nor their shapes and names. We will compile
447 # after the first time the model gets called on training data.
448 return
449 self._is_compiled = True
450 base_layer.keras_api_gauge.get_cell("compile").set(True)
452 # Prepare list of loss functions, same size as model outputs.
453 self.loss_functions = training_utils_v1.prepare_loss_functions(
454 self.loss, self.output_names
455 )
457 target_tensors = self._process_target_tensor_for_compile(target_tensors)
459 for o, n, l, t in zip(
460 self.outputs, self.output_names, self.loss_functions, target_tensors
461 ):
462 endpoint = _TrainingEndpoint(o, n, l)
463 endpoint.create_training_target(t, run_eagerly=self.run_eagerly)
464 self._training_endpoints.append(endpoint)
466 # Prepare list of loss weights, same size as model outputs.
467 training_utils_v1.prepare_loss_weights(
468 self._training_endpoints, loss_weights
469 )
471 # Initialization for Eager mode execution.
472 if self.run_eagerly:
473 self._compile_eagerly(metrics, weighted_metrics, sample_weight_mode)
474 return
476 with backend.get_graph().as_default():
477 # Save all metric attributes per output of the model.
478 self._cache_output_metric_attributes(metrics, weighted_metrics)
480 # Set metric attributes on model.
481 self._set_metric_attributes()
483 # Invoke metric functions (unweighted) for all the outputs.
484 self._handle_metrics(
485 self.outputs,
486 targets=self._targets,
487 skip_target_masks=self._prepare_skip_target_masks(),
488 masks=self._prepare_output_masks(),
489 )
491 # Prepare sample weight modes. List with the same length as model
492 # outputs.
493 training_utils_v1.prepare_sample_weight_modes(
494 self._training_endpoints, sample_weight_mode
495 )
497 # Creates the model loss and weighted metrics sub-graphs.
498 self._compile_weights_loss_and_weighted_metrics()
500 # Functions for train, test and predict will
501 # be compiled lazily when required.
502 # This saves time when the user is not using all functions.
503 self.train_function = None
504 self.test_function = None
505 self.predict_function = None
507 # Collected trainable weights, sorted in topological order.
508 self._collected_trainable_weights = self.trainable_weights
510 # Validate all variables were correctly created in distribution
511 # scope.
512 if self._distribution_strategy and not self._compile_distribution:
513 for v in self.variables:
514 strategy = self._distribution_strategy
515 if not strategy.extended.variable_created_in_scope(v):
516 raise ValueError(
517 "Variable (%s) was not created in the distribution "
518 "strategy scope of (%s). It is most likely because "
519 "not all layers, or the model or optimizer, were "
520 "created inside the distribution strategy scope. "
521 "Try to make sure your code looks similar "
522 "to the following.\n"
523 "with strategy.scope():\n"
524 " model=_create_model()\n"
525 " model.compile(...)" % (v, strategy)
526 )
528 @tf.__internal__.tracking.no_automatic_dependency_tracking
529 def _init_distributed_function_cache_if_not_compiled(self):
530 if not hasattr(self, "_distributed_function_cache"):
531 self._distributed_function_cache = {}
533 @property
534 def metrics(self):
535 """Returns the model's metrics added using `compile`, `add_metric`
536 APIs."""
537 metrics = []
538 if self._is_compiled:
539 if not hasattr(self, "_v1_compile_was_called"):
540 # See b/155687393 for more details, the model is created as a v2
541 # instance but converted to v1. Fallback to use base Model to
542 # retrieve the metrics.
543 return super().metrics
544 metrics += self._compile_metric_functions
545 metrics.extend(self._metrics)
546 metrics.extend(
547 _get_metrics_from_layers(
548 list(self._flatten_layers(include_self=False, recursive=False))
549 )
550 )
551 return metrics
553 @property
554 def metrics_names(self):
555 """Returns the model's display labels for all outputs."""
557 # This property covers all output names, including `loss` and the
558 # per-output losses, for backward compatibility.
559 metrics_names = ["loss"]
560 if self._is_compiled:
561 if not hasattr(self, "_v1_compile_was_called"):
562 # See b/155687393 for more details, the model is created as a v2
563 # instance but converted to v1. Fallback to use base Model to
564 # retrieve the metrics name
565 return super().metrics_names
567 # Add output loss metric names to the metric names list.
568 if len(self._training_endpoints) > 1:
569 metrics_names.extend(
570 [
571 e.loss_name()
572 for e in self._training_endpoints
573 if not e.should_skip_target()
574 ]
575 )
577 # Add all metric names.
578 metrics_names += [m.name for m in self.metrics]
579 return metrics_names
581 @property
582 def run_eagerly(self):
583 """Settable attribute indicating whether the model should run eagerly.
585 Running eagerly means that your model will be run step by step,
586 like Python code. Your model might run slower, but it should become
587 easier for you to debug it by stepping into individual layer calls.
589 By default, we will attempt to compile your model to a static graph to
590 deliver the best execution performance.
592 Returns:
593 Boolean, whether the model should run eagerly.
594 """
595 if self._run_eagerly is True and not tf.executing_eagerly():
596 raise ValueError(
597 "You can only set `run_eagerly=True` if eager execution "
598 "is enabled."
599 )
600 if not self.dynamic:
601 if self._run_eagerly is None:
602 # Respect `tf.config.run_functions_eagerly` unless
603 # `run_eagerly` was explicitly passed to `compile`.
604 return tf.config.functions_run_eagerly()
605 else:
606 return self._run_eagerly
607 else:
608 if not tf.executing_eagerly():
609 raise ValueError(
610 "Your model contains layers that can only be "
611 "successfully run in eager execution (layers "
612 "constructed with `dynamic=True`). "
613 "You must enable eager execution with "
614 "`tf.enable_eager_execution()`."
615 )
616 if self._run_eagerly is False:
617 # TODO(fchollet): consider using py_func to enable this.
618 raise ValueError(
619 "Your model contains layers that can only be "
620 "successfully run in eager execution (layers "
621 "constructed with `dynamic=True`). "
622 "You cannot set `run_eagerly=False`."
623 )
624 return tf.executing_eagerly()
626 @run_eagerly.setter
627 def run_eagerly(self, value):
628 self._run_eagerly = value
630 def _select_training_loop(self, inputs):
631 """Select training loop for fit/eval/predict based on the inputs."""
632 # TODO(kaftan) or TODO(scottzhu): This check should eventually be nicely
633 # integrated into the data adapters in the v2 loop. We can't do this yet
634 # because we currently have to fall back for unhandled data types.
635 if isinstance(inputs, (tf.compat.v1.data.Iterator, tf.data.Iterator)):
636 raise ValueError(
637 "For performance reasons Keras `fit`, `evaluate` and "
638 "`predict` accept tf.data `Datasets` as input but not "
639 "iterators that have been manually generated from "
640 "Datasets by users. Please directly pass in the "
641 "original `Dataset` object instead of passing in "
642 "`iter(dataset)`."
643 )
645 # Case 1: distribution strategy.
646 if self._distribution_strategy:
647 if self._in_multi_worker_mode():
648 return training_distributed_v1.DistributionMultiWorkerTrainingLoop( # noqa: E501
649 training_distributed_v1.DistributionSingleWorkerTrainingLoop() # noqa: E501
650 )
651 else:
652 return (
653 training_distributed_v1.DistributionSingleWorkerTrainingLoop() # noqa: E501
654 )
656 # Case 2: generator-like. Input is a Python generator, a Sequence object,
657 # or a non-distributed Dataset or iterator in eager execution.
658 if data_utils.is_generator_or_sequence(inputs):
659 return training_generator_v1.GeneratorOrSequenceTrainingLoop()
660 if training_utils_v1.is_eager_dataset_or_iterator(inputs):
661 return training_generator_v1.EagerDatasetOrIteratorTrainingLoop()
663 # Case 3: Symbolic tensors or Numpy array-like.
664 # This includes Datasets and iterators in graph mode (since they
665 # generate symbolic tensors).
666 if self.run_eagerly:
667 return training_generator_v1.GeneratorLikeTrainingLoop()
668 else:
669 return training_arrays_v1.ArrayLikeTrainingLoop()
671 def fit(
672 self,
673 x=None,
674 y=None,
675 batch_size=None,
676 epochs=1,
677 verbose=1,
678 callbacks=None,
679 validation_split=0.0,
680 validation_data=None,
681 shuffle=True,
682 class_weight=None,
683 sample_weight=None,
684 initial_epoch=0,
685 steps_per_epoch=None,
686 validation_steps=None,
687 validation_freq=1,
688 max_queue_size=10,
689 workers=1,
690 use_multiprocessing=False,
691 **kwargs,
692 ):
693 """Trains the model for a fixed number of epochs (dataset iterations).
695 Args:
696 x: Input data. It could be:
697 - A Numpy array (or array-like), or a list of arrays
698 (in case the model has multiple inputs).
699 - A TensorFlow tensor, or a list of tensors
700 (in case the model has multiple inputs).
701 - A dict mapping input names to the corresponding array/tensors,
702 if the model has named inputs.
703 - A `tf.data` dataset. Should return a tuple
704 of either `(inputs, targets)` or
705 `(inputs, targets, sample_weights)`.
706 - A generator or `keras.utils.Sequence` returning `(inputs,
707 targets)` or `(inputs, targets, sample weights)`.
708 y: Target data. Like the input data `x`,
709 it could be either Numpy array(s) or TensorFlow tensor(s).
710 It should be consistent with `x` (you cannot have Numpy inputs and
711 tensor targets, or inversely). If `x` is a dataset, generator,
712 or `keras.utils.Sequence` instance, `y` should
713 not be specified (since targets will be obtained from `x`).
714 batch_size: Integer or `None`.
715 Number of samples per gradient update.
716 If unspecified, `batch_size` will default to 32.
717 Do not specify the `batch_size` if your data is in the
718 form of symbolic tensors, datasets,
719 generators, or `keras.utils.Sequence` instances (since they
720 generate batches).
721 epochs: Integer. Number of epochs to train the model.
722 An epoch is an iteration over the entire `x` and `y`
723 data provided.
724 Note that in conjunction with `initial_epoch`,
725 `epochs` is to be understood as "final epoch".
726 The model is not trained for a number of iterations
727 given by `epochs`, but merely until the epoch
728 of index `epochs` is reached.
729 verbose: 0, 1, or 2. Verbosity mode.
730 0 = silent, 1 = progress bar, 2 = one line per epoch.
731 Note that the progress bar is not particularly useful when
732 logged to a file, so verbose=2 is recommended when not running
733 interactively (e.g., in a production environment).
734 callbacks: List of `keras.callbacks.Callback` instances.
735 List of callbacks to apply during training.
736 See `tf.keras.callbacks`.
737 validation_split: Float between 0 and 1.
738 Fraction of the training data to be used as validation data.
739 The model will set apart this fraction of the training data,
740 will not train on it, and will evaluate
741 the loss and any model metrics
742 on this data at the end of each epoch.
743 The validation data is selected from the last samples
744 in the `x` and `y` data provided, before shuffling. This
745 argument is not supported when `x` is a dataset, generator or
746 `keras.utils.Sequence` instance.
747 validation_data: Data on which to evaluate
748 the loss and any model metrics at the end of each epoch.
749 The model will not be trained on this data.
750 `validation_data` will override `validation_split`.
751 `validation_data` could be:
752 - tuple `(x_val, y_val)` of Numpy arrays or tensors
753 - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
754 - dataset
755 For the first two cases, `batch_size` must be provided.
756 For the last case, `validation_steps` could be provided.
757 shuffle: Boolean (whether to shuffle the training data
758 before each epoch) or str (for 'batch').
759 'batch' is a special option for dealing with the
760 limitations of HDF5 data; it shuffles in batch-sized chunks.
761 Has no effect when `steps_per_epoch` is not `None`.
762 class_weight: Optional dictionary mapping class indices (integers)
763 to a weight (float) value, used for weighting the loss function
764 (during training only).
765 This can be useful to tell the model to
766 "pay more attention" to samples from
767 an under-represented class.
768 sample_weight: Optional Numpy array of weights for
769 the training samples, used for weighting the loss function
770 (during training only). You can either pass a flat (1D)
771 Numpy array with the same length as the input samples
772 (1:1 mapping between weights and samples),
773 or in the case of temporal data,
774 you can pass a 2D array with shape
775 `(samples, sequence_length)`,
776 to apply a different weight to every timestep of every sample.
777 In this case you should make sure to specify
778 `sample_weight_mode="temporal"` in `compile()`. This argument is
779 not supported when `x` is a dataset, generator, or
780 `keras.utils.Sequence` instance; instead, provide the
781 sample_weights as the third element of `x`.
782 initial_epoch: Integer.
783 Epoch at which to start training
784 (useful for resuming a previous training run).
785 steps_per_epoch: Integer or `None`.
786 Total number of steps (batches of samples)
787 before declaring one epoch finished and starting the
788 next epoch. When training with input tensors such as
789 TensorFlow data tensors, the default `None` is equal to
790 the number of samples in your dataset divided by
791 the batch size, or 1 if that cannot be determined. If x is a
792 `tf.data` dataset, and 'steps_per_epoch'
793 is None, the epoch will run until the input dataset is
794 exhausted. This argument is not supported with array inputs.
795 validation_steps: Only relevant if `validation_data` is provided and
796 is a `tf.data` dataset. Total number of steps (batches of
797 samples) to draw before stopping when performing validation at
798 the end of every epoch. If 'validation_steps' is None,
799 validation will run until the `validation_data` dataset is
800 exhausted. In the case of an infinite dataset, it will run into an
801 infinite loop. If 'validation_steps' is specified and only part
802 of the dataset is consumed, the evaluation will start from
803 the beginning of the dataset at each epoch. This ensures that
804 the same validation samples are used every time.
805 validation_freq: Only relevant if validation data is provided.
806 Integer or `collections.abc.Container` instance (e.g. list,
807 tuple, etc.). If an integer, specifies how many training epochs
808 to run before a new validation run is performed, e.g.
809 `validation_freq=2` runs validation every 2 epochs. If a
810 Container, specifies the epochs on which to run validation, e.g.
811 `validation_freq=[1, 2, 10]` runs validation at the end of the
812 1st, 2nd, and 10th epochs.
813 max_queue_size: Integer. Used for generator or
814 `keras.utils.Sequence` input only. Maximum size for the
815 generator queue. If unspecified, `max_queue_size` will default
816 to 10.
817 workers: Integer. Used for generator or `keras.utils.Sequence` input
818 only. Maximum number of processes to spin up
819 when using process-based threading. If unspecified, `workers`
820 will default to 1. If 0, will execute the generator on the main
821 thread.
822 use_multiprocessing: Boolean. Used for generator or
823 `keras.utils.Sequence` input only. If `True`, use process-based
824 threading. If unspecified, `use_multiprocessing` will default to
825 `False`. Note that because this implementation relies on
826 multiprocessing, you should not pass non-picklable arguments to
827 the generator as they can't be passed easily to children
828 processes.
829 **kwargs: Used for backwards compatibility.
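Example (a minimal sketch; `x_train` and `y_train` are placeholder Numpy
arrays, not defined here):

```python
history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=10,
    validation_split=0.2,
)
print(history.history["loss"])  # per-epoch training loss
```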
831 Returns:
832 A `History` object. Its `History.history` attribute is
833 a record of training loss values and metrics values
834 at successive epochs, as well as validation loss values
835 and validation metrics values (if applicable).
837 Raises:
838 RuntimeError: If the model was never compiled.
839 ValueError: In case of mismatch between the provided input data
840 and what the model expects.
841 """
842 self._assert_built_as_v1()
843 base_layer.keras_api_gauge.get_cell("fit").set(True)
844 # Legacy support
845 if "nb_epoch" in kwargs:
846 logging.warning(
847 "The `nb_epoch` argument in `fit` has been renamed `epochs`."
848 )
849 epochs = kwargs.pop("nb_epoch")
850 if kwargs:
851 raise TypeError("Unrecognized keyword arguments: " + str(kwargs))
852 self._assert_compile_was_called()
853 self._check_call_args("fit")
855 func = self._select_training_loop(x)
856 return func.fit(
857 self,
858 x=x,
859 y=y,
860 batch_size=batch_size,
861 epochs=epochs,
862 verbose=verbose,
863 callbacks=callbacks,
864 validation_split=validation_split,
865 validation_data=validation_data,
866 shuffle=shuffle,
867 class_weight=class_weight,
868 sample_weight=sample_weight,
869 initial_epoch=initial_epoch,
870 steps_per_epoch=steps_per_epoch,
871 validation_steps=validation_steps,
872 validation_freq=validation_freq,
873 max_queue_size=max_queue_size,
874 workers=workers,
875 use_multiprocessing=use_multiprocessing,
876 )
878 def evaluate(
879 self,
880 x=None,
881 y=None,
882 batch_size=None,
883 verbose=1,
884 sample_weight=None,
885 steps=None,
886 callbacks=None,
887 max_queue_size=10,
888 workers=1,
889 use_multiprocessing=False,
890 ):
891 """Returns the loss value & metrics values for the model in test mode.
893 Computation is done in batches (see the `batch_size` arg).
895 Args:
896 x: Input data. It could be:
897 - A Numpy array (or array-like), or a list of arrays
898 (in case the model has multiple inputs).
899 - A TensorFlow tensor, or a list of tensors
900 (in case the model has multiple inputs).
901 - A dict mapping input names to the corresponding array/tensors,
902 if the model has named inputs.
903 - A `tf.data` dataset.
904 - A generator or `keras.utils.Sequence` instance.
905 y: Target data. Like the input data `x`,
906 it could be either Numpy array(s) or TensorFlow tensor(s).
907 It should be consistent with `x` (you cannot have Numpy inputs and
908 tensor targets, or inversely).
909 If `x` is a dataset, generator or
910 `keras.utils.Sequence` instance, `y` should not be specified
911 (since targets will be obtained from the iterator/dataset).
912 batch_size: Integer or `None`.
913 Number of samples per batch of computation.
914 If unspecified, `batch_size` will default to 32.
915 Do not specify the `batch_size` if your data is in the
916 form of symbolic tensors, dataset,
917 generators, or `keras.utils.Sequence` instances (since they
918 generate batches).
919 verbose: 0 or 1. Verbosity mode.
920 0 = silent, 1 = progress bar.
921 sample_weight: Optional Numpy array of weights for
922 the test samples, used for weighting the loss function.
923 You can either pass a flat (1D)
924 Numpy array with the same length as the input samples
925 (1:1 mapping between weights and samples),
926 or in the case of temporal data,
927 you can pass a 2D array with shape
928 `(samples, sequence_length)`,
929 to apply a different weight to every timestep of every sample.
930 In this case you should make sure to specify
931 `sample_weight_mode="temporal"` in `compile()`. This argument is
932 not supported when `x` is a dataset, instead pass sample weights
933 as the third element of `x`.
934 steps: Integer or `None`.
935 Total number of steps (batches of samples)
936 before declaring the evaluation round finished.
937 Ignored with the default value of `None`.
938 If x is a `tf.data` dataset and `steps` is
939 None, 'evaluate' will run until the dataset is exhausted.
940 This argument is not supported with array inputs.
941 callbacks: List of `keras.callbacks.Callback` instances.
942 List of callbacks to apply during evaluation.
943 See [callbacks](/api_docs/python/tf/keras/callbacks).
944 max_queue_size: Integer. Used for generator or
945 `keras.utils.Sequence` input only. Maximum size for the
946 generator queue. If unspecified, `max_queue_size` will default
947 to 10.
948 workers: Integer. Used for generator or `keras.utils.Sequence` input
949 only. Maximum number of processes to spin up when using
950 process-based threading. If unspecified, `workers` will default
951 to 1. If 0, will execute the generator on the main thread.
952 use_multiprocessing: Boolean. Used for generator or
953 `keras.utils.Sequence` input only. If `True`, use process-based
954 threading. If unspecified, `use_multiprocessing` will default to
955 `False`. Note that because this implementation relies on
956 multiprocessing, you should not pass non-picklable arguments to
957 the generator as they can't be passed easily to children
958 processes.
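Example (a minimal sketch; `x_test` and `y_test` are placeholder Numpy
arrays, and the unpacking assumes the model was compiled with exactly one
metric):

```python
loss, metric_value = model.evaluate(x_test, y_test, batch_size=32)
```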
960 Returns:
961 Scalar test loss (if the model has a single output and no metrics)
962 or list of scalars (if the model has multiple outputs
963 and/or metrics). The attribute `model.metrics_names` will give you
964 the display labels for the scalar outputs.
966 Raises:
967 ValueError: in case of invalid arguments.
968 """
969 self._assert_built_as_v1()
970 base_layer.keras_api_gauge.get_cell("evaluate").set(True)
971 self._assert_compile_was_called()
972 self._check_call_args("evaluate")
974 func = self._select_training_loop(x)
975 return func.evaluate(
976 self,
977 x=x,
978 y=y,
979 batch_size=batch_size,
980 verbose=verbose,
981 sample_weight=sample_weight,
982 steps=steps,
983 callbacks=callbacks,
984 max_queue_size=max_queue_size,
985 workers=workers,
986 use_multiprocessing=use_multiprocessing,
987 )
989 def predict(
990 self,
991 x,
992 batch_size=None,
993 verbose=0,
994 steps=None,
995 callbacks=None,
996 max_queue_size=10,
997 workers=1,
998 use_multiprocessing=False,
999 ):
1000 """Generates output predictions for the input samples.
1002 Computation is done in batches (see the `batch_size` arg).
1004 Args:
1005 x: Input samples. It could be:
1006 - A Numpy array (or array-like), or a list of arrays
1007 (in case the model has multiple inputs).
1008 - A TensorFlow tensor, or a list of tensors
1009 (in case the model has multiple inputs).
1010 - A `tf.data` dataset.
1011 - A generator or `keras.utils.Sequence` instance.
1012 batch_size: Integer or `None`.
1013 Number of samples per batch of computation.
1014 If unspecified, `batch_size` will default to 32.
1015 Do not specify the `batch_size` if your data is in the
1016 form of symbolic tensors, dataset,
1017 generators, or `keras.utils.Sequence` instances (since they
1018 generate batches).
1019 verbose: Verbosity mode, 0 or 1.
1020 steps: Total number of steps (batches of samples)
1021 before declaring the prediction round finished.
1022 Ignored with the default value of `None`. If x is a `tf.data`
1023 dataset and `steps` is None, `predict` will
1024 run until the input dataset is exhausted.
1025 callbacks: List of `keras.callbacks.Callback` instances.
1026 List of callbacks to apply during prediction.
1027 See [callbacks](/api_docs/python/tf/keras/callbacks).
1028 max_queue_size: Integer. Used for generator or
1029 `keras.utils.Sequence` input only. Maximum size for the
1030 generator queue. If unspecified, `max_queue_size` will default
1031 to 10.
1032 workers: Integer. Used for generator or `keras.utils.Sequence` input
1033 only. Maximum number of processes to spin up when using
1034 process-based threading. If unspecified, `workers` will default
1035 to 1. If 0, will execute the generator on the main thread.
1036 use_multiprocessing: Boolean. Used for generator or
1037 `keras.utils.Sequence` input only. If `True`, use process-based
1038 threading. If unspecified, `use_multiprocessing` will default to
1039 `False`. Note that because this implementation relies on
1040 multiprocessing, you should not pass non-picklable arguments to
1041 the generator as they can't be passed easily to children
1042 processes.
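Example (a minimal sketch; `x_new` is a placeholder Numpy array of input
samples):

```python
predictions = model.predict(x_new, batch_size=32)
print(predictions.shape)  # one row of outputs per input sample
```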
1045 Returns:
1046 Numpy array(s) of predictions.
1048 Raises:
1049 ValueError: In case of mismatch between the provided
1050 input data and the model's expectations,
1051 or in case a stateful model receives a number of samples
1052 that is not a multiple of the batch size.
1053 """
1054 self._assert_built_as_v1()
1055 base_layer.keras_api_gauge.get_cell("predict").set(True)
1056 self._check_call_args("predict")
1058 func = self._select_training_loop(x)
1059 return func.predict(
1060 self,
1061 x=x,
1062 batch_size=batch_size,
1063 verbose=verbose,
1064 steps=steps,
1065 callbacks=callbacks,
1066 max_queue_size=max_queue_size,
1067 workers=workers,
1068 use_multiprocessing=use_multiprocessing,
1069 )
1071 def reset_metrics(self):
1072 """Resets the state of metrics."""
1073 metrics = self._get_training_eval_metrics()
1074 for m in metrics:
1075 m.reset_state()
1077 # Reset metrics on all the distributed (cloned) models.
1078 if self._distribution_strategy:
1079 distributed_training_utils_v1._reset_metrics(self)
1081 def train_on_batch(
1082 self,
1083 x,
1084 y=None,
1085 sample_weight=None,
1086 class_weight=None,
1087 reset_metrics=True,
1088 ):
1089 """Runs a single gradient update on a single batch of data.
1091 Args:
1092 x: Input data. It could be:
1093 - A Numpy array (or array-like), or a list of arrays
1094 (in case the model has multiple inputs).
1095 - A TensorFlow tensor, or a list of tensors
1096 (in case the model has multiple inputs).
1097 - A dict mapping input names to the corresponding array/tensors,
1098 if the model has named inputs.
1099 - A `tf.data` dataset.
1100 y: Target data. Like the input data `x`, it could be either Numpy
1101 array(s) or TensorFlow tensor(s). It should be consistent with `x`
1102 (you cannot have Numpy inputs and tensor targets, or inversely).
1103 If `x` is a dataset, `y` should not be specified
1104 (since targets will be obtained from the iterator).
1105 sample_weight: Optional array of the same length as x, containing
1106 weights to apply to the model's loss for each sample. In the case
1107 of temporal data, you can pass a 2D array with shape (samples,
1108 sequence_length), to apply a different weight to every timestep of
1109 every sample. In this case you should make sure to specify
1110 sample_weight_mode="temporal" in compile(). This argument is not
1111 supported when `x` is a dataset.
1112 class_weight: Optional dictionary mapping class indices (integers)
1113 to a weight (float) to apply to the model's loss for the samples
1114 from this class during training. This can be useful to tell the
1115 model to "pay more attention" to samples from an under-represented
1116 class.
1117 reset_metrics: If `True`, the metrics returned will be only for this
1118 batch. If `False`, the metrics will be statefully accumulated
1119 across batches.
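Example (a minimal sketch of a manual training loop; `batch_iterator` is a
placeholder that yields `(x_batch, y_batch)` Numpy pairs):

```python
for x_batch, y_batch in batch_iterator:
    results = model.train_on_batch(x_batch, y_batch)
```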
1121 Returns:
1122 Scalar training loss
1123 (if the model has a single output and no metrics)
1124 or list of scalars (if the model has multiple outputs
1125 and/or metrics). The attribute `model.metrics_names` will give you
1126 the display labels for the scalar outputs.
1128 Raises:
1129 ValueError: In case of invalid user-provided arguments.
1130 """
1131 self._assert_compile_was_called()
1132 self._check_call_args("train_on_batch")
1134 # If at this point we are in the replica context, then it is okay to
1135 # execute the Eager code path. The expected way to get here is to call
1136 # `fit`, which calls `train_on_batch` on each replica.
1137 if (
1138 self._distribution_strategy
1139 and tf.distribute.in_cross_replica_context()
1140 ):
1141 raise NotImplementedError(
1142 "`train_on_batch` is not supported for models "
1143 "distributed with tf.distribute.Strategy."
1144 )
1145 # Validate and standardize user data.
1146 x, y, sample_weights = self._standardize_user_data(
1147 x,
1148 y,
1149 sample_weight=sample_weight,
1150 class_weight=class_weight,
1151 extract_tensors_from_dataset=True,
1152 )
1154 # If `self._distribution_strategy` is True, then we are in a replica
1155 # context at this point because of the check above. `train_on_batch` is
1156 # being run for each replica by `self._distribution_strategy` and the
1157 # same code path as Eager is expected to be taken.
1158 if self.run_eagerly or self._distribution_strategy:
1159 output_dict = training_eager_v1.train_on_batch(
1160 self,
1161 x,
1162 y,
1163 sample_weights=sample_weights,
1164 output_loss_metrics=self._output_loss_metrics,
1165 )
1166 outputs = (
1167 output_dict["total_loss"]
1168 + output_dict["output_losses"]
1169 + output_dict["metrics"]
1170 )
1171 outputs = [_non_none_constant_value(v) for v in outputs]
1172 else:
1173 x = training_utils_v1.ModelInputs(x).as_list()
1174 ins = x + list(y or []) + list(sample_weights or [])
1176 if not isinstance(backend.symbolic_learning_phase(), int):
1177 ins += [True] # Add learning phase value.
1179 self._update_sample_weight_modes(sample_weights=sample_weights)
1180 self._make_train_function()
1181 outputs = self.train_function(ins)
1183 if reset_metrics:
1184 self.reset_metrics()
1186 if len(outputs) == 1:
1187 return outputs[0]
1188 return outputs
1190 def test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True):
1191 """Test the model on a single batch of samples.
1193 Args:
1194 x: Input data. It could be:
1195 - A Numpy array (or array-like), or a list of arrays
1196 (in case the model has multiple inputs).
1197 - A TensorFlow tensor, or a list of tensors
1198 (in case the model has multiple inputs).
1199 - A dict mapping input names to the corresponding array/tensors,
1200 if the model has named inputs.
1201 - A `tf.data` dataset.
1202 y: Target data. Like the input data `x`,
1203 it could be either Numpy array(s) or TensorFlow tensor(s).
1204 It should be consistent with `x` (you cannot have Numpy inputs and
1205 tensor targets, or inversely). If `x` is a dataset, `y` should
1206 not be specified (since targets will be obtained from the
1207 iterator).
1208 sample_weight: Optional array of the same length as x, containing
1209 weights to apply to the model's loss for each sample.
1210 In the case of temporal data, you can pass a 2D array
1211 with shape (samples, sequence_length),
1212 to apply a different weight to every timestep of every sample.
1213 In this case you should make sure to specify
1214 sample_weight_mode="temporal" in compile(). This argument is not
1215 supported when `x` is a dataset.
1216 reset_metrics: If `True`, the metrics returned will be only for this
1217 batch. If `False`, the metrics will be statefully accumulated
1218 across batches.
1220 Returns:
1221 Scalar test loss (if the model has a single output and no metrics)
1222 or list of scalars (if the model has multiple outputs
1223 and/or metrics). The attribute `model.metrics_names` will give you
1224 the display labels for the scalar outputs.
1226 Raises:
1227 ValueError: In case of invalid user-provided arguments.
1228 """
1229 self._assert_compile_was_called()
1230 self._check_call_args("test_on_batch")
1232 if (
1233 self._distribution_strategy
1234 and tf.distribute.in_cross_replica_context()
1235 ):
1236 raise NotImplementedError(
1237 "`test_on_batch` is not supported for models "
1238 "distributed with tf.distribute.Strategy."
1239 )
1240 # Validate and standardize user data.
1241 x, y, sample_weights = self._standardize_user_data(
1242 x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True
1243 )
1245 # If `self._distribution_strategy` is True, then we are in a replica
1246 # context at this point.
1247 if self.run_eagerly or self._distribution_strategy:
1248 output_dict = training_eager_v1.test_on_batch(
1249 self,
1250 x,
1251 y,
1252 sample_weights=sample_weights,
1253 output_loss_metrics=self._output_loss_metrics,
1254 )
1255 outputs = (
1256 output_dict["total_loss"]
1257 + output_dict["output_losses"]
1258 + output_dict["metrics"]
1259 )
1260 outputs = [_non_none_constant_value(v) for v in outputs]
1261 else:
1262 x = training_utils_v1.ModelInputs(x).as_list()
1263 inputs = x + list(y or []) + list(sample_weights or [])
1265 self._update_sample_weight_modes(sample_weights=sample_weights)
1266 self._make_test_function()
1267 outputs = self.test_function(inputs)
1269 if reset_metrics:
1270 self.reset_metrics()
1272 if len(outputs) == 1:
1273 return outputs[0]
1274 return outputs
1276 def predict_on_batch(self, x):
1277 """Returns predictions for a single batch of samples.
1279 Args:
1280 x: Input data. It could be:
1281 - A Numpy array (or array-like), or a list of arrays
1282 (in case the model has multiple inputs).
1283 - A TensorFlow tensor, or a list of tensors
1284 (in case the model has multiple inputs).
1285 - A `tf.data` dataset.
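Example (a minimal sketch; `x_batch` is a placeholder Numpy array holding a
single batch of inputs):

```python
batch_predictions = model.predict_on_batch(x_batch)
```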
1287 Returns:
1288 Numpy array(s) of predictions.
1290 Raises:
1291 ValueError: In case of mismatch between given number of inputs and
1292 expectations of the model.
1293 """
1294 self._check_call_args("predict_on_batch")
1296 if (
1297 self._distribution_strategy
1298 and tf.distribute.in_cross_replica_context()
1299 ):
1300 raise NotImplementedError(
1301 "`predict_on_batch` is not supported for models distributed "
1302 "with tf.distribute.Strategy."
1303 )
1304 # Validate and standardize user data.
1305 inputs, _, _ = self._standardize_user_data(
1306 x, extract_tensors_from_dataset=True
1307 )
1308 # If `self._distribution_strategy` is True, then we are in a replica
1309 # context at this point.
1310 if self.run_eagerly or self._distribution_strategy:
1311 inputs = training_utils_v1.cast_if_floating_dtype(inputs)
1312 if isinstance(inputs, collections.abc.Sequence):
1313 # Unwrap lists with only one input, as we do when training on
1314 # batch
1315 if len(inputs) == 1:
1316 inputs = inputs[0]
1318 return self(inputs)
1320 self._make_predict_function()
1321 outputs = self.predict_function(inputs)
1323 if len(outputs) == 1:
1324 return outputs[0]
1325 return outputs
1327 def fit_generator(
1328 self,
1329 generator,
1330 steps_per_epoch=None,
1331 epochs=1,
1332 verbose=1,
1333 callbacks=None,
1334 validation_data=None,
1335 validation_steps=None,
1336 validation_freq=1,
1337 class_weight=None,
1338 max_queue_size=10,
1339 workers=1,
1340 use_multiprocessing=False,
1341 shuffle=True,
1342 initial_epoch=0,
1343 ):
1344 """Fits the model on data yielded batch-by-batch by a Python generator.
1346 DEPRECATED:
1347 `Model.fit` now supports generators, so there is no longer any need to
1348 use this endpoint.
1349 """
1350 warnings.warn(
1351 "`model.fit_generator` is deprecated and "
1352 "will be removed in a future version. "
1353 "Please use `Model.fit`, which supports generators.",
1354 stacklevel=2,
1355 )
1356 return self.fit(
1357 generator,
1358 steps_per_epoch=steps_per_epoch,
1359 epochs=epochs,
1360 verbose=verbose,
1361 callbacks=callbacks,
1362 validation_data=validation_data,
1363 validation_steps=validation_steps,
1364 validation_freq=validation_freq,
1365 class_weight=class_weight,
1366 max_queue_size=max_queue_size,
1367 workers=workers,
1368 use_multiprocessing=use_multiprocessing,
1369 shuffle=shuffle,
1370 initial_epoch=initial_epoch,
1371 )
1373 def evaluate_generator(
1374 self,
1375 generator,
1376 steps=None,
1377 callbacks=None,
1378 max_queue_size=10,
1379 workers=1,
1380 use_multiprocessing=False,
1381 verbose=0,
1382 ):
1383 """Evaluates the model on a data generator.
1385 DEPRECATED:
1386 `Model.evaluate` now supports generators, so there is no longer any
1387 need to use this endpoint.
1388 """
1389 warnings.warn(
1390 "`Model.evaluate_generator` is deprecated and "
1391 "will be removed in a future version. "
1392 "Please use `Model.evaluate`, which supports generators.",
1393 stacklevel=2,
1394 )
1395 self._check_call_args("evaluate_generator")
1397 return self.evaluate(
1398 generator,
1399 steps=steps,
1400 max_queue_size=max_queue_size,
1401 workers=workers,
1402 use_multiprocessing=use_multiprocessing,
1403 verbose=verbose,
1404 callbacks=callbacks,
1405 )
1407 def predict_generator(
1408 self,
1409 generator,
1410 steps=None,
1411 callbacks=None,
1412 max_queue_size=10,
1413 workers=1,
1414 use_multiprocessing=False,
1415 verbose=0,
1416 ):
1417 """Generates predictions for the input samples from a data generator.
1419 DEPRECATED:
1420 `Model.predict` now supports generators, so there is no longer any
1421 need to use this endpoint.
1422 """
1423 warnings.warn(
1424 "`Model.predict_generator` is deprecated and "
1425 "will be removed in a future version. "
1426 "Please use `Model.predict`, which supports generators.",
1427 stacklevel=2,
1428 )
1429 return self.predict(
1430 generator,
1431 steps=steps,
1432 max_queue_size=max_queue_size,
1433 workers=workers,
1434 use_multiprocessing=use_multiprocessing,
1435 verbose=verbose,
1436 callbacks=callbacks,
1437 )
1439 def _check_call_args(self, method_name):
1440 """Check that `call` has only one positional arg."""
1441 # Always allow first arg, regardless of arg name.
1442 fullargspec = self._call_spec.full_argspec
1443 if fullargspec.defaults:
1444 positional_args = fullargspec.args[: -len(fullargspec.defaults)]
1445 else:
1446 positional_args = fullargspec.args
1447 if "training" in positional_args:
1448 positional_args.remove("training")
1450 # self and first arg can be positional.
1451 if len(positional_args) > 2:
1452 extra_args = positional_args[2:]
1453 raise ValueError(
1454 "Models passed to `"
1455 + method_name
1456 + "` can only have `training` "
1457 "and the first argument in `call` as positional arguments, "
1458 "found: " + str(extra_args) + "."
1459 )
1461 def _set_optimizer(self, optimizer):
1462 """Sets self.optimizer.
1464 Sets self.optimizer to `optimizer`, potentially wrapping it with a
1465 LossScaleOptimizer.
1467 Args:
1468 optimizer: The optimizer(s) to assign to self.optimizer.
1469 """
1470 if isinstance(optimizer, (list, tuple)):
1471 self.optimizer = [optimizers.get(opt) for opt in optimizer]
1472 else:
1473 self.optimizer = optimizers.get(optimizer)
1475 if self._dtype_policy.name == "mixed_float16" and not isinstance(
1476 self.optimizer, loss_scale_optimizer.LossScaleOptimizer
1477 ):
1478 if isinstance(self.optimizer, list):
1479 raise ValueError(
1480 'When the "mixed_float16" dtype policy is used, you '
1481 "can only pass a single optimizer. Using policy %s "
1482 "and got optimizers: %s"
1483 % (self._dtype_policy, self.optimizer)
1484 )
1485 if not isinstance(self.optimizer, optimizer_v2.OptimizerV2):
1486 raise ValueError(
1487 '"optimizer" must be an instance of '
1488 "tf.keras.optimizers.legacy.Optimizer when a dtype policy "
1489 "with a loss scale is used, but got: %s. Using policy: "
1490 "%s" % (self.optimizer, self._dtype_policy)
1491 )
1492 self.optimizer = loss_scale_optimizer.LossScaleOptimizer(
1493 self.optimizer
1494 )
1496 def _prepare_validation_data(
1497 self, validation_data, batch_size, validation_steps
1498 ):
1499 """Unpack and check the validation data."""
1500 (
1501 val_x,
1502 val_y,
1503 val_sample_weights,
1504 ) = training_utils_v1.unpack_validation_data(validation_data)
1505 return self._standardize_user_data(
1506 val_x,
1507 val_y,
1508 sample_weight=val_sample_weights,
1509 batch_size=batch_size,
1510 steps=validation_steps,
1511 steps_name="validation_steps",
1512 )
1514 def _validate_compile_param_for_distribution_strategy(
1515 self, run_eagerly, sample_weight_mode, target_tensors, weighted_metrics
1516 ):
1517 # Validate that arguments passed by the user to `compile` are supported
1518 # by tf.distribute.Strategy.
1519 if self._distribution_strategy:
1520 if sample_weight_mode:
1521 raise NotImplementedError(
1522 "sample_weight_mode is not supported with "
1523 "tf.distribute.Strategy."
1524 )
1525 if weighted_metrics:
1526 raise NotImplementedError(
1527 "weighted_metrics is not supported with "
1528 "tf.distribute.Strategy."
1529 )
1530 if target_tensors:
1531 raise ValueError(
1532 "target_tensors is not supported with "
1533 "tf.distribute.Strategy."
1534 )
1536 if run_eagerly:
1537 raise ValueError(
1538 "We currently do not support enabling `run_eagerly` with "
1539 "distribution strategy."
1540 )
1542 if distributed_training_utils_v1.is_distributing_by_cloning(
1543 self
1544 ) and (not self.built or not self.inputs or not self.outputs):
1545 raise ValueError(
1546 "We currently do not support distribution strategy with a "
1547 "`Sequential` model that is created without `input_shape`/"
1548 "`input_dim` set in its first layer or a subclassed model."
1549 )
1551 def _process_target_tensor_for_compile(self, target_tensors):
1552 if self.run_eagerly:
1553 # target tensor is not supported with run_eagerly. Create a list
1554 # with None as placeholder for each output.
1555 return [None for _ in self.output_names]
1557 if target_tensors is not None and not (
1558 isinstance(target_tensors, list) and target_tensors == []
1559 ):
1560 if isinstance(target_tensors, list):
1561 if len(target_tensors) != len(self.outputs):
1562 raise ValueError(
1563 "When passing a list as `target_tensors`, "
1564 "it should have one entry per model output. "
1565 "The model has %s outputs, "
1566 "but you passed target_tensors=%s"
1567 % (len(self.outputs), target_tensors)
1568 )
1569 elif isinstance(target_tensors, dict):
1570 unexpected_target_tensor_names = set(
1571 target_tensors.keys()
1572 ).difference(self.output_names)
1573 if unexpected_target_tensor_names:
1574 raise ValueError(
1575 "Unknown entry in `target_tensors` dictionary: "
1576 '"{name}". '
1577 "Only expected the following keys: {keys}".format(
1578 name=unexpected_target_tensor_names,
1579 keys=str(self.output_names),
1580 )
1581 )
1582 tmp_target_tensors = []
1583 for name in self.output_names:
1584 tmp_target_tensors.append(target_tensors.get(name, None))
1585 target_tensors = tmp_target_tensors
1586 elif tf.is_tensor(target_tensors):
1587 target_tensors = [target_tensors]
1588 else:
1589 raise TypeError(
1590 "Expected `target_tensors` to be a list or tuple or "
1591 "dict or a single tensor, but got:",
1592 target_tensors,
1593 )
1594 else:
1595 # In case target tensor is empty or None, create a list with Nones
1596 # that has same length as self.output_names. With that, the None
1597 # check of target tensor can be skipped downstream.
1598 target_tensors = [None for _ in self.output_names]
1599 return target_tensors
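# Illustrative example (assumes a model with output names ["a", "b"]): a dict
# such as {"b": t} is reordered into the per-output list [None, t]; an empty
# list or `None` becomes [None, None], so downstream None checks can be
# skipped uniformly.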
1601 def _compile_eagerly(self, metrics, weighted_metrics, sample_weight_mode):
1602 # Prepare sample weight modes. List with the same length as model
1603 # outputs.
1604 training_utils_v1.prepare_sample_weight_modes(
1605 self._training_endpoints, sample_weight_mode
1606 )
1607 # Prepare sample weights.
1608 self._prepare_sample_weights()
1609 # Save all metric attributes per output of the model.
1610 self._cache_output_metric_attributes(metrics, weighted_metrics)
1611 self.total_loss = None
1612 # Set metric attributes on model.
1613 self._set_metric_attributes()
1615 self._collected_trainable_weights = self.trainable_weights
1617 def _update_sample_weight_modes(self, sample_weights=None):
1618 """Updates sample weight modes based on training/eval inputs.
1620 Sample weight placeholders will be created for all or no outputs
1621 based on whether sample_weight is provided for any output.
1623 If the model contains `_sample_weight_modes`, we check whether the input
1624 `sample_weights` corresponds to the sample weight modes.
1625 1. Set sample weight mode to be 'temporal' for output i, if `compile`
1626 sample_weight_mode was set to `temporal` and sample weight inputs
1627 are given for one or more outputs.
1628 2. Set sample weight mode to be 'samplewise' for output i, if
1629 `compile` sample_weight_mode was not set and sample weight inputs
1630 are given for one or more outputs.
1631 3. Reset sample weight mode to None for output i if sample weight mode
1632 was set but there is no sample weight input.
1634 Args:
1635 sample_weights: List of sample weights of the same length as model
1636 outputs or None.
1637 """
1638 if not self._is_compiled:
1639 return
1640 if sample_weights and any(s is not None for s in sample_weights):
1641 for endpoint in self._training_endpoints:
1642 endpoint.sample_weight_mode = (
1643 endpoint.sample_weight_mode or "samplewise"
1644 )
1645 else:
1646 for endpoint in self._training_endpoints:
1647 endpoint.sample_weight_mode = None
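# Minimal sketch (assumed values): for a compiled two-output model whose
# `compile` call did not set `sample_weight_mode`, providing any sample
# weights switches every endpoint to "samplewise", and providing none resets
# them:
#
#     model._update_sample_weight_modes([np.ones(8), None])  # -> "samplewise"
#     model._update_sample_weight_modes(None)                 # -> None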
1649 def _recompile_weights_loss_and_weighted_metrics(self):
1650 if not self._is_compiled:
1651 return False
1652 recompile = any(
1653 e.sample_weights_mismatch() for e in self._training_endpoints
1654 )
1656 if recompile:
1657 self._compile_weights_loss_and_weighted_metrics()
1658 return recompile
1660 @tf.__internal__.tracking.no_automatic_dependency_tracking
1661 def _compile_weights_loss_and_weighted_metrics(self, sample_weights=None):
1662 """Compiles the model loss and weighted metric sub-graphs.
1664 This may be used to set graph tensors as sample weights (instead of
1665 creating placeholders). This functionality is necessary for
1666 `tf.keras.estimator.model_to_estimator`, which calls Keras models in a
1667 v1 graph, and creates iterator tensors for inputs, targets, and sample
1668 weights.
1670 Args:
1671 sample_weights: List of tensors to use as the sample weights. Must be
1672 the same length as the number of outputs. If left as `None`,
1673 placeholders are used instead.
1674 """
1675 with backend.get_graph().as_default():
1676 if sample_weights is not None:
1677 self._update_sample_weight_modes(sample_weights)
1678 self._prepare_sample_weights(sample_weights)
1680 masks = self._prepare_output_masks()
1682 # Compute weighted metrics.
1683 self._handle_metrics(
1684 self.outputs,
1685 targets=self._targets,
1686 skip_target_masks=self._prepare_skip_target_masks(),
1687 sample_weights=self.sample_weights,
1688 masks=masks,
1689 return_weighted_metrics=True,
1690 )
1692 # Compute total loss.
1693 # Used to keep track of the total loss value (stateless).
1694 # e.g., total_loss = loss_weight_1 * output_1_loss_fn(...) +
1695 # loss_weight_2 * output_2_loss_fn(...) +
1696 # layer losses.
1697 self.total_loss = self._prepare_total_loss(masks)
1699 def _prepare_skip_target_masks(self):
1700 """Boolean mask for whether target in output list should be skipped.
1702 If the loss function corresponding to a model output is None, then this
1703 output will be skipped during total loss calculation and feed targets
1704 preparation.
1706 Returns:
1707 A boolean list for whether the corresponding target in the output list
1708 should be skipped during loss calculation.
1709 """
1710 return [l is None for l in self.loss_functions]
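# Example (illustrative): with `self.loss_functions == [mse_fn, None]` this
# returns [False, True], so the second output is skipped during total-loss
# computation and feed-target preparation.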
1712 def _prepare_output_masks(self):
1713 """Returns masks corresponding to model outputs."""
1714 return [getattr(x, "_keras_mask", None) for x in self.outputs]
1716 def _prepare_total_loss(self, masks):
1717 """Computes total loss from loss functions.
1719 Args:
1720 masks: List of mask values corresponding to each model output.
1722 Returns:
1723 The total loss tensor computed from the per-output losses, loss weights, and layer losses.
1725 Raises:
1726 TypeError: If model run_eagerly is True.
1727 """
1728 if self.run_eagerly:
1729 raise TypeError(
1730 "total loss can not be computed when compiled with "
1731 "run_eagerly = True."
1732 )
1733 loss_list = []
1734 with backend.name_scope("loss"):
1735 for endpoint, mask in zip(self._training_endpoints, masks):
1736 if endpoint.should_skip_target():
1737 continue
1738 y_true = endpoint.training_target.target
1739 y_pred = endpoint.output
1740 loss_fn = endpoint.loss_fn
1741 loss_weight = endpoint.loss_weight
1742 loss_name = endpoint.loss_name()
1743 sample_weight = endpoint.sample_weight
1745 with backend.name_scope(loss_name):
1746 if mask is not None:
1747 mask = tf.cast(mask, y_pred.dtype)
1748 # Update weights with mask.
1749 if sample_weight is None:
1750 sample_weight = mask
1751 else:
1752 # Update dimensions of weights to match with mask if
1753 # possible.
1754 (
1755 mask,
1756 _,
1757 sample_weight,
1758 ) = losses_utils.squeeze_or_expand_dimensions(
1759 mask, sample_weight=sample_weight
1760 )
1762 if hasattr(loss_fn, "reduction"):
1763 per_sample_losses = loss_fn.call(y_true, y_pred)
1764 sample_weight = losses_utils.apply_valid_mask(
1765 per_sample_losses,
1766 sample_weight,
1767 mask,
1768 loss_fn.reduction,
1769 )
1770 weighted_losses = losses_utils.compute_weighted_loss(
1771 per_sample_losses,
1772 sample_weight=sample_weight,
1773 reduction=losses_utils.ReductionV2.NONE,
1774 )
1775 loss_reduction = loss_fn.reduction
1777 # `AUTO` loss reduction defaults to
1778 # `SUM_OVER_BATCH_SIZE` for all compile use cases.
1779 if loss_reduction == losses_utils.ReductionV2.AUTO:
1780 loss_reduction = (
1781 losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
1782 )
1784 # Compute the stateless loss value.
1785 output_loss = losses_utils.reduce_weighted_loss(
1786 weighted_losses, reduction=loss_reduction
1787 )
1788 else:
1789 # Compute the stateless loss value for a custom loss
1790 # class. Here we assume that the class takes care of
1791 # loss reduction because if this class returns a vector
1792 # value we cannot differentiate between use case where a
1793 # custom optimizer expects a vector loss value vs
1794 # unreduced per-sample loss value.
1795 output_loss = loss_fn(
1796 y_true, y_pred, sample_weight=sample_weight
1797 )
1798 loss_reduction = (
1799 losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
1800 )
1802 if len(self.outputs) > 1:
1803 # Keep track of stateful result tensor for the loss.
1804 endpoint.output_loss_metric(output_loss)
1806 # Scale output loss for distribution. For custom losses we
1807 # assume reduction was mean.
1808 if (
1809 loss_reduction
1810 == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
1811 ):
1812 output_loss = losses_utils.scale_loss_for_distribution(
1813 output_loss
1814 )
1816 loss_list.append(loss_weight * output_loss)
1817 if not loss_list and not self.losses:
1818 raise ValueError(
1819 "The model cannot be compiled "
1820 "because it has no loss to optimize."
1821 )
1823 # Add regularization penalties and other layer-specific losses.
1824 custom_losses = self.get_losses_for(None) + self.get_losses_for(
1825 self.inputs
1826 )
1827 if custom_losses:
1828 total_custom_loss = tf.add_n(
1829 losses_utils.cast_losses_to_common_dtype(custom_losses)
1830 )
1831 loss_list.append(
1832 losses_utils.scale_loss_for_distribution(total_custom_loss)
1833 )
1835 loss_list = losses_utils.cast_losses_to_common_dtype(loss_list)
1836 if loss_list:
1837 total_loss = tf.add_n(loss_list)
1838 else:
1839 total_loss = 0.0
1840 return total_loss
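# Worked example (illustrative, assumed values, no distribution strategy):
# two outputs with loss weights 1.0 and 0.5 and SUM_OVER_BATCH_SIZE losses of
# 0.8 and 0.4 give loss_list == [0.8, 0.2]; a 0.1 regularization penalty
# collected from the layers brings the returned total loss to 1.1.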
1842 def _get_callback_model(self):
1843 """Returns the Callback Model for this Model."""
1845 if hasattr(self, "_replicated_model") and self._replicated_model:
1846 # When using training_distributed, we set the callback model
1847 # to an instance of the `DistributedModel` that we create in
1848 # the `compile` call. The `DistributedModel` is initialized
1849 # with the first replicated model. We need to set the callback
1850 # model to a DistributedModel to allow us to override saving
1851 # and loading weights when we checkpoint the model during training.
1852 return self._replicated_model
1853 if hasattr(self, "callback_model") and self.callback_model:
1854 return self.callback_model
1855 return self
1857 @tf.__internal__.tracking.no_automatic_dependency_tracking
1858 def _make_callback_model(self, grouped_model):
1859 first_replicated_model = self._distribution_strategy.unwrap(
1860 grouped_model
1861 )[0]
1862 # We initialize the callback model with the first replicated model.
1863 self._replicated_model = DistributedCallbackModel(
1864 first_replicated_model
1865 )
1866 self._replicated_model.set_original_model(self)
1868 def _validate_or_infer_batch_size(self, batch_size, steps, x):
1869 """Validates that `batch_size` provided is consistent with InputLayer.
1871 It's possible that the user specified a static batch size in their
1872 InputLayer. If so, this method checks the provided `batch_size` and `x`
1873 arguments are consistent with this static batch size. Also, if
1874 `batch_size` is `None`, this method will attempt to infer the batch size
1875 from the static batch size of the InputLayer. Lastly, ValueError will be
1876 raised if `x` is a tf.data.Dataset and `batch_size` is specified as we
1877 expect users to provide batched datasets.
1879 Args:
1880 batch_size: The batch_size provided as an argument to
1881 fit/evaluate/predict.
1882 steps: The steps provided as an argument to fit/evaluate/predict.
1883 x: The data passed as `x` to fit/evaluate/predict.
1885 Returns:
1886 The validated batch_size, auto-inferred from the first layer if not
1887 provided.
1888 """
1889 if isinstance(
1890 x, (tf.compat.v1.data.Dataset, tf.data.Dataset, data_utils.Sequence)
1891 ) or tf_inspect.isgenerator(x):
1892 if batch_size is not None:
1893 raise ValueError(
1894 "The `batch_size` argument must not be specified for the "
1895 "given input type. Received input: "
1896 "{}, batch_size: {}".format(x, batch_size)
1897 )
1898 return
1900 # Avoids the override in Sequential.layers which filters Input layers.
1901 # (Which are often the very layers that we're after.)
1902 layers = self._flatten_layers(include_self=False, recursive=False)
1903 first_layer = next(layers, None)
1904 if first_layer:
1905 # The per-replica static batch size.
1906 static_batch_size = training_utils.get_static_batch_size(
1907 first_layer
1908 )
1909 if static_batch_size is not None:
1911 # Determine number of times the user-supplied batch size will be
1912 # split.
1913 if (
1914 self._distribution_strategy
1915 and distributed_training_utils.global_batch_size_supported(
1916 self._distribution_strategy
1917 )
1918 ):
1919 num_splits_for_ds = (
1920 self._distribution_strategy.num_replicas_in_sync
1921 )
1922 else:
1923 num_splits_for_ds = 1
1925 # Check `batch_size` argument is consistent with InputLayer.
1926 if batch_size is not None:
1927 if batch_size % num_splits_for_ds != 0:
1928 raise ValueError(
1929 "The `batch_size` argument ({}) must be divisible "
1930 "the by number of replicas ({})".format(
1931 batch_size, num_splits_for_ds
1932 )
1933 )
1934 per_replica_batch_size = batch_size // num_splits_for_ds
1936 if per_replica_batch_size != static_batch_size:
1937 raise ValueError(
1938 "The `batch_size` argument value {} is "
1939 "incompatible with the specified batch size of "
1940 "your Input Layer: {}".format(
1941 per_replica_batch_size, static_batch_size
1942 )
1943 )
1945 # Check Dataset/Iterator batch size is consistent with
1946 # InputLayer.
1947 if isinstance(
1948 x,
1949 (
1950 tf.data.Dataset,
1951 tf.compat.v1.data.Iterator,
1952 tf.data.Iterator,
1953 ),
1954 ):
1955 ds_batch_size = tf.compat.v1.Dimension(
1956 tf.nest.flatten(tf.compat.v1.data.get_output_shapes(x))[
1957 0
1958 ][0]
1959 ).value
1960 if ds_batch_size is not None:
1961 if ds_batch_size % num_splits_for_ds != 0:
1962 raise ValueError(
1963 "The batch output shape of your `Dataset` {} "
1964 "cannot be divisible by number of "
1965 "replicas {}".format(
1966 ds_batch_size, num_splits_for_ds
1967 )
1968 )
1970 ds_per_replica_batch_size = (
1971 ds_batch_size // num_splits_for_ds
1972 )
1973 if ds_per_replica_batch_size != static_batch_size:
1974 raise ValueError(
1975 "The batch output shape of your `Dataset` is "
1976 "{}, which is incompatible with the specified "
1977 "batch size of your Input Layer: {}".format(
1978 ds_per_replica_batch_size, static_batch_size
1979 )
1980 )
1982 # Set inferred batch size from the InputLayer.
1983 if steps is None:
1984 batch_size = static_batch_size * num_splits_for_ds
1986 if batch_size is None and steps is None:
1987 # Backwards compatibility
1988 batch_size = 32
1989 return batch_size
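# Worked example (illustrative, assumed values): with an InputLayer whose
# static per-replica batch size is 16 and a strategy that supports global
# batch sizes with `num_replicas_in_sync == 2`, a user-supplied `batch_size`
# must be divisible by 2 and satisfy batch_size // 2 == 16, i.e. 32; with
# `batch_size=None` and `steps=None`, the method infers 16 * 2 == 32.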
1991 def _prepare_sample_weights(self, sample_weights=None):
1992 """Sets sample weight attribute on the model."""
1993 # List with the same length as model outputs.
1994 if sample_weights is not None:
1995 if len(sample_weights) != len(self._training_endpoints):
1996 raise ValueError(
1997 "Provided sample weights must have same length as the "
1998 "number of outputs. Expected: {}, got: {}.".format(
1999 len(self._training_endpoints), len(sample_weights)
2000 )
2001 )
2002 else:
2003 sample_weights = [None] * len(self._training_endpoints)
2004 for endpoint, weight in zip(self._training_endpoints, sample_weights):
2005 endpoint.populate_sample_weight(weight, endpoint.sample_weight_mode)
2007 def _cache_output_metric_attributes(self, metrics, weighted_metrics):
2008 """Caches metric name and function attributes for every model output."""
2009 output_shapes = []
2010 for output in self.outputs:
2011 if output is None or output.shape.rank is None:
2012 output_shapes.append(None)
2013 else:
2014 output_shapes.append(output.shape.as_list())
2015 self._per_output_metrics = (
2016 training_utils_v1.collect_per_output_metric_info(
2017 metrics,
2018 self.output_names,
2019 output_shapes,
2020 self.loss_functions,
2021 from_serialized=self._from_serialized,
2022 )
2023 )
2024 self._per_output_weighted_metrics = (
2025 training_utils_v1.collect_per_output_metric_info(
2026 weighted_metrics,
2027 self.output_names,
2028 output_shapes,
2029 self.loss_functions,
2030 from_serialized=self._from_serialized,
2031 is_weighted=True,
2032 )
2033 )
2035 def _add_unique_metric_name(self, metric_name, metric_fn, output_index):
2036 """Makes the metric name unique.
2038 If there are multiple outputs for which the metrics are calculated,
2039 the metric names have to be made unique by appending an integer.
2041 Args:
2042 metric_name: Metric name that corresponds to the metric specified by
2043 the user. For example: 'acc'.
2044 metric_fn: The Metric object.
2045 output_index: The index of the model output for which the metric name
2046 is being added.
2048 Returns:
2049 string, name of the model's unique metric name
2050 """
2051 # For multi-output models, prepend the output names to the metric name.
2052 if len(self.output_names) > 1:
2053 # If we're loading from an already-serialized model, we've already
2054 # prepended the output name, and we don't want to do it again.
2055 #
2056 # Alternatively, we may be receiving a stateless metric (e.g. the
2057 # string "accuracy") rather than a `Metric` object, in which case we
2058 # want to prepend the output name even if we are loading a
2059 # serialized model.
2060 if not getattr(metric_fn, "_from_serialized", False):
2061 metric_name = f"{self.output_names[output_index]}_{metric_name}"
2063 j = 1
2064 base_metric_name = metric_name
2065 while metric_name in self.metrics_names:
2066 metric_name = "%s_%d" % (base_metric_name, j)
2067 j += 1
2069 return metric_name
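# Illustrative example (assumes output names ["a", "b"]): requesting "acc" on
# both outputs yields "a_acc" and "b_acc"; if "a_acc" already appears in
# `metrics_names`, the suffix loop produces "a_acc_1", then "a_acc_2", etc.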
2071 def _init_metric_attributes(self):
2072 """Initialized model metric attributes."""
2073 # List of stateful metric functions. Used for resetting metric state
2074 # during training/eval.
2075 self._compile_metric_functions = []
2077 def _set_per_output_metric_attributes(self, metrics_dict, output_index):
2078 """Sets the metric attributes on the model for the given output.
2080 Args:
2081 metrics_dict: A dict with metric names as keys and metric fns as
2082 values.
2083 output_index: The index of the model output for which the metric
2084 attributes are added.
2086 Returns:
2087 Metrics dict updated with unique metric names as keys.
2088 """
2089 updated_metrics_dict = collections.OrderedDict()
2090 for metric_name, metric_fn in metrics_dict.items():
2091 metric_name = self._add_unique_metric_name(
2092 metric_name, metric_fn, output_index
2093 )
2095 # Update the name on the metric class to be the unique generated
2096 # name.
2097 metric_fn._name = metric_name
2098 updated_metrics_dict[metric_name] = metric_fn
2099 # Keep track of metric name and function.
2100 self._compile_metric_functions.append(metric_fn)
2101 return updated_metrics_dict
2103 def _set_metric_attributes(self):
2104 """Sets the metric attributes on the model for all the model outputs."""
2105 updated_per_output_metrics = []
2106 updated_per_output_weighted_metrics = []
2107 for i, endpoint in enumerate(self._training_endpoints):
2108 if endpoint.should_skip_target():
2109 updated_per_output_metrics.append(self._per_output_metrics[i])
2110 updated_per_output_weighted_metrics.append(
2111 self._per_output_weighted_metrics[i]
2112 )
2113 continue
2114 updated_per_output_metrics.append(
2115 self._set_per_output_metric_attributes(
2116 self._per_output_metrics[i], i
2117 )
2118 )
2119 updated_per_output_weighted_metrics.append(
2120 self._set_per_output_metric_attributes(
2121 self._per_output_weighted_metrics[i], i
2122 )
2123 )
2125 # Create a metric wrapper for each output loss. This computes mean of an
2126 # output loss across mini-batches (irrespective of how we reduce within
2127 # a batch).
2128 if len(self._training_endpoints) > 1:
2129 for endpoint in self._training_endpoints:
2130 if not endpoint.should_skip_target():
2131 endpoint.output_loss_metric = metrics_module.Mean(
2132 name=endpoint.loss_name()
2133 )
2135 self._per_output_metrics = updated_per_output_metrics
2136 self._per_output_weighted_metrics = updated_per_output_weighted_metrics
2138 def _handle_per_output_metrics(
2139 self, metrics_dict, y_true, y_pred, mask, weights=None
2140 ):
2141 """Calls metric functions for a single output.
2143 Args:
2144 metrics_dict: A dict with metric names as keys and metric fns as
2145 values.
2146 y_true: Target output.
2147 y_pred: Predicted output.
2148 mask: Computed mask value for the current output.
2149 weights: Weights to be applied on the current output.
2151 Returns:
2152 A list of metric result tensors.
2153 """
2154 metric_results = []
2155 for metric_name, metric_fn in metrics_dict.items():
2156 with backend.name_scope(metric_name):
2157 metric_result = training_utils_v1.call_metric_function(
2158 metric_fn, y_true, y_pred, weights=weights, mask=mask
2159 )
2160 metric_results.append(metric_result)
2161 return metric_results
2163 def _handle_metrics(
2164 self,
2165 outputs,
2166 targets=None,
2167 skip_target_masks=None,
2168 sample_weights=None,
2169 masks=None,
2170 return_weighted_metrics=False,
2171 return_weighted_and_unweighted_metrics=False,
2172 ):
2173 """Handles calling metric functions.
2175 Args:
2176 outputs: List of outputs (predictions).
2177 targets: List of targets.
2178 skip_target_masks: Optional. List of boolean for whether the
2179 corresponding target should be ignored or not.
2180 sample_weights: Optional list of sample weight arrays.
2181 masks: List of computed output mask values.
2182 return_weighted_metrics: Flag that indicates whether weighted metrics
2183 should be computed instead of unweighted metrics. This flag is
2184 ignored when `return_weighted_and_unweighted_metrics` is enabled.
2185 return_weighted_and_unweighted_metrics: Flag that is used to indicate
2186 whether both weighted and unweighted metrics should be computed.
2187 When this is not enabled, we use `return_weighted_metrics` param to
2188 indicate whether weighted or unweighted metrics should be returned.
2190 Returns:
2191 A list of metric result tensors.
2192 """
2193 # TODO(scottzhu): Update this to use the new training_endpoints.
2194 # Currently the eager and graph logic is bit different.
2195 skip_target_masks = skip_target_masks or [False] * len(outputs)
2196 metric_results = []
2197 with backend.name_scope("metrics"):
2198 # Invoke all metrics added using `compile`.
2199 for i in range(len(outputs)):
2200 if skip_target_masks[i]:
2201 continue
2202 output = outputs[i] if outputs else None
2203 target = targets[i] if targets else None
2204 output_mask = masks[i] if masks else None
2206 if (
2207 return_weighted_and_unweighted_metrics
2208 or not return_weighted_metrics
2209 ):
2210 metric_results.extend(
2211 self._handle_per_output_metrics(
2212 self._per_output_metrics[i],
2213 target,
2214 output,
2215 output_mask,
2216 )
2217 )
2218 if (
2219 return_weighted_and_unweighted_metrics
2220 or return_weighted_metrics
2221 ):
2222 metric_results.extend(
2223 self._handle_per_output_metrics(
2224 self._per_output_weighted_metrics[i],
2225 target,
2226 output,
2227 output_mask,
2228 weights=sample_weights[i]
2229 if sample_weights
2230 else None,
2231 )
2232 )
2233 return metric_results
2235 def _check_trainable_weights_consistency(self):
2236 """Check trainable weights count consistency.
2238 This will raise a warning if `trainable_weights` and
2239 `_collected_trainable_weights` are inconsistent (i.e. have different
2240 number of parameters).
2241 Inconsistency will typically arise when one modifies `model.trainable`
2242 without calling `model.compile` again.
2243 """
2244 if not hasattr(self, "_collected_trainable_weights"):
2245 return
2247 if len(self.trainable_weights) != len(
2248 self._collected_trainable_weights
2249 ):
2250 logging.log_first_n(
2251 logging.WARN,
2252 "Discrepancy between trainable weights and collected"
2253 " trainable weights, did you set `model.trainable`"
2254 " without calling `model.compile` after ?",
2255 1,
2256 )
2258 def _make_train_function(self):
2259 has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
2260 self._check_trainable_weights_consistency()
2261 if isinstance(self.optimizer, list):
2262 raise ValueError(
2263 "The `optimizer` in `compile` should be a single optimizer."
2264 )
2265 # If we have re-compiled the loss/weighted metric sub-graphs then create
2266 # train function even if one exists already. This is because
2267 # `_feed_sample_weights` list has been updated on re-compile.
2268 if getattr(self, "train_function", None) is None or has_recompiled:
2269 # Restore the compiled trainable state.
2270 current_trainable_state = self._get_trainable_state()
2271 self._set_trainable_state(self._compiled_trainable_state)
2273 inputs = (
2274 self._feed_inputs
2275 + self._feed_targets
2276 + self._feed_sample_weights
2277 )
2278 if not isinstance(backend.symbolic_learning_phase(), int):
2279 inputs += [backend.symbolic_learning_phase()]
2281 with backend.get_graph().as_default():
2282 with backend.name_scope("training"):
2283 # Training updates
2284 updates = self.optimizer.get_updates(
2285 params=self._collected_trainable_weights,
2286 loss=self.total_loss,
2287 )
2288 # Unconditional updates
2289 updates += self.get_updates_for(None)
2290 # Conditional updates relevant to this model
2291 updates += self.get_updates_for(self.inputs)
2293 metrics = self._get_training_eval_metrics()
2294 metrics_tensors = [
2295 m._call_result
2296 for m in metrics
2297 if hasattr(m, "_call_result")
2298 ]
2300 with backend.name_scope("training"):
2301 # Gets loss and metrics. Updates weights at each call.
2302 fn = backend.function(
2303 inputs,
2304 [self.total_loss] + metrics_tensors,
2305 updates=updates,
2306 name="train_function",
2307 **self._function_kwargs,
2308 )
2309 setattr(self, "train_function", fn)
2311 # Restore the current trainable state
2312 self._set_trainable_state(current_trainable_state)
2314 def _make_test_function(self):
2315 has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
2316 # If we have re-compiled the loss/weighted metric sub-graphs then create
2317 # test function even if one exists already. This is because
2318 # `_feed_sample_weights` list has been updated on re-compile.
2319 if getattr(self, "test_function", None) is None or has_recompiled:
2320 inputs = (
2321 self._feed_inputs
2322 + self._feed_targets
2323 + self._feed_sample_weights
2324 )
2326 with backend.get_graph().as_default():
2327 metrics = self._get_training_eval_metrics()
2328 metrics_tensors = [
2329 m._call_result
2330 for m in metrics
2331 if hasattr(m, "_call_result")
2332 ]
2334 with backend.name_scope("evaluation"):
2335 updates = self.state_updates
2336 # Return loss and metrics, no gradient updates.
2337 # Does update the network states.
2338 fn = backend.function(
2339 inputs,
2340 [self.total_loss] + metrics_tensors,
2341 updates=updates,
2342 name="test_function",
2343 **self._function_kwargs,
2344 )
2345 setattr(self, "test_function", fn)
2347 def _make_predict_function(self):
2348 if not hasattr(self, "predict_function"):
2349 self.predict_function = None
2350 if self.predict_function is None:
2351 inputs = self._feed_inputs
2352 # Gets network outputs. Does not update weights.
2353 # Does update the network states.
2354 kwargs = getattr(self, "_function_kwargs", {})
2355 with backend.name_scope(ModeKeys.PREDICT):
2356 self.predict_function = backend.function(
2357 inputs,
2358 self.outputs,
2359 updates=self.state_updates,
2360 name="predict_function",
2361 **kwargs,
2362 )
2364 def _make_execution_function(self, mode):
2365 if mode == ModeKeys.TRAIN:
2366 self._make_train_function()
2367 return self.train_function
2368 if mode == ModeKeys.TEST:
2369 self._make_test_function()
2370 return self.test_function
2371 if mode == ModeKeys.PREDICT:
2372 self._make_predict_function()
2373 return self.predict_function
2375 def _distribution_standardize_user_data(
2376 self,
2377 x,
2378 y=None,
2379 sample_weight=None,
2380 class_weight=None,
2381 batch_size=None,
2382 validation_split=0.0,
2383 shuffle=False,
2384 epochs=1,
2385 allow_partial_batch=False,
2386 ):
2387 """Runs validation checks on input and target data passed by the user.
2389 This is called when using tf.distribute.Strategy to train, evaluate or
2390 serve the model.
2392 Args:
2393 x: Input data. A numpy array or `tf.data` dataset.
2394 y: Target data. A numpy array or None if x is a `tf.data` dataset.
2395 sample_weight: An optional sample-weight array passed by the user to
2396 weight the importance of each sample in `x`.
2397 class_weight: An optional class-weight array by the user to
2398 weight the importance of samples in `x` based on the class they
2399 belong to, as conveyed by `y`.
2400 batch_size: Integer batch size. If provided, it is used to run
2401 additional validation checks on stateful models.
2402 validation_split: Float between 0 and 1.
2403 Fraction of the training data to be used as validation data.
2404 shuffle: Boolean whether to shuffle the training data before each
2405 epoch.
2406 epochs: Integer epochs. If > 1, repeat the numpy training data epochs
2407 times when converting to training dataset.
2408 allow_partial_batch: Boolean whether to allow a smaller final batch
2409 instead of enforcing that all batches have the same size.
2411 Returns:
2412 Dataset instance.
2414 Raises:
2415 ValueError: In case of invalid user-provided data.
2416 RuntimeError: If the model was never compiled.
2417 """
2418 if class_weight:
2419 raise NotImplementedError(
2420 "`class_weight` is currently not supported "
2421 "when using tf.distribute.Strategy."
2422 )
2424 if (
2425 sample_weight is not None
2426 and sample_weight.all()
2427 and backend.is_tpu_strategy(self._distribution_strategy)
2428 ):
2429 raise NotImplementedError(
2430 "`sample_weight` is currently not supported "
2431 "when using TPUStrategy."
2432 )
2434 # Validates `steps` and `shuffle` arguments right at the beginning
2435 # since we use it to construct the dataset object.
2436 # TODO(anjalisridhar): Remove this check once we refactor the
2437 # _standardize_user_data code path. This check is already present
2438 # elsewhere in the codebase.
2439 if isinstance(x, tf.data.Dataset):
2440 if shuffle:
2441 training_utils_v1.verify_dataset_shuffled(x)
2443 strategy = self._distribution_strategy
2444 with strategy.scope():
2445 # We should be sure to call get_session() inside the
2446 # strategy.scope() so the strategy can affect the session options.
2447 if tf.compat.v1.executing_eagerly_outside_functions():
2448 session = None
2449 else:
2450 session = backend.get_session()
2452 first_x_value = tf.nest.flatten(x)[0]
2453 if isinstance(first_x_value, np.ndarray):
2454 x = training_utils.list_to_tuple(x)
2455 if y is not None:
2456 y = training_utils.list_to_tuple(y)
2457 if sample_weight is not None:
2458 sample_weight = training_utils.list_to_tuple(
2459 sample_weight
2460 )
2461 in_tuple = (x, y, sample_weight)
2462 else:
2463 in_tuple = (x, y)
2464 else:
2465 in_tuple = x
2467 ds = strategy.extended.experimental_make_numpy_dataset(
2468 in_tuple, session=session
2469 )
2470 if shuffle:
2471 # We want a buffer size that is larger than the batch size
2472 # provided by the user and provides sufficient randomness.
2473 # Note that larger numbers introduce more memory usage based
2474 # on the size of each sample.
2475 ds = ds.shuffle(max(1024, batch_size * 8))
2476 if epochs > 1:
2477 ds = ds.repeat(epochs)
2479 # We need to use the drop_remainder argument to get a known
2480 # static input shape which is required for TPUs.
2481 drop_remainder = (
2482 not allow_partial_batch
2483 and strategy.extended.experimental_require_static_shapes
2484 )
2486 # TODO(b/131720208): We still drop remainder here if number of
2487 # examples is divisible by batch size, as sometimes dynamic
2488 # padder will time out with keras.metrics.CategoricalAccuracy()
2489 # metric.
2490 if backend.is_tpu_strategy(strategy) and not drop_remainder:
2491 dataset_size = first_x_value.shape[0]
2492 if dataset_size % batch_size == 0:
2493 drop_remainder = True
2495 x = ds.batch(batch_size, drop_remainder=drop_remainder)
2496 else:
2497 assert isinstance(x, tf.data.Dataset)
2498 training_utils_v1.validate_dataset_input(
2499 x, y, sample_weight, validation_split
2500 )
2501 return x
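# Illustrative sketch (assumed values): for numpy inputs on a TPU strategy
# (which requires static shapes), `allow_partial_batch=False` yields
# `drop_remainder=True` directly; with `allow_partial_batch=True` it starts
# out False but is flipped back to True when `dataset_size % batch_size == 0`,
# e.g. 1024 examples with batch_size=32.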
2503 def _standardize_user_data(
2504 self,
2505 x,
2506 y=None,
2507 sample_weight=None,
2508 class_weight=None,
2509 batch_size=None,
2510 check_steps=False,
2511 steps_name="steps",
2512 steps=None,
2513 validation_split=0.0,
2514 shuffle=False,
2515 extract_tensors_from_dataset=False,
2516 ):
2517 """Runs validation checks on input and target data passed by the user.
2519 Also standardizes the data to lists of arrays, in order.
2521 Also builds and compiles the model on the fly if it is a subclassed
2522 model that has never been called before (and thus has no
2523 inputs/outputs).
2525 This is a purely internal method, subject to refactoring at any time.
2527 Args:
2528 x: Input data. It could be:
2529 - A Numpy array (or array-like), or a list of arrays
2530 (in case the model has multiple inputs).
2531 - A TensorFlow tensor, or a list of tensors
2532 (in case the model has multiple inputs).
2533 - A dict mapping input names to the corresponding array/tensors,
2534 if the model has named inputs.
2535 - A `tf.data` dataset.
2536 y: Target data. Like the input data `x`,
2537 it could be either Numpy array(s) or TensorFlow tensor(s).
2538 It should be consistent with `x` (you cannot have Numpy inputs and
2539 tensor targets, or inversely). If `x` is a dataset, `y` should not
2540 be specified (since targets will be obtained from the iterator).
2541 sample_weight: An optional sample-weight array passed by the user to
2542 weight the importance of each sample in `x`.
2543 class_weight: An optional class-weight array by the user to
2544 weight the importance of samples in `x` based on the class they
2545 belong to, as conveyed by `y`. If both `sample_weight` and
2546 `class_weight` are provided, the weights are multiplied.
2547 batch_size: Integer batch size. If provided, it is used to run
2548 additional validation checks on stateful models.
2549 check_steps: boolean, True if we want to check for validity of `steps`
2550 and False, otherwise. For example, when we are standardizing one
2551 batch of data for train_on_batch/predict_on_batch/test_on_batch
2552 APIs, `steps` value is not required and we should not check for its
2553 validity in these cases.
2554 steps_name: The public API's parameter name for `steps`.
2555 steps: Integer or `None`. Total number of steps (batches of samples)
2556 to execute.
2557 validation_split: Float between 0 and 1.
2558 Fraction of the training data to be used as validation data.
2559 shuffle: Boolean whether to shuffle the training data before each
2560 epoch.
2561 extract_tensors_from_dataset: Boolean. When `x` is a dataset instance,
2562 this indicates whether to extract actual tensors from the dataset or
2563 instead output the dataset instance itself.
2564 Set to True when calling from `train_on_batch`/etc.
2566 Returns:
2567 A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a
2568 dict or not), target arrays, sample-weight arrays. If the model's
2569 input and targets are symbolic, these lists are empty (since the model
2570 takes no user-provided data, instead the data comes from the symbolic
2571 inputs/targets).
2573 Raises:
2574 ValueError: In case of invalid user-provided data.
2575 RuntimeError: If the model was never compiled.
2576 """
2577 if isinstance(x, (tf.compat.v1.data.Dataset, tf.data.Dataset)):
2578 # Graph mode dataset. We'll pass the dataset as-is (unless
2579 # `extract_tensors_from_dataset` is True, in which case we extract
2580 # the tensors from the dataset and output them).
2581 training_utils_v1.validate_dataset_input(
2582 x, y, sample_weight, validation_split
2583 )
2584 if shuffle:
2585 training_utils_v1.verify_dataset_shuffled(x)
2587 is_dataset = True
2588 if extract_tensors_from_dataset:
2589 # We do this for `train_on_batch`/etc.
2590 (
2591 x,
2592 y,
2593 sample_weight,
2594 ) = training_utils_v1.extract_tensors_from_dataset(x)
2595 elif isinstance(x, tf.compat.v1.data.Iterator):
2596 # Graph mode iterator. We extract the symbolic tensors.
2597 training_utils_v1.validate_dataset_input(
2598 x, y, sample_weight, validation_split
2599 )
2600 iterator = x
2601 x, y, sample_weight = training_utils_v1.unpack_iterator_input(
2602 iterator
2603 )
2604 is_dataset = True
2605 else:
2606 is_dataset = False
2608 # Validates `steps` argument based on x's type.
2609 if check_steps:
2610 training_utils_v1.check_steps_argument(x, steps, steps_name)
2612 # First, we build the model on the fly if necessary.
2613 if not self.inputs:
2614 all_inputs, y_input, dict_inputs = self._build_model_with_inputs(
2615 x, y
2616 )
2617 is_build_called = True
2618 else:
2619 all_inputs = []
2620 # Whether this is a subclassed model that expects dictionary inputs
2621 # rather than list inputs (e.g. FeatureColumn-based models).
2622 dict_inputs = isinstance(self.inputs, dict)
2623 is_build_called = False
2624 y_input = y
2626 # Second, we compile the model on the fly if necessary, mostly for
2627 # subclass models.
2628 is_compile_called = False
2629 if not self._is_compiled and self.optimizer:
2630 self._compile_from_inputs(all_inputs, y_input, x, y)
2631 is_compile_called = True
2633 # In graph mode, if we had just set inputs and targets as symbolic
2634 # tensors by invoking build and compile on the model respectively, we do
2635 # not have to feed anything to the model. Model already has input and
2636 # target data as part of the graph. Note: in this case, `any` and `all`
2637 # are equivalent since we disallow mixed symbolic/value inputs.
2639 # self.run_eagerly is not free to compute, so we want to reuse the
2640 # value.
2641 run_eagerly = self.run_eagerly
2643 if (
2644 not run_eagerly
2645 and is_build_called
2646 and is_compile_called
2647 and not is_dataset
2648 and any(_is_symbolic_tensor(v) for v in all_inputs)
2649 ):
2650 return [], [], None
2652 return self._standardize_tensors(
2653 x,
2654 y,
2655 sample_weight,
2656 run_eagerly=run_eagerly,
2657 dict_inputs=dict_inputs,
2658 is_dataset=is_dataset,
2659 class_weight=class_weight,
2660 batch_size=batch_size,
2661 )
2663 def _standardize_tensors(
2664 self,
2665 x,
2666 y,
2667 sample_weight,
2668 run_eagerly,
2669 dict_inputs,
2670 is_dataset,
2671 class_weight=None,
2672 batch_size=None,
2673 ):
2674 if run_eagerly:
2675 # In eager mode, do not do shape validation
2676 # since the network has no input nodes (placeholders) to be fed.
2677 feed_input_names = self.input_names
2678 feed_input_shapes = None
2679 elif not self._is_graph_network:
2680 # Case: symbolic-mode subclassed network. Do not do shape
2681 # validation.
2682 feed_input_names = self._feed_input_names
2683 feed_input_shapes = None
2684 else:
2685 # Case: symbolic-mode graph network.
2686 # In this case, we run extensive shape validation checks.
2687 feed_input_names = self._feed_input_names
2688 feed_input_shapes = self._feed_input_shapes
2690 # Standardize the inputs.
2691 if not isinstance(x, (tf.compat.v1.data.Dataset, tf.data.Dataset)):
2692 # TODO(fchollet): run static checks with dataset output shape(s).
2693 x = training_utils_v1.standardize_input_data(
2694 x,
2695 feed_input_names,
2696 feed_input_shapes,
2697 check_batch_axis=False, # Don't enforce the batch size.
2698 exception_prefix="input",
2699 )
2701 # Get typespecs for the input data and sanitize it if necessary.
2702 # TODO(momernick): This should be capable of doing full input validation
2703 # at all times - validate that this is so and refactor the
2704 # standardization code.
2705 if isinstance(x, tf.data.Dataset):
2706 x_shapes = tf.data.experimental.get_structure(x)
2707 if isinstance(x_shapes, tuple):
2708 # If the output of a Dataset is a tuple, we assume it's either
2709 # of the form (x_data, y_data) or (x_data, y_data,
2710 # sample_weights). In either case, we only care about x_data
2711 # here.
2712 x_shapes = x_shapes[0]
2713 else:
2714 flat_inputs = tf.nest.flatten(x)
2715 flat_expected_inputs = tf.nest.flatten(self.inputs)
2716 converted_x = []
2717 for a, b in zip(flat_inputs, flat_expected_inputs):
2718 converted_x.append(_convert_scipy_sparse_tensor(a, b))
2719 x = tf.nest.pack_sequence_as(x, converted_x)
2721 # Convert ResourceVariables to tensors so nest.assert_same_structure
2722 # below won't fail with Variable and Tensor.
2723 x_tensors = tf_utils.convert_variables_to_tensors(x)
2724 x_shapes = tf.nest.map_structure(
2725 tf_utils.type_spec_from_value, x_tensors
2726 )
2728 flat_inputs = tf.nest.flatten(x_shapes)
2729 # Convert ResourceVariables to tensors so nest.assert_same_structure
2730 # below won't fail with Variable and Tensor.
2731 flat_expected_inputs = tf.nest.flatten(
2732 tf_utils.convert_variables_to_tensors(self.inputs)
2733 )
2734 for a, b in zip(flat_inputs, flat_expected_inputs):
2735 tf.nest.assert_same_structure(a, b, expand_composites=True)
2737 if y is not None:
2738 # Prepare self._sample_weight_modes. List with the same length as
2739 # model outputs.
2740 training_utils_v1.prepare_sample_weight_modes(
2741 self._training_endpoints, self.sample_weight_mode
2742 )
2743 feed_output_names = self._feed_output_names
2744 feed_sample_weight_modes = self._sample_weight_modes
2745 if not self._is_graph_network:
2746 feed_output_shapes = None
2747 else:
2748 feed_output_shapes = self._feed_output_shapes
2750 # Standardize the outputs.
2751 y = training_utils_v1.standardize_input_data(
2752 y,
2753 feed_output_names,
2754 # Don't enforce target shapes to match output shapes.
2755 # Precise checks will be run in
2756 # `check_loss_and_target_compatibility`.
2757 shapes=None,
2758 check_batch_axis=False, # Don't enforce the batch size.
2759 exception_prefix="target",
2760 )
2762 # Generate sample-wise weight values given the `sample_weight` and
2763 # `class_weight` arguments.
2764 sample_weights = training_utils_v1.standardize_sample_weights(
2765 sample_weight, feed_output_names
2766 )
2767 class_weights = training_utils_v1.standardize_class_weights(
2768 class_weight, feed_output_names
2769 )
2771 sample_weights = [
2772 training_utils_v1.standardize_weights(ref, sw, cw, mode)
2773 for (ref, sw, cw, mode) in zip(
2774 y, sample_weights, class_weights, feed_sample_weight_modes
2775 )
2776 ]
2777 # Check that all arrays have the same length.
2778 if not self._distribution_strategy:
2779 training_utils_v1.check_array_lengths(x, y, sample_weights)
2780 if self._is_graph_network and not run_eagerly:
2781 # Additional checks to avoid users mistakenly using improper
2782 # loss fns.
2783 training_utils_v1.check_loss_and_target_compatibility(
2784 y, self._feed_loss_fns, feed_output_shapes
2785 )
2787 sample_weights, _, _ = training_utils.handle_partial_sample_weights(
2788 y, sample_weights, feed_sample_weight_modes, check_all_flat=True
2789 )
2790 else:
2791 y = []
2792 sample_weights = None
2794 if self.stateful and batch_size and not is_dataset:
2795 # Check that for stateful networks, number of samples is a multiple
2796 # of the static batch size.
2797 if x[0].shape[0] % batch_size != 0:
2798 raise ValueError(
2799 "In a stateful network, "
2800 "you should only pass inputs with "
2801 "a number of samples that can be "
2802 "divided by the batch size. Found: "
2803 + str(x[0].shape[0])
2804 + " samples"
2805 )
2807 # If dictionary inputs were provided, we return a dictionary as well.
2808 if dict_inputs and not isinstance(
2809 x, (tf.compat.v1.data.Dataset, tf.data.Dataset)
2810 ):
2811 x = dict(zip(feed_input_names, x))
2812 return x, y, sample_weights
2814 def _build_model_with_inputs(self, inputs, targets):
2815 """Build the model (set model inputs/outputs), mainly for subclass
2816 model."""
2817 processed_inputs = []
2818 is_dict_inputs = False
2819 orig_inputs = inputs
2820 # We need to use `inputs` to set the model inputs.
2821 # If input data is a dataset iterator in graph mode or if it is an eager
2822 # iterator and only one batch of samples is required, we fetch the data
2823 # tensors from the iterator and then standardize them.
2824 if isinstance(inputs, (tf.compat.v1.data.Dataset, tf.data.Dataset)):
2825 inputs, targets, _ = training_utils_v1.extract_tensors_from_dataset(
2826 inputs
2827 )
2828 # We type-check that `inputs` and `targets` are either single arrays
2829 # or lists of arrays, and extract a flat list of inputs from the passed
2830 # structure.
2831 training_utils_v1.validate_input_types(inputs, orig_inputs)
2833 if isinstance(inputs, (list, tuple)):
2834 processed_inputs += list(inputs)
2835 elif isinstance(inputs, dict):
2836 is_dict_inputs = True
2837 keys = sorted(inputs.keys())
2838 processed_inputs = [inputs[k] for k in keys]
2839 else:
2840 processed_inputs.append(inputs)
2841 # Now that we have a flat set of inputs, we make sure that none of them
2842 # are CompositeTensors or CompositeTensorValues of any type (or scipy
2843 # sparse arrays, which we treat as SparseTensor values). We cannot
2844 # safely infer input data from an arbitrary composite tensor, so we
2845 # don't try - users should explicitly add composite tensor inputs to
2846 # their subclassed models.
2847 for input_tensor in processed_inputs:
2848 if training_utils_v1.is_composite_or_composite_value(
2849 input_tensor
2850 ) and not isinstance(input_tensor, tf.Variable):
2851 # TODO(b/132691975): Document subclass-model CT input handling.
2852 raise ValueError(
2853 "All SparseTensor and RaggedTensor inputs must be "
2854 "explicitly declared using a keras.Input() with "
2855 "sparse=True or ragged=True. We found an undeclared "
2856 "input %s. For Sequential models, please add a "
2857 "keras.Input() as your first Layer. For subclassed models, "
2858 "please call self._set_inputs() on your input set, which "
2859 "you can create using keras.Input() for each input to your "
2860 "model." % (input_tensor,)
2861 )
2862 # Build the model using the retrieved inputs (value or symbolic).
2863 # If values are generated from a dataset, then in symbolic-mode
2864 # placeholders will be created to match the value shapes.
2865 if isinstance(
2866 orig_inputs,
2867 (
2868 tf.compat.v1.data.Dataset,
2869 tf.data.Dataset,
2870 tf.compat.v1.data.Iterator,
2871 ),
2872 ):
2873 if not self.inputs:
2874 # For subclassed models, a robust input spec is not available so
2875 # we must cast to the model dtype.
2876 inputs = training_utils_v1.cast_if_floating_dtype(
2877 inputs, self.dtype
2878 )
2880 def create_tensor_spec(t):
2881 return tf.TensorSpec(t.shape, t.dtype)
2883 cast_inputs = tf.nest.map_structure(create_tensor_spec, inputs)
2884 elif training_utils_v1.has_tensors(inputs):
2885 cast_inputs = training_utils_v1.cast_if_floating_dtype(inputs)
2886 else:
2887 cast_inputs = inputs
2888 self._set_inputs(cast_inputs)
2889 return processed_inputs, targets, is_dict_inputs
2891 def _compile_from_inputs(
2892 self, all_inputs, target, orig_inputs, orig_target
2893 ):
2894 if target is not None:
2895 # We need to use `y` to set the model targets.
2896 if training_utils_v1.has_tensors(target):
2897 target = training_utils_v1.cast_if_floating_dtype_and_mismatch(
2898 target, self.outputs
2899 )
2900 training_utils_v1.validate_input_types(
2901 target, orig_target, allow_dict=False, field_name="target"
2902 )
2903 if isinstance(target, (list, tuple)):
2904 all_inputs += list(target)
2905 else:
2906 all_inputs.append(target)
2907 # Type check that all inputs are *either* value *or* symbolic.
2908 # TODO(fchollet): this check could be removed in Eager mode?
2909 if any(tf.is_tensor(v) for v in all_inputs):
2910 if not all(tf.is_tensor(v) for v in all_inputs):
2911 raise ValueError(
2912 "Do not pass inputs that mix Numpy arrays and "
2913 "TensorFlow tensors. "
2914 "You passed: x="
2915 + str(orig_inputs)
2916 + "; y="
2917 + str(orig_target)
2918 )
2919 is_dataset = isinstance(
2920 orig_inputs,
2921 (
2922 tf.compat.v1.data.Dataset,
2923 tf.data.Dataset,
2924 tf.compat.v1.data.Iterator,
2925 ),
2926 )
2927 if is_dataset or tf.executing_eagerly():
2928 target_tensors = None
2929 else:
2930 # Handle target tensors if any passed.
2931 if target is not None:
2932 if not isinstance(target, (list, tuple)):
2933 target = [target]
2934 target_tensors = [v for v in target if _is_symbolic_tensor(v)]
2935 else:
2936 target_tensors = None
2938 self.compile(
2939 optimizer=self.optimizer,
2940 loss=self.loss,
2941 metrics=self._compile_metrics,
2942 weighted_metrics=self._compile_weighted_metrics,
2943 loss_weights=self.loss_weights,
2944 target_tensors=target_tensors,
2945 sample_weight_mode=self.sample_weight_mode,
2946 run_eagerly=self.run_eagerly,
2947 experimental_run_tf_function=self._experimental_run_tf_function,
2948 )
2950 # TODO(omalleyt): Consider changing to a more descriptive function name.
2951 def _set_inputs(self, inputs, outputs=None, training=None):
2952 """Set model's input and output specs based on the input data received.
2954 This is to be used for Model subclasses, which do not know at
2955 instantiation time what their inputs look like.
2957 Args:
2958 inputs: Single array, or list of arrays. The arrays could be
2959 placeholders, Numpy arrays, data tensors, or TensorSpecs.
2960 - if placeholders: the model is built on top of these placeholders,
2961 and we expect Numpy data to be fed for them when calling
2962 `fit`/etc.
2963 - if Numpy data or TensorShapes: we create placeholders matching the
2964 TensorShapes or shapes of the Numpy arrays. We expect Numpy data
2965 to be fed for these placeholders when calling `fit`/etc.
2966 - if data tensors: the model is built on top of these tensors.
2967 We do not expect any Numpy data to be provided when calling
2968 `fit`/etc.
2969 outputs: None, a data tensor, or a list of tensors. If None, the
2970 outputs will be determined by invoking `self.call()`, otherwise the
2971 provided value will be used.
2972 training: Boolean or None. Only relevant in symbolic mode. Specifies
2973 whether to build the model's graph in inference mode (False),
2974 training mode (True), or using the Keras learning phase (None).
2975 Raises:
2976 ValueError: If dict inputs are passed to a Sequential Model where the
2977 first layer isn't FeatureLayer.
2978 """
2979 self._set_save_spec(inputs)
2980 inputs = self._set_input_attrs(inputs)
2982 if outputs is None:
2983 kwargs = {}
2984 if self._expects_training_arg:
2985 # In V2 mode, feeding `training=None` is not allowed because any
2986 # value explicitly passed by the user is respected, even
2987 # `None`.
2988 if (
2989 training is None
2990 and not tf.compat.v1.executing_eagerly_outside_functions()
2991 ):
2992 training = backend.learning_phase()
2993 if training is not None:
2994 kwargs["training"] = training
2995 try:
2996 outputs = self(inputs, **kwargs)
2997 except NotImplementedError:
2998 # This Model or a submodel is dynamic and hasn't overridden
2999 # `compute_output_shape`.
3000 outputs = None
3002 self._set_output_attrs(outputs)
3004 @tf.__internal__.tracking.no_automatic_dependency_tracking
3005 def _set_input_attrs(self, inputs):
3006 """Sets attributes related to the inputs of the Model."""
3007 if self.inputs:
3008 raise ValueError("Model inputs are already set.")
3010 if self.__class__.__name__ == "Sequential" and not self.built:
3011 if tf.is_tensor(inputs):
3012 input_shape = (None,) + tuple(inputs.shape.as_list()[1:])
3013 elif isinstance(inputs, tf.TensorShape):
3014 input_shape = (None,) + tuple(inputs.as_list()[1:])
3015 elif isinstance(inputs, dict):
3016 # We assert that the first layer is a FeatureLayer.
3017 if not training_utils_v1.is_feature_layer(self.layers[0]):
3018 raise ValueError(
3019 "Passing a dictionary input to a Sequential Model "
3020 "which doesn't have FeatureLayer as the first layer"
3021 " is an error."
3022 )
3023 input_shape = (None,)
3024 else:
3025 input_shape = (None,) + tuple(inputs.shape[1:])
3026 self._build_input_shape = input_shape
3028 # Cast inputs to the compute dtype. This is primarily used
3029 # when saving to determine the correct dtype in the input signature.
3030 inputs = self._maybe_cast_inputs(inputs)
3032 # On-the-fly setting of symbolic model inputs (either by using the
3033 # tensor provided, or by creating a placeholder if Numpy data was
3034 # provided).
3035 model_inputs = training_utils_v1.ModelInputs(inputs)
3036 inputs = model_inputs.get_symbolic_inputs()
3037 self.inputs = model_inputs.get_symbolic_inputs(
3038 return_single_as_list=True
3039 )
3040 self.input_names = model_inputs.get_input_names()
3042 self._feed_inputs = []
3043 self._feed_input_names = []
3044 self._feed_input_shapes = []
3046 for k, v in model_inputs.as_dict():
3047 if backend.is_placeholder(v):
3048 self._feed_input_names.append(k)
3049 self._feed_inputs.append(v)
3050 self._feed_input_shapes.append(backend.int_shape(v))
3052 return inputs
3054 @tf.__internal__.tracking.no_automatic_dependency_tracking
3055 def _set_output_attrs(self, outputs):
3056 """Sets attributes related to the outputs of the Model."""
3057 # NOTE(taylorrobie): This convention cannot be changed without updating
3058 # the data adapter since it assumes nest.flatten ordering.
3059 outputs = tf.nest.flatten(outputs)
3060 self.outputs = outputs
3061 self.output_names = training_utils_v1.generic_output_names(outputs)
3062 # TODO(scottzhu): Should we cleanup the self._training_endpoints here?
3063 self.built = True
3065 @property
3066 def _targets(self):
3067 """The output target tensors for the model."""
3068 return [
3069 e.training_target.target
3070 for e in self._training_endpoints
3071 if e.has_training_target()
3072 ]
3074 @property
3075 def _feed_targets(self):
3076 return [
3077 e.training_target.target
3078 for e in self._training_endpoints
3079 if e.has_feedable_training_target()
3080 ]
3082 @property
3083 def _feed_output_names(self):
3084 return [
3085 e.output_name
3086 for e in self._training_endpoints
3087 if e.has_feedable_training_target()
3088 ]
3090 @property
3091 def _feed_output_shapes(self):
3092 return [
3093 e.feed_output_shape
3094 for e in self._training_endpoints
3095 if e.has_feedable_training_target()
3096 ]
3098 @property
3099 def _feed_loss_fns(self):
3100 return [
3101 e.loss_fn
3102 for e in self._training_endpoints
3103 if e.has_feedable_training_target()
3104 ]
3106 @property
3107 def _loss_weights_list(self):
3108 return [e.loss_weight for e in self._training_endpoints]
3110 @property
3111 def _output_loss_metrics(self):
3112 if hasattr(self, "_training_endpoints"):
3113 return [
3114 e.output_loss_metric
3115 for e in self._training_endpoints
3116 if e.output_loss_metric is not None
3117 ]
3118 return None
3120 @property
3121 def sample_weights(self):
3122 return [e.sample_weight for e in self._training_endpoints]
3124 @property
3125 def _sample_weight_modes(self):
3126 return [e.sample_weight_mode for e in self._training_endpoints]
3128 @property
3129 def _feed_sample_weights(self):
3130 return [
3131 e.sample_weight
3132 for e in self._training_endpoints
3133 if e.sample_weight is not None
3134 ]
3136 def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
3137 """Maybe load 1st epoch from checkpoint, considering worker recovery.
3139 Refer to tensorflow/python/keras/distribute/worker_training_state.py
3140 for more information.
3142 Args:
3143 initial_epoch: The original initial_epoch the user passes in to `fit()`.
3144 mode: The mode for running `model.fit()`.
3146 Returns:
3147 If the training is recovering from previous failure under multi-worker
3148 training setting, return the epoch the training is supposed to
3149 continue at. Otherwise, return the `initial_epoch` the user passes in.
3150 """
3151 if self._training_state is not None:
3152 return self._training_state.maybe_load_initial_epoch_from_ckpt(
3153 initial_epoch, mode
3154 )
3155 return initial_epoch
3157 def _get_training_eval_metrics(self):
3158 """Returns all the metrics that are to be reported.
3160 This includes the output loss metrics, the metrics and weighted metrics
3161 passed to `compile`, and any metrics added via `add_metric`.
3162 """
3163 metrics = []
3164 metrics.extend(getattr(self, "_output_loss_metrics", None) or [])
3165 metrics.extend(getattr(self, "metrics", None) or [])
3166 return metrics
3168 def _assert_compile_was_called(self):
3169 # Checks whether `compile` has been called. If it has been called,
3170 # then the optimizer is set. This is different from whether the
3171 # model is compiled
3172 # (i.e. whether the model is built and its inputs/outputs are set).
3173 if not self._compile_was_called:
3174 raise RuntimeError(
3175 "You must compile your model before "
3176 "training/testing. "
3177 "Use `model.compile(optimizer, loss)`."
3178 )
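For reference, this is the failure mode the check guards against, reproduced with the public API (the exact wording of the error may differ between Keras versions):

```python
import numpy as np
import tensorflow.compat.v2 as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
try:
    # fit() before compile() trips the check above.
    model.fit(np.zeros((2, 3)), np.zeros((2, 1)), epochs=1, verbose=0)
except RuntimeError as err:
    print(err)

model.compile(optimizer="sgd", loss="mse")
model.fit(np.zeros((2, 3)), np.zeros((2, 1)), epochs=1, verbose=0)  # now succeeds
```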
3180 def _in_multi_worker_mode(self):
3181 """Method to infer if this `Model` is working in multi-worker settings.
3183 Multi-worker training refers to the setup where the training is
3184 distributed across multiple workers, as opposed to the case where
3185 only a local process performs the training. This function is used to
3186 infer, for example, whether a distribute coordinator should be run (and
3187 thus whether TensorFlow servers should be started for communication with
3188 other servers in the cluster), or whether saving/restoring checkpoints is
3189 relevant for preemption fault tolerance.
3191 Experimental. Signature and implementation are subject to change.
3193 Returns:
3194 Whether this model indicates it's working in multi-worker settings.
3195 """
3196 strategy = self._distribution_strategy
3198 # If the model has no strategy of its own, use the strategy whose scope this is in.
3199 if not strategy and tf.distribute.has_strategy():
3200 strategy = tf.distribute.get_strategy()
3201 return strategy and strategy.extended._in_multi_worker_mode()
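A rough illustration of the strategy lookup (not the private `_in_multi_worker_mode` call itself): when the model carries no strategy, the strategy active in the enclosing scope is consulted.

```python
import tensorflow.compat.v2 as tf

strategy = tf.distribute.MirroredStrategy()  # a single-worker strategy
print(tf.distribute.has_strategy())          # False outside any scope

with strategy.scope():
    print(tf.distribute.has_strategy())              # True
    print(tf.distribute.get_strategy() is strategy)  # True
# With MirroredStrategy this would not count as multi-worker mode; that only
# happens for strategies such as MultiWorkerMirroredStrategy driven by TF_CONFIG.
```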
3203 @property
3204 def _trackable_saved_model_saver(self):
3205 return model_serialization.ModelSavedModelSaver(self)
3207 def _get_compile_args(self, user_metrics=True):
3208 del user_metrics
3209 self._assert_compile_was_called()
3210 kwargs = {
3211 "loss": self.loss,
3212 "metrics": self._compile_metrics,
3213 "loss_weights": self.loss_weights,
3214 "sample_weight_mode": self.sample_weight_mode,
3215 "weighted_metrics": self._compile_weighted_metrics,
3216 }
3217 return kwargs
3219 @property
3220 def _compile_was_called(self):
3221 return self._v1_compile_was_called
3224class DistributedCallbackModel(Model):
3225 """Model that is used for callbacks with tf.distribute.Strategy."""
3227 def __init__(self, model):
3228 super().__init__()
3229 self.optimizer = model.optimizer
3231 def set_original_model(self, orig_model):
3232 self._original_model = orig_model
3234 def save_weights(self, filepath, overwrite=True, save_format=None):
3235 self._replicated_model.save_weights(
3236 filepath, overwrite=overwrite, save_format=save_format
3237 )
3239 def save(self, filepath, overwrite=True, include_optimizer=True):
3240 # save weights from the distributed model to the original model
3241 distributed_model_weights = self.get_weights()
3242 self._original_model.set_weights(distributed_model_weights)
3243 # TODO(anjalisridhar): Do we need to save the original model here?
3244 # Saving the first replicated model works as well.
3245 self._original_model.save(
3246 filepath, overwrite=True, include_optimizer=False
3247 )
3249 def load_weights(self, filepath, by_name=False):
3250 self._original_model.load_weights(filepath, by_name=False)
3251 # Copy the weights from the original model to each of the replicated
3252 # models.
3253 orig_model_weights = self._original_model.get_weights()
3254 distributed_training_utils_v1.set_weights(
3255 self._original_model._distribution_strategy,
3256 self,
3257 orig_model_weights,
3258 )
3260 def __getattr__(self, item):
3261 # Allowed attributes of the model that can be accessed by the user
3262 # during a callback.
3263 if item not in ("_setattr_tracking", "_layers"):
3264 logging.warning(
3265 "You are accessing attribute " + item + " of the "
3266 "DistributedCallbackModel that may not have been set "
3267 "correctly."
3268 )
3269 return super().__getattr__(item)
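The save/load overrides above boil down to copying flat weight lists between two structurally identical models; a minimal sketch of that round trip with the public API:

```python
import tensorflow.compat.v2 as tf

def make_model():
    return tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])

original = make_model()
replica = make_model()

# Push the original weights into the replica (what load_weights() above does),
# then pull them back (what save() above does before delegating to the original).
replica.set_weights(original.get_weights())
original.set_weights(replica.get_weights())
```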
3272class _TrainingEndpoint:
3273 """A container for the training output/target and related entities.
3275 In the case of a model with multiple outputs, there is a one-to-one mapping
3276 between model output (y_pred), model target (y_true), loss, metrics, etc.
3277 By unifying these entities into one class, each entity can access
3278 information about the others, rather than having to reach into separate
3279 lists of attributes on the model.
3280 """
3282 def __init__(
3283 self,
3284 output,
3285 output_name,
3286 loss_fn,
3287 loss_weight=None,
3288 training_target=None,
3289 output_loss_metric=None,
3290 sample_weight=None,
3291 sample_weight_mode=None,
3292 ):
3293 """Initialize the _TrainingEndpoint.
3295 Note that the output and output_name should be stable as long as the
3296 model structure doesn't change. The training_target is supposed to be
3297 mutable, since that information is provided via `compile()`.
3299 Args:
3300 output: the output tensor of the model.
3301 output_name: the unique name of the output tensor.
3302 loss_fn: the loss function for the output tensor.
3303 loss_weight: float, the weight for the loss.
3304 training_target: the _TrainingTarget for the model.
3305 output_loss_metric: the metric object for the loss function.
3306 sample_weight: the weights for how a sample is weighted during metric
3307 and loss calculation. Could be None.
3308 sample_weight_mode: string, 'temporal', 'samplewise' or None. The mode
3309 for how the sample_weight is populated.
3310 """
3311 self._output = output
3312 self._output_name = output_name
3313 self._loss_fn = loss_fn
3314 self._loss_weight = loss_weight
3315 self._training_target = training_target
3316 self._output_loss_metric = output_loss_metric
3317 self._sample_weight = sample_weight
3318 self._sample_weight_mode = sample_weight_mode
3320 @property
3321 def output(self):
3322 return self._output
3324 @property
3325 def output_name(self):
3326 return self._output_name
3328 @property
3329 def shape(self):
3330 return backend.int_shape(self.output)
3332 @property
3333 def loss_fn(self):
3334 return self._loss_fn
3336 @property
3337 def loss_weight(self):
3338 return self._loss_weight
3340 @loss_weight.setter
3341 def loss_weight(self, value):
3342 self._loss_weight = value
3344 @property
3345 def training_target(self):
3346 return self._training_target
3348 @training_target.setter
3349 def training_target(self, value):
3350 self._training_target = value
3352 def create_training_target(self, target, run_eagerly=False):
3353 """Create training_target instance and update the self.training_target.
3355 Note that the input target should just be a tensor or None, and
3356 corresponding training target will be created based on the output and
3357 loss_fn.
3359 Args:
3360 target: the target tensor for the current output. Could be None.
3361 run_eagerly: boolean, whether the model is in run_eagerly mode.
3363 Raises:
3364 ValueError if the training_target field for the current instance has
3365 already been populated.
3366 """
3367 if self.has_training_target():
3368 raise ValueError(
3369 "The training_target field for the _TrainingEndpoint "
3370 "instance has already been populated"
3371 )
3372 if run_eagerly:
3373 # When running eagerly, the target tensor is ignored, and a `None`
3374 # target is used instead of creating a placeholder.
3375 self.training_target = _TrainingTarget(
3376 None, feedable=True, skip_target_weights=False
3377 )
3378 return
3380 if self.should_skip_target():
3381 self.training_target = _TrainingTarget(None)
3382 else:
3383 if target is not None and not backend.is_placeholder(target):
3384 feedable = False
3385 skip_target_weights = True
3386 else:
3387 feedable = True
3388 skip_target_weights = False
3390 if target is None:
3391 target_dtype = losses.LABEL_DTYPES_FOR_LOSSES.get(
3392 self.loss_fn, backend.dtype(self.output)
3393 )
3395 target = backend.placeholder(
3396 ndim=len(self.shape),
3397 name=self.output_name + "_target",
3398 sparse=backend.is_sparse(self.output),
3399 dtype=target_dtype,
3400 )
3402 self.training_target = _TrainingTarget(
3403 target,
3404 feedable=feedable,
3405 skip_target_weights=skip_target_weights,
3406 )
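The placeholder branch above builds a target whose rank matches the output and whose dtype comes from the loss (e.g. an integer dtype for sparse categorical losses, per `losses.LABEL_DTYPES_FOR_LOSSES`). A graph-mode sketch of the same idea with public APIs only; the names and the hard-coded dtype here are illustrative:

```python
import tensorflow.compat.v2 as tf

tf.compat.v1.disable_eager_execution()

output = tf.keras.Input(shape=(10,), name="dense")   # rank-2, softmax-style output
target_dtype = "int64"  # what a sparse categorical loss expects for its labels
target = tf.compat.v1.placeholder(
    dtype=target_dtype,
    shape=[None] * len(output.shape),  # same rank as the output, all dims free
    name="dense_target",
)
print(target.shape, target.dtype)  # (None, None) <dtype: 'int64'>
```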
3408 @property
3409 def output_loss_metric(self):
3410 return self._output_loss_metric
3412 @output_loss_metric.setter
3413 def output_loss_metric(self, value):
3414 self._output_loss_metric = value
3416 @property
3417 def sample_weight(self):
3418 return self._sample_weight
3420 @sample_weight.setter
3421 def sample_weight(self, value):
3422 self._sample_weight = value
3424 @property
3425 def sample_weight_mode(self):
3426 return self._sample_weight_mode
3428 @sample_weight_mode.setter
3429 def sample_weight_mode(self, value):
3430 self._sample_weight_mode = value
3432 def should_skip_target(self):
3433 return self._loss_fn is None
3435 def should_skip_target_weights(self):
3436 return (
3437 self.should_skip_target()
3438 or self.training_target is None
3439 or self.training_target.skip_target_weights
3440 )
3442 def has_training_target(self):
3443 return self.training_target is not None
3445 def has_feedable_training_target(self):
3446 return (
3447 not self.should_skip_target()
3448 and self.training_target is not None
3449 and self.training_target.feedable
3450 )
3452 def loss_name(self):
3453 if self._loss_fn is not None:
3454 return self._output_name + "_loss"
3455 return None
3457 @property
3458 def feed_output_shape(self):
3459 """The output shape for the feedable target."""
3460 if not self.has_feedable_training_target():
3461 return None
3463 if (
3464 (
3465 isinstance(self.loss_fn, losses.LossFunctionWrapper)
3466 and self.loss_fn.fn == losses.sparse_categorical_crossentropy
3467 )
3468 ) or (isinstance(self.loss_fn, losses.SparseCategoricalCrossentropy)):
3469 if backend.image_data_format() == "channels_first":
3470 return (self.shape[0], 1) + self.shape[2:]
3471 else:
3472 return self.shape[:-1] + (1,)
3473 elif not isinstance(self.loss_fn, losses.Loss) or (
3474 isinstance(self.loss_fn, losses.LossFunctionWrapper)
3475 and (getattr(losses, self.loss_fn.fn.__name__, None) is None)
3476 ):
3477 # If the given loss is not an instance of the `Loss` class (custom
3478 # class) or if the loss function that is wrapped is not in the
3479 # `losses` module, then it is a user-defined loss and we make no
3480 # assumptions about it.
3481 return None
3482 else:
3483 return self.shape
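Concretely, for a sparse categorical loss the fed target keeps every axis of the prediction except the class axis, which collapses to a single integer-label column; the other branches either return the full output shape or give up (`None`) for custom losses. A plain-Python illustration of the sparse case:

```python
# Per-pixel softmax over 10 classes, channels_last: the class axis is last.
pred_shape = (None, 32, 32, 10)
print(pred_shape[:-1] + (1,))               # (None, 32, 32, 1)

# Same prediction laid out channels_first: the class axis is axis 1 instead.
pred_shape = (None, 10, 32, 32)
print((pred_shape[0], 1) + pred_shape[2:])  # (None, 1, 32, 32)
```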
3485 def sample_weights_mismatch(self):
3486 """Check if the sample weight and the mode match or not."""
3487 # If there is a mismatch between sample weight mode and the placeholders
3488 # created, then recompile the sub-graphs that depend on sample weights.
3489 return (
3490 self.sample_weight_mode is not None and self.sample_weight is None
3491 ) or (
3492 self.sample_weight_mode is None and self.sample_weight is not None
3493 )
3495 def populate_sample_weight(self, sample_weight, sample_weight_mode):
3496 """Populate the sample weight and based on the sample weight mode."""
3497 if sample_weight is None and (
3498 self.should_skip_target_weights()
3499 or sample_weight_mode is None
3500 or tf.executing_eagerly()
3501 ):
3502 self._sample_weight = None
3503 return
3505 assert sample_weight_mode in ["temporal", "samplewise"]
3506 if sample_weight_mode == "temporal":
3507 default_value = [[1.0]]
3508 shape = [None, None]
3509 else:
3510 # sample_weight_mode == 'samplewise'
3511 default_value = [1.0]
3512 shape = [None]
3514 if sample_weight is not None:
3515 if not sample_weight.shape.is_compatible_with(shape):
3516 raise ValueError(
3517 "Received sample weight with shape {}. Expected shape "
3518 "{}.".format(sample_weight.shape, shape)
3519 )
3520 self._sample_weight = sample_weight
3521 else:
3522 self._sample_weight = tf.compat.v1.placeholder_with_default(
3523 tf.constant(default_value, dtype=backend.floatx()),
3524 shape=shape,
3525 name=self.output_name + "_sample_weights",
3526 )
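The two default placeholders created above, spelled out on their own: "samplewise" weights are a vector with one weight per sample, while "temporal" weights add a timestep axis. A small graph-mode sketch (the names are illustrative):

```python
import tensorflow.compat.v2 as tf

tf.compat.v1.disable_eager_execution()

samplewise = tf.compat.v1.placeholder_with_default(
    tf.constant([1.0]), shape=[None], name="out_sample_weights"
)
temporal = tf.compat.v1.placeholder_with_default(
    tf.constant([[1.0]]), shape=[None, None], name="out_temporal_sample_weights"
)
print(samplewise.shape, temporal.shape)  # (None,) (None, None)
```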
3529class _TrainingTarget:
3530 """Container for a target tensor (y_true) and its metadata (shape, loss...).
3532 Args:
3533 target: A target tensor for the model. It may be `None` if the
3534 output is excluded from loss computation; the entry is still kept (as
3535 `None`) so that each output of the model has a corresponding target. If
3536 the target is None, the rest of the attributes will be None as well.
3537 feedable: Boolean, whether the target is feedable (requires data to be
3538 passed in `fit` or `train_on_batch`), or not (model compiled with
3539 `target_tensors` argument).
3540 skip_target_weights: Boolean, whether the target should be skipped during
3541 weights calculation.
3542 """
3544 def __init__(self, target, feedable=False, skip_target_weights=True):
3545 self._target = target
3546 self._feedable = feedable
3547 self._skip_target_weights = skip_target_weights
3549 @property
3550 def target(self):
3551 return self._target
3553 @property
3554 def feedable(self):
3555 return self._feedable
3557 @property
3558 def skip_target_weights(self):
3559 return self._skip_target_weights
3562def _is_symbolic_tensor(x):
3563 return tf.is_tensor(x)
3566def _convert_scipy_sparse_tensor(value, expected_input):
3567 """Handle scipy sparse tensor conversions.
3569 This method takes a value 'value' and returns the proper conversion. If
3570 value is a scipy sparse tensor and the expected input is a dense tensor,
3571 we densify 'value'. If value is a scipy sparse tensor and the expected input
3572 is a TF SparseTensor, we convert 'value' to a SparseTensor. If 'value' is
3573 not a scipy sparse tensor, or scipy is not imported, we pass it through
3574 unchanged.
3576 Args:
3577 value: An object that may be a SciPy sparse matrix.
3578 expected_input: The expected input placeholder.
3580 Returns:
3581 The possibly-converted 'value'.
3582 """
3583 if issparse is not None and issparse(value):
3584 if backend.is_sparse(expected_input):
3585 sparse_coo = value.tocoo()
3586 row, col = sparse_coo.row, sparse_coo.col
3587 data, shape = sparse_coo.data, sparse_coo.shape
3588 indices = np.concatenate(
3589 (np.expand_dims(row, 1), np.expand_dims(col, 1)), 1
3590 )
3591 return tf.SparseTensor(indices, data, shape)
3592 else:
3593 if tf.compat.v1.executing_eagerly_outside_functions():
3594 # In TF2 we do not silently densify sparse matrices.
3595 raise ValueError(
3596 "A SciPy sparse matrix was passed to a model "
3597 "that expects dense inputs. Please densify your "
3598 "inputs first, such as by calling `x.toarray()."
3599 )
3600 return value.toarray()
3601 else:
3602 return value
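The sparse branch above in isolation: a SciPy COO matrix stores (row, col, data) triplets, which map directly onto a `tf.SparseTensor`'s (indices, values, dense_shape). A standalone sketch of that conversion:

```python
import numpy as np
import tensorflow.compat.v2 as tf
from scipy.sparse import coo_matrix

mat = coo_matrix(np.array([[0.0, 2.0], [3.0, 0.0]]))
indices = np.stack([mat.row, mat.col], axis=1).astype(np.int64)  # [[0, 1], [1, 0]]
sparse = tf.SparseTensor(indices=indices, values=mat.data, dense_shape=mat.shape)
print(tf.sparse.to_dense(tf.sparse.reorder(sparse)).numpy())
# [[0. 2.]
#  [3. 0.]]
```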
3605def _get_metrics_from_layers(layers):
3606 """Returns list of metrics from the given layers.
3608 This will not include the `compile` metrics of a model layer.
3610 Args:
3611 layers: List of layers.
3613 Returns:
3614 List of metrics.
3615 """
3616 metrics = []
3617 layers = layer_utils.filter_empty_layer_containers(layers)
3618 for layer in layers:
3619 if isinstance(layer, Model):
3620 # We cannot call 'metrics' on the model because we do not want to
3621 # include the metrics that were added via the `compile` API of a
3622 # nested model.
3623 metrics.extend(layer._metrics)
3624 metrics.extend(_get_metrics_from_layers(layer.layers))
3625 else:
3626 metrics.extend(layer.metrics)
3627 return metrics
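The metrics this walk picks up are the per-layer ones, e.g. those attached with `add_metric`, not the metrics passed to a nested model's `compile`. A small sketch of a layer-level metric (the `aggregation` argument is accepted across Keras 2.x, though newer releases may warn that it is deprecated):

```python
import tensorflow.compat.v2 as tf

class MeanActivation(tf.keras.layers.Layer):
    def call(self, inputs):
        # Attach a layer-level metric; this is what the recursive walk collects.
        self.add_metric(tf.reduce_mean(inputs), name="mean_act", aggregation="mean")
        return inputs

layer = MeanActivation()
_ = layer(tf.ones((2, 3)))
print([m.name for m in layer.metrics])  # ['mean_act']
```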
3630def _non_none_constant_value(v):
3631 constant_value = tf.get_static_value(v)
3632 return constant_value if constant_value is not None else v
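`tf.get_static_value` returns a NumPy value when the tensor's value can be determined without executing the graph, and `None` otherwise, which is exactly the case the helper above falls back on:

```python
import tensorflow.compat.v2 as tf

print(tf.get_static_value(tf.constant(3)))           # 3
print(tf.get_static_value(tf.constant([1, 2]) * 2))  # [2 4]

@tf.function
def f(x):
    # Inside a traced function `x` is symbolic, so there is no static value and
    # the helper above would return the tensor itself.
    print(tf.get_static_value(x))  # None (printed at trace time)
    return x

f(tf.constant(5))
```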