Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/compile_utils.py: 16%
376 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""Utilities for `Model.compile`."""
19import copy
21import tensorflow.compat.v2 as tf
23from keras.src import losses as losses_mod
24from keras.src import metrics as metrics_mod
25from keras.src.saving import saving_lib
26from keras.src.utils import generic_utils
27from keras.src.utils import losses_utils
28from keras.src.utils import tf_utils
class Container:
    """Base Container class.

    Stores the model output names and offers helpers that align a
    user-supplied nested structure (labels, sample weights, losses or
    metrics) with the structure of the model outputs. Subclasses must
    implement `_should_broadcast` and `_copy_object`.
    """

    def __init__(self, output_names=None):
        self._output_names = output_names

    def build(self, y_pred):
        """Fills in pseudo output names when none were supplied."""
        if self._output_names is not None:
            return
        # In Subclass API, output names like 'output_1' are used for
        # `Metric` names.
        self._output_names = create_pseudo_output_names(y_pred)

    def _conform_to_outputs(self, outputs, struct):
        """Convenience method to conform `struct` to `outputs` structure.

        Mappings performed:

        (1) Map a dict to a list of outputs, using the output names.
        (2) Fill missing keys in a dict w/ `None`s.
        (3) Map a single item to all outputs.

        Args:
          outputs: Model predictions.
          struct: Arbitrary nested structure (e.g. of labels, sample_weights,
            losses, or metrics).

        Returns:
          Mapping of `struct` to `outputs` structure.
        """
        by_name = map_to_output_names(outputs, self._output_names, struct)
        conformed = map_missing_dict_keys(outputs, by_name)
        # One non-nested object passed for nested outputs is reused for
        # every output.
        if tf.nest.is_nested(outputs) and not tf.nest.is_nested(conformed):
            shared = conformed
            conformed = tf.nest.map_structure(lambda _: shared, outputs)
        return conformed

    def _maybe_broadcast_to_outputs(self, outputs, objects):
        """Determines if losses / metrics should be applied to all outputs.

        NOTE: This method should only be called for Metrics / Losses, not for
        y_true / sample_weight.

        When `_should_broadcast(objects)` is true, `objects` is repeated once
        per output; with more than one output each repetition is built with
        `_copy_object` so that every output keeps a separate Metric / Loss.

        Args:
          outputs: Model predictions.
          objects: Arbitrary nested structure (e.g. of losses or metrics)

        Returns:
          Arbitrary nested structure of objects, maybe copied to each output.
        """
        if not self._should_broadcast(objects):
            return objects

        # With a single model output the user-supplied object itself is used.
        needs_copies = len(tf.nest.flatten(outputs)) > 1

        def _objects_for_one_output(_):
            if not needs_copies:
                return objects
            return tf.nest.map_structure(self._copy_object, objects)

        return tf.nest.map_structure(_objects_for_one_output, outputs)

    def _should_broadcast(self, objects):
        """Subclass hook: whether `objects` applies to every output."""
        raise NotImplementedError

    def _copy_object(self, obj):
        """Subclass hook: returns an independent copy of `obj`."""
        raise NotImplementedError
class LossesContainer(Container):
    """A container class for losses passed to `Model.compile()`.

    Args:
      losses: Struct of loss function(s). See `Model.compile()` doc for more
        information.
      loss_weights: Weights of the losses contributions of different model
        outputs. See `Model.compile()` doc for more information.
      output_names: List of string. Per-output metric names.
      total_loss_mean: A `keras.metrics.Mean` instance that is used to track the
        mean of all losses (including compiled and regularization losses).
    """

    def __init__(
        self, losses, loss_weights=None, output_names=None, total_loss_mean=None
    ):
        super(LossesContainer, self).__init__(output_names=output_names)

        # Keep user-supplied values untouched for recompiling and serialization.
        self._user_losses = losses
        self._user_loss_weights = loss_weights

        # Working values; `build()` replaces them with flat per-output lists.
        self._losses = losses
        self._loss_weights = loss_weights
        self._per_output_metrics = None  # Per-output losses become metrics.

        # Mean of the total loss.
        self._total_loss_mean = total_loss_mean or metrics_mod.Mean(name="loss")
        self._built = False

    def get_config(self):
        """Returns a serializable config for this container.

        The losses are flattened into a list (a single string loss becomes a
        one-element list) and `None` entries are dropped.

        The flattening is done on a local copy only. Overwriting
        `self._losses` here, before `build()` has run, would stop a single
        loss from being broadcast to every output. It would also break the
        dict-key alignment between losses and output names.
        """
        flat_losses = tf.nest.flatten(self._losses)
        return {
            "losses": [
                saving_lib.serialize_keras_object(obj)
                for obj in flat_losses
                if obj is not None
            ],
            "total_loss_mean": saving_lib.serialize_keras_object(
                self._total_loss_mean
            ),
        }

    @classmethod
    def from_config(cls, config):
        """Returns the `LossesContainer` instance given the `config`.

        List-valued entries are deserialized item by item; every other entry
        is deserialized as a single object.
        """
        deserialized_config = {}
        for key, value in config.items():
            if isinstance(value, list):
                deserialized_config[key] = [
                    saving_lib.deserialize_keras_object(item) for item in value
                ]
            else:
                deserialized_config[key] = saving_lib.deserialize_keras_object(
                    value
                )
        return cls(**deserialized_config)

    @property
    def metrics(self):
        """Per-output loss metrics (empty until the container is built)."""
        if not self._built:
            return []
        per_output_metrics = [
            metric_obj
            for metric_obj in tf.nest.flatten(self._per_output_metrics)
            if metric_obj is not None
        ]
        return [self._total_loss_mean] + per_output_metrics

    def build(self, y_pred):
        """One-time setup of loss objects.

        Broadcasts and conforms the losses and loss weights to the structure
        of `y_pred`, then stores both as flat lists aligned with the flattened
        outputs.
        """
        super(LossesContainer, self).build(y_pred)

        self._losses = self._maybe_broadcast_to_outputs(y_pred, self._losses)
        self._losses = self._conform_to_outputs(y_pred, self._losses)
        self._losses = tf.nest.map_structure(
            self._get_loss_object, self._losses
        )
        self._losses = tf.nest.flatten(self._losses)

        self._loss_weights = self._maybe_broadcast_to_outputs(
            y_pred, self._loss_weights
        )
        self._loss_weights = self._conform_to_outputs(
            y_pred, self._loss_weights
        )
        self._loss_weights = tf.nest.flatten(self._loss_weights)

        self._create_metrics()
        self._built = True

    @property
    def built(self):
        return self._built

    def _create_metrics(self):
        """Creates per-output loss metrics, but only for multi-output Models."""
        if len(self._output_names) == 1:
            self._per_output_metrics = [None]
        else:
            self._per_output_metrics = []
            for loss_obj, output_name in zip(self._losses, self._output_names):
                if loss_obj is None:
                    self._per_output_metrics.append(None)
                else:
                    self._per_output_metrics.append(
                        metrics_mod.Mean(output_name + "_loss")
                    )

    def __call__(
        self, y_true, y_pred, sample_weight=None, regularization_losses=None
    ):
        """Computes the overall loss.

        Args:
          y_true: An arbitrary structure of Tensors representing the ground
            truth.
          y_pred: An arbitrary structure of Tensors representing a Model's
            outputs.
          sample_weight: An arbitrary structure of Tensors representing the
            per-sample loss weights. If one Tensor is passed, it is used for all
            losses. If multiple Tensors are passed, the structure should match
            `y_pred`.
          regularization_losses: Additional losses to be added to the total
            loss.

        Returns:
          The total loss as a `tf.Tensor`, or `None` if no loss results.
        """
        y_true = self._conform_to_outputs(y_pred, y_true)
        sample_weight = self._conform_to_outputs(y_pred, sample_weight)

        if not self._built:
            self.build(y_pred)

        y_pred = tf.nest.flatten(y_pred)
        y_true = tf.nest.flatten(y_true)
        sample_weight = tf.nest.flatten(sample_weight)

        loss_values = []  # Used for gradient calculation.
        total_loss_mean_values = []  # Used for loss metric calculation.
        batch_dim = None
        zip_args = (
            y_true,
            y_pred,
            sample_weight,
            self._losses,
            self._loss_weights,
            self._per_output_metrics,
        )
        for y_t, y_p, sw, loss_obj, loss_weight, metric_obj in zip(*zip_args):
            if (
                y_t is None or loss_obj is None
            ):  # Ok to have no loss for an output.
                continue

            y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
            sw = losses_utils.apply_mask(y_p, sw, losses_utils.get_mask(y_p))
            loss_value = loss_obj(y_t, y_p, sample_weight=sw)

            total_loss_mean_value = loss_value
            # Correct for the `Mean` loss metrics counting each replica as a
            # batch.
            if loss_obj.reduction == losses_utils.ReductionV2.SUM:
                total_loss_mean_value *= (
                    tf.distribute.get_strategy().num_replicas_in_sync
                )

            # The first output with a loss determines the batch size used to
            # weight the running loss means.
            if batch_dim is None:
                if tf_utils.is_ragged(y_t):
                    batch_dim = y_t.nrows()
                else:
                    batch_dim = tf.shape(y_t)[0]

            # Per-output metrics track the loss before `loss_weight` is applied.
            if metric_obj is not None:
                metric_obj.update_state(
                    total_loss_mean_value, sample_weight=batch_dim
                )

            if loss_weight is not None:
                loss_value *= loss_weight
                total_loss_mean_value *= loss_weight

            if (
                loss_obj.reduction
                == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE
                or loss_obj.reduction == losses_utils.ReductionV2.AUTO
            ):
                loss_value = losses_utils.scale_loss_for_distribution(
                    loss_value
                )

            loss_values.append(loss_value)
            total_loss_mean_values.append(total_loss_mean_value)

        if regularization_losses:
            regularization_losses = losses_utils.cast_losses_to_common_dtype(
                regularization_losses
            )
            reg_loss = tf.add_n(regularization_losses)
            total_loss_mean_values.append(reg_loss)
            loss_values.append(
                losses_utils.scale_loss_for_distribution(reg_loss)
            )

        if loss_values:
            total_loss_mean_values = losses_utils.cast_losses_to_common_dtype(
                total_loss_mean_values
            )
            total_total_loss_mean_value = tf.add_n(total_loss_mean_values)
            self._total_loss_mean.update_state(
                total_total_loss_mean_value, sample_weight=batch_dim
            )

            loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
            total_loss = tf.add_n(loss_values)
            return total_loss
        else:
            return None

    def reset_state(self):
        """Resets the state of loss metrics (no-op before `build()`)."""
        if not self._built:
            return
        metrics = [self._total_loss_mean] + tf.nest.flatten(
            self._per_output_metrics
        )
        for metric_obj in metrics:
            if metric_obj is not None:
                metric_obj.reset_state()

    def _get_loss_object(self, loss):
        """Returns a `Loss` object.

        Converts the user-supplied loss to a `Loss` object. Also allows
        `SUM_OVER_BATCH_SIZE` reduction to be used for this loss.

        Args:
          loss: A string, function, or `Loss` object.

        Returns:
          A `Loss` object, or `None` when `loss` is `None`.

        Raises:
          ValueError: If no name can be derived for a non-`Loss` callable.
        """
        if loss is None:
            return None  # Ok to have no loss for an output.

        loss = losses_mod.get(loss)
        if not isinstance(loss, losses_mod.Loss):
            loss_name = get_custom_object_name(loss)
            if loss_name is None:
                raise ValueError(f"Loss should be a callable, received: {loss}")
            loss = losses_mod.LossFunctionWrapper(loss, name=loss_name)
        loss._allow_sum_over_batch_size = True
        return loss

    def _should_broadcast(self, obj):
        # Any non-nested loss (e.g. 'mse') applies to every output.
        return not tf.nest.is_nested(obj)

    def _copy_object(self, obj):
        return obj  # Losses don't need to be copied.
class MetricsContainer(Container):
    """A container class for metrics passed to `Model.compile`."""

    def __init__(
        self,
        metrics=None,
        weighted_metrics=None,
        output_names=None,
        from_serialized=False,
    ):
        """Initializes a container for metrics.

        Arguments:
          metrics: see the `metrics` argument from `tf.keras.Model.compile`.
          weighted_metrics: see the `weighted_metrics` argument from
            `tf.keras.Model.compile`.
          output_names: A list of strings of names of outputs for the model.
          from_serialized: Whether the model being compiled is from a serialized
            model. Used to avoid redundantly applying pre-processing renaming
            steps.

        Raises:
          ValueError: If the same `Metric` instance appears more than once
            across `metrics` and `weighted_metrics`.
        """
        super(MetricsContainer, self).__init__(output_names=output_names)

        self._check_duplicated_metrics(metrics, weighted_metrics)
        # Keep user-supplied values untouched for recompiling and serialization.
        self._user_metrics = metrics
        self._user_weighted_metrics = weighted_metrics

        self._metrics = metrics
        self._weighted_metrics = weighted_metrics
        self._built = False

        self._from_serialized = from_serialized

    def _check_duplicated_metrics(self, metrics, weighted_metrics):
        """Raise error when user provided metrics have any duplications.

        Note that metrics are stateful containers. A metric instance shared
        between `metrics` and `weighted_metrics` would make the same instance
        be updated twice and report wrong values.

        Args:
          metrics: User provided metrics list.
          weighted_metrics: User provided weighted metrics list.

        Raises:
          ValueError, when duplicated metrics instance discovered in user
            provided metrics and weighted metrics.
        """
        seen = set()
        duplicated = []
        for x in tf.nest.flatten(metrics) + tf.nest.flatten(weighted_metrics):
            # We only check metrics object. The string and function objects
            # will be converted to unique Metric instance.
            if not isinstance(x, metrics_mod.Metric):
                continue
            if x in seen:
                duplicated.append(x)
            seen.add(x)

        if duplicated:
            raise ValueError(
                "Found duplicated metrics object in the user provided "
                "metrics and weighted metrics. This will cause the same "
                "metric object to be updated multiple times, and report "
                "wrong results. \n"
                f"Duplicated items: {duplicated}"
            )

    @property
    def metrics(self):
        """All metrics in this container (empty until built)."""
        if not self._built:
            return []
        return self._metrics_in_order

    @property
    def unweighted_metrics(self):
        """Metrics in the container that should not be passed sample_weight."""
        if not self._built:
            return None
        return tf.nest.flatten(self._metrics)

    @property
    def weighted_metrics(self):
        """Metrics in this container that should be passed `sample_weight`."""
        if not self._built:
            return None
        return tf.nest.flatten(self._weighted_metrics)

    def build(self, y_pred, y_true):
        """One-time setup of metric objects.

        Broadcasts and conforms both metric structures to `y_pred`, converts
        every entry to a `Metric` object, then flattens them to one list of
        metrics per model output.
        """
        super(MetricsContainer, self).build(y_pred)

        self._metrics = self._maybe_broadcast_to_outputs(y_pred, self._metrics)
        self._metrics = self._conform_to_outputs(y_pred, self._metrics)

        self._weighted_metrics = self._maybe_broadcast_to_outputs(
            y_pred, self._weighted_metrics
        )
        self._weighted_metrics = self._conform_to_outputs(
            y_pred, self._weighted_metrics
        )

        # Standardize on tuple since `tf.data` turns lists into `Tensor`s.
        y_pred = tf.__internal__.nest.list_to_tuple(y_pred)
        y_true = tf.__internal__.nest.list_to_tuple(y_true)
        self._metrics = tf.__internal__.nest.list_to_tuple(self._metrics)
        self._weighted_metrics = tf.__internal__.nest.list_to_tuple(
            self._weighted_metrics
        )

        # Convert to `Metric` objects, potentially disambiguating based on
        # output properties.
        self._metrics = tf.__internal__.nest.map_structure_up_to(
            y_pred, self._get_metric_objects, self._metrics, y_true, y_pred
        )
        self._weighted_metrics = tf.__internal__.nest.map_structure_up_to(
            y_pred,
            self._get_metric_objects,
            self._weighted_metrics,
            y_true,
            y_pred,
        )

        self._metrics = tf.__internal__.nest.flatten_up_to(
            y_pred, self._metrics, check_types=False
        )
        self._weighted_metrics = tf.__internal__.nest.flatten_up_to(
            y_pred, self._weighted_metrics, check_types=False
        )

        # Assumes metrics, weighted_metrics have been flattened up to outputs.
        #
        # If we are loading a model that has been already serialized, we do not
        # want to re-apply any pre-processing metric renaming steps.
        if not self._from_serialized:
            self._set_metric_names()
        self._create_ordered_metrics()
        self._built = True

    @property
    def built(self):
        return self._built

    def _set_metric_names(self):
        """Sets unique metric names.

        Raises:
          ValueError: If two (weighted) metrics end up with the same name.
        """
        # For multi-output models, prepend the output name to the metric name.
        # For weighted metrics, prepend "weighted_" if the name would be
        # non-unique.

        metric_names = set()
        is_multi_output = len(self._output_names) > 1
        zip_args = (self._output_names, self._metrics, self._weighted_metrics)
        for output_name, output_metrics, weighted_output_metrics in zip(
            *zip_args
        ):
            for m in output_metrics:
                if m is None:
                    continue
                if is_multi_output:
                    m._name = output_name + "_" + m._name
                if m._name in metric_names:
                    raise ValueError(
                        f"Found two metrics with the same name: {m._name}. "
                        "All the metrics added to the model need to have "
                        "unique names."
                    )
                metric_names.add(m._name)

            for wm in weighted_output_metrics:
                if wm is None:
                    continue
                if is_multi_output:
                    if output_name + "_" + wm._name in metric_names:
                        wm._name = output_name + "_weighted_" + wm._name
                    else:
                        wm._name = output_name + "_" + wm._name
                elif wm._name in metric_names:
                    wm._name = "weighted_" + wm._name

                if wm._name in metric_names:
                    raise ValueError(
                        "Found two weighted metrics with the same name: "
                        f"{wm._name}. All the metrics added to the model need "
                        "to have unique names."
                    )
                metric_names.add(wm._name)

    def _create_ordered_metrics(self):
        """Cache the flat order needed when return metrics, for backcompat.

        For each output, its unweighted metrics come first, followed by its
        weighted metrics; `None` placeholders are skipped.
        """
        self._metrics_in_order = []
        for output_metrics, output_weighted_metrics in zip(
            self._metrics, self._weighted_metrics
        ):
            for m in tf.nest.flatten(output_metrics):
                if m is not None:
                    self._metrics_in_order.append(m)
            for wm in tf.nest.flatten(output_weighted_metrics):
                if wm is not None:
                    self._metrics_in_order.append(wm)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Updates the state of per-output metrics.

        Unweighted metrics receive only the output mask as `sample_weight`;
        weighted metrics receive the (masked) user `sample_weight`.
        """
        y_true = self._conform_to_outputs(y_pred, y_true)
        sample_weight = self._conform_to_outputs(y_pred, sample_weight)

        if not self._built:
            self.build(y_pred, y_true)

        y_pred = tf.nest.flatten(y_pred)
        y_true = tf.nest.flatten(y_true) if y_true is not None else []
        sample_weight = tf.nest.flatten(sample_weight)

        zip_args = (
            y_true,
            y_pred,
            sample_weight,
            self._metrics,
            self._weighted_metrics,
        )
        for y_t, y_p, sw, metric_objs, weighted_metric_objs in zip(*zip_args):
            # Ok to have no metrics for an output.
            if y_t is None or (
                all(m is None for m in metric_objs)
                and all(wm is None for wm in weighted_metric_objs)
            ):
                continue

            y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
            mask = losses_utils.get_mask(y_p)
            sw = losses_utils.apply_mask(y_p, sw, mask)

            for metric_obj in metric_objs:
                if metric_obj is None:
                    continue
                metric_obj.update_state(y_t, y_p, sample_weight=mask)

            for weighted_metric_obj in weighted_metric_objs:
                if weighted_metric_obj is None:
                    continue
                weighted_metric_obj.update_state(y_t, y_p, sample_weight=sw)

    def reset_state(self):
        """Resets the state of all `Metric`s in this container."""
        if self._built:
            metrics = self._metrics_in_order
        else:
            # If the user supplied `Metric` objects directly, we should
            # reset those. This could also contain `str`s or `function`s
            # though.
            metrics = tf.nest.flatten(self._user_metrics) + tf.nest.flatten(
                self._user_weighted_metrics
            )

        for metric_obj in metrics:
            if isinstance(metric_obj, metrics_mod.Metric):
                metric_obj.reset_state()

    def _get_metric_objects(self, metrics, y_t, y_p):
        """Convert user-supplied metrics to `Metric` objects."""
        metrics = tf.nest.flatten(metrics)
        return [self._get_metric_object(m, y_t, y_p) for m in metrics]

    def _get_metric_object(self, metric, y_t, y_p):
        """Converts user-supplied metric to a `Metric` object.

        Args:
          metric: A string, function, or `Metric` object.
          y_t: Sample of label.
          y_p: Sample of output.

        Returns:
          A `Metric` object, or `None` when `metric` is `None`.

        Raises:
          ValueError: If no name can be derived for a non-`Metric` callable.
        """
        if metric is None:
            return None  # Ok to have no metric for an output.

        # Convenience feature for selecting between binary, categorical,
        # and sparse categorical variants based on label/output shapes.
        if str(metric).lower() not in ["accuracy", "acc", "crossentropy", "ce"]:
            metric_obj = metrics_mod.get(metric)
        else:
            # NOTE(review): `as_list()` and `[-1]` assume known, non-scalar
            # shapes for both y_t and y_p.
            y_t_rank = len(y_t.shape.as_list())
            y_p_rank = len(y_p.shape.as_list())
            y_t_last_dim = y_t.shape.as_list()[-1]
            y_p_last_dim = y_p.shape.as_list()[-1]

            is_binary = y_p_last_dim == 1
            is_sparse_categorical = (
                y_t_rank < y_p_rank or y_t_last_dim == 1 and y_p_last_dim > 1
            )

            if str(metric).lower() in ["accuracy", "acc"]:
                if is_binary:
                    metric_obj = metrics_mod.binary_accuracy
                elif is_sparse_categorical:
                    metric_obj = metrics_mod.sparse_categorical_accuracy
                else:
                    metric_obj = metrics_mod.categorical_accuracy
            else:
                if is_binary:
                    metric_obj = metrics_mod.binary_crossentropy
                elif is_sparse_categorical:
                    metric_obj = metrics_mod.sparse_categorical_crossentropy
                else:
                    metric_obj = metrics_mod.categorical_crossentropy

        if isinstance(metric_obj, losses_mod.Loss):
            metric_obj._allow_sum_over_batch_size = True

        if not isinstance(metric_obj, metrics_mod.Metric):
            if isinstance(metric, str):
                metric_name = metric
            else:
                metric_name = get_custom_object_name(metric)
                if metric_name is None:
                    raise ValueError(
                        f"Metric should be a callable, received: {metric}"
                    )

            metric_obj = metrics_mod.MeanMetricWrapper(
                metric_obj, name=metric_name
            )

        return metric_obj

    def _should_broadcast(self, obj):
        # e.g. 'mse'.
        if not tf.nest.is_nested(obj):
            return True
        # e.g. ['mse'] or ['mse', 'mae'].
        return isinstance(obj, (list, tuple)) and not any(
            tf.nest.is_nested(o) for o in obj
        )

    def _copy_object(self, obj):
        # Each output needs its own stateful `Metric`, so rebuild from config.
        if isinstance(obj, metrics_mod.Metric):
            return obj.__class__.from_config(obj.get_config())
        return obj  # Can be a function or `None`.
