Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training_v1.py: 20%
1081 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""V1 Training-related part of the Keras engine."""
17import collections
18import warnings
20import numpy as np
22from tensorflow.python import tf2
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.data.ops import iterator_ops
25from tensorflow.python.distribute import distribute_lib
26from tensorflow.python.distribute import parameter_server_strategy
27from tensorflow.python.distribute import parameter_server_strategy_v2
28from tensorflow.python.eager import context
29from tensorflow.python.eager import def_function
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import sparse_tensor
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.framework import tensor_spec
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.framework import type_spec
37from tensorflow.python.keras import backend
38from tensorflow.python.keras import losses
39from tensorflow.python.keras import metrics as metrics_module
40from tensorflow.python.keras import optimizer_v1
41from tensorflow.python.keras import optimizers
42from tensorflow.python.keras.distribute import distributed_training_utils
43from tensorflow.python.keras.distribute import distributed_training_utils_v1
44from tensorflow.python.keras.engine import base_layer
45from tensorflow.python.keras.engine import training as training_lib
46from tensorflow.python.keras.engine import training_arrays_v1
47from tensorflow.python.keras.engine import training_distributed_v1
48from tensorflow.python.keras.engine import training_eager_v1
49from tensorflow.python.keras.engine import training_generator_v1
50from tensorflow.python.keras.engine import training_utils
51from tensorflow.python.keras.engine import training_utils_v1
52from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
53from tensorflow.python.keras.mixed_precision import policy
54from tensorflow.python.keras.optimizer_v2 import optimizer_v2
55from tensorflow.python.keras.saving import saving_utils
56from tensorflow.python.keras.saving.saved_model import model_serialization
57from tensorflow.python.keras.utils import data_utils
58from tensorflow.python.keras.utils import layer_utils
59from tensorflow.python.keras.utils import losses_utils
60from tensorflow.python.keras.utils import tf_inspect
61from tensorflow.python.keras.utils import tf_utils
62from tensorflow.python.keras.utils.mode_keys import ModeKeys
63from tensorflow.python.ops import array_ops
64from tensorflow.python.ops import math_ops
65from tensorflow.python.platform import tf_logging as logging
66from tensorflow.python.trackable import base as trackable
67from tensorflow.python.types import data as data_types
68from tensorflow.python.util import nest
70try:
71 from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
72except ImportError:
73 issparse = None
76class Model(training_lib.Model):
77 """`Model` groups layers into an object with training and inference features.
79 There are two ways to instantiate a `Model`:
81 1 - With the "functional API", where you start from `Input`,
82 you chain layer calls to specify the model's forward pass,
83 and finally you create your model from inputs and outputs:
85 ```python
86 import tensorflow as tf
88 inputs = tf.keras.Input(shape=(3,))
89 x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
90 outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
91 model = tf.keras.Model(inputs=inputs, outputs=outputs)
92 ```
94 2 - By subclassing the `Model` class: in that case, you should define your
95 layers in `__init__` and you should implement the model's forward pass
96 in `call`.
98 ```python
99 import tensorflow as tf
101 class MyModel(tf.keras.Model):
103 def __init__(self):
104 super(MyModel, self).__init__()
105 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
106 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
108 def call(self, inputs):
109 x = self.dense1(inputs)
110 return self.dense2(x)
112 model = MyModel()
113 ```
115 If you subclass `Model`, you can optionally have
116 a `training` argument (boolean) in `call`, which you can use to specify
117 a different behavior in training and inference:
119 ```python
120 import tensorflow as tf
122 class MyModel(tf.keras.Model):
124 def __init__(self):
125 super(MyModel, self).__init__()
126 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
127 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
128 self.dropout = tf.keras.layers.Dropout(0.5)
130 def call(self, inputs, training=False):
131 x = self.dense1(inputs)
132 if training:
133 x = self.dropout(x, training=training)
134 return self.dense2(x)
136 model = MyModel()
137 ```
138 """
140 def __init__(self, *args, **kwargs):
141 super(Model, self).__init__(*args, **kwargs)
142 # initializing _distribution_strategy here since it is possible to call
143 # predict on a model without compiling it.
144 self._distribution_strategy = None
145 self._compile_time_distribution_strategy = None
146 if (ops.executing_eagerly_outside_functions() and
147 distribute_lib.has_strategy()):
148 self._set_strategy(
149 distribute_lib.get_strategy())
151 # This flag is used to track if the user is using the deprecated path of
152 # passing distribution strategy to compile rather than creating the model
153 # under distribution strategy scope.
154 self._compile_distribution = False
156 self._run_eagerly = None
157 self._experimental_run_tf_function = (
158 ops.executing_eagerly_outside_functions())
160 self._v1_compile_was_called = False
162 def _init_batch_counters(self):
163 pass # Batch counters should not be created in legacy graph mode.
165 @trackable.no_automatic_dependency_tracking
166 def _set_strategy(self, strategy):
167 self._compile_time_distribution_strategy = strategy
169 def get_weights(self):
170 """Retrieves the weights of the model.
172 Returns:
173 A flat list of Numpy arrays.
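Example (an illustrative sketch added for clarity; the toy architecture is arbitrary):

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.Model(inputs, outputs)

weights = model.get_weights()  # flat list of Numpy arrays: [kernel, bias]
model.set_weights(weights)     # the inverse operation accepts the same list
```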
174 """
175 strategy = (self._distribution_strategy or
176 self._compile_time_distribution_strategy)
177 if strategy:
178 with strategy.scope():
179 return base_layer.Layer.get_weights(self)
180 return base_layer.Layer.get_weights(self)
182 def load_weights(self, filepath, by_name=False, skip_mismatch=False):
183 """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
185 If `by_name` is False, weights are loaded based on the network's
186 topology. This means the architecture should be the same as when the weights
187 were saved. Note that layers that don't have weights are not taken into
188 account in the topological ordering, so adding or removing layers is fine as
189 long as they don't have weights.
191 If `by_name` is True, weights are loaded into layers only if they share the
192 same name. This is useful for fine-tuning or transfer-learning models where
193 some of the layers have changed.
195 Only topological loading (`by_name=False`) is supported when loading weights
196 from the TensorFlow format. Note that topological loading differs slightly
197 between TensorFlow and HDF5 formats for user-defined classes inheriting from
198 `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
199 TensorFlow format loads based on the object-local names of attributes to
200 which layers are assigned in the `Model`'s constructor.
202 Args:
203 filepath: String, path to the weights file to load. For weight files in
204 TensorFlow format, this is the file prefix (the same as was passed
205 to `save_weights`).
206 by_name: Boolean, whether to load weights by name or by topological
207 order. Only topological loading is supported for weight files in
208 TensorFlow format.
209 skip_mismatch: Boolean, whether to skip loading of layers where there is
210 a mismatch in the number of weights, or a mismatch in the shape of
211 the weight (only valid when `by_name=True`).
213 Returns:
214 When loading a weight file in TensorFlow format, returns the same status
215 object as `tf.train.Checkpoint.restore`. When graph building, restore
216 ops are run automatically as soon as the network is built (on first call
217 for user-defined classes inheriting from `Model`, immediately if it is
218 already built).
220 When loading weights in HDF5 format, returns `None`.
222 Raises:
223 ImportError: If h5py is not available and the weight file is in HDF5
224 format.
225 ValueError: If `skip_mismatch` is set to `True` when `by_name` is
226 `False`.
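Example (an illustrative sketch; the checkpoint path and architecture are arbitrary):

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2, name='out')(inputs)
model = tf.keras.Model(inputs, outputs)

# Saving without an `.h5` suffix writes the TensorFlow format, so the path
# below is a file prefix rather than a single file.
model.save_weights('/tmp/my_checkpoint')
# Topological loading (`by_name=False`) is the mode supported for this format.
model.load_weights('/tmp/my_checkpoint')
```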
227 """
228 if backend.is_tpu_strategy(self._distribution_strategy):
229 if (self._distribution_strategy.extended.steps_per_run > 1 and
230 (not saving_utils.is_hdf5_filepath(filepath))): # pylint: disable=protected-access
231 raise ValueError('Load weights is not yet supported with TPUStrategy '
232 'with steps_per_run greater than 1.')
233 return super(Model, self).load_weights(filepath, by_name, skip_mismatch)
235 @trackable.no_automatic_dependency_tracking
236 def compile(self,
237 optimizer='rmsprop',
238 loss=None,
239 metrics=None,
240 loss_weights=None,
241 sample_weight_mode=None,
242 weighted_metrics=None,
243 target_tensors=None,
244 distribute=None,
245 **kwargs):
246 """Configures the model for training.
248 Args:
249 optimizer: String (name of optimizer) or optimizer instance.
250 See `tf.keras.optimizers`.
251 loss: String (name of objective function), objective function or
252 `tf.keras.losses.Loss` instance. See `tf.keras.losses`. An objective
253 function is any callable with the signature
254 `scalar_loss = fn(y_true, y_pred)`. If the model has multiple
255 outputs, you can use a different loss on each output by passing a
256 dictionary or a list of losses. The loss value that will be
257 minimized by the model will then be the sum of all individual
258 losses.
259 metrics: List of metrics to be evaluated by the model during training
260 and testing. Typically you will use `metrics=['accuracy']`.
261 To specify different metrics for different outputs of a
262 multi-output model, you could also pass a dictionary, such as
263 `metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}`.
264 You can also pass a list (len = len(outputs)) of lists of metrics
265 such as `metrics=[['accuracy'], ['accuracy', 'mse']]` or
266 `metrics=['accuracy', ['accuracy', 'mse']]`.
267 loss_weights: Optional list or dictionary specifying scalar
268 coefficients (Python floats) to weight the loss contributions
269 of different model outputs.
270 The loss value that will be minimized by the model
271 will then be the *weighted sum* of all individual losses,
272 weighted by the `loss_weights` coefficients.
273 If a list, it is expected to have a 1:1 mapping
274 to the model's outputs. If a dict, it is expected to map
275 output names (strings) to scalar coefficients.
276 sample_weight_mode: If you need to do timestep-wise
277 sample weighting (2D weights), set this to `"temporal"`.
278 `None` defaults to sample-wise weights (1D).
279 If the model has multiple outputs, you can use a different
280 `sample_weight_mode` on each output by passing a
281 dictionary or a list of modes.
282 weighted_metrics: List of metrics to be evaluated and weighted
283 by sample_weight or class_weight during training and testing.
284 target_tensors: By default, Keras will create placeholders for the
285 model's target, which will be fed with the target data during
286 training. If instead you would like to use your own
287 target tensors (in turn, Keras will not expect external
288 Numpy data for these targets at training time), you
289 can specify them via the `target_tensors` argument. It can be
290 a single tensor (for a single-output model), a list of tensors,
291 or a dict mapping output names to target tensors.
292 distribute: NOT SUPPORTED IN TF 2.0, please create and compile the
293 model under distribution strategy scope instead of passing it to
294 compile.
295 **kwargs: Any additional arguments.
297 Raises:
298 ValueError: In case of invalid arguments for
299 `optimizer`, `loss`, `metrics` or `sample_weight_mode`.
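Example (an illustrative sketch; the architecture, optimizer, loss and metrics below are arbitrary choices):

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)

model.compile(
    optimizer='rmsprop',                      # string name or optimizer instance
    loss=tf.keras.losses.MeanSquaredError(),  # string, callable or `Loss` instance
    metrics=['mae'])                          # evaluated during training and testing
```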
300 """
301 self._assert_built_as_v1()
302 self._run_eagerly = kwargs.pop('run_eagerly', None)
303 self._experimental_run_tf_function = kwargs.pop(
304 'experimental_run_tf_function', True)
305 self._v1_compile_was_called = True
307 # Prepare Session arguments (legacy).
308 kwargs.pop('cloning', None) # Legacy DistStrat argument, never used.
309 self._from_serialized = kwargs.pop('from_serialized', False)
310 allowed_kwargs = {'feed_dict', 'fetches', 'options', 'run_metadata'}
311 unknown_kwargs = set(kwargs.keys()) - allowed_kwargs
312 if unknown_kwargs:
313 raise TypeError(
314 'Invalid keyword argument(s) in `compile`: %s' % (unknown_kwargs,))
315 self._function_kwargs = kwargs
316 if self._function_kwargs:
317 self._experimental_run_tf_function = False
318 if self.run_eagerly:
319 raise ValueError(
320 'Session keyword arguments are not supported '
321 'when `run_eagerly=True`. You passed the following '
322 'Session arguments: %s' % (self._function_kwargs,))
324 self._set_optimizer(optimizer)
325 is_any_keras_optimizer_v1 = any(
326 (isinstance(opt, optimizer_v1.Optimizer)
327 and not isinstance(opt, optimizer_v1.TFOptimizer)
328 ) for opt in nest.flatten(self.optimizer))
330 if is_any_keras_optimizer_v1 and ops.executing_eagerly_outside_functions():
331 raise ValueError('`tf.compat.v1.keras` Optimizer (%s) is '
332 'not supported when eager execution is enabled. Use a '
333 '`tf.keras` Optimizer instead, or disable eager '
334 'execution.' % (optimizer,))
336 if ((target_tensors is not None)
337 or not ops.executing_eagerly_outside_functions()):
338 # Fallback out of things that aren't supported with v2 loops
339 self._experimental_run_tf_function = False
341 if distribute is not None:
342 if tf2.enabled() or self._experimental_run_tf_function:
343 raise ValueError(
344 'Distribute argument in compile is not available in TF 2.0, please '
345 'create the model under the distribution strategy scope.')
346 logging.warning('Distribute argument in compile is deprecated, please '
347 'create the model under the distribution strategy scope.')
348 self._distribution_strategy = distribute
349 self._compile_distribution = True
350 else:
351 if distribute_lib.has_strategy():
352 # When the user builds the model in the DS scope and cross replica
353 # context we want distribution strategy to be set but when building the
354 # replica copies of the models internally we should not be compiling
355 # with distribution strategy and use the default compilation path.
356 if distribute_lib.in_cross_replica_context():
357 self._distribution_strategy = (
358 distribute_lib.get_strategy())
360 if isinstance(self._distribution_strategy,
361 parameter_server_strategy.ParameterServerStrategyV1):
362 raise NotImplementedError(
363 '`tf.compat.v1.distribute.experimental.ParameterServerStrategy` '
364 'currently only works with the tf.Estimator API')
366 if isinstance(self._distribution_strategy,
367 parameter_server_strategy_v2.ParameterServerStrategyV2):
368 raise NotImplementedError(
369 '`tf.distribute.experimental.ParameterServerStrategy` is only '
370 'supported in TF2.')
372 if not self._experimental_run_tf_function:
373 self._validate_compile_param_for_distribution_strategy(self.run_eagerly,
374 sample_weight_mode,
375 target_tensors,
376 weighted_metrics)
377 # We've disabled automatic dependency tracking for this method, but do want
378 # to add a checkpoint dependency on the optimizer if it's trackable.
379 if isinstance(self.optimizer, trackable.Trackable):
380 self._track_trackable(
381 self.optimizer, name='optimizer', overwrite=True)
382 self.loss = loss or {}
383 self.loss_weights = loss_weights
384 self.sample_weight_mode = sample_weight_mode
385 self._compile_metrics = metrics or []
386 self._compile_weighted_metrics = weighted_metrics
387 if self.run_eagerly and target_tensors is not None:
388 raise ValueError(
389 'target_tensors argument is not supported when '
390 'running a model eagerly.')
392 # _training_endpoints contains a list of _TrainingEndpoint object, which has
393 # all the model output/target/loss and related metadata.
394 self._training_endpoints = []
396 # Used to freeze the behavior of the Model once `compile` has been called.
397 self._compiled_trainable_state = self._get_trainable_state()
399 # Set tf.distribute.Strategy specific parameters.
400 self._distributed_model_cache = {}
401 self._distributed_function_cache = {}
403 # Clear any `_eager_losses` that was added.
404 self._clear_losses()
406 if (not context.executing_eagerly() and
407 self._distribution_strategy is not None):
408 # Ensures a Session is created and configured correctly for Distribution
409 # Strategy.
410 backend.configure_and_create_distributed_session(
411 self._distribution_strategy)
412 # Initialize model metric attributes.
413 self._init_metric_attributes()
414 if not self.built or not self.inputs or not self.outputs:
415 # Model is not compilable because it does not know its number of inputs
416 # and outputs, nor their shapes and names. We will compile after the first
417 # time the model gets called on training data.
418 return
419 self._is_compiled = True
421 # Prepare list of loss functions, same size of model outputs.
422 self.loss_functions = training_utils_v1.prepare_loss_functions(
423 self.loss, self.output_names)
425 target_tensors = self._process_target_tensor_for_compile(target_tensors)
427 for o, n, l, t in zip(self.outputs, self.output_names,
428 self.loss_functions, target_tensors):
429 endpoint = _TrainingEndpoint(o, n, l)
430 endpoint.create_training_target(t, run_eagerly=self.run_eagerly)
431 self._training_endpoints.append(endpoint)
433 # Prepare list loss weights, same size of model outputs.
434 training_utils_v1.prepare_loss_weights(self._training_endpoints,
435 loss_weights)
437 # Initialization for Eager mode execution.
438 if self.run_eagerly:
439 self._compile_eagerly(metrics, weighted_metrics, sample_weight_mode)
440 return
442 with backend.get_graph().as_default():
443 # Save all metric attributes per output of the model.
444 self._cache_output_metric_attributes(metrics, weighted_metrics)
446 # Set metric attributes on model.
447 self._set_metric_attributes()
449 # Invoke metric functions (unweighted) for all the outputs.
450 self._handle_metrics(
451 self.outputs,
452 targets=self._targets,
453 skip_target_masks=self._prepare_skip_target_masks(),
454 masks=self._prepare_output_masks())
456 # Prepare sample weight modes. List with the same length as model outputs.
457 training_utils_v1.prepare_sample_weight_modes(
458 self._training_endpoints, sample_weight_mode)
460 # Creates the model loss and weighted metrics sub-graphs.
461 self._compile_weights_loss_and_weighted_metrics()
463 # Functions for train, test and predict will
464 # be compiled lazily when required.
465 # This saves time when the user is not using all functions.
466 self.train_function = None
467 self.test_function = None
468 self.predict_function = None
470 # Collected trainable weights, sorted in topological order.
471 self._collected_trainable_weights = self.trainable_weights
473 # Validate all variables were correctly created in distribution scope.
474 if self._distribution_strategy and not self._compile_distribution:
475 for v in self.variables:
476 strategy = self._distribution_strategy
477 if not strategy.extended.variable_created_in_scope(v):
478 raise ValueError(
479 'Variable (%s) was not created in the distribution strategy '
480 'scope of (%s). It is most likely due to not all layers or '
481 'the model or optimizer being created outside the distribution '
482 'strategy scope. Try to make sure your code looks similar '
483 'to the following.\n'
484 'with strategy.scope():\n'
485 ' model=_create_model()\n'
486 ' model.compile(...)' % (v, strategy))
488 @trackable.no_automatic_dependency_tracking
489 def _init_distributed_function_cache_if_not_compiled(self):
490 if not hasattr(self, '_distributed_function_cache'):
491 self._distributed_function_cache = {}
493 @property
494 def metrics(self):
495 """Returns the model's metrics added using `compile`, `add_metric` APIs."""
496 metrics = []
497 if self._is_compiled:
498 if not hasattr(self, '_v1_compile_was_called'):
499 # See b/155687393 for more details, the model is created as a v2
500 # instance but converted to v1. Fallback to use base Model to retrieve
501 # the metrics.
502 return super(Model, self).metrics
503 metrics += self._compile_metric_functions
504 metrics.extend(self._metrics)
505 metrics.extend(
506 _get_metrics_from_layers(
507 list(self._flatten_layers(include_self=False, recursive=False))))
508 return metrics
510 @property
511 def metrics_names(self):
512 """Returns the model's display labels for all outputs."""
514 # This property includes all output names including `loss` and per-output
515 # losses for backward compatibility.
516 metrics_names = ['loss']
517 if self._is_compiled:
518 if not hasattr(self, '_v1_compile_was_called'):
519 # See b/155687393 for more details, the model is created as a v2
520 # instance but converted to v1. Fallback to use base Model to retrieve
521 # the metrics name
522 return super(Model, self).metrics_names
524 # Add output loss metric names to the metric names list.
525 if len(self._training_endpoints) > 1:
526 metrics_names.extend([
527 e.loss_name()
528 for e in self._training_endpoints
529 if not e.should_skip_target()
530 ])
532 # Add all metric names.
533 metrics_names += [m.name for m in self.metrics]
534 return metrics_names
536 @property
537 def run_eagerly(self):
538 """Settable attribute indicating whether the model should run eagerly.
540 Running eagerly means that your model will be run step by step,
541 like Python code. Your model might run slower, but it should become easier
542 for you to debug it by stepping into individual layer calls.
544 By default, we will attempt to compile your model to a static graph to
545 deliver the best execution performance.
547 Returns:
548 Boolean, whether the model should run eagerly.
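Example (a minimal sketch, assuming eager execution is enabled, which is the TF2 default):

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)

# `run_eagerly` is accepted by `compile` and controls this property.
model.compile(optimizer='rmsprop', loss='mse', run_eagerly=True)
assert model.run_eagerly
```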
549 """
550 if self._run_eagerly is True and not context.executing_eagerly():
551 raise ValueError('You can only set `run_eagerly=True` if eager execution '
552 'is enabled.')
553 if not self.dynamic:
554 if self._run_eagerly is None:
555 # Respect `tf.config.run_functions_eagerly` unless
556 # `run_eagerly` was explicitly passed to `compile`.
557 return def_function.functions_run_eagerly()
558 else:
559 return self._run_eagerly
560 else:
561 if not context.executing_eagerly():
562 raise ValueError('Your model contains layers that can only be '
563 'successfully run in eager execution (layers '
564 'constructed with `dynamic=True`). '
565 'You must enable eager execution with '
566 '`tf.enable_eager_execution()`.')
567 if self._run_eagerly is False:
568 # TODO(fchollet): consider using py_func to enable this.
569 raise ValueError('Your model contains layers that can only be '
570 'successfully run in eager execution (layers '
571 'constructed with `dynamic=True`). '
572 'You cannot set `run_eagerly=False`.')
573 return context.executing_eagerly()
575 @run_eagerly.setter
576 def run_eagerly(self, value):
577 self._run_eagerly = value
579 def _select_training_loop(self, inputs):
580 """Select training loop for fit/eval/predict based on the inputs."""
581 # TODO(kaftan) or TODO(scottzhu): This check should eventually be nicely
582 # integrated into the data adapters in the v2 loop. We can't do this yet
583 # because we currently have to fall back for unhandled data types.
584 if isinstance(inputs, (iterator_ops.Iterator,
585 iterator_ops.IteratorBase)):
586 raise ValueError('For performance reasons Keras `fit`, `evaluate` and '
587 '`predict` accept tf.data `Datasets` as input but not '
588 'iterators that have been manually generated from '
589 'Datasets by users. Please directly pass in the '
590 'original `Dataset` object instead of passing in '
591 '`iter(dataset)`.')
593 # Case 1: distribution strategy.
594 if self._distribution_strategy:
595 if self._in_multi_worker_mode():
596 return training_distributed_v1.DistributionMultiWorkerTrainingLoop(
597 training_distributed_v1.DistributionSingleWorkerTrainingLoop())
598 else:
599 return training_distributed_v1.DistributionSingleWorkerTrainingLoop()
601 # Case 2: generator-like. Input is Python generator, or Sequence object,
602 # or a non-distributed Dataset or iterator in eager execution.
603 if data_utils.is_generator_or_sequence(inputs):
604 return training_generator_v1.GeneratorOrSequenceTrainingLoop()
605 if training_utils_v1.is_eager_dataset_or_iterator(inputs):
606 return training_generator_v1.EagerDatasetOrIteratorTrainingLoop()
608 # Case 3: Symbolic tensors or Numpy array-like.
609 # This includes Datasets and iterators in graph mode (since they
610 # generate symbolic tensors).
611 if self.run_eagerly:
612 return training_generator_v1.GeneratorLikeTrainingLoop()
613 else:
614 return training_arrays_v1.ArrayLikeTrainingLoop()
616 def fit(self,
617 x=None,
618 y=None,
619 batch_size=None,
620 epochs=1,
621 verbose=1,
622 callbacks=None,
623 validation_split=0.,
624 validation_data=None,
625 shuffle=True,
626 class_weight=None,
627 sample_weight=None,
628 initial_epoch=0,
629 steps_per_epoch=None,
630 validation_steps=None,
631 validation_freq=1,
632 max_queue_size=10,
633 workers=1,
634 use_multiprocessing=False,
635 **kwargs):
636 """Trains the model for a fixed number of epochs (iterations on a dataset).
638 Args:
639 x: Input data. It could be:
640 - A Numpy array (or array-like), or a list of arrays
641 (in case the model has multiple inputs).
642 - A TensorFlow tensor, or a list of tensors
643 (in case the model has multiple inputs).
644 - A dict mapping input names to the corresponding array/tensors,
645 if the model has named inputs.
646 - A `tf.data` dataset. Should return a tuple
647 of either `(inputs, targets)` or
648 `(inputs, targets, sample_weights)`.
649 - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
650 or `(inputs, targets, sample weights)`.
651 y: Target data. Like the input data `x`,
652 it could be either Numpy array(s) or TensorFlow tensor(s).
653 It should be consistent with `x` (you cannot have Numpy inputs and
654 tensor targets, or inversely). If `x` is a dataset, generator,
655 or `keras.utils.Sequence` instance, `y` should
656 not be specified (since targets will be obtained from `x`).
657 batch_size: Integer or `None`.
658 Number of samples per gradient update.
659 If unspecified, `batch_size` will default to 32.
660 Do not specify the `batch_size` if your data is in the
661 form of symbolic tensors, datasets,
662 generators, or `keras.utils.Sequence` instances (since they generate
663 batches).
664 epochs: Integer. Number of epochs to train the model.
665 An epoch is an iteration over the entire `x` and `y`
666 data provided.
667 Note that in conjunction with `initial_epoch`,
668 `epochs` is to be understood as "final epoch".
669 The model is not trained for a number of iterations
670 given by `epochs`, but merely until the epoch
671 of index `epochs` is reached.
672 verbose: 0, 1, or 2. Verbosity mode.
673 0 = silent, 1 = progress bar, 2 = one line per epoch.
674 Note that the progress bar is not particularly useful when
675 logged to a file, so verbose=2 is recommended when not running
676 interactively (e.g., in a production environment).
677 callbacks: List of `keras.callbacks.Callback` instances.
678 List of callbacks to apply during training.
679 See `tf.keras.callbacks`.
680 validation_split: Float between 0 and 1.
681 Fraction of the training data to be used as validation data.
682 The model will set apart this fraction of the training data,
683 will not train on it, and will evaluate
684 the loss and any model metrics
685 on this data at the end of each epoch.
686 The validation data is selected from the last samples
687 in the `x` and `y` data provided, before shuffling. This argument is
688 not supported when `x` is a dataset, generator or
689 `keras.utils.Sequence` instance.
690 validation_data: Data on which to evaluate
691 the loss and any model metrics at the end of each epoch.
692 The model will not be trained on this data.
693 `validation_data` will override `validation_split`.
694 `validation_data` could be:
695 - tuple `(x_val, y_val)` of Numpy arrays or tensors
696 - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
697 - dataset
698 For the first two cases, `batch_size` must be provided.
699 For the last case, `validation_steps` could be provided.
700 shuffle: Boolean (whether to shuffle the training data
701 before each epoch) or str (for 'batch').
702 'batch' is a special option for dealing with the
703 limitations of HDF5 data; it shuffles in batch-sized chunks.
704 Has no effect when `steps_per_epoch` is not `None`.
705 class_weight: Optional dictionary mapping class indices (integers)
706 to a weight (float) value, used for weighting the loss function
707 (during training only).
708 This can be useful to tell the model to
709 "pay more attention" to samples from
710 an under-represented class.
711 sample_weight: Optional Numpy array of weights for
712 the training samples, used for weighting the loss function
713 (during training only). You can either pass a flat (1D)
714 Numpy array with the same length as the input samples
715 (1:1 mapping between weights and samples),
716 or in the case of temporal data,
717 you can pass a 2D array with shape
718 `(samples, sequence_length)`,
719 to apply a different weight to every timestep of every sample.
720 In this case you should make sure to specify
721 `sample_weight_mode="temporal"` in `compile()`. This argument is not
722 supported when `x` is a dataset, generator, or
723 `keras.utils.Sequence` instance, instead provide the sample_weights
724 `keras.utils.Sequence` instance; instead, provide the sample_weights
725 initial_epoch: Integer.
726 Epoch at which to start training
727 (useful for resuming a previous training run).
728 steps_per_epoch: Integer or `None`.
729 Total number of steps (batches of samples)
730 before declaring one epoch finished and starting the
731 next epoch. When training with input tensors such as
732 TensorFlow data tensors, the default `None` is equal to
733 the number of samples in your dataset divided by
734 the batch size, or 1 if that cannot be determined. If x is a
735 `tf.data` dataset, and 'steps_per_epoch'
736 is None, the epoch will run until the input dataset is exhausted.
737 This argument is not supported with array inputs.
738 validation_steps: Only relevant if `validation_data` is provided and
739 is a `tf.data` dataset. Total number of steps (batches of
740 samples) to draw before stopping when performing validation
741 at the end of every epoch. If 'validation_steps' is None, validation
742 will run until the `validation_data` dataset is exhausted. In the
743 case of an infinite dataset, it will run into an infinite loop.
744 If 'validation_steps' is specified and only part of the dataset
745 will be consumed, the evaluation will start from the beginning of
746 the dataset at each epoch. This ensures that the same validation
747 samples are used every time.
748 validation_freq: Only relevant if validation data is provided. Integer
749 or `collections.abc.Container` instance (e.g. list, tuple, etc.).
750 If an integer, specifies how many training epochs to run before a
751 new validation run is performed, e.g. `validation_freq=2` runs
752 validation every 2 epochs. If a Container, specifies the epochs on
753 which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
754 validation at the end of the 1st, 2nd, and 10th epochs.
755 max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
756 input only. Maximum size for the generator queue.
757 If unspecified, `max_queue_size` will default to 10.
758 workers: Integer. Used for generator or `keras.utils.Sequence` input
759 only. Maximum number of processes to spin up
760 when using process-based threading. If unspecified, `workers`
761 will default to 1. If 0, will execute the generator on the main
762 thread.
763 use_multiprocessing: Boolean. Used for generator or
764 `keras.utils.Sequence` input only. If `True`, use process-based
765 threading. If unspecified, `use_multiprocessing` will default to
766 `False`. Note that because this implementation relies on
767 multiprocessing, you should not pass non-picklable arguments to
768 the generator as they can't be passed easily to child processes.
769 **kwargs: Used for backwards compatibility.
771 Returns:
772 A `History` object. Its `History.history` attribute is
773 a record of training loss values and metrics values
774 at successive epochs, as well as validation loss values
775 and validation metrics values (if applicable).
777 Raises:
778 RuntimeError: If the model was never compiled.
779 ValueError: In case of mismatch between the provided input data
780 and what the model expects.
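Example (an illustrative sketch; shapes, sizes and hyper-parameters are arbitrary, and the random data is only a stand-in for real training data):

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='rmsprop', loss='mse')

x = np.random.random((100, 3))
y = np.random.random((100, 1))

history = model.fit(x, y,
                    batch_size=32,
                    epochs=2,
                    validation_split=0.2)  # last 20% of x/y held out for validation
print(history.history['loss'])             # one loss value per epoch
```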
781 """
782 self._assert_built_as_v1()
783 # Legacy support
784 if 'nb_epoch' in kwargs:
785 logging.warning(
786 'The `nb_epoch` argument in `fit` has been renamed `epochs`.')
787 epochs = kwargs.pop('nb_epoch')
788 if kwargs:
789 raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
790 self._assert_compile_was_called()
791 self._check_call_args('fit')
793 func = self._select_training_loop(x)
794 return func.fit(
795 self,
796 x=x,
797 y=y,
798 batch_size=batch_size,
799 epochs=epochs,
800 verbose=verbose,
801 callbacks=callbacks,
802 validation_split=validation_split,
803 validation_data=validation_data,
804 shuffle=shuffle,
805 class_weight=class_weight,
806 sample_weight=sample_weight,
807 initial_epoch=initial_epoch,
808 steps_per_epoch=steps_per_epoch,
809 validation_steps=validation_steps,
810 validation_freq=validation_freq,
811 max_queue_size=max_queue_size,
812 workers=workers,
813 use_multiprocessing=use_multiprocessing)
815 def evaluate(self,
816 x=None,
817 y=None,
818 batch_size=None,
819 verbose=1,
820 sample_weight=None,
821 steps=None,
822 callbacks=None,
823 max_queue_size=10,
824 workers=1,
825 use_multiprocessing=False):
826 """Returns the loss value & metrics values for the model in test mode.
828 Computation is done in batches (see the `batch_size` arg).
830 Args:
831 x: Input data. It could be:
832 - A Numpy array (or array-like), or a list of arrays
833 (in case the model has multiple inputs).
834 - A TensorFlow tensor, or a list of tensors
835 (in case the model has multiple inputs).
836 - A dict mapping input names to the corresponding array/tensors,
837 if the model has named inputs.
838 - A `tf.data` dataset.
839 - A generator or `keras.utils.Sequence` instance.
840 y: Target data. Like the input data `x`,
841 it could be either Numpy array(s) or TensorFlow tensor(s).
842 It should be consistent with `x` (you cannot have Numpy inputs and
843 tensor targets, or inversely).
844 If `x` is a dataset, generator or
845 `keras.utils.Sequence` instance, `y` should not be specified (since
846 targets will be obtained from the iterator/dataset).
847 batch_size: Integer or `None`.
848 Number of samples per batch of computation.
849 If unspecified, `batch_size` will default to 32.
850 Do not specify the `batch_size` if your data is in the
851 form of symbolic tensors, dataset,
852 generators, or `keras.utils.Sequence` instances (since they generate
853 batches).
854 verbose: 0 or 1. Verbosity mode.
855 0 = silent, 1 = progress bar.
856 sample_weight: Optional Numpy array of weights for
857 the test samples, used for weighting the loss function.
858 You can either pass a flat (1D)
859 Numpy array with the same length as the input samples
860 (1:1 mapping between weights and samples),
861 or in the case of temporal data,
862 you can pass a 2D array with shape
863 `(samples, sequence_length)`,
864 to apply a different weight to every timestep of every sample.
865 In this case you should make sure to specify
866 `sample_weight_mode="temporal"` in `compile()`. This argument is not
867 supported when `x` is a dataset; instead, pass
868 sample weights as the third element of `x`.
869 steps: Integer or `None`.
870 Total number of steps (batches of samples)
871 before declaring the evaluation round finished.
872 Ignored with the default value of `None`.
873 If x is a `tf.data` dataset and `steps` is
874 None, 'evaluate' will run until the dataset is exhausted.
875 This argument is not supported with array inputs.
876 callbacks: List of `keras.callbacks.Callback` instances.
877 List of callbacks to apply during evaluation.
878 See [callbacks](/api_docs/python/tf/keras/callbacks).
879 max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
880 input only. Maximum size for the generator queue.
881 If unspecified, `max_queue_size` will default to 10.
882 workers: Integer. Used for generator or `keras.utils.Sequence` input
883 only. Maximum number of processes to spin up when using
884 process-based threading. If unspecified, `workers` will default
885 to 1. If 0, will execute the generator on the main thread.
886 use_multiprocessing: Boolean. Used for generator or
887 `keras.utils.Sequence` input only. If `True`, use process-based
888 threading. If unspecified, `use_multiprocessing` will default to
889 `False`. Note that because this implementation relies on
890 multiprocessing, you should not pass non-picklable arguments to
891 the generator as they can't be passed easily to child processes.
893 Returns:
894 Scalar test loss (if the model has a single output and no metrics)
895 or list of scalars (if the model has multiple outputs
896 and/or metrics). The attribute `model.metrics_names` will give you
897 the display labels for the scalar outputs.
899 Raises:
900 ValueError: in case of invalid arguments.
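Example (an illustrative sketch; data, batch size and metrics are arbitrary):

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='rmsprop', loss='mse', metrics=['mae'])

x = np.random.random((64, 3))
y = np.random.random((64, 1))

results = model.evaluate(x, y, batch_size=16)
# `results` lines up with `model.metrics_names`, e.g. the loss followed by mae.
```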
901 """
902 self._assert_built_as_v1()
903 self._assert_compile_was_called()
904 self._check_call_args('evaluate')
906 func = self._select_training_loop(x)
907 return func.evaluate(
908 self,
909 x=x,
910 y=y,
911 batch_size=batch_size,
912 verbose=verbose,
913 sample_weight=sample_weight,
914 steps=steps,
915 callbacks=callbacks,
916 max_queue_size=max_queue_size,
917 workers=workers,
918 use_multiprocessing=use_multiprocessing)
920 def predict(self,
921 x,
922 batch_size=None,
923 verbose=0,
924 steps=None,
925 callbacks=None,
926 max_queue_size=10,
927 workers=1,
928 use_multiprocessing=False):
929 """Generates output predictions for the input samples.
931 Computation is done in batches (see the `batch_size` arg).
933 Args:
934 x: Input samples. It could be:
935 - A Numpy array (or array-like), or a list of arrays
936 (in case the model has multiple inputs).
937 - A TensorFlow tensor, or a list of tensors
938 (in case the model has multiple inputs).
939 - A `tf.data` dataset.
940 - A generator or `keras.utils.Sequence` instance.
941 batch_size: Integer or `None`.
942 Number of samples per batch of computation.
943 If unspecified, `batch_size` will default to 32.
944 Do not specify the `batch_size` if your data is in the
945 form of symbolic tensors, dataset,
946 generators, or `keras.utils.Sequence` instances (since they generate
947 batches).
948 verbose: Verbosity mode, 0 or 1.
949 steps: Total number of steps (batches of samples)
950 before declaring the prediction round finished.
951 Ignored with the default value of `None`. If x is a `tf.data`
952 dataset and `steps` is None, `predict` will
953 run until the input dataset is exhausted.
954 callbacks: List of `keras.callbacks.Callback` instances.
955 List of callbacks to apply during prediction.
956 See [callbacks](/api_docs/python/tf/keras/callbacks).
957 max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
958 input only. Maximum size for the generator queue.
959 If unspecified, `max_queue_size` will default to 10.
960 workers: Integer. Used for generator or `keras.utils.Sequence` input
961 only. Maximum number of processes to spin up when using
962 process-based threading. If unspecified, `workers` will default
963 to 1. If 0, will execute the generator on the main thread.
964 use_multiprocessing: Boolean. Used for generator or
965 `keras.utils.Sequence` input only. If `True`, use process-based
966 threading. If unspecified, `use_multiprocessing` will default to
967 `False`. Note that because this implementation relies on
968 multiprocessing, you should not pass non-picklable arguments to
969 the generator as they can't be passed easily to child processes.
972 Returns:
973 Numpy array(s) of predictions.
975 Raises:
976 ValueError: In case of mismatch between the provided
977 input data and the model's expectations,
978 or in case a stateful model receives a number of samples
979 that is not a multiple of the batch size.
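Example (an illustrative sketch with arbitrary shapes and random inputs):

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.Model(inputs, outputs)

predictions = model.predict(np.random.random((10, 3)), batch_size=4)
print(predictions.shape)  # (10, 2): one row of outputs per input sample
```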
980 """
981 self._assert_built_as_v1()
982 self._check_call_args('predict')
984 func = self._select_training_loop(x)
985 return func.predict(
986 self,
987 x=x,
988 batch_size=batch_size,
989 verbose=verbose,
990 steps=steps,
991 callbacks=callbacks,
992 max_queue_size=max_queue_size,
993 workers=workers,
994 use_multiprocessing=use_multiprocessing)
996 def reset_metrics(self):
997 """Resets the state of metrics."""
998 metrics = self._get_training_eval_metrics()
999 for m in metrics:
1000 m.reset_state()
1002 # Reset metrics on all the distributed (cloned) models.
1003 if self._distribution_strategy:
1004 distributed_training_utils_v1._reset_metrics(self) # pylint: disable=protected-access
1006 def train_on_batch(self,
1007 x,
1008 y=None,
1009 sample_weight=None,
1010 class_weight=None,
1011 reset_metrics=True):
1012 """Runs a single gradient update on a single batch of data.
1014 Args:
1015 x: Input data. It could be:
1016 - A Numpy array (or array-like), or a list of arrays
1017 (in case the model has multiple inputs).
1018 - A TensorFlow tensor, or a list of tensors
1019 (in case the model has multiple inputs).
1020 - A dict mapping input names to the corresponding array/tensors,
1021 if the model has named inputs.
1022 - A `tf.data` dataset.
1023 y: Target data. Like the input data `x`, it could be either Numpy
1024 array(s) or TensorFlow tensor(s). It should be consistent with `x`
1025 (you cannot have Numpy inputs and tensor targets, or inversely). If
1026 `x` is a dataset, `y` should not be specified
1027 (since targets will be obtained from the iterator).
1028 sample_weight: Optional array of the same length as x, containing
1029 weights to apply to the model's loss for each sample. In the case of
1030 temporal data, you can pass a 2D array with shape (samples,
1031 sequence_length), to apply a different weight to every timestep of
1032 every sample. In this case you should make sure to specify
1033 sample_weight_mode="temporal" in compile(). This argument is not
1034 supported when `x` is a dataset.
1035 class_weight: Optional dictionary mapping class indices (integers) to a
1036 weight (float) to apply to the model's loss for the samples from this
1037 class during training. This can be useful to tell the model to "pay
1038 more attention" to samples from an under-represented class.
1039 reset_metrics: If `True`, the metrics returned will be only for this
1040 batch. If `False`, the metrics will be statefully accumulated across
1041 batches.
1043 Returns:
1044 Scalar training loss
1045 (if the model has a single output and no metrics)
1046 or list of scalars (if the model has multiple outputs
1047 and/or metrics). The attribute `model.metrics_names` will give you
1048 the display labels for the scalar outputs.
1050 Raises:
1051 ValueError: In case of invalid user-provided arguments.
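Example (an illustrative sketch; the batch below is random stand-in data):

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='rmsprop', loss='mse')

x_batch = np.random.random((32, 3))
y_batch = np.random.random((32, 1))
loss = model.train_on_batch(x_batch, y_batch)  # one gradient update on this batch
```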
1052 """
1053 self._assert_compile_was_called()
1054 self._check_call_args('train_on_batch')
1056 # If at this point we are in the replica context, then it is okay to execute
1057 # the Eager code path. The expected way to get here is to call `fit` that
1058 # calls `train_on_batch` on each replica.
1059 if (self._distribution_strategy and
1060 distribute_lib.in_cross_replica_context()):
1061 raise NotImplementedError('`train_on_batch` is not supported for models '
1062 'distributed with tf.distribute.Strategy.')
1063 # Validate and standardize user data.
1064 x, y, sample_weights = self._standardize_user_data(
1065 x, y, sample_weight=sample_weight, class_weight=class_weight,
1066 extract_tensors_from_dataset=True)
1068 # If `self._distribution_strategy` is True, then we are in a replica context
1069 # at this point because of the check above. `train_on_batch` is being run
1070 # for each replica by `self._distribution_strategy` and the same code path
1071 # as Eager is expected to be taken.
1072 if self.run_eagerly or self._distribution_strategy:
1073 output_dict = training_eager_v1.train_on_batch(
1074 self,
1075 x,
1076 y,
1077 sample_weights=sample_weights,
1078 output_loss_metrics=self._output_loss_metrics)
1079 outputs = (output_dict['total_loss'] + output_dict['output_losses']
1080 + output_dict['metrics'])
1081 outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
1082 else:
1083 x = training_utils_v1.ModelInputs(x).as_list()
1084 ins = x + list(y or []) + list(sample_weights or [])
1086 if not isinstance(backend.symbolic_learning_phase(), int):
1087 ins += [True] # Add learning phase value.
1089 self._update_sample_weight_modes(sample_weights=sample_weights)
1090 self._make_train_function()
1091 outputs = self.train_function(ins) # pylint: disable=not-callable
1093 if reset_metrics:
1094 self.reset_metrics()
1096 if len(outputs) == 1:
1097 return outputs[0]
1098 return outputs
1100 def test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True):
1101 """Test the model on a single batch of samples.
1103 Args:
1104 x: Input data. It could be:
1105 - A Numpy array (or array-like), or a list of arrays
1106 (in case the model has multiple inputs).
1107 - A TensorFlow tensor, or a list of tensors
1108 (in case the model has multiple inputs).
1109 - A dict mapping input names to the corresponding array/tensors,
1110 if the model has named inputs.
1111 - A `tf.data` dataset.
1112 y: Target data. Like the input data `x`,
1113 it could be either Numpy array(s) or TensorFlow tensor(s).
1114 It should be consistent with `x` (you cannot have Numpy inputs and
1115 tensor targets, or inversely). If `x` is a dataset `y` should
1116 not be specified (since targets will be obtained from the iterator).
1117 sample_weight: Optional array of the same length as x, containing
1118 weights to apply to the model's loss for each sample.
1119 In the case of temporal data, you can pass a 2D array
1120 with shape (samples, sequence_length),
1121 to apply a different weight to every timestep of every sample.
1122 In this case you should make sure to specify
1123 sample_weight_mode="temporal" in compile(). This argument is not
1124 supported when `x` is a dataset.
1125 reset_metrics: If `True`, the metrics returned will be only for this
1126 batch. If `False`, the metrics will be statefully accumulated across
1127 batches.
1129 Returns:
1130 Scalar test loss (if the model has a single output and no metrics)
1131 or list of scalars (if the model has multiple outputs
1132 and/or metrics). The attribute `model.metrics_names` will give you
1133 the display labels for the scalar outputs.
1135 Raises:
1136 ValueError: In case of invalid user-provided arguments.
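Example (an illustrative sketch; random stand-in data and an arbitrary metric):

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='rmsprop', loss='mse', metrics=['mae'])

x_batch = np.random.random((16, 3))
y_batch = np.random.random((16, 1))

# With `reset_metrics=False`, metric values accumulate across successive batches.
results = model.test_on_batch(x_batch, y_batch, reset_metrics=False)
```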
1137 """
1138 self._assert_compile_was_called()
1139 self._check_call_args('test_on_batch')
1141 if (self._distribution_strategy and
1142 distribute_lib.in_cross_replica_context()):
1143 raise NotImplementedError('`test_on_batch` is not supported for models '
1144 'distributed with tf.distribute.Strategy.')
1145 # Validate and standardize user data.
1146 x, y, sample_weights = self._standardize_user_data(
1147 x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True)
1149 # If `self._distribution_strategy` is True, then we are in a replica context
1150 # at this point.
1151 if self.run_eagerly or self._distribution_strategy:
1152 output_dict = training_eager_v1.test_on_batch(
1153 self,
1154 x,
1155 y,
1156 sample_weights=sample_weights,
1157 output_loss_metrics=self._output_loss_metrics)
1158 outputs = (output_dict['total_loss'] + output_dict['output_losses']
1159 + output_dict['metrics'])
1160 outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
1161 else:
1162 x = training_utils_v1.ModelInputs(x).as_list()
1163 inputs = x + list(y or []) + list(sample_weights or [])
1165 self._update_sample_weight_modes(sample_weights=sample_weights)
1166 self._make_test_function()
1167 outputs = self.test_function(inputs) # pylint: disable=not-callable
1169 if reset_metrics:
1170 self.reset_metrics()
1172 if len(outputs) == 1:
1173 return outputs[0]
1174 return outputs
1176 def predict_on_batch(self, x):
1177 """Returns predictions for a single batch of samples.
1179 Args:
1180 x: Input data. It could be:
1181 - A Numpy array (or array-like), or a list of arrays
1182 (in case the model has multiple inputs).
1183 - A TensorFlow tensor, or a list of tensors
1184 (in case the model has multiple inputs).
1185 - A `tf.data` dataset.
1187 Returns:
1188 Numpy array(s) of predictions.
1190 Raises:
1191 ValueError: In case of mismatch between given number of inputs and
1192 expectations of the model.
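Example (an illustrative sketch with arbitrary shapes and random inputs):

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.Model(inputs, outputs)

out = model.predict_on_batch(np.random.random((8, 3)))
print(out.shape)  # (8, 2): one prediction per sample in the batch
```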
1193 """
1194 self._check_call_args('predict_on_batch')
1196 if (self._distribution_strategy and
1197 distribute_lib.in_cross_replica_context()):
1198 raise NotImplementedError(
1199 '`predict_on_batch` is not supported for models distributed with'
1200 ' tf.distribute.Strategy.')
1201 # Validate and standardize user data.
1202 inputs, _, _ = self._standardize_user_data(
1203 x, extract_tensors_from_dataset=True)
1204 # If `self._distribution_strategy` is True, then we are in a replica context
1205 # at this point.
1206 if self.run_eagerly or self._distribution_strategy:
1207 inputs = training_utils_v1.cast_if_floating_dtype(inputs)
1208 if isinstance(inputs, collections.abc.Sequence):
1209 # Unwrap lists with only one input, as we do when training on batch
1210 if len(inputs) == 1:
1211 inputs = inputs[0]
1213 return self(inputs) # pylint: disable=not-callable
1215 self._make_predict_function()
1216 outputs = self.predict_function(inputs)
1218 if len(outputs) == 1:
1219 return outputs[0]
1220 return outputs
1222 def fit_generator(self,
1223 generator,
1224 steps_per_epoch=None,
1225 epochs=1,
1226 verbose=1,
1227 callbacks=None,
1228 validation_data=None,
1229 validation_steps=None,
1230 validation_freq=1,
1231 class_weight=None,
1232 max_queue_size=10,
1233 workers=1,
1234 use_multiprocessing=False,
1235 shuffle=True,
1236 initial_epoch=0):
1237 """Fits the model on data yielded batch-by-batch by a Python generator.
1239 DEPRECATED:
1240 `Model.fit` now supports generators, so there is no longer any need to use
1241 this endpoint.
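A minimal migration sketch (the generator below is only illustrative):

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='rmsprop', loss='mse')

def batch_generator():
  while True:
    yield np.random.random((32, 3)), np.random.random((32, 1))

# Instead of `model.fit_generator(gen, ...)`, pass the generator to `fit` directly.
model.fit(batch_generator(), steps_per_epoch=5, epochs=2)
```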
1242 """
1243 warnings.warn('`model.fit_generator` is deprecated and '
1244 'will be removed in a future version. '
1245 'Please use `Model.fit`, which supports generators.')
1246 return self.fit(
1247 generator,
1248 steps_per_epoch=steps_per_epoch,
1249 epochs=epochs,
1250 verbose=verbose,
1251 callbacks=callbacks,
1252 validation_data=validation_data,
1253 validation_steps=validation_steps,
1254 validation_freq=validation_freq,
1255 class_weight=class_weight,
1256 max_queue_size=max_queue_size,
1257 workers=workers,
1258 use_multiprocessing=use_multiprocessing,
1259 shuffle=shuffle,
1260 initial_epoch=initial_epoch)
1262 def evaluate_generator(self,
1263 generator,
1264 steps=None,
1265 callbacks=None,
1266 max_queue_size=10,
1267 workers=1,
1268 use_multiprocessing=False,
1269 verbose=0):
1270 """Evaluates the model on a data generator.
1272 DEPRECATED:
1273 `Model.evaluate` now supports generators, so there is no longer any need
1274 to use this endpoint.
1275 """
1276 warnings.warn('`Model.evaluate_generator` is deprecated and '
1277 'will be removed in a future version. '
1278 'Please use `Model.evaluate`, which supports generators.')
1279 self._check_call_args('evaluate_generator')
1281 return self.evaluate(
1282 generator,
1283 steps=steps,
1284 max_queue_size=max_queue_size,
1285 workers=workers,
1286 use_multiprocessing=use_multiprocessing,
1287 verbose=verbose,
1288 callbacks=callbacks)
1290 def predict_generator(self,
1291 generator,
1292 steps=None,
1293 callbacks=None,
1294 max_queue_size=10,
1295 workers=1,
1296 use_multiprocessing=False,
1297 verbose=0):
1298 """Generates predictions for the input samples from a data generator.
1300 DEPRECATED:
1301 `Model.predict` now supports generators, so there is no longer any need
1302 to use this endpoint.
1303 """
1304 warnings.warn('`Model.predict_generator` is deprecated and '
1305 'will be removed in a future version. '
1306 'Please use `Model.predict`, which supports generators.')
1307 return self.predict(
1308 generator,
1309 steps=steps,
1310 max_queue_size=max_queue_size,
1311 workers=workers,
1312 use_multiprocessing=use_multiprocessing,
1313 verbose=verbose,
1314 callbacks=callbacks)
1316 def _check_call_args(self, method_name):
1317 """Check that `call` has only one positional arg."""
1318 # Always allow first arg, regardless of arg name.
1319 fullargspec = self._call_full_argspec
1320 if fullargspec.defaults:
1321 positional_args = fullargspec.args[:-len(fullargspec.defaults)]
1322 else:
1323 positional_args = fullargspec.args
1324 if 'training' in positional_args:
1325 positional_args.remove('training')
1327 # self and first arg can be positional.
1328 if len(positional_args) > 2:
1329 extra_args = positional_args[2:]
1330 raise ValueError(
1331 'Models passed to `' + method_name + '` can only have `training` '
1332 'and the first argument in `call` as positional arguments, '
1333 'found: ' + str(extra_args) + '.')
1335 def _set_optimizer(self, optimizer):
1336 """Sets self.optimizer.
1338 Sets self.optimizer to `optimizer`, potentially wrapping it with a
1339 LossScaleOptimizer.
1341 Args:
1342 optimizer: The optimizer(s) to assign to self.optimizer.
1343 """
1344 if isinstance(optimizer, (list, tuple)):
1345 self.optimizer = [optimizers.get(opt) for opt in optimizer]
1346 else:
1347 self.optimizer = optimizers.get(optimizer)
1349 if isinstance(self._dtype_policy, policy.PolicyV1):
1350 loss_scale = self._dtype_policy.loss_scale
1351 elif self._dtype_policy.name == 'mixed_float16':
1352 loss_scale = 'dynamic'
1353 else:
1354 loss_scale = None
1356 if (loss_scale is not None and
1357 not isinstance(self.optimizer,
1358 loss_scale_optimizer.LossScaleOptimizer)):
1359 if isinstance(self.optimizer, list):
1360 raise ValueError('When a dtype policy with a loss scale is used, you '
1361 'can only pass a single optimizer. Using policy %s '
1362 'and got optimizers: %s' %
1363 (self._dtype_policy, self.optimizer))
1364 if not isinstance(self.optimizer, optimizer_v2.OptimizerV2):
1365 raise ValueError('"optimizer" must be an instance of '
1366 'tf.keras.optimizers.Optimizer when a dtype policy '
1367 'with a loss scale is used, but got: %s. Using policy: '
1368 '%s' %
1369 (self.optimizer, self._dtype_policy))
1370 if loss_scale == 'dynamic':
1371 self.optimizer = loss_scale_optimizer.LossScaleOptimizer(self.optimizer)
1372 else:
1373 self.optimizer = loss_scale_optimizer.LossScaleOptimizerV1(
1374 self.optimizer, loss_scale)
1376 def _prepare_validation_data(self, validation_data, batch_size,
1377 validation_steps):
1378 """Unpack and check the validation data."""
1379 val_x, val_y, val_sample_weights = training_utils_v1.unpack_validation_data(
1380 validation_data)
1381 return self._standardize_user_data(
1382 val_x,
1383 val_y,
1384 sample_weight=val_sample_weights,
1385 batch_size=batch_size,
1386 steps=validation_steps,
1387 steps_name='validation_steps')
1389 def _validate_compile_param_for_distribution_strategy(
1390 self, run_eagerly, sample_weight_mode, target_tensors, weighted_metrics):
1391 # Validate that arguments passed by the user to `compile` are supported by
1392 # tf.distribute.Strategy.
1393 if self._distribution_strategy:
1394 if sample_weight_mode:
1395 raise NotImplementedError('sample_weight_mode is not supported with '
1396 'tf.distribute.Strategy.')
1397 if weighted_metrics:
1398 raise NotImplementedError('weighted_metrics is not supported with '
1399 'tf.distribute.Strategy.')
1400 if target_tensors:
1401 raise ValueError('target_tensors is not supported with '
1402 'tf.distribute.Strategy.')
1404 if run_eagerly:
1405 raise ValueError(
1406 'We currently do not support enabling `run_eagerly` with '
1407 'distribution strategy.')
1409 if (distributed_training_utils_v1.is_distributing_by_cloning(self) and
1410 (not self.built or not self.inputs or not self.outputs)):
1411 raise ValueError(
1412 'We currently do not support distribution strategy with a '
1413 '`Sequential` model that is created without `input_shape`/'
1414 '`input_dim` set in its first layer or a subclassed model.')
1416 def _process_target_tensor_for_compile(self, target_tensors):
1417 if self.run_eagerly:
1418 # target tensor is not supported with run_eagerly. Create a list with None
1419 # as placeholder for each output.
1420 return [None for _ in self.output_names]
1422 if target_tensors is not None and not (isinstance(target_tensors, list) and
1423 target_tensors == []): # pylint: disable=g-explicit-bool-comparison
1424 if isinstance(target_tensors, list):
1425 if len(target_tensors) != len(self.outputs):
1426 raise ValueError(
1427 'When passing a list as `target_tensors`, '
1428 'it should have one entry per model output. '
1429 'The model has %s outputs, but you passed target_tensors=%s' %
1430 (len(self.outputs), target_tensors))
1431 elif isinstance(target_tensors, dict):
1432 unexpected_target_tensor_names = set(target_tensors.keys()).difference(
1433 self.output_names)
1434 if unexpected_target_tensor_names:
1435 raise ValueError(
1436 'Unknown entry in `target_tensors` dictionary: "{name}". '
1437 'Only expected the following keys: {keys}'.format(
1438 name=unexpected_target_tensor_names,
1439 keys=str(self.output_names)))
1440 tmp_target_tensors = []
1441 for name in self.output_names:
1442 tmp_target_tensors.append(target_tensors.get(name, None))
1443 target_tensors = tmp_target_tensors
1444 elif tensor_util.is_tf_type(target_tensors):
1445 target_tensors = [target_tensors]
1446 else:
1447 raise TypeError('Expected `target_tensors` to be a list or tuple or '
1448 'dict or a single tensor, but got:', target_tensors)
1449 else:
1450 # In case the target tensor is empty or None, create a list of Nones
1451 # that has the same length as self.output_names. With that, the None check
1452 # of the target tensor can be skipped downstream.
1453 target_tensors = [None for _ in self.output_names]
1454 return target_tensors
1456 def _compile_eagerly(self, metrics, weighted_metrics, sample_weight_mode):
1457 # Prepare sample weight modes. List with the same length as model outputs.
1458 training_utils_v1.prepare_sample_weight_modes(
1459 self._training_endpoints, sample_weight_mode)
1460 # Prepare sample weights.
1461 self._prepare_sample_weights()
1462 # Save all metric attributes per output of the model.
1463 self._cache_output_metric_attributes(metrics, weighted_metrics)
1464 self.total_loss = None
1465 # Set metric attributes on model.
1466 self._set_metric_attributes()
1468 self._collected_trainable_weights = self.trainable_weights
1470 def _update_sample_weight_modes(self, sample_weights=None):
1471 """Updates sample weight modes based on training/eval inputs.
1473 Sample weight placeholders will be created for all or no outputs
1474 based on whether sample_weight is provided for any output.
1476 If the model contains `_sample_weight_modes`, we check whether the input
1477 `sample_weights` corresponds to the sample weight modes.
1478 1. Set sample weight mode to be 'temporal' for output i, if `compile`
1479 sample_weight_mode was set to `temporal` and sample weight inputs
1480 are given for one or more outputs.
1481 2. Set sample weight mode to be 'samplewise' for output i, if `compile`
1482 sample_weight_mode was not set and sample weight inputs are given for
1483 one or more outputs.
1484 3. Reset sample weight mode to None for output i if sample weight mode
1485 was set but there is no sample weight input.
1487 Args:
1488 sample_weights: List of sample weights of the same length as model outputs
1489 or None.
1490 """
1491 if not self._is_compiled:
1492 return
1493 if sample_weights and any(s is not None for s in sample_weights):
1494 for endpoint in self._training_endpoints:
1495 endpoint.sample_weight_mode = (
1496 endpoint.sample_weight_mode or 'samplewise')
1497 else:
1498 for endpoint in self._training_endpoints:
1499 endpoint.sample_weight_mode = None
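# Quick illustration of the two modes referenced above (hypothetical shapes):
# 'samplewise' weights carry one entry per sample, 'temporal' weights one entry
# per sample per timestep.
#
#   import numpy as np
#
#   samplewise = np.ones((32,))     # shape (batch,)
#   temporal = np.ones((32, 10))    # shape (batch, timesteps)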
1501 def _recompile_weights_loss_and_weighted_metrics(self):
1502 if not self._is_compiled:
1503 return False
1504 recompile = any(
1505 e.sample_weights_mismatch() for e in self._training_endpoints)
1507 if recompile:
1508 self._compile_weights_loss_and_weighted_metrics()
1509 return recompile
1511 @trackable.no_automatic_dependency_tracking
1512 def _compile_weights_loss_and_weighted_metrics(self, sample_weights=None):
1513 """Compiles the model loss and weighted metric sub-graphs.
1515 This may be used to set graph tensors as sample weights (instead of creating
1516 placeholders). This functionality is necessary for
1517 `tf.keras.estimator.model_to_estimator`, which calls Keras models in a v1
1518 graph, and creates iterator tensors for inputs, targets, and sample weights.
1520 Args:
1521 sample_weights: List of tensors to use as the sample weights. Must be the
1522 same length as the number of outputs. If left as `None`, placeholders
1523 are used instead.
1524 """
1525 with backend.get_graph().as_default():
1526 if sample_weights is not None:
1527 self._update_sample_weight_modes(sample_weights)
1528 self._prepare_sample_weights(sample_weights)
1530 masks = self._prepare_output_masks()
1532 # Compute weighted metrics.
1533 self._handle_metrics(
1534 self.outputs,
1535 targets=self._targets,
1536 skip_target_masks=self._prepare_skip_target_masks(),
1537 sample_weights=self.sample_weights,
1538 masks=masks,
1539 return_weighted_metrics=True)
1541 # Compute total loss.
1542 # Used to keep track of the total loss value (stateless).
1543 # eg., total_loss = loss_weight_1 * output_1_loss_fn(...) +
1544 # loss_weight_2 * output_2_loss_fn(...) +
1545 # layer losses.
1546 self.total_loss = self._prepare_total_loss(masks)
1548 def _prepare_skip_target_masks(self):
1549 """Boolean mask for whether the target in the output list should be skipped.
1551 If the loss function corresponding to a model output is None, then this
1552 output will be skipped during total loss calculation and feed targets
1553 preparation.
1555 Returns:
1556 A boolean list for whether the corresponding target in the output list
1557 should be skipped during loss calculation.
1558 """
1559 return [l is None for l in self.loss_functions]
1561 def _prepare_output_masks(self):
1562 """Returns masks corresponding to model outputs."""
1563 return [getattr(x, '_keras_mask', None) for x in self.outputs]
1565 def _prepare_total_loss(self, masks):
1566 """Computes total loss from loss functions.
1568 Args:
1569 masks: List of mask values corresponding to each model output.
1571 Returns:
1572 The total loss as a scalar tensor, or `0.` if there are no losses to sum.
1574 Raises:
1575 TypeError: If model run_eagerly is True.
1576 """
1577 if self.run_eagerly:
1578 raise TypeError('total loss can not be computed when compiled with '
1579 'run_eagerly = True.')
1580 loss_list = []
1581 with backend.name_scope('loss'):
1582 for endpoint, mask in zip(self._training_endpoints, masks):
1583 if endpoint.should_skip_target():
1584 continue
1585 y_true = endpoint.training_target.target
1586 y_pred = endpoint.output
1587 loss_fn = endpoint.loss_fn
1588 loss_weight = endpoint.loss_weight
1589 loss_name = endpoint.loss_name()
1590 sample_weight = endpoint.sample_weight
1592 with backend.name_scope(loss_name):
1593 if mask is not None:
1594 mask = math_ops.cast(mask, y_pred.dtype)
1595 # Update weights with mask.
1596 if sample_weight is None:
1597 sample_weight = mask
1598 else:
1599 # Update dimensions of weights to match with mask if possible.
1600 mask, _, sample_weight = (
1601 losses_utils.squeeze_or_expand_dimensions(
1602 mask, sample_weight=sample_weight))
1603 sample_weight *= mask
1605 if hasattr(loss_fn, 'reduction'):
1606 per_sample_losses = loss_fn.call(y_true, y_pred)
1607 weighted_losses = losses_utils.compute_weighted_loss(
1608 per_sample_losses,
1609 sample_weight=sample_weight,
1610 reduction=losses_utils.ReductionV2.NONE)
1611 loss_reduction = loss_fn.reduction
1613 # `AUTO` loss reduction defaults to `SUM_OVER_BATCH_SIZE` for all
1614 # compile use cases.
1615 if loss_reduction == losses_utils.ReductionV2.AUTO:
1616 loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
1618 # Compute the stateless loss value.
1619 output_loss = losses_utils.reduce_weighted_loss(
1620 weighted_losses, reduction=loss_reduction)
1621 else:
1622 # Compute the stateless loss value for a custom loss class.
1623 # Here we assume that the class takes care of loss reduction,
1624 # because if this class returns a vector value we cannot
1625 # differentiate between the case where a custom optimizer
1626 # expects a vector loss value and the case of an unreduced per-sample loss.
1627 output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
1628 loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
1630 if len(self.outputs) > 1:
1631 # Keep track of stateful result tensor for the loss.
1632 endpoint.output_loss_metric(output_loss)
1634 # Scale output loss for distribution. For custom losses we assume
1635 # reduction was mean.
1636 if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE:
1637 output_loss = losses_utils.scale_loss_for_distribution(output_loss)
1639 loss_list.append(loss_weight * output_loss)
1640 if not loss_list and not self.losses:
1641 raise ValueError('The model cannot be compiled '
1642 'because it has no loss to optimize.')
1644 # Add regularization penalties and other layer-specific losses.
1645 custom_losses = self.get_losses_for(None) + self.get_losses_for(
1646 self.inputs)
1647 if custom_losses:
1648 total_custom_loss = math_ops.add_n(
1649 losses_utils.cast_losses_to_common_dtype(custom_losses))
1650 loss_list.append(
1651 losses_utils.scale_loss_for_distribution(total_custom_loss))
1653 loss_list = losses_utils.cast_losses_to_common_dtype(loss_list)
1654 if loss_list:
1655 total_loss = math_ops.add_n(loss_list)
1656 else:
1657 total_loss = 0.
1658 return total_loss
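# A worked sketch of the reduction above (hypothetical numbers): per-output losses
# are reduced (SUM_OVER_BATCH_SIZE by default), scaled by their loss weights, and
# summed together with any regularization losses.
#
#   per_sample_loss_1 = [0.2, 0.4]   # output 1, batch of 2
#   per_sample_loss_2 = [1.0, 3.0]   # output 2, batch of 2
#   output_loss_1 = sum(per_sample_loss_1) / 2               # 0.3
#   output_loss_2 = sum(per_sample_loss_2) / 2               # 2.0
#   total_loss = 1.0 * output_loss_1 + 0.5 * output_loss_2   # 1.3, for loss_weights (1.0, 0.5)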
1660 def _get_callback_model(self):
1661 """Returns the Callback Model for this Model."""
1663 if hasattr(self, '_replicated_model') and self._replicated_model:
1664 # When using training_distributed, we set the callback model
1665 # to an instance of the `DistributedModel` that we create in
1666 # the `compile` call. The `DistributedModel` is initialized
1667 # with the first replicated model. We need to set the callback
1668 # model to a DistributedModel to allow us to override saving
1669 # and loading weights when we checkpoint the model during training.
1670 return self._replicated_model
1671 if hasattr(self, 'callback_model') and self.callback_model:
1672 return self.callback_model
1673 return self
1675 @trackable.no_automatic_dependency_tracking
1676 def _make_callback_model(self, grouped_model):
1677 first_replicated_model = self._distribution_strategy.unwrap(
1678 grouped_model)[0]
1679 # We initialize the callback model with the first replicated model.
1680 self._replicated_model = DistributedCallbackModel(first_replicated_model)
1681 self._replicated_model.set_original_model(self)
1683 def _validate_or_infer_batch_size(self, batch_size, steps, x):
1684 """Validates that the `batch_size` provided is consistent with InputLayer.
1686 It's possible that the user specified a static batch size in their
1687 InputLayer. If so, this method checks the provided `batch_size` and `x`
1688 arguments are consistent with this static batch size. Also, if
1689 `batch_size` is `None`, this method will attempt to infer the batch size
1690 from the static batch size of the InputLayer. Lastly, ValueError will be
1691 raised if `x` is a tf.data.Dataset and `batch_size` is specified as we
1692 expect users to provide batched datasets.
1694 Args:
1695 batch_size: The batch_size provided as an argument to
1696 fit/evaluate/predict.
1697 steps: The steps provided as an argument to fit/evaluate/predict.
1698 x: The data passed as `x` to fit/evaluate/predict.
1700 Returns:
1701 The validated batch_size, auto-inferred from the first layer if not
1702 provided.
1703 """
1704 if (isinstance(x, (data_types.DatasetV1,
1705 data_types.DatasetV2,
1706 data_utils.Sequence)) or
1707 tf_inspect.isgenerator(x)):
1708 if batch_size is not None:
1709 raise ValueError(
1710 'The `batch_size` argument must not be specified for the given '
1711 'input type. Received input: {}, batch_size: {}'.format(
1712 x, batch_size))
1713 return
1715 # Avoids the override in Sequential.layers which filters Input layers.
1716 # (Which are often the very layers that we're after.)
1717 layers = self._flatten_layers(include_self=False, recursive=False)
1718 first_layer = next(layers, None)
1719 if first_layer:
1720 # The per-replica static batch size.
1721 static_batch_size = training_utils.get_static_batch_size(first_layer)
1722 if static_batch_size is not None:
1724 # Determine number of times the user-supplied batch size will be split.
1725 if (self._distribution_strategy and
1726 distributed_training_utils.global_batch_size_supported(
1727 self._distribution_strategy)):
1728 num_splits_for_ds = self._distribution_strategy.num_replicas_in_sync
1729 else:
1730 num_splits_for_ds = 1
1732 # Check `batch_size` argument is consistent with InputLayer.
1733 if batch_size is not None:
1734 if batch_size % num_splits_for_ds != 0:
1735 raise ValueError('The `batch_size` argument ({}) must be divisible '
1736 'by the number of replicas ({})'.format(
1737 batch_size, num_splits_for_ds))
1738 per_replica_batch_size = batch_size // num_splits_for_ds
1740 if per_replica_batch_size != static_batch_size:
1741 raise ValueError('The `batch_size` argument value {} is '
1742 'incompatible with the specified batch size of '
1743 'your Input Layer: {}'.format(
1744 per_replica_batch_size, static_batch_size))
1746 # Check Dataset/Iterator batch size is consistent with InputLayer.
1747 if isinstance(x, (data_types.DatasetV2, iterator_ops.Iterator,
1748 iterator_ops.IteratorBase)):
1749 ds_batch_size = tensor_shape.Dimension(
1750 nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value
1751 if ds_batch_size is not None:
1752 if ds_batch_size % num_splits_for_ds != 0:
1753 raise ValueError(
1754 'The batch output shape of your `Dataset` {} '
1755 'is not divisible by the number of replicas {}'.format(
1756 ds_batch_size, num_splits_for_ds))
1758 ds_per_replica_batch_size = ds_batch_size // num_splits_for_ds
1759 if ds_per_replica_batch_size != static_batch_size:
1760 raise ValueError('The batch output shape of your `Dataset` is '
1761 '{}, which is incompatible with the specified '
1762 'batch size of your Input Layer: {}'.format(
1763 ds_per_replica_batch_size,
1764 static_batch_size))
1766 # Set inferred batch size from the InputLayer.
1767 if steps is None:
1768 batch_size = static_batch_size * num_splits_for_ds
1770 if batch_size is None and steps is None:
1771 # Backwards compatibility
1772 batch_size = 32
1773 return batch_size
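# Worked example of the divisibility checks above (hypothetical numbers): with a
# global `batch_size` of 64 and 8 replicas in sync, each replica receives 8
# samples, which must match any static batch size declared on the InputLayer.
#
#   batch_size = 64
#   num_splits_for_ds = 8
#   assert batch_size % num_splits_for_ds == 0
#   per_replica_batch_size = batch_size // num_splits_for_ds  # 8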
1775 def _prepare_sample_weights(self, sample_weights=None):
1776 """Sets sample weight attribute on the model."""
1777 # List with the same length as model outputs.
1778 if sample_weights is not None:
1779 if len(sample_weights) != len(self._training_endpoints):
1780 raise ValueError('Provided sample weights must have the same length as the '
1781 'number of outputs. Expected: {}, got: {}.'.format(
1782 len(self._training_endpoints),
1783 len(sample_weights)))
1784 else:
1785 sample_weights = [None] * len(self._training_endpoints)
1786 for endpoint, weight in zip(self._training_endpoints, sample_weights):
1787 endpoint.populate_sample_weight(weight, endpoint.sample_weight_mode)
1789 def _cache_output_metric_attributes(self, metrics, weighted_metrics):
1790 """Caches metric name and function attributes for every model output."""
1791 output_shapes = []
1792 for output in self.outputs:
1793 if output is None or output.shape.rank is None:
1794 output_shapes.append(None)
1795 else:
1796 output_shapes.append(output.shape.as_list())
1797 self._per_output_metrics = training_utils_v1.collect_per_output_metric_info(
1798 metrics, self.output_names, output_shapes, self.loss_functions,
1799 from_serialized=self._from_serialized)
1800 self._per_output_weighted_metrics = (
1801 training_utils_v1.collect_per_output_metric_info(
1802 weighted_metrics,
1803 self.output_names,
1804 output_shapes,
1805 self.loss_functions,
1806 from_serialized=self._from_serialized,
1807 is_weighted=True))
1809 def _add_unique_metric_name(self, metric_name, metric_fn, output_index):
1810 """Makes the metric name unique.
1812 If there are multiple outputs for which the metrics are calculated, the
1813 metric names have to be made unique by appending an integer.
1815 Args:
1816 metric_name: Metric name that corresponds to the metric specified by the
1817 user. For example: 'acc'.
1818 metric_fn: The Metric object.
1819 output_index: The index of the model output for which the metric name is
1820 being added.
1822 Returns:
1823 string, name of the model's unique metric name
1824 """
1825 # For multi-output models, prepend the output names to the metric name.
1826 if len(self.output_names) > 1:
1827 # If we're loading from an already-serialized model, we've already
1828 # prepended the output name, and we don't want to do it again.
1829 #
1830 # Alternatively, we may be receiving a stateless metric (e.g. the string
1831 # "accuracy") rather than a `Metric` object, in which case we want to
1832 # prepend the output name even if we are loading a serialized model.
1833 if not getattr(metric_fn, '_from_serialized', False):
1834 metric_name = '%s_%s' % (self.output_names[output_index], metric_name)
1836 j = 1
1837 base_metric_name = metric_name
1838 while metric_name in self.metrics_names:
1839 metric_name = '%s_%d' % (base_metric_name, j)
1840 j += 1
1842 return metric_name
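# Illustration of the uniquification loop above (hypothetical names): a metric
# 'acc' on output 'main' first becomes 'main_acc'; a clash with an existing name
# then gets an integer suffix.
#
#   metrics_names = ['loss', 'main_acc']
#   base_metric_name = metric_name = 'main_acc'
#   j = 1
#   while metric_name in metrics_names:
#     metric_name = '%s_%d' % (base_metric_name, j)
#     j += 1
#   # metric_name == 'main_acc_1'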
1844 def _init_metric_attributes(self):
1845 """Initialized model metric attributes."""
1846 # List of stateful metric functions. Used for resetting metric state during
1847 # training/eval.
1848 self._compile_metric_functions = []
1850 def _set_per_output_metric_attributes(self, metrics_dict, output_index):
1851 """Sets the metric attributes on the model for the given output.
1853 Args:
1854 metrics_dict: A dict with metric names as keys and metric fns as values.
1855 output_index: The index of the model output for which the metric
1856 attributes are added.
1858 Returns:
1859 Metrics dict updated with unique metric names as keys.
1860 """
1861 updated_metrics_dict = collections.OrderedDict()
1862 for metric_name, metric_fn in metrics_dict.items():
1863 metric_name = self._add_unique_metric_name(
1864 metric_name, metric_fn, output_index)
1866 # Update the name on the metric class to be the unique generated name.
1867 metric_fn._name = metric_name # pylint: disable=protected-access
1868 updated_metrics_dict[metric_name] = metric_fn
1869 # Keep track of metric name and function.
1870 self._compile_metric_functions.append(metric_fn)
1871 return updated_metrics_dict
1873 def _set_metric_attributes(self):
1874 """Sets the metric attributes on the model for all the model outputs."""
1875 updated_per_output_metrics = []
1876 updated_per_output_weighted_metrics = []
1877 for i, endpoint in enumerate(self._training_endpoints):
1878 if endpoint.should_skip_target():
1879 updated_per_output_metrics.append(self._per_output_metrics[i])
1880 updated_per_output_weighted_metrics.append(
1881 self._per_output_weighted_metrics[i])
1882 continue
1883 updated_per_output_metrics.append(
1884 self._set_per_output_metric_attributes(self._per_output_metrics[i],
1885 i))
1886 updated_per_output_weighted_metrics.append(
1887 self._set_per_output_metric_attributes(
1888 self._per_output_weighted_metrics[i], i))
1890 # Create a metric wrapper for each output loss. This computes mean of an
1891 # output loss across mini-batches (irrespective of how we reduce within a
1892 # batch).
1893 if len(self._training_endpoints) > 1:
1894 for endpoint in self._training_endpoints:
1895 if not endpoint.should_skip_target():
1896 endpoint.output_loss_metric = metrics_module.Mean(
1897 name=endpoint.loss_name())
1899 self._per_output_metrics = updated_per_output_metrics
1900 self._per_output_weighted_metrics = updated_per_output_weighted_metrics
1902 def _handle_per_output_metrics(self,
1903 metrics_dict,
1904 y_true,
1905 y_pred,
1906 mask,
1907 weights=None):
1908 """Calls metric functions for a single output.
1910 Args:
1911 metrics_dict: A dict with metric names as keys and metric fns as values.
1912 y_true: Target output.
1913 y_pred: Predicted output.
1914 mask: Computed mask value for the current output.
1915 weights: Weights to be applied on the current output.
1917 Returns:
1918 A list of metric result tensors.
1919 """
1920 metric_results = []
1921 for metric_name, metric_fn in metrics_dict.items():
1922 with backend.name_scope(metric_name):
1923 metric_result = training_utils_v1.call_metric_function(
1924 metric_fn, y_true, y_pred, weights=weights, mask=mask)
1925 metric_results.append(metric_result)
1926 return metric_results
1928 def _handle_metrics(self,
1929 outputs,
1930 targets=None,
1931 skip_target_masks=None,
1932 sample_weights=None,
1933 masks=None,
1934 return_weighted_metrics=False,
1935 return_weighted_and_unweighted_metrics=False):
1936 """Handles calling metric functions.
1938 Args:
1939 outputs: List of outputs (predictions).
1940 targets: List of targets.
1941 skip_target_masks: Optional. List of boolean for whether the corresponding
1942 target should be ignored or not.
1943 sample_weights: Optional list of sample weight arrays.
1944 masks: List of computed output mask values.
1945 return_weighted_metrics: Flag that indicates whether weighted metrics
1946 should be computed instead of unweighted metrics. This flag is ignored
1947 when `return_weighted_and_unweighted_metrics` is enabled.
1948 return_weighted_and_unweighted_metrics: Flag that is used to indicate
1949 whether both weighted and unweighted metrics should be computed. When
1950 this is not enabled, we use `return_weighted_metrics` param to indicate
1951 whether weighted or unweighted metrics should be returned.
1953 Returns:
1954 A list of metric result tensors.
1955 """
1956 # TODO(scottzhu): Update this to use the new training_endpoints. Currently
1957 # the eager and graph logic is bit different.
1958 skip_target_masks = skip_target_masks or [False] * len(outputs)
1959 metric_results = []
1960 with backend.name_scope('metrics'):
1961 # Invoke all metrics added using `compile`.
1962 for i in range(len(outputs)):
1963 if skip_target_masks[i]:
1964 continue
1965 output = outputs[i] if outputs else None
1966 target = targets[i] if targets else None
1967 output_mask = masks[i] if masks else None
1969 if (return_weighted_and_unweighted_metrics or
1970 not return_weighted_metrics):
1971 metric_results.extend(
1972 self._handle_per_output_metrics(self._per_output_metrics[i],
1973 target, output, output_mask))
1974 if return_weighted_and_unweighted_metrics or return_weighted_metrics:
1975 metric_results.extend(
1976 self._handle_per_output_metrics(
1977 self._per_output_weighted_metrics[i],
1978 target,
1979 output,
1980 output_mask,
1981 weights=sample_weights[i] if sample_weights else None))
1982 return metric_results
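# Public-API analogue of a single per-output metric call (hypothetical data,
# illustration only): each compiled metric is invoked with targets, predictions
# and optional sample weights, producing a result tensor.
#
#   import tensorflow as tf
#
#   m = tf.keras.metrics.BinaryAccuracy()
#   m.update_state([1., 0., 1.], [0.9, 0.2, 0.4], sample_weight=[1., 1., 0.])
#   result = m.result()  # accuracy over the weighted samples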
1984 def _check_trainable_weights_consistency(self):
1985 """Check trainable weights count consistency.
1987 This will raise a warning if `trainable_weights` and
1988 `_collected_trainable_weights` are inconsistent (i.e. have a different
1989 number of parameters).
1990 Inconsistency will typically arise when one modifies `model.trainable`
1991 without calling `model.compile` again.
1992 """
1993 if not hasattr(self, '_collected_trainable_weights'):
1994 return
1996 if len(self.trainable_weights) != len(self._collected_trainable_weights):
1997 logging.log_first_n(
1998 logging.WARN, 'Discrepancy between trainable weights and collected'
1999 ' trainable weights, did you set `model.trainable`'
2000 ' without calling `model.compile` afterwards?', 1)
2002 def _make_train_function(self):
2003 has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
2004 self._check_trainable_weights_consistency()
2005 if isinstance(self.optimizer, list):
2006 raise ValueError('The `optimizer` in `compile` should be a single '
2007 'optimizer.')
2008 # If we have re-compiled the loss/weighted metric sub-graphs then create
2009 # train function even if one exists already. This is because
2010 # `_feed_sample_weights` list has been updated on re-compile.
2011 if getattr(self, 'train_function', None) is None or has_recompiled:
2012 # Restore the compiled trainable state.
2013 current_trainable_state = self._get_trainable_state()
2014 self._set_trainable_state(self._compiled_trainable_state)
2016 inputs = (self._feed_inputs +
2017 self._feed_targets +
2018 self._feed_sample_weights)
2019 if not isinstance(backend.symbolic_learning_phase(), int):
2020 inputs += [backend.symbolic_learning_phase()]
2022 with backend.get_graph().as_default():
2023 with backend.name_scope('training'):
2024 # Training updates
2025 updates = self.optimizer.get_updates(
2026 params=self._collected_trainable_weights, loss=self.total_loss)
2027 # Unconditional updates
2028 updates += self.get_updates_for(None)
2029 # Conditional updates relevant to this model
2030 updates += self.get_updates_for(self.inputs)
2032 metrics = self._get_training_eval_metrics()
2033 metrics_tensors = [
2034 m._call_result for m in metrics if hasattr(m, '_call_result') # pylint: disable=protected-access
2035 ]
2037 with backend.name_scope('training'):
2038 # Gets loss and metrics. Updates weights at each call.
2039 fn = backend.function(
2040 inputs, [self.total_loss] + metrics_tensors,
2041 updates=updates,
2042 name='train_function',
2043 **self._function_kwargs)
2044 setattr(self, 'train_function', fn)
2046 # Restore the current trainable state
2047 self._set_trainable_state(current_trainable_state)
2049 def _make_test_function(self):
2050 has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
2051 # If we have re-compiled the loss/weighted metric sub-graphs then create
2052 # test function even if one exists already. This is because
2053 # `_feed_sample_weights` list has been updated on re-compile.
2054 if getattr(self, 'test_function', None) is None or has_recompiled:
2055 inputs = (self._feed_inputs +
2056 self._feed_targets +
2057 self._feed_sample_weights)
2059 with backend.get_graph().as_default():
2060 metrics = self._get_training_eval_metrics()
2061 metrics_tensors = [
2062 m._call_result for m in metrics if hasattr(m, '_call_result') # pylint: disable=protected-access
2063 ]
2065 with backend.name_scope('evaluation'):
2066 updates = self.state_updates
2067 # Return loss and metrics, no gradient updates.
2068 # Does update the network states.
2069 fn = backend.function(
2070 inputs, [self.total_loss] + metrics_tensors,
2071 updates=updates,
2072 name='test_function',
2073 **self._function_kwargs)
2074 setattr(self, 'test_function', fn)
2076 def _make_predict_function(self):
2077 if not hasattr(self, 'predict_function'):
2078 self.predict_function = None
2079 if self.predict_function is None:
2080 inputs = self._feed_inputs
2081 # Gets network outputs. Does not update weights.
2082 # Does update the network states.
2083 kwargs = getattr(self, '_function_kwargs', {})
2084 with backend.name_scope(ModeKeys.PREDICT):
2085 self.predict_function = backend.function(
2086 inputs,
2087 self.outputs,
2088 updates=self.state_updates,
2089 name='predict_function',
2090 **kwargs)
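# Conceptual sketch of the symbolic functions built above (graph mode assumed,
# public-API analogue, illustration only): `backend.function` compiles a mapping
# from feeds to fetches, optionally with update ops, which is how the train/test/
# predict functions are materialised.
#
#   import numpy as np
#   import tensorflow.compat.v1 as tf1
#   tf1.disable_eager_execution()
#   K = tf1.keras.backend
#
#   x = K.placeholder(shape=(None, 3))
#   y = x * 2.0
#   f = K.function([x], [y])             # inputs -> outputs, like predict_function
#   out = f([np.array([[1., 2., 3.]])])  # [array([[2., 4., 6.]])]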
2092 def _make_execution_function(self, mode):
2093 if mode == ModeKeys.TRAIN:
2094 self._make_train_function()
2095 return self.train_function
2096 if mode == ModeKeys.TEST:
2097 self._make_test_function()
2098 return self.test_function
2099 if mode == ModeKeys.PREDICT:
2100 self._make_predict_function()
2101 return self.predict_function
2103 def _distribution_standardize_user_data(self,
2104 x,
2105 y=None,
2106 sample_weight=None,
2107 class_weight=None,
2108 batch_size=None,
2109 validation_split=0,
2110 shuffle=False,
2111 epochs=1,
2112 allow_partial_batch=False):
2113 """Runs validation checks on input and target data passed by the user.
2115 This is called when using tf.distribute.Strategy to train, evaluate or serve
2116 the model.
2118 Args:
2119 x: Input data. A numpy array or `tf.data` dataset.
2120 y: Target data. A numpy array or None if x is a `tf.data` dataset.
2121 sample_weight: An optional sample-weight array passed by the user to
2122 weight the importance of each sample in `x`.
2123 class_weight: An optional class-weight array by the user to
2124 weight the importance of samples in `x` based on the class they belong
2125 to, as conveyed by `y`.
2126 batch_size: Integer batch size. If provided, it is used to run additional
2127 validation checks on stateful models.
2128 validation_split: Float between 0 and 1.
2129 Fraction of the training data to be used as validation data.
2130 shuffle: Boolean whether to shuffle the training data before each epoch.
2131 epochs: Integer epochs. If > 1, repeat the numpy training data `epochs`
2132 times when converting it to a training dataset.
2133 allow_partial_batch: Boolean. If True, the final batch is allowed to be
2134 smaller; otherwise all batches are required to have the same size.
2136 Returns:
2137 Dataset instance.
2139 Raises:
2140 ValueError: In case of invalid user-provided data.
2141 RuntimeError: If the model was never compiled.
2142 """
2143 if class_weight:
2144 raise NotImplementedError('`class_weight` is currently not supported '
2145 'when using tf.distribute.Strategy.')
2147 if (sample_weight is not None and sample_weight.all() and
2148 backend.is_tpu_strategy(self._distribution_strategy)):
2149 raise NotImplementedError('`sample_weight` is currently not supported '
2150 'when using TPUStrategy.')
2152 # Validates `steps` and `shuffle` arguments right at the beginning
2153 # since we use it to construct the dataset object.
2154 # TODO(anjalisridhar): Remove this check once we refactor the
2155 # _standardize_user_data code path. This check is already present elsewhere
2156 # in the codebase.
2157 if isinstance(x, data_types.DatasetV2):
2158 if shuffle:
2159 training_utils_v1.verify_dataset_shuffled(x)
2161 strategy = self._distribution_strategy
2162 with strategy.scope():
2163 # We should be sure to call get_session() inside the strategy.scope()
2164 # so the strategy can affect the session options.
2165 if ops.executing_eagerly_outside_functions():
2166 session = None
2167 else:
2168 session = backend.get_session()
2170 first_x_value = nest.flatten(x)[0]
2171 if isinstance(first_x_value, np.ndarray):
2172 x = training_utils.list_to_tuple(x)
2173 if y is not None:
2174 y = training_utils.list_to_tuple(y)
2175 if sample_weight is not None:
2176 sample_weight = training_utils.list_to_tuple(sample_weight)
2177 in_tuple = (x, y, sample_weight)
2178 else:
2179 in_tuple = (x, y)
2180 else:
2181 in_tuple = x
2183 ds = strategy.extended.experimental_make_numpy_dataset(in_tuple,
2184 session=session)
2185 if shuffle:
2186 # We want a buffer size that is larger than the batch size provided by
2187 # the user and provides sufficient randomness. Note that larger
2188 # numbers introduce more memory usage based on the size of each
2189 # sample.
2190 ds = ds.shuffle(max(1024, batch_size * 8))
2191 if epochs > 1:
2192 ds = ds.repeat(epochs)
2194 # We need to use the drop_remainder argument to get a known static
2195 # input shape which is required for TPUs.
2196 drop_remainder = (not allow_partial_batch and
2197 strategy.extended.experimental_require_static_shapes)
2199 # TODO(b/131720208): We still drop remainder here if number of examples
2200 # is divisible by batch size, as sometimes dynamic padder will time out
2201 # with keras.metrics.CategoricalAccuracy() metric.
2202 if backend.is_tpu_strategy(strategy) and not drop_remainder:
2203 dataset_size = first_x_value.shape[0]
2204 if dataset_size % batch_size == 0:
2205 drop_remainder = True
2207 x = ds.batch(batch_size, drop_remainder=drop_remainder)
2208 else:
2209 assert isinstance(x, data_types.DatasetV2)
2210 training_utils_v1.validate_dataset_input(x, y, sample_weight,
2211 validation_split)
2212 return x
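# A simplified sketch of the numpy-to-dataset path above (hypothetical sizes; the
# strategy-specific `experimental_make_numpy_dataset` call is replaced here by
# `tf.data.Dataset.from_tensor_slices` purely for illustration).
#
#   import numpy as np
#   import tensorflow as tf
#
#   x = np.random.rand(100, 4).astype('float32')
#   y = np.random.rand(100, 1).astype('float32')
#   batch_size, epochs = 16, 2
#   ds = tf.data.Dataset.from_tensor_slices((x, y))
#   ds = ds.shuffle(max(1024, batch_size * 8))
#   ds = ds.repeat(epochs)
#   ds = ds.batch(batch_size, drop_remainder=True)  # static shapes, as needed for TPUs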
2214 def _standardize_user_data(self,
2215 x,
2216 y=None,
2217 sample_weight=None,
2218 class_weight=None,
2219 batch_size=None,
2220 check_steps=False,
2221 steps_name='steps',
2222 steps=None,
2223 validation_split=0,
2224 shuffle=False,
2225 extract_tensors_from_dataset=False):
2226 """Runs validation checks on input and target data passed by the user.
2228 Also standardizes the data to lists of arrays, in order.
2230 Also builds and compiles the model on the fly if it is a subclassed model
2231 that has never been called before (and thus has no inputs/outputs).
2233 This is a purely internal method, subject to refactoring at any time.
2235 Args:
2236 x: Input data. It could be:
2237 - A Numpy array (or array-like), or a list of arrays
2238 (in case the model has multiple inputs).
2239 - A TensorFlow tensor, or a list of tensors
2240 (in case the model has multiple inputs).
2241 - A dict mapping input names to the corresponding array/tensors,
2242 if the model has named inputs.
2243 - A `tf.data` dataset.
2244 y: Target data. Like the input data `x`,
2245 it could be either Numpy array(s) or TensorFlow tensor(s).
2246 It should be consistent with `x` (you cannot have Numpy inputs and
2247 tensor targets, or inversely). If `x` is a dataset, `y` should not be
2248 specified (since targets will be obtained from the iterator).
2249 sample_weight: An optional sample-weight array passed by the user to
2250 weight the importance of each sample in `x`.
2251 class_weight: An optional class-weight array passed by the user to
2252 weight the importance of samples in `x` based on the class they belong
2253 to, as conveyed by `y`. If both `sample_weight` and `class_weight` are
2254 provided, the weights are multiplied.
2255 batch_size: Integer batch size. If provided, it is used to run additional
2256 validation checks on stateful models.
2257 check_steps: boolean, True if we want to check for validity of `steps` and
2258 False, otherwise. For example, when we are standardizing one batch of
2259 data for train_on_batch/predict_on_batch/test_on_batch APIs, `steps`
2260 value is not required and we should not check for its validity in these
2261 cases.
2262 steps_name: The public API's parameter name for `steps`.
2263 steps: Integer or `None`. Total number of steps (batches of samples) to
2264 execute.
2265 validation_split: Float between 0 and 1.
2266 Fraction of the training data to be used as validation data.
2267 shuffle: Boolean whether to shuffle the training data before each epoch.
2268 extract_tensors_from_dataset: Boolean. When `x` is a dataset instance,
2269 this indicates whether to extract actual tensors from the dataset or
2270 instead output the dataset instance itself.
2271 Set to True when calling from `train_on_batch`/etc.
2273 Returns:
2274 A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a dict
2275 or not), target arrays, sample-weight arrays.
2276 If the model's input and targets are symbolic, these lists are empty
2277 (since the model takes no user-provided data, instead the data comes
2278 from the symbolic inputs/targets).
2280 Raises:
2281 ValueError: In case of invalid user-provided data.
2282 RuntimeError: If the model was never compiled.
2283 """
2284 if isinstance(x, (data_types.DatasetV1, data_types.DatasetV2)):
2285 # Graph mode dataset. We'll pass the dataset as-is (unless
2286 # `extract_tensors_from_dataset` is True, in which case we extract
2287 # the tensors from the dataset and output them instead).
2288 training_utils_v1.validate_dataset_input(x, y, sample_weight,
2289 validation_split)
2290 if shuffle:
2291 training_utils_v1.verify_dataset_shuffled(x)
2293 is_dataset = True
2294 if extract_tensors_from_dataset:
2295 # We do this for `train_on_batch`/etc.
2296 x, y, sample_weight = training_utils_v1.extract_tensors_from_dataset(x)
2297 elif isinstance(x, iterator_ops.Iterator):
2298 # Graph mode iterator. We extract the symbolic tensors.
2299 training_utils_v1.validate_dataset_input(x, y, sample_weight,
2300 validation_split)
2301 iterator = x
2302 x, y, sample_weight = training_utils_v1.unpack_iterator_input(iterator)
2303 is_dataset = True
2304 else:
2305 is_dataset = False
2307 # Validates `steps` argument based on x's type.
2308 if check_steps:
2309 training_utils_v1.check_steps_argument(x, steps, steps_name)
2311 # First, we build the model on the fly if necessary.
2312 if not self.inputs:
2313 all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y)
2314 is_build_called = True
2315 else:
2316 all_inputs = []
2317 # Whether this is a subclassed model that expects dictionary inputs
2318 # rather than list inputs (e.g. FeatureColumn-based models).
2319 dict_inputs = isinstance(self.inputs, dict)
2320 is_build_called = False
2321 y_input = y
2323 # Second, we compile the model on the fly if necessary, mostly for subclass
2324 # models.
2325 is_compile_called = False
2326 if not self._is_compiled and self.optimizer:
2327 self._compile_from_inputs(all_inputs, y_input, x, y)
2328 is_compile_called = True
2330 # In graph mode, if we had just set inputs and targets as symbolic tensors
2331 # by invoking build and compile on the model respectively, we do not have to
2332 # feed anything to the model. Model already has input and target data as
2333 # part of the graph.
2334 # Note: in this case, `any` and `all` are equivalent since we disallow
2335 # mixed symbolic/value inputs.
2337 # self.run_eagerly is not free to compute, so we want to reuse the value.
2338 run_eagerly = self.run_eagerly
2340 if (not run_eagerly and is_build_called and is_compile_called and
2341 not is_dataset and any(_is_symbolic_tensor(v) for v in all_inputs)):
2342 return [], [], None
2344 return self._standardize_tensors(
2345 x, y, sample_weight,
2346 run_eagerly=run_eagerly,
2347 dict_inputs=dict_inputs,
2348 is_dataset=is_dataset,
2349 class_weight=class_weight,
2350 batch_size=batch_size)
2352 def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs,
2353 is_dataset, class_weight=None, batch_size=None):
2354 if run_eagerly:
2355 # In eager mode, do not do shape validation
2356 # since the network has no input nodes (placeholders) to be fed.
2357 feed_input_names = self.input_names
2358 feed_input_shapes = None
2359 elif not self._is_graph_network:
2360 # Case: symbolic-mode subclassed network. Do not do shape validation.
2361 feed_input_names = self._feed_input_names
2362 feed_input_shapes = None
2363 else:
2364 # Case: symbolic-mode graph network.
2365 # In this case, we run extensive shape validation checks.
2366 feed_input_names = self._feed_input_names
2367 feed_input_shapes = self._feed_input_shapes
2369 # Standardize the inputs.
2370 if not isinstance(x, (data_types.DatasetV1, data_types.DatasetV2)):
2371 # TODO(fchollet): run static checks with dataset output shape(s).
2372 x = training_utils_v1.standardize_input_data(
2373 x,
2374 feed_input_names,
2375 feed_input_shapes,
2376 check_batch_axis=False, # Don't enforce the batch size.
2377 exception_prefix='input')
2379 # Get typespecs for the input data and sanitize it if necessary.
2380 # TODO(momernick): This should be capable of doing full input validation
2381 # at all times - validate that this is so and refactor the standardization
2382 # code.
2383 if isinstance(x, data_types.DatasetV2):
2384 x_shapes = dataset_ops.get_structure(x)
2385 if isinstance(x_shapes, tuple):
2386 # If the output of a Dataset is a tuple, we assume it's either of the
2387 # form (x_data, y_data) or (x_data, y_data, sample_weights). In either
2388 # case, we only care about x_data here.
2389 x_shapes = x_shapes[0]
2390 else:
2391 flat_inputs = nest.flatten(x, expand_composites=False)
2392 flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False)
2393 converted_x = []
2394 for (a, b) in zip(flat_inputs, flat_expected_inputs):
2395 converted_x.append(_convert_scipy_sparse_tensor(a, b))
2396 x = nest.pack_sequence_as(x, converted_x, expand_composites=False)
2398 def _type_spec_from_value(value):
2399 """Grab type_spec without converting array-likes to tensors."""
2400 if tf_utils.is_extension_type(value):
2401 return value._type_spec # pylint: disable=protected-access
2402 # Get a TensorSpec for array-like data without
2403 # converting the data to a Tensor
2404 if hasattr(value, 'shape') and hasattr(value, 'dtype'):
2405 return tensor_spec.TensorSpec(value.shape, value.dtype)
2406 else:
2407 return type_spec.type_spec_from_value(value)
2409 x_shapes = nest.map_structure(_type_spec_from_value, x)
2411 flat_inputs = nest.flatten(x_shapes, expand_composites=False)
2412 flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False)
2413 for (a, b) in zip(flat_inputs, flat_expected_inputs):
2414 nest.assert_same_structure(a, b, expand_composites=True)
2416 if y is not None:
2417 # Prepare self._sample_weight_modes. List with the same length as
2418 # model outputs.
2419 training_utils_v1.prepare_sample_weight_modes(self._training_endpoints,
2420 self.sample_weight_mode)
2421 feed_output_names = self._feed_output_names
2422 feed_sample_weight_modes = self._sample_weight_modes
2423 if not self._is_graph_network:
2424 feed_output_shapes = None
2425 else:
2426 feed_output_shapes = self._feed_output_shapes
2428 # Standardize the outputs.
2429 y = training_utils_v1.standardize_input_data(
2430 y,
2431 feed_output_names,
2432 # Don't enforce target shapes to match output shapes.
2433 # Precise checks will be run in `check_loss_and_target_compatibility`.
2434 shapes=None,
2435 check_batch_axis=False, # Don't enforce the batch size.
2436 exception_prefix='target')
2438 # Generate sample-wise weight values given the `sample_weight` and
2439 # `class_weight` arguments.
2440 sample_weights = training_utils_v1.standardize_sample_weights(
2441 sample_weight, feed_output_names)
2442 class_weights = training_utils_v1.standardize_class_weights(
2443 class_weight, feed_output_names)
2445 sample_weights = [
2446 training_utils_v1.standardize_weights(ref, sw, cw, mode)
2447 for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights,
2448 feed_sample_weight_modes)
2449 ]
2450 # Check that all arrays have the same length.
2451 if not self._distribution_strategy:
2452 training_utils_v1.check_array_lengths(x, y, sample_weights)
2453 if self._is_graph_network and not run_eagerly:
2454 # Additional checks to avoid users mistakenly using improper loss fns.
2455 training_utils_v1.check_loss_and_target_compatibility(
2456 y, self._feed_loss_fns, feed_output_shapes)
2458 sample_weights, _, _ = training_utils.handle_partial_sample_weights(
2459 y, sample_weights, feed_sample_weight_modes, check_all_flat=True)
2460 else:
2461 y = []
2462 sample_weights = None
2464 if self.stateful and batch_size and not is_dataset:
2465 # Check that for stateful networks, number of samples is a multiple
2466 # of the static batch size.
2467 if x[0].shape[0] % batch_size != 0:
2468 raise ValueError('In a stateful network, '
2469 'you should only pass inputs with '
2470 'a number of samples that is '
2471 'divisible by the batch size. Found: ' +
2472 str(x[0].shape[0]) + ' samples')
2474 # If dictionary inputs were provided, we return a dictionary as well.
2475 if dict_inputs and not isinstance(x, (data_types.DatasetV1,
2476 data_types.DatasetV2)):
2477 x = dict(zip(feed_input_names, x))
2478 return x, y, sample_weights
2480 def _build_model_with_inputs(self, inputs, targets):
2481 """Build the model (set model inputs/outputs), mainly for subclass model."""
2482 processed_inputs = []
2483 is_dict_inputs = False
2484 orig_inputs = inputs
2485 # We need to use `inputs` to set the model inputs.
2486 # If input data is a dataset iterator in graph mode or if it is an eager
2487 # iterator and only one batch of samples is required, we fetch the data
2488 # tensors from the iterator and then standardize them.
2489 if isinstance(inputs, (data_types.DatasetV1, data_types.DatasetV2)):
2490 inputs, targets, _ = training_utils_v1.extract_tensors_from_dataset(
2491 inputs)
2492 # We type-check that `inputs` and `targets` are either single arrays
2493 # or lists of arrays, and extract a flat list of inputs from the passed
2494 # structure.
2495 training_utils_v1.validate_input_types(inputs, orig_inputs)
2497 if isinstance(inputs, (list, tuple)):
2498 processed_inputs += list(inputs)
2499 elif isinstance(inputs, dict):
2500 is_dict_inputs = True
2501 keys = sorted(inputs.keys())
2502 processed_inputs = [inputs[k] for k in keys]
2503 else:
2504 processed_inputs.append(inputs)
2505 # Now that we have a flat set of inputs, we make sure that none of them
2506 # are CompositeTensors or CompositeTensorValues of any type (or scipy
2507 # sparse arrays, which we treat as SparseTensor values). We cannot safely
2508 # infer input data from an arbitrary composite tensor, so we don't try -
2509 # users should explicitly add composite tensor inputs to their subclassed
2510 # models.
2511 for input_tensor in processed_inputs:
2512 if training_utils_v1.is_composite_or_composite_value(input_tensor):
2513 # TODO(b/132691975): Document subclass-model CT input handling.
2514 raise ValueError(
2515 'All SparseTensor and RaggedTensor inputs must be explicitly '
2516 'declared using a keras.Input() with sparse=True or ragged=True. '
2517 'We found an undeclared input %s. For Sequential models, please '
2518 'add a keras.Input() as your first Layer. For subclassed models, '
2519 'please call self._set_inputs() on your input set, which you can '
2520 'create using keras.Input() for each input to your model.' %
2521 (input_tensor,))
2522 # Build the model using the retrieved inputs (value or symbolic).
2523 # If values are generated from a dataset, then in symbolic-mode
2524 # placeholders will be created to match the value shapes.
2525 if isinstance(orig_inputs, (data_types.DatasetV1, data_types.DatasetV2,
2526 iterator_ops.Iterator)):
2527 if not self.inputs:
2528 # For subclassed models, a robust input spec is not available so we
2529 # must cast to the model dtype.
2530 inputs = training_utils_v1.cast_if_floating_dtype(inputs, self.dtype)
2532 def create_tensor_spec(t):
2533 return tensor_spec.TensorSpec(t.shape, t.dtype)
2535 cast_inputs = nest.map_structure(create_tensor_spec, inputs)
2536 elif training_utils_v1.has_tensors(inputs):
2537 cast_inputs = training_utils_v1.cast_if_floating_dtype(inputs)
2538 else:
2539 cast_inputs = inputs
2540 self._set_inputs(cast_inputs)
2541 return processed_inputs, targets, is_dict_inputs
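# Hedged example of the requirement above (hypothetical shapes; the Dense layer
# simply stands in for whatever consumes the input): composite inputs such as
# SparseTensors must be declared explicitly via `keras.Input(..., sparse=True)`
# rather than inferred from the data.
#
#   import tensorflow as tf
#
#   inp = tf.keras.Input(shape=(10,), sparse=True)  # explicit sparse declaration
#   out = tf.keras.layers.Dense(1)(inp)
#   model = tf.keras.Model(inp, out)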
2543 def _compile_from_inputs(self, all_inputs, target, orig_inputs, orig_target):
2544 if target is not None:
2545 # We need to use `y` to set the model targets.
2546 if training_utils_v1.has_tensors(target):
2547 target = training_utils_v1.cast_if_floating_dtype_and_mismatch(
2548 target, self.outputs)
2549 training_utils_v1.validate_input_types(
2550 target, orig_target, allow_dict=False, field_name='target')
2551 if isinstance(target, (list, tuple)):
2552 all_inputs += list(target)
2553 else:
2554 all_inputs.append(target)
2555 # Type check that all inputs are *either* value *or* symbolic.
2556 # TODO(fchollet): this check could be removed in Eager mode?
2557 if any(tensor_util.is_tf_type(v) for v in all_inputs):
2558 if not all(tensor_util.is_tf_type(v) for v in all_inputs):
2559 raise ValueError('Do not pass inputs that mix Numpy arrays and '
2560 'TensorFlow tensors. '
2561 'You passed: x=' + str(orig_inputs) +
2562 '; y=' + str(orig_target))
2563 is_dataset = isinstance(orig_inputs, (data_types.DatasetV1,
2564 data_types.DatasetV2,
2565 iterator_ops.Iterator))
2566 if is_dataset or context.executing_eagerly():
2567 target_tensors = None
2568 else:
2569 # Handle target tensors if any passed.
2570 if target is not None:
2571 if not isinstance(target, (list, tuple)):
2572 target = [target]
2573 target_tensors = [v for v in target if _is_symbolic_tensor(v)]
2574 else:
2575 target_tensors = None
2577 self.compile(
2578 optimizer=self.optimizer,
2579 loss=self.loss,
2580 metrics=self._compile_metrics,
2581 weighted_metrics=self._compile_weighted_metrics,
2582 loss_weights=self.loss_weights,
2583 target_tensors=target_tensors,
2584 sample_weight_mode=self.sample_weight_mode,
2585 run_eagerly=self.run_eagerly,
2586 experimental_run_tf_function=self._experimental_run_tf_function)
2588 # TODO(omalleyt): Consider changing to a more descriptive function name.
2589 def _set_inputs(self, inputs, outputs=None, training=None):
2590 """Set model's input and output specs based on the input data received.
2592 This is to be used for Model subclasses, which do not know at instantiation
2593 time what their inputs look like.
2595 Args:
2596 inputs: Single array, or list of arrays. The arrays could be placeholders,
2597 Numpy arrays, data tensors, or TensorSpecs.
2598 - if placeholders: the model is built on top of these placeholders,
2599 and we expect Numpy data to be fed for them when calling `fit`/etc.
2600 - if Numpy data or TensorShapes: we create placeholders matching the
2601 TensorShapes or shapes of the Numpy arrays. We expect Numpy data to be
2602 fed for these placeholders when calling `fit`/etc.
2603 - if data tensors: the model is built on top of these tensors.
2604 We do not expect any Numpy data to be provided when calling `fit`/etc.
2605 outputs: None, a data tensor, or a list of tensors. If None, the
2606 outputs will be determined by invoking `self.call()`, otherwise the
2607 provided value will be used.
2608 training: Boolean or None. Only relevant in symbolic mode. Specifies
2609 whether to build the model's graph in inference mode (False), training
2610 mode (True), or using the Keras learning phase (None).
2611 Raises:
2612 ValueError: If dict inputs are passed to a Sequential Model where the
2613 first layer isn't FeatureLayer.
2614 """
2615 self._set_save_spec(inputs)
2616 inputs = self._set_input_attrs(inputs)
2618 if outputs is None:
2619 kwargs = {}
2620 if self._expects_training_arg:
2621 # In V2 mode, feeding `training=None` is not allowed because any value
2622 # explicitly passed by the user is respected, even `None`.
2623 if training is None and not ops.executing_eagerly_outside_functions():
2624 training = backend.learning_phase()
2625 if training is not None:
2626 kwargs['training'] = training
2627 try:
2628 outputs = self(inputs, **kwargs)
2629 except NotImplementedError:
2630 # This Model or a submodel is dynamic and hasn't overridden
2631 # `compute_output_shape`.
2632 outputs = None
2634 self._set_output_attrs(outputs)
2636 @trackable.no_automatic_dependency_tracking
2637 def _set_input_attrs(self, inputs):
2638 """Sets attributes related to the inputs of the Model."""
2639 if self.inputs:
2640 raise ValueError('Model inputs are already set.')
2642 if self.__class__.__name__ == 'Sequential' and not self.built:
2643 if tensor_util.is_tf_type(inputs):
2644 input_shape = (None,) + tuple(inputs.shape.as_list()[1:])
2645 elif isinstance(inputs, tensor_shape.TensorShape):
2646 input_shape = (None,) + tuple(inputs.as_list()[1:])
2647 elif isinstance(inputs, dict):
2648 # We assert that the first layer is a FeatureLayer.
2649 if not training_utils_v1.is_feature_layer(self.layers[0]):
2650 raise ValueError('Passing a dictionary input to a Sequential Model '
2651 'which doesn\'t have FeatureLayer as the first layer'
2652 ' is an error.')
2653 input_shape = (None,)
2654 else:
2655 input_shape = (None,) + tuple(inputs.shape[1:])
2656 self._build_input_shape = input_shape
2658 # Cast inputs to the compute dtype. This is primarily used
2659 # when saving to determine the correct dtype in the input signature.
2660 inputs = self._maybe_cast_inputs(inputs)
2662 # On-the-fly setting of symbolic model inputs (either by using the tensor
2663 # provided, or by creating a placeholder if Numpy data was provided).
2664 model_inputs = training_utils_v1.ModelInputs(inputs)
2665 inputs = model_inputs.get_symbolic_inputs()
2666 self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True)
2667 self.input_names = model_inputs.get_input_names()
2669 self._feed_inputs = []
2670 self._feed_input_names = []
2671 self._feed_input_shapes = []
2673 for k, v in model_inputs.as_dict():
2674 if backend.is_placeholder(v):
2675 self._feed_input_names.append(k)
2676 self._feed_inputs.append(v)
2677 self._feed_input_shapes.append(backend.int_shape(v))
2679 return inputs
2681 @trackable.no_automatic_dependency_tracking
2682 def _set_output_attrs(self, outputs):
2683 """Sets attributes related to the outputs of the Model."""
2684 # NOTE(taylorrobie): This convention cannot be changed without updating the
2685 # data adapter since it assumes nest.flatten ordering.
2686 outputs = nest.flatten(outputs)
2687 self.outputs = outputs
2688 self.output_names = training_utils_v1.generic_output_names(outputs)
2689 # TODO(scottzhu): Should we cleanup the self._training_endpoints here?
2690 self.built = True
2692 @property
2693 def _targets(self):
2694 """The output target tensors for the model."""
2695 return [
2696 e.training_target.target
2697 for e in self._training_endpoints
2698 if e.has_training_target()
2699 ]
2701 @property
2702 def _feed_targets(self):
2703 return [
2704 e.training_target.target
2705 for e in self._training_endpoints
2706 if e.has_feedable_training_target()
2707 ]
2709 @property
2710 def _feed_output_names(self):
2711 return [
2712 e.output_name
2713 for e in self._training_endpoints
2714 if e.has_feedable_training_target()
2715 ]
2717 @property
2718 def _feed_output_shapes(self):
2719 return [
2720 e.feed_output_shape
2721 for e in self._training_endpoints
2722 if e.has_feedable_training_target()
2723 ]
2725 @property
2726 def _feed_loss_fns(self):
2727 return [
2728 e.loss_fn
2729 for e in self._training_endpoints
2730 if e.has_feedable_training_target()
2731 ]
2733 @property
2734 def _loss_weights_list(self):
2735 return [e.loss_weight for e in self._training_endpoints]
2737 @property
2738 def _output_loss_metrics(self):
2739 if hasattr(self, '_training_endpoints'):
2740 return [
2741 e.output_loss_metric
2742 for e in self._training_endpoints
2743 if e.output_loss_metric is not None
2744 ]
2745 return None
2747 @property
2748 def sample_weights(self):
2749 return [e.sample_weight for e in self._training_endpoints]
2751 @property
2752 def _sample_weight_modes(self):
2753 return [e.sample_weight_mode for e in self._training_endpoints]
2755 @property
2756 def _feed_sample_weights(self):
2757 return [e.sample_weight for e in self._training_endpoints
2758 if e.sample_weight is not None]
2760 def _maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
2761 """Maybe load initial epoch from ckpt considering possible worker recovery.
2763 Refer to tensorflow/python/keras/distribute/worker_training_state.py
2764 for more information.
2766 Args:
2767 initial_epoch: The original `initial_epoch` the user passed to `fit()`.
2768 mode: The mode for running `model.fit()`.
2770 Returns:
2771 If the training is recovering from previous failure under multi-worker
2772 training setting, return the epoch the training is supposed to continue
2773 at. Otherwise, return the `initial_epoch` the user passes in.
2774 """
2775 if self._training_state is not None:
2776 return self._training_state.maybe_load_initial_epoch_from_ckpt(
2777 initial_epoch, mode)
2778 return initial_epoch
2780 def _get_training_eval_metrics(self):
2781 """Returns all the metrics that are to be reported.
2783 This includes the output loss metrics, compile metrics/weighted metrics,
2784 add_metric metrics.
2785 """
2786 metrics = []
2787 metrics.extend(getattr(self, '_output_loss_metrics', None) or [])
2788 metrics.extend(getattr(self, 'metrics', None) or [])
2789 return metrics
2791 def _assert_compile_was_called(self):
2792 # Checks whether `compile` has been called. If it has been called,
2793 # then the optimizer is set. This is different from whether the
2794 # model is compiled
2795 # (i.e. whether the model is built and its inputs/outputs are set).
2796 if not self._compile_was_called:
2797 raise RuntimeError('You must compile your model before '
2798 'training/testing. '
2799 'Use `model.compile(optimizer, loss)`.')
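# Illustrative sketch (hypothetical usage): the training/evaluation entry points
# go through this check, roughly:
#   m = Model(inputs, outputs)
#   m.fit(x, y)                 # raises RuntimeError: compile the model first
#   m.compile('sgd', 'mse')
#   m.fit(x, y)                 # OK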
2801 def _in_multi_worker_mode(self):
2802 """Method to infer if this `Model` is working in multi-worker settings.
2804 Multi-worker training refers to the setup where the training is
2805 distributed across multiple workers, as opposed to the case where
2806 only a local process performs the training. This function is
2807 used to infer for example whether or not a distribute coordinator
2808 should be run, and thus TensorFlow servers should be started for
2809 communication with other servers in the cluster, or whether or not
2810 saving/restoring checkpoints is relevant for preemption fault tolerance.
2812 Experimental. Signature and implementation are subject to change.
2814 Returns:
2815 Whether this model indicates it's working in multi-worker settings.
2816 """
2817 strategy = self._distribution_strategy
2819 # If the model was not compiled with a strategy, use the strategy whose scope this is in.
2820 if not strategy and distribute_lib.has_strategy():
2821 strategy = distribute_lib.get_strategy()
2822 return strategy and strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access
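# Illustrative sketch (assumes a multi-worker TF_CONFIG is set in the
# environment): building the model under a multi-worker strategy scope makes
# this return True, e.g.
#   strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
#   with strategy.scope():
#     model = build_and_compile_model()   # hypothetical helper
#   model._in_multi_worker_mode()         # -> True
# Without a strategy (or with a single-worker strategy) the result is falsy.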
2824 @property
2825 def _trackable_saved_model_saver(self):
2826 return model_serialization.ModelSavedModelSaver(self)
2828 def _get_compile_args(self, user_metrics=True):
2829 del user_metrics
2830 self._assert_compile_was_called()
2831 kwargs = {
2832 'loss': self.loss,
2833 'metrics': self._compile_metrics,
2834 'loss_weights': self.loss_weights,
2835 'sample_weight_mode': self.sample_weight_mode,
2836 'weighted_metrics': self._compile_weighted_metrics,
2837 }
2838 return kwargs
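# Illustrative note: the optimizer is not part of these kwargs, so code that
# rebuilds a model (e.g. during cloning or SavedModel export) would, as a
# hedged sketch, recompile roughly as
#   clone.compile(optimizer=model.optimizer, **model._get_compile_args())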
2840 @property
2841 def _compile_was_called(self):
2842 return self._v1_compile_was_called
2845class DistributedCallbackModel(Model):
2846 """Model that is used for callbacks with tf.distribute.Strategy."""
2848 def __init__(self, model):
2849 super(DistributedCallbackModel, self).__init__()
2850 self.optimizer = model.optimizer
2852 def set_original_model(self, orig_model):
2853 self._original_model = orig_model
2855 def save_weights(self, filepath, overwrite=True, save_format=None):
2856 self._replicated_model.save_weights(filepath, overwrite=overwrite,
2857 save_format=save_format)
2859 def save(self, filepath, overwrite=True, include_optimizer=True):
2860 # save weights from the distributed model to the original model
2861 distributed_model_weights = self.get_weights()
2862 self._original_model.set_weights(distributed_model_weights)
2863 # TODO(anjalisridhar): Do we need to save the original model here?
2864 # Saving the first replicated model works as well.
2865 self._original_model.save(filepath, overwrite=overwrite, include_optimizer=False)
2867 def load_weights(self, filepath, by_name=False):
2868 self._original_model.load_weights(filepath, by_name=by_name)
2869 # Copy the weights from the original model to each of the replicated models.
2870 orig_model_weights = self._original_model.get_weights()
2871 distributed_training_utils_v1.set_weights(
2872 self._original_model._distribution_strategy, self, # pylint: disable=protected-access
2873 orig_model_weights)
2875 def __getattr__(self, item):
2876 # Allowed attributes of the model that can be accessed by the user
2877 # during a callback.
2878 if item not in ('_setattr_tracking', '_layers'):
2879 logging.warning('You are accessing attribute ' + item + ' of the '
2880 'DistributedCallbackModel that may not have been set '
2881 'correctly.')
2882 return super(DistributedCallbackModel, self).__getattr__(item)
2885class _TrainingEndpoint(object):
2886 """A container for the training output/target and related entities.
2888 In the case of a model with multiple outputs, there is a one-to-one mapping
2889 between model output (y_pred), model target (y_true), loss, metrics, etc.
2890 By unifying these entities into one class, each entity can access
2891 information about the others, rather than having to reach into separate
2892 per-output lists of attributes on the model.
2893 """
2895 def __init__(self,
2896 output,
2897 output_name,
2898 loss_fn,
2899 loss_weight=None,
2900 training_target=None,
2901 output_loss_metric=None,
2902 sample_weight=None,
2903 sample_weight_mode=None):
2904 """Initialize the _TrainingEndpoint.
2906 Note that the output and output_name should be stable as long as the model
2907 structure doesn't change. The training_target is supposed to be mutable
2908 since the information is provided via `compile()`.
2910 Args:
2911 output: the output tensor of the model.
2912 output_name: the unique name of the output tensor.
2913 loss_fn: the loss function for the output tensor.
2914 loss_weight: float, the weight for the loss.
2915 training_target: the _TrainingTarget for the model.
2916 output_loss_metric: the metric object for the loss function.
2917 sample_weight: the weights for how a sample is weighted during metric and
2918 loss calculation. Could be None.
2919 sample_weight_mode: string, 'temporal', 'samplewise' or None. The mode for
2920 how the sample_weight is populated.
2921 """
2922 self._output = output
2923 self._output_name = output_name
2924 self._loss_fn = loss_fn
2925 self._loss_weight = loss_weight
2926 self._training_target = training_target
2927 self._output_loss_metric = output_loss_metric
2928 self._sample_weight = sample_weight
2929 self._sample_weight_mode = sample_weight_mode
2931 @property
2932 def output(self):
2933 return self._output
2935 @property
2936 def output_name(self):
2937 return self._output_name
2939 @property
2940 def shape(self):
2941 return backend.int_shape(self.output)
2943 @property
2944 def loss_fn(self):
2945 return self._loss_fn
2947 @property
2948 def loss_weight(self):
2949 return self._loss_weight
2951 @loss_weight.setter
2952 def loss_weight(self, value):
2953 self._loss_weight = value
2955 @property
2956 def training_target(self):
2957 return self._training_target
2959 @training_target.setter
2960 def training_target(self, value):
2961 self._training_target = value
2963 def create_training_target(self, target, run_eagerly=False):
2964 """Create training_target instance and update the self.training_target.
2966 Note that the input target should just be a tensor or None, and
2967 corresponding training target will be created based on the output and
2968 loss_fn.
2970 Args:
2971 target: the target tensor for the current output. Could be None.
2972 run_eagerly: boolean, whether the model is in run_eagerly mode.
2974 Raises:
2975 ValueError: if the training_target field for the current instance has
2976 already been populated.
2977 """
2978 if self.has_training_target():
2979 raise ValueError('The training_target field for the _TrainingEndpoint '
2980 'instance has already been populated')
2981 if run_eagerly:
2982 # When run_eagerly, the target tensor is ignored, and a None placeholder
2983 # is created instead.
2984 self.training_target = _TrainingTarget(
2985 None, feedable=True, skip_target_weights=False)
2986 return
2988 if self.should_skip_target():
2989 self.training_target = _TrainingTarget(None)
2990 else:
2991 if target is not None and not backend.is_placeholder(target):
2992 feedable = False
2993 skip_target_weights = True
2994 else:
2995 feedable = True
2996 skip_target_weights = False
2998 if target is None:
2999 target_dtype = losses.LABEL_DTYPES_FOR_LOSSES.get(
3000 self.loss_fn, backend.dtype(self.output))
3002 target = backend.placeholder(
3003 ndim=len(self.shape),
3004 name=self.output_name + '_target',
3005 sparse=backend.is_sparse(self.output),
3006 dtype=target_dtype)
3008 self.training_target = _TrainingTarget(
3009 target,
3010 feedable=feedable,
3011 skip_target_weights=skip_target_weights)
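# Illustrative sketch (hypothetical shapes/names): for an endpoint whose output
# has shape (None, 10) and whose loss_fn is sparse_categorical_crossentropy, the
# graph-mode branch above creates a feedable target roughly equivalent to
#   backend.placeholder(ndim=2, name='dense_target', sparse=False, dtype='int64')
# because losses.LABEL_DTYPES_FOR_LOSSES maps that loss to 'int64'.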
3013 @property
3014 def output_loss_metric(self):
3015 return self._output_loss_metric
3017 @output_loss_metric.setter
3018 def output_loss_metric(self, value):
3019 self._output_loss_metric = value
3021 @property
3022 def sample_weight(self):
3023 return self._sample_weight
3025 @sample_weight.setter
3026 def sample_weight(self, value):
3027 self._sample_weight = value
3029 @property
3030 def sample_weight_mode(self):
3031 return self._sample_weight_mode
3033 @sample_weight_mode.setter
3034 def sample_weight_mode(self, value):
3035 self._sample_weight_mode = value
3037 def should_skip_target(self):
3038 return self._loss_fn is None
3040 def should_skip_target_weights(self):
3041 return (self.should_skip_target() or self.training_target is None or
3042 self.training_target.skip_target_weights)
3044 def has_training_target(self):
3045 return self.training_target is not None
3047 def has_feedable_training_target(self):
3048 return (not self.should_skip_target() and
3049 self.training_target is not None and self.training_target.feedable)
3051 def loss_name(self):
3052 if self._loss_fn is not None:
3053 return self._output_name + '_loss'
3054 return None
3056 @property
3057 def feed_output_shape(self):
3058 """The output shape for the feedable target."""
3059 if not self.has_feedable_training_target():
3060 return None
3062 if ((isinstance(self.loss_fn, losses.LossFunctionWrapper) and
3063 self.loss_fn.fn == losses.sparse_categorical_crossentropy)) or (
3064 isinstance(self.loss_fn, losses.SparseCategoricalCrossentropy)):
3065 if backend.image_data_format() == 'channels_first':
3066 return (self.shape[0], 1) + self.shape[2:]
3067 else:
3068 return self.shape[:-1] + (1,)
3069 elif (not isinstance(self.loss_fn, losses.Loss) or
3070 (isinstance(self.loss_fn, losses.LossFunctionWrapper) and
3071 (getattr(losses, self.loss_fn.fn.__name__, None) is None))):
3072 # If the given loss is not an instance of the `Loss` class (custom
3073 # class) or if the loss function that is wrapped is not in the
3074 # `losses` module, then it is a user-defined loss and we make no
3075 # assumptions about it.
3076 return None
3077 else:
3078 return self.shape
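# Illustrative sketch (hypothetical shapes): with an output of shape
# (None, 5, 10) and a SparseCategoricalCrossentropy loss, the feedable target
# shape is (None, 5, 1) under 'channels_last' or (None, 1, 10) under
# 'channels_first'; for a custom loss the property returns None, i.e. no
# assumption is made about the target shape.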
3080 def sample_weights_mismatch(self):
3081 """Check if the sample weight and the mode match or not."""
3082 # If there is a mismatch between sample weight mode and the placeholders
3083 # created, then recompile the sub-graphs that depend on sample weights.
3084 return (
3085 (self.sample_weight_mode is not None and self.sample_weight is None) or
3086 (self.sample_weight_mode is None and self.sample_weight is not None))
3088 def populate_sample_weight(self, sample_weight, sample_weight_mode):
3089 """Populate the sample weight and based on the sample weight mode."""
3090 if (sample_weight is None and
3091 (self.should_skip_target_weights() or sample_weight_mode is None or
3092 context.executing_eagerly())):
3093 self._sample_weight = None
3094 return
3096 assert sample_weight_mode in ['temporal', 'samplewise']
3097 if sample_weight_mode == 'temporal':
3098 default_value = [[1.]]
3099 shape = [None, None]
3100 else:
3101 # sample_weight_mode == 'samplewise'
3102 default_value = [1.]
3103 shape = [None]
3105 if sample_weight is not None:
3106 if not sample_weight.shape.is_compatible_with(shape):
3107 raise ValueError('Received sample weight with shape {}. Expected shape '
3108 '{}.'.format(sample_weight.shape, shape))
3109 self._sample_weight = sample_weight
3110 else:
3111 self._sample_weight = array_ops.placeholder_with_default(
3112 constant_op.constant(default_value, dtype=backend.floatx()),
3113 shape=shape,
3114 name=self.output_name + '_sample_weights')
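# Illustrative sketch (hypothetical usage): with sample_weight_mode='temporal'
# the placeholder above has shape [None, None] and defaults to [[1.]], so a
# caller could feed per-timestep weights such as np.ones((batch, timesteps));
# with 'samplewise' the shape is [None] and a vector like np.ones((batch,))
# is expected.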
3117class _TrainingTarget(object):
3118 """Container for a target tensor (y_true) and its metadata (shape, loss...).
3120 Args:
3121 target: A target tensor for the model. It may be `None` if the
3122 output is excluded from loss computation. It is still kept (as None)
3123 since each output of the model should have a corresponding target. If
3124 the target is None, the rest of the attributes will be None as well.
3125 feedable: Boolean, whether the target is feedable (requires data to be
3126 passed in `fit` or `train_on_batch`), or not (model compiled with
3127 `target_tensors` argument).
3128 skip_target_weights: Boolean, whether the target should be skipped during
3129 weights calculation.
3130 """
3132 def __init__(self, target, feedable=False, skip_target_weights=True):
3133 self._target = target
3134 self._feedable = feedable
3135 self._skip_target_weights = skip_target_weights
3137 @property
3138 def target(self):
3139 return self._target
3141 @property
3142 def feedable(self):
3143 return self._feedable
3145 @property
3146 def skip_target_weights(self):
3147 return self._skip_target_weights
3150def _is_symbolic_tensor(x):
3151 return tensor_util.is_tf_type(x)
3154def _convert_scipy_sparse_tensor(value, expected_input):
3155 """Handle scipy sparse tensor conversions.
3157 This method takes a value 'value' and returns the proper conversion. If
3158 value is a scipy sparse tensor and the expected input is a dense tensor,
3159 we densify 'value'. If value is a scipy sparse tensor and the expected input
3160 is a TF SparseTensor, we convert 'value' to a SparseTensor. If 'value' is
3161 not a scipy sparse tensor, or scipy is not imported, we pass it through
3162 unchanged.
3164 Args:
3165 value: An object that may be a scipy sparse tensor.
3166 expected_input: The expected input placeholder.
3168 Returns:
3169 The possibly-converted 'value'.
3170 """
3171 if issparse is not None and issparse(value):
3172 if backend.is_sparse(expected_input):
3173 sparse_coo = value.tocoo()
3174 row, col = sparse_coo.row, sparse_coo.col
3175 data, shape = sparse_coo.data, sparse_coo.shape
3176 indices = np.concatenate((np.expand_dims(row, 1), np.expand_dims(col, 1)),
3177 1)
3178 return sparse_tensor.SparseTensor(indices, data, shape)
3179 else:
3180 if ops.executing_eagerly_outside_functions():
3181 # In TF2 we do not silently densify sparse matrices.
3182 raise ValueError('A SciPy sparse matrix was passed to a model '
3183 'that expects dense inputs. Please densify your '
3184 'inputs first, such as by calling `x.toarray()`.')
3185 return value.toarray()
3186 else:
3187 return value
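# Illustrative sketch (assumes scipy is importable): a 3x3 identity passed as a
# scipy matrix against a sparse placeholder becomes a SparseTensor:
#   value = scipy.sparse.csr_matrix(np.eye(3))
#   # -> SparseTensor(indices=[[0, 0], [1, 1], [2, 2]], values=[1., 1., 1.],
#   #                 dense_shape=(3, 3))
# Against a dense placeholder it is densified via value.toarray() in graph
# mode, while TF2/eager raises the ValueError above instead of densifying.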
3190def _get_metrics_from_layers(layers):
3191 """Returns list of metrics from the given layers.
3193 This will not include the `compile` metrics of a model layer.
3195 Args:
3196 layers: List of layers.
3198 Returns:
3199 List of metrics.
3200 """
3201 metrics = []
3202 layers = layer_utils.filter_empty_layer_containers(layers)
3203 for layer in layers:
3204 if isinstance(layer, Model):
3205 # We cannot call 'metrics' on the model because we do not want to
3206 # include the metrics that were added via the compile API of a nested model.
3207 metrics.extend(layer._metrics) # pylint: disable=protected-access
3208 metrics.extend(_get_metrics_from_layers(layer.layers))
3209 else:
3210 metrics.extend(layer.metrics)
3211 return metrics
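# Illustrative note (hypothetical nesting): for an outer model containing an
# inner compiled Model layer, metrics the inner model tracks via add_metric are
# collected (through layer._metrics and recursion into layer.layers), while
# metrics passed to the inner model's compile() are deliberately left out.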
3214def _non_none_constant_value(v):
3215 constant_value = tensor_util.constant_value(v)
3216 return constant_value if constant_value is not None else v
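# Illustrative sketch: a statically-known tensor folds to its NumPy value, e.g.
# _non_none_constant_value(constant_op.constant(3)) evaluates equal to 3, while
# a placeholder or other non-constant tensor is returned unchanged.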