# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras training and evaluation routines for eager execution."""
# pylint: disable=protected-access

import numpy as np

from tensorflow.python.eager.backprop import GradientTape
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def _eager_loss_fn(outputs, targets, loss_fn, output_name):
  with backend.name_scope(output_name + '_loss'):
    loss = loss_fn(targets, outputs)
  return loss


def _eager_metrics_fn(model, outputs, targets, sample_weights=None, masks=None):
  """Calculates the metrics for each output of the given model.

  Args:
    model: The model on which metrics are being calculated.
    outputs: The outputs of the given model.
    targets: The predictions or targets of the given model.
    sample_weights: Optional list of sample weights for each output.
    masks: Optional list of masks for each output.

  Returns:
    Returns the metric results for each output of the model.
  """
  outputs = nest.flatten(outputs)
  targets = nest.flatten(targets)
  # Invoke all (weighted and unweighted) metrics.
  metric_results = []
  if targets:
    # Insert None values corresponding to the targets that need to be skipped
    # on the model.
    if len(model._targets) != len(targets):
      new_targets = [
          None if t is None else targets.pop(0) for t in model._targets
      ]
      targets = new_targets

    metric_results = model._handle_metrics(
        outputs,
        targets=targets,
        sample_weights=sample_weights,
        masks=masks,
        return_weighted_and_unweighted_metrics=True,
        skip_target_masks=model._prepare_skip_target_masks())

  # Add metric results from the `add_metric` metrics.
  metric_results.extend([
      m.result()
      for m in model.metrics
      if m not in model._compile_metric_functions
  ])
  return metric_results
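

# Illustrative sketch (hypothetical, not part of the original module): the
# target re-alignment performed in `_eager_metrics_fn` above, isolated as a
# standalone helper. Given compile-time targets [t0, None, t2] and runtime
# targets [a, b], it yields [a, None, b] so skipped outputs line up with the
# `None` entries.
def _example_align_targets(compile_targets, runtime_targets):
  # Copy so the pop() calls below do not mutate the caller's list.
  runtime_targets = list(runtime_targets)
  return [None if t is None else runtime_targets.pop(0)
          for t in compile_targets]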


def _model_loss(model,
                inputs,
                targets,
                output_loss_metrics=None,
                sample_weights=None,
                training=False):
  """Calculates the loss for a given model.

  Args:
    model: The model on which metrics are being calculated.
    inputs: Either a dictionary of inputs to the model or a list of input
      arrays.
    targets: List of target arrays.
    output_loss_metrics: List of metrics that are used to aggregate output
      loss values.
    sample_weights: Optional list of sample weight arrays.
    training: Whether the model should be run in inference or training mode.

  Returns:
    Returns the model output, total loss, loss value calculated using the
    specified loss function and masks for each output. The total loss includes
    regularization losses and applies masking and sample weighting
    to the loss value.
  """
  # TODO(psv): Dedup code here with graph mode prepare_total_loss() fn.
  # Used to keep track of the total loss value (stateless).
  # e.g., total_loss = loss_weight_1 * output_1_loss_fn(...) +
  #                    loss_weight_2 * output_2_loss_fn(...) +
  #                    layer losses.
  total_loss = 0
  kwargs = {}
  if model._expects_training_arg:
    kwargs['training'] = training
  if len(inputs) == 1 and not isinstance(inputs, dict):
    inputs = inputs[0]

  # Allow mixed `NumPy` and `EagerTensor` input here.
  if any(
      isinstance(input_t, (np.ndarray, float, int))
      for input_t in nest.flatten(inputs)):
    inputs = nest.map_structure(
        tensor_conversion.convert_to_tensor_v2_with_dispatch, inputs
    )

  outs = model(inputs, **kwargs)
  outs = nest.flatten(outs)

  if targets:
    targets = training_utils_v1.cast_if_floating_dtype_and_mismatch(
        targets, outs)
  # TODO(sallymatson/psv): check if we should do same mismatch fix for weights
  if sample_weights:
    new_sample_weights = []
    for val in sample_weights:
      if val is not None:
        new_sample_weights.append(training_utils_v1.cast_if_floating_dtype(
            tensor_conversion.convert_to_tensor_v2_with_dispatch(val)))
      else:
        new_sample_weights.append(None)
    sample_weights = new_sample_weights

  masks = [getattr(t, '_keras_mask', None) for t in outs]
  targets = nest.flatten(targets)

  # Used to keep track of individual output losses.
  output_losses = []

  with backend.name_scope('loss'):
    loss_fns = [
        loss_fn for loss_fn in model.loss_functions if loss_fn is not None
    ]
    custom_losses = model.losses  # Regularization losses

    if not loss_fns and not custom_losses:
      if training:
        raise ValueError('The model cannot be trained '
                         'because it has no loss to optimize.')
      else:
        raise ValueError('The model cannot be evaluated '
                         'because it has no loss to compute.')

    for i, loss_fn in enumerate(loss_fns):
      weights = sample_weights[i] if sample_weights else None
      mask = masks[i]
      with backend.name_scope(model.output_names[i] + '_loss'):
        if mask is not None:
          mask = math_ops.cast(mask, outs[i].dtype)
          # Update weights with mask.
          if weights is None:
            weights = mask
          else:
            # Update dimensions of weights to match with mask if possible.
            weights = math_ops.cast(weights, outs[i].dtype)
            mask, _, weights = (
                losses_utils.squeeze_or_expand_dimensions(
                    mask, sample_weight=weights))
            weights *= mask

        if hasattr(loss_fn, 'reduction'):
          per_sample_losses = loss_fn.call(targets[i], outs[i])
          weighted_losses = losses_utils.compute_weighted_loss(
              per_sample_losses,
              sample_weight=weights,
              reduction=losses_utils.ReductionV2.NONE)
          loss_reduction = loss_fn.reduction

          # `AUTO` loss reduction defaults to `SUM_OVER_BATCH_SIZE` for all
          # compile use cases.
          if loss_reduction == losses_utils.ReductionV2.AUTO:
            loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE

          # Compute the stateless loss value.
          output_loss = losses_utils.reduce_weighted_loss(
              weighted_losses, reduction=loss_reduction)
        else:
          # Compute the stateless loss value for a custom loss class.
          # Here we assume that the class takes care of loss reduction
          # because if this class returns a vector value we cannot
          # differentiate between use case where a custom optimizer
          # expects a vector loss value vs unreduced per-sample loss value.
          output_loss = loss_fn(targets[i], outs[i], sample_weight=weights)
          loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE

      # If the number of outputs is 1 then we don't append the loss metric
      # associated with each model output. When there are multiple outputs
      # associated with a model, each output's loss is calculated and returned
      # as part of the loss_metrics.
      if len(model.outputs) > 1:
        # Keep track of the stateful output loss result.
        output_losses.append(output_loss_metrics[i](output_loss))

      # Scale output loss for distribution. For custom losses we assume
      # reduction was mean.
      if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE:
        output_loss = losses_utils.scale_loss_for_distribution(output_loss)
      total_loss += model._loss_weights_list[i] * output_loss

    # Add regularization losses
    if custom_losses:
      total_loss += losses_utils.scale_loss_for_distribution(
          math_ops.add_n(custom_losses))
  return outs, total_loss, output_losses, masks
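

# Minimal sketch of the mask/sample-weight combination rule used in
# `_model_loss` above, pulled out as a hypothetical standalone helper for
# clarity. `mask` and `weights` are assumed to be tensors (or None) that are
# broadcast-compatible with a per-sample loss of dtype `dtype`.
def _example_merge_mask_and_weights(mask, weights, dtype):
  if mask is None:
    return weights
  mask = math_ops.cast(mask, dtype)
  if weights is None:
    # With no explicit sample weights, the mask alone acts as the weight.
    return mask
  weights = math_ops.cast(weights, dtype)
  # Align ranks before multiplying, as the loss loop above does.
  mask, _, weights = losses_utils.squeeze_or_expand_dimensions(
      mask, sample_weight=weights)
  return weights * mask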


def _process_single_batch(model,
                          inputs,
                          targets,
                          output_loss_metrics=None,
                          sample_weights=None,
                          training=False):
  """Calculate the loss and gradient for one input batch.

  The model weights are updated if training is set to True.

  Args:
    model: Model whose loss has to be calculated.
    inputs: List of input arrays.
    targets: List of target arrays.
    output_loss_metrics: List of metrics that are used to aggregate output
      loss values.
    sample_weights: Optional list of sample weight arrays.
    training: Boolean indicating whether the weights of the model should be
      updated. 'fit' methods will set this to True while 'evaluate' methods
      will set this to False.

  Returns:
    output of the model, total loss, the loss and the mask
    associated with each output.

  Raises:
    ValueError: If the model has no loss to optimize.
  """
  with backend.eager_learning_phase_scope(1 if training else 0), \
      training_utils.RespectCompiledTrainableState(model):
    with GradientTape() as tape:
      outs, total_loss, output_losses, masks = (
          _model_loss(
              model,
              inputs,
              targets,
              output_loss_metrics=output_loss_metrics,
              sample_weights=sample_weights,
              training=training))
      if isinstance(model.optimizer, loss_scale_optimizer.LossScaleOptimizer):
        scaled_total_loss = model.optimizer.get_scaled_loss(total_loss)
      else:
        scaled_total_loss = total_loss
    if training:
      trainable_weights = model.trainable_weights
      if trainable_weights:
        # TODO(tanzheny) b/132690565: Provide mechanism for user to override
        # model.train_on_batch.
        if hasattr(model, '_backwards'):
          model._backwards(tape, scaled_total_loss)
        else:
          grads = tape.gradient(scaled_total_loss, trainable_weights)
          if isinstance(model.optimizer,
                        loss_scale_optimizer.LossScaleOptimizer):
            grads = model.optimizer.get_unscaled_gradients(grads)
          model.optimizer.apply_gradients(zip(grads, trainable_weights))
      else:
        logging.warning('The list of trainable weights is empty. Make sure '
                        'that you are not setting model.trainable to False '
                        'before compiling the model.')
    return outs, total_loss, output_losses, masks
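

# Sketch of the mixed-precision gradient pattern used in
# `_process_single_batch`, shown as a hypothetical standalone training step.
# The loss is scaled inside the tape so the scaling op is recorded, and the
# gradients are unscaled before being applied. `model`, `loss_fn`, `x`, and
# `y` are assumed inputs, not names from this module.
def _example_scaled_gradient_step(model, loss_fn, x, y):
  optimizer = model.optimizer
  is_scaled = isinstance(optimizer, loss_scale_optimizer.LossScaleOptimizer)
  with GradientTape() as tape:
    loss = loss_fn(y, model(x, training=True))
    if is_scaled:
      # Scale up the loss to avoid float16 gradient underflow.
      loss = optimizer.get_scaled_loss(loss)
  grads = tape.gradient(loss, model.trainable_weights)
  if is_scaled:
    # Undo the scaling so the optimizer sees true-magnitude gradients.
    grads = optimizer.get_unscaled_gradients(grads)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))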


def train_on_batch(model,
                   inputs,
                   targets,
                   sample_weights=None,
                   output_loss_metrics=None):
  """Calculates the loss and gradient updates for one input batch.

  Args:
    model: Model whose loss has to be calculated.
    inputs: Input batch data.
    targets: Target batch data.
    sample_weights: Sample weight batch data.
    output_loss_metrics: List of metrics that are used to aggregate output
      loss values.

  Returns:
    Dict with three items:
      'total_loss': list with a single tensor for overall loss,
      'output_losses': list of tensors for loss corresponding to each of the
        model outputs. Could be an empty list when the model has only one
        output.
      'metrics': list of tensors for the metrics specified.
  """
  inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model)
  outs, total_loss, output_losses, masks = (
      _process_single_batch(
          model,
          inputs,
          targets,
          sample_weights=sample_weights,
          training=True,
          output_loss_metrics=output_loss_metrics))
  if not isinstance(outs, list):
    outs = [outs]
  metrics_results = _eager_metrics_fn(
      model, outs, targets, sample_weights=sample_weights, masks=masks)
  total_loss = nest.flatten(total_loss)
  return {'total_loss': total_loss,
          'output_losses': output_losses,
          'metrics': metrics_results}
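

# Illustrative usage (hypothetical, not part of the module): consuming the
# dict returned by `train_on_batch` in a custom eager training step. `model`
# is assumed to be a compiled v1-style Keras model; `x` and `y` are single
# input and target batches.
def _example_eager_train_step(model, x, y):
  results = train_on_batch(model, [x], [y])
  # 'total_loss' is a flat list holding the single overall-loss tensor.
  loss = results['total_loss'][0]
  # 'metrics' holds one result tensor per compiled metric.
  return loss, results['metrics']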


def test_on_batch(model,
                  inputs,
                  targets,
                  sample_weights=None,
                  output_loss_metrics=None):
  """Calculates the loss for one input batch.

  Args:
    model: Model whose loss has to be calculated.
    inputs: Input batch data.
    targets: Target batch data.
    sample_weights: Sample weight batch data.
    output_loss_metrics: List of metrics that are used to aggregate output
      loss values.

  Returns:
    Dict with three items:
      'total_loss': list with a single tensor for overall loss,
      'output_losses': list of tensors for loss corresponding to each of the
        model outputs. Could be an empty list when the model has only one
        output.
      'metrics': list of tensors for the metrics specified.
  """
  inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model)

  with backend.eager_learning_phase_scope(0):
    outs, total_loss, output_losses, masks = (
        _model_loss(
            model,
            inputs,
            targets,
            sample_weights=sample_weights,
            training=False,
            output_loss_metrics=output_loss_metrics))
    if not isinstance(outs, list):
      outs = [outs]
    metrics_results = _eager_metrics_fn(
        model, outs, targets, sample_weights=sample_weights, masks=masks)
    total_loss = nest.flatten(total_loss)

  return {'total_loss': total_loss,
          'output_losses': output_losses,
          'metrics': metrics_results}
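

# Illustrative usage (hypothetical): a minimal eager evaluation loop built on
# `test_on_batch`, averaging the overall loss across a list of (x, y) batches.
# `model` and `batches` are assumed inputs, not names from this module.
def _example_eager_evaluate(model, batches):
  losses = []
  for x, y in batches:
    results = test_on_batch(model, [x], [y])
    losses.append(results['total_loss'][0])
  # Simple mean of the per-batch loss tensors.
  return math_ops.add_n(losses) / len(losses)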