def create_pseudo_output_names(outputs):
    """Returns pseudo output names (prefix `output_`) for a subclassed Model.

    Args:
      outputs: The Model's outputs.

    Returns:
      Flat list of names produced by `_create_pseudo_names`.
    """
    output_prefix = "output_"
    return _create_pseudo_names(outputs, prefix=output_prefix)
def create_pseudo_input_names(inputs):
    """Returns pseudo input names (prefix `input_`) for a subclassed Model.

    Args:
      inputs: The Model's inputs.

    Returns:
      Flat list of names produced by `_create_pseudo_names`.
    """
    input_prefix = "input_"
    return _create_pseudo_names(inputs, prefix=input_prefix)
def _create_pseudo_names(tensors, prefix):
    """Creates pseudo {input | output} names for subclassed Models.

    Warning: this function should only be used to define default
    names for `Metrics` and `SavedModel`. No other use cases should
    rely on a `Model`'s input or output names.

    Each flattened leaf of `tensors` is named after its nest path, joined
    with "_". Integer path components are shifted to start at 1. Names whose
    path starts with an integer index get `prefix` prepended. A non-nested
    (single) tensor is named `prefix + "1"`.

    Example with dict:

    `{'a': [x1, x2], 'b': x3}` becomes:
    `['a_1', 'a_2', 'b']`

    Example with list:

    `[x, y]` becomes:
    `['output_1', 'output_2']`

    Args:
      tensors: `Model`'s outputs or inputs.
      prefix: 'output_' for outputs, 'input_' for inputs.

    Returns:
      Flattened list of pseudo names.
    """

    def _shift_to_one_based(component):
        # Start with "output_1" instead of "output_0".
        return component + 1 if isinstance(component, int) else component

    paths = tf.nest.map_structure(
        _shift_to_one_based,
        list(tf.__internal__.nest.yield_flat_paths(tensors)),
    )
    names = []
    for path in paths:
        if not path:
            names.append(prefix + "1")  # Single output.
            continue
        joined = "_".join(map(str, path))
        names.append(prefix + joined if isinstance(path[0], int) else joined)
    return names
def map_to_output_names(y_pred, output_names, struct):
    """Maps a dict to a list using `output_names` as keys.

    This is a convenience feature only. When a `Model`'s outputs
    are a list, you can specify per-output losses and metrics as
    a dict, where the keys are the output names. If you specify
    per-output losses and metrics via the same structure as the
    `Model`'s outputs (recommended), no mapping is performed.

    For the Functional API, the output names are the names of the
    last layer of each output. For the Subclass API, the output names
    are determined by `create_pseudo_output_names` (For example:
    `['output_1', 'output_2']` for a list of outputs).

    This mapping preserves backwards compatibility for `compile` and
    `fit`.

    Args:
      y_pred: Sample outputs of the Model, to determine if this convenience
        feature should be applied (`struct` is returned unmodified if `y_pred`
        isn't a flat list).
      output_names: List. The names of the outputs of the Model.
      struct: The structure to map.

    Returns:
      `struct` mapped to a list in same order as `output_names` (the single
      element itself when there is exactly one output name). Missing names
      map to `None`. Any input that is not a dict is returned unchanged, and
      so is any input where `y_pred` is neither a single output nor a flat
      list/tuple.

    Raises:
      ValueError: If `struct` has keys that match no output name.
    """
    single_output = not tf.nest.is_nested(y_pred)
    outputs_are_flat_list = (
        not single_output
        and isinstance(y_pred, (list, tuple))
        and not any(tf.nest.is_nested(y_p) for y_p in y_pred)
    )

    if not ((single_output or outputs_are_flat_list) and isinstance(struct, dict)):
        return struct

    output_names = output_names or create_pseudo_output_names(y_pred)
    # Pop from a copy so the caller's dict is untouched and `struct` still
    # holds the full user input for the error message below.
    remaining = copy.copy(struct)
    new_struct = [remaining.pop(name, None) for name in output_names]
    if remaining:
        raise ValueError(
            "Found unexpected losses or metrics that do not correspond "
            f"to any Model output: {remaining.keys()}. "
            f"Valid model output names: {output_names}. "
            f"Received struct is: {struct}."
        )
    if len(new_struct) == 1:
        return new_struct[0]
    return new_struct
def map_missing_dict_keys(y_pred, struct):
    """Replaces missing dict keys in `struct` with `None` placeholders.

    Only applies when both `y_pred` and `struct` are dicts. In that case a
    shallow copy of `struct` is returned that has an entry (defaulting to
    `None`) for every key of `y_pred`. Otherwise `struct` is returned as is.
    """
    if isinstance(y_pred, dict) and isinstance(struct, dict):
        filled = copy.copy(struct)
        for key in y_pred:
            filled.setdefault(key, None)
        return filled
    return struct
def match_dtype_and_rank(y_t, y_p, sw):
    """Match dtype and rank of predictions.

    When `y_p` has rank 2, a rank-1 `y_t` and/or rank-1 `sw` gets a trailing
    axis. `y_t` is cast to `y_p.dtype` when both are floating or both are
    integer. A non-`None` `sw` is always cast to `y_p.dtype`.

    Returns:
      The `(y_t, y_p, sw)` triple; `y_p` itself is never modified.
    """
    # Rank.
    if y_p.shape.rank == 2:
        if y_t.shape.rank == 1:
            y_t = tf.expand_dims(y_t, axis=-1)
        if sw is not None and sw.shape.rank == 1:
            sw = tf.expand_dims(sw, axis=-1)

    # Dtype.
    # This is required mainly for custom loss functions which do not take care
    # casting dtypes.
    both_floating = y_t.dtype.is_floating and y_p.dtype.is_floating
    both_integer = y_t.dtype.is_integer and y_p.dtype.is_integer
    if both_floating or both_integer:
        y_t = tf.cast(y_t, y_p.dtype)

    if sw is not None:
        sw = tf.cast(sw, y_p.dtype)
    return y_t, y_p, sw
def get_custom_object_name(obj):
    """Returns the name to use for a custom loss or metric callable.

    Lookup order: an explicit `name` attribute (so a `Loss` instance can be
    used as a `Metric`), then `__name__` (plain functions), then the
    snake_cased class name (other class instances).

    Args:
      obj: Custom loss or metric callable.

    Returns:
      Name to use, or `None` if the object was not recognized. Every ordinary
      Python object has `__class__`, so in practice the class-name fallback
      is reached and `None` is effectively never returned.
    """
    for attr_name in ("name", "__name__"):
        if hasattr(obj, attr_name):
            return getattr(obj, attr_name)
    if not hasattr(obj, "__class__"):
        return None  # Unrecognized object.
    return generic_utils.to_snake_case(obj.__class__.__name__)