Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/engine/compile_utils.py: 16%
364 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Utilities for `Model.compile`."""
17import copy
19from tensorflow.python.distribute import distribute_lib
20from tensorflow.python.keras import losses as losses_mod
21from tensorflow.python.keras import metrics as metrics_mod
22from tensorflow.python.keras.utils import generic_utils
23from tensorflow.python.keras.utils import losses_utils
24from tensorflow.python.keras.utils import tf_utils
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.util import nest
class Container(object):
  """Shared base for the loss and metric containers used by `Model.compile`.

  Subclasses must implement `_should_broadcast` and `_copy_object`.
  """

  def __init__(self, output_names=None):
    self._output_names = output_names

  def build(self, y_pred):
    """Fills in output names from `y_pred` if none were supplied."""
    if self._output_names is not None:
      return
    # In Subclass API, output names like 'output_1' are used for
    # `Metric` names.
    self._output_names = create_pseudo_output_names(y_pred)

  def _conform_to_outputs(self, outputs, struct):
    """Convenience method to conform `struct` to `outputs` structure.

    Mappings performed:

    (1) Map a dict to a list of outputs, using the output names.
    (2) Fill missing keys in a dict w/ `None`s.
    (3) Map a single item to all outputs.

    Args:
      outputs: Model predictions.
      struct: Arbitrary nested structure (e.g. of labels, sample_weights,
        losses, or metrics).

    Returns:
      Mapping of `struct` to `outputs` structure.
    """
    conformed = map_missing_dict_keys(
        outputs, map_to_output_names(outputs, self._output_names, struct))
    if nest.is_nested(conformed) or not nest.is_nested(outputs):
      return conformed
    # A single (non-nested) object applies to every output.
    return nest.map_structure(lambda _: conformed, outputs)

  def _maybe_broadcast_to_outputs(self, outputs, objects):
    """Applies a Loss / Metric to all outputs when appropriate.

    NOTE: This method should only be called for Metrics / Losses, not for
    y_true / sample_weight.

    Args:
      outputs: Model predictions.
      objects: Arbitrary nested structure (e.g. of losses or metrics).

    Returns:
      `objects` unchanged if `_should_broadcast(objects)` is false; otherwise a
      structure shaped like `outputs` with `objects` (copied per output when
      there is more than one output) at every leaf.
    """
    if not self._should_broadcast(objects):
      return objects

    # With several outputs each one needs its own copy so that every
    # Metric / Loss stays separate; with one output the user's object is used.
    copy_per_output = len(nest.flatten(outputs)) > 1

    def _for_output(unused_output):
      return (nest.map_structure(self._copy_object, objects)
              if copy_per_output else objects)

    return nest.map_structure(_for_output, outputs)

  def _should_broadcast(self, objects):
    """Returns whether `objects` should be applied to every output."""
    raise NotImplementedError

  def _copy_object(self, obj):
    """Returns a per-output copy of `obj`."""
    raise NotImplementedError
class LossesContainer(Container):
  """A container class for losses passed to `Model.compile`.

  On `build`, the user's `loss` and `loss_weights` arguments become flat lists
  holding one `Loss` object (or `None`) and one weight (or `None`) per
  flattened model output. Calling the container returns the total loss and
  updates a `Mean` metric named 'loss'. Multi-output models also get one
  '<output_name>_loss' `Mean` metric for each output that has a loss.
  """

  def __init__(self, losses, loss_weights=None, output_names=None):
    super(LossesContainer, self).__init__(output_names=output_names)

    # Keep user-supplied values untouched for recompiling and serialization.
    self._user_losses = losses
    self._user_loss_weights = loss_weights

    # Working values; `build` replaces them with flat per-output lists.
    self._losses = losses
    self._loss_weights = loss_weights
    self._per_output_metrics = None  # Per-output losses become metrics.
    self._loss_metric = metrics_mod.Mean(name='loss')  # Total loss.
    self._built = False

  @property
  def metrics(self):
    """Per-output loss metrics, preceded by the total-loss metric.

    Empty until `build` has run.
    """
    if not self._built:
      return []
    per_output_metrics = [
        metric_obj for metric_obj in nest.flatten(self._per_output_metrics)
        if metric_obj is not None
    ]
    return [self._loss_metric] + per_output_metrics

  def build(self, y_pred):
    """One-time setup of loss objects.

    Args:
      y_pred: Sample model outputs. Their structure determines how `loss` and
        `loss_weights` are broadcast and aligned.
    """
    super(LossesContainer, self).build(y_pred)

    self._losses = self._maybe_broadcast_to_outputs(y_pred, self._losses)
    self._losses = self._conform_to_outputs(y_pred, self._losses)
    self._losses = nest.map_structure(self._get_loss_object, self._losses)
    self._losses = nest.flatten(self._losses)

    self._loss_weights = self._maybe_broadcast_to_outputs(
        y_pred, self._loss_weights)
    self._loss_weights = self._conform_to_outputs(y_pred, self._loss_weights)
    self._loss_weights = nest.flatten(self._loss_weights)

    self._create_metrics()
    self._built = True

  @property
  def built(self):
    return self._built

  def _create_metrics(self):
    """Creates per-output loss metrics, but only for multi-output Models."""
    if len(self._output_names) == 1:
      self._per_output_metrics = [None]
    else:
      # Outputs without a loss get no metric (`None` placeholder keeps the
      # list aligned with `self._losses`).
      self._per_output_metrics = [
          None if loss_obj is None else metrics_mod.Mean(output_name + '_loss')
          for loss_obj, output_name in zip(self._losses, self._output_names)
      ]

  def __call__(self,
               y_true,
               y_pred,
               sample_weight=None,
               regularization_losses=None):
    """Computes the overall loss and updates the loss metrics.

    Args:
      y_true: An arbitrary structure of Tensors representing the ground truth.
      y_pred: An arbitrary structure of Tensors representing a Model's outputs.
      sample_weight: An arbitrary structure of Tensors representing the
        per-sample loss weights. If one Tensor is passed, it is used for all
        losses. If multiple Tensors are passed, the structure should match
        `y_pred`.
      regularization_losses: Additional losses to be added to the total loss.

    Returns:
      The total loss `Tensor` (only the total; no per-output list is returned).
      This is the sum of every weighted per-output loss and the summed
      `regularization_losses`. Losses with `SUM_OVER_BATCH_SIZE` / `AUTO`
      reduction and the regularization sum are first passed through
      `losses_utils.scale_loss_for_distribution`. If there is no loss at all,
      a zero scalar `Tensor` is returned.
    """
    y_true = self._conform_to_outputs(y_pred, y_true)
    sample_weight = self._conform_to_outputs(y_pred, sample_weight)

    if not self._built:
      self.build(y_pred)

    y_pred = nest.flatten(y_pred)
    y_true = nest.flatten(y_true)
    sample_weight = nest.flatten(sample_weight)

    loss_values = []  # Used for gradient calculation.
    loss_metric_values = []  # Used for loss metric calculation.
    batch_dim = None
    zip_args = (y_true, y_pred, sample_weight, self._losses, self._loss_weights,
                self._per_output_metrics)
    for y_t, y_p, sw, loss_obj, loss_weight, metric_obj in zip(*zip_args):
      if y_t is None or loss_obj is None:  # Ok to have no loss for an output.
        continue

      y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
      sw = apply_mask(y_p, sw, get_mask(y_p))
      loss_value = loss_obj(y_t, y_p, sample_weight=sw)

      # Explicit (non-augmented) multiplication below keeps `loss_value` and
      # `loss_metric_value` independent even if they start as the same object.
      loss_metric_value = loss_value
      # Correct for the `Mean` loss metrics counting each replica as a batch.
      if loss_obj.reduction == losses_utils.ReductionV2.SUM:
        loss_metric_value = (
            loss_metric_value *
            distribute_lib.get_strategy().num_replicas_in_sync)

      # The first output's batch size weights every metric update below.
      if batch_dim is None:
        if tf_utils.is_ragged(y_t):
          batch_dim = y_t.nrows()
        else:
          batch_dim = array_ops.shape(y_t)[0]

      # Per-output metrics record the loss before `loss_weight` is applied.
      if metric_obj is not None:
        metric_obj.update_state(loss_metric_value, sample_weight=batch_dim)

      if loss_weight is not None:
        loss_value = loss_value * loss_weight
        loss_metric_value = loss_metric_value * loss_weight

      if (loss_obj.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE or
          loss_obj.reduction == losses_utils.ReductionV2.AUTO):
        loss_value = losses_utils.scale_loss_for_distribution(loss_value)

      loss_values.append(loss_value)
      loss_metric_values.append(loss_metric_value)

    if regularization_losses:
      regularization_losses = losses_utils.cast_losses_to_common_dtype(
          regularization_losses)
      reg_loss = math_ops.add_n(regularization_losses)
      loss_metric_values.append(reg_loss)
      loss_values.append(losses_utils.scale_loss_for_distribution(reg_loss))

    if loss_values:
      loss_metric_values = losses_utils.cast_losses_to_common_dtype(
          loss_metric_values)
      total_loss_metric_value = math_ops.add_n(loss_metric_values)
      self._loss_metric.update_state(
          total_loss_metric_value, sample_weight=batch_dim)

      loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
      total_loss = math_ops.add_n(loss_values)
      return total_loss
    else:
      # Ok for a model to have no compiled loss.
      return array_ops.zeros(shape=())

  def reset_state(self):
    """Resets the state of loss metrics (no-op before `build`)."""
    if not self._built:
      return
    metrics = [self._loss_metric] + nest.flatten(self._per_output_metrics)
    for metric_obj in metrics:
      if metric_obj is not None:
        metric_obj.reset_state()

  def _get_loss_object(self, loss):
    """Returns a `Loss` object.

    Converts the user-supplied loss to a `Loss` object. Also allows
    `SUM_OVER_BATCH_SIZE` reduction to be used for this loss.

    Args:
      loss: A string, function, or `Loss` object.

    Returns:
      A `Loss` object, or `None` if `loss` is `None`.

    Raises:
      ValueError: If `loss` is not a `Loss` and no name can be derived for it.
    """
    if loss is None:
      return None  # Ok to have no loss for an output.

    loss = losses_mod.get(loss)
    if not isinstance(loss, losses_mod.Loss):
      loss_name = get_custom_object_name(loss)
      if loss_name is None:
        raise ValueError('Loss should be a callable, found: {}'.format(loss))
      loss = losses_mod.LossFunctionWrapper(loss, name=loss_name)
    loss._allow_sum_over_batch_size = True  # pylint: disable=protected-access
    return loss

  def _should_broadcast(self, obj):
    # Any single (non-nested) loss applies to every output.
    return not nest.is_nested(obj)

  def _copy_object(self, obj):
    return obj  # Losses don't need to be copied.
class MetricsContainer(Container):
  """Holds the `metrics` and `weighted_metrics` given to `Model.compile`.

  After `build`, both collections are stored as one list of `Metric` objects
  (or `None`s) per flattened model output.
  """

  def __init__(self, metrics=None, weighted_metrics=None, output_names=None,
               from_serialized=False):
    """Initializes a container for metrics.

    Arguments:
      metrics: see the `metrics` argument from `tf.keras.Model.compile`.
      weighted_metrics: see the `weighted_metrics` argument from
        `tf.keras.Model.compile`.
      output_names: A list of strings of names of outputs for the model.
      from_serialized: Whether the model being compiled is from a serialized
        model. Used to avoid redundantly applying pre-processing renaming
        steps.
    """
    super(MetricsContainer, self).__init__(output_names=output_names)

    # The raw arguments are retained for recompiling and serialization; the
    # working copies below are rewritten by `build`.
    self._user_metrics, self._user_weighted_metrics = metrics, weighted_metrics
    self._metrics, self._weighted_metrics = metrics, weighted_metrics
    self._built = False

    self._from_serialized = from_serialized

  @property
  def metrics(self):
    """All metrics in this container (empty until `build` has run)."""
    return self._metrics_in_order if self._built else []

  @property
  def unweighted_metrics(self):
    """Metrics updated with only the output mask, not the user `sample_weight`.

    `None` until `build` has run.
    """
    return nest.flatten(self._metrics) if self._built else None

  @property
  def weighted_metrics(self):
    """Metrics updated with the (masked) `sample_weight`.

    `None` until `build` has run.
    """
    return nest.flatten(self._weighted_metrics) if self._built else None

  def build(self, y_pred, y_true):
    """One-time conversion of the user's metrics into per-output `Metric`s.

    Args:
      y_pred: Sample model outputs; they define the output structure.
      y_true: Sample labels (`update_state` passes them already conformed to
        `y_pred`).
    """
    super(MetricsContainer, self).build(y_pred)

    # Broadcast bare metrics to every output, then align them with `y_pred`.
    self._metrics = self._conform_to_outputs(
        y_pred, self._maybe_broadcast_to_outputs(y_pred, self._metrics))
    self._weighted_metrics = self._conform_to_outputs(
        y_pred,
        self._maybe_broadcast_to_outputs(y_pred, self._weighted_metrics))

    # Standardize on tuple since `tf.data` turns lists into `Tensor`s.
    y_pred, y_true = nest.list_to_tuple(y_pred), nest.list_to_tuple(y_true)
    self._metrics = nest.list_to_tuple(self._metrics)
    self._weighted_metrics = nest.list_to_tuple(self._weighted_metrics)

    # Resolve strings / functions into `Metric` objects. Some aliases (e.g.
    # 'accuracy') are disambiguated from label and output shapes.
    to_objects = self._get_metric_objects
    self._metrics = nest.map_structure_up_to(
        y_pred, to_objects, self._metrics, y_true, y_pred)
    self._weighted_metrics = nest.map_structure_up_to(
        y_pred, to_objects, self._weighted_metrics, y_true, y_pred)

    # One list of metrics per flattened output.
    self._metrics = nest.flatten_up_to(
        y_pred, self._metrics, check_types=False)
    self._weighted_metrics = nest.flatten_up_to(
        y_pred, self._weighted_metrics, check_types=False)

    # A model restored from a serialized form already went through the
    # renaming step, so it is not re-applied.
    if not self._from_serialized:
      self._set_metric_names()
    self._create_ordered_metrics()
    self._built = True

  @property
  def built(self):
    return self._built

  def _set_metric_names(self):
    """Gives every metric a unique name.

    Multi-output models prefix each name with its output name. A weighted
    metric whose name would collide gets a "weighted_" marker.

    Raises:
      ValueError: if two metrics still end up with the same name.
    """
    # pylint: disable=protected-access
    seen = set()
    prefix_with_output = len(self._output_names) > 1

    def _claim(metric):
      if metric._name in seen:
        raise ValueError('Found two metrics with the same name: {}'.format(
            metric._name))
      seen.add(metric._name)

    per_output = zip(self._output_names, self._metrics, self._weighted_metrics)
    for output_name, plain_metrics, weighted_metrics in per_output:
      for metric in (m for m in plain_metrics if m is not None):
        if prefix_with_output:
          metric._name = output_name + '_' + metric._name
        _claim(metric)

      for metric in (w for w in weighted_metrics if w is not None):
        if prefix_with_output:
          candidate = output_name + '_' + metric._name
          if candidate in seen:
            metric._name = output_name + '_weighted_' + metric._name
          else:
            metric._name = candidate
        elif metric._name in seen:
          metric._name = 'weighted_' + metric._name
        _claim(metric)
    # pylint: enable=protected-access

  def _create_ordered_metrics(self):
    """Cache the flat order needed when returning metrics, for backwards compat."""
    ordered = []
    for plain_metrics, weighted_metrics in zip(self._metrics,
                                               self._weighted_metrics):
      ordered.extend(m for m in nest.flatten(plain_metrics) if m is not None)
      ordered.extend(
          w for w in nest.flatten(weighted_metrics) if w is not None)
    self._metrics_in_order = ordered

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Updates the state of per-output metrics.

    Unweighted metrics receive only the output's Keras mask as
    `sample_weight`; weighted metrics receive `sample_weight` combined with
    that mask.
    """
    y_true = self._conform_to_outputs(y_pred, y_true)
    sample_weight = self._conform_to_outputs(y_pred, sample_weight)

    if not self._built:
      self.build(y_pred, y_true)

    flat_pred = nest.flatten(y_pred)
    flat_true = [] if y_true is None else nest.flatten(y_true)
    flat_weights = nest.flatten(sample_weight)

    per_output = zip(flat_true, flat_pred, flat_weights, self._metrics,
                     self._weighted_metrics)
    for y_t, y_p, sw, plain_metrics, weighted_metrics in per_output:
      # Ok to have no labels or no metrics for an output.
      if y_t is None:
        continue
      if (not any(m is not None for m in plain_metrics) and
          not any(w is not None for w in weighted_metrics)):
        continue

      y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
      mask = get_mask(y_p)
      sw = apply_mask(y_p, sw, mask)

      for metric in plain_metrics:
        if metric is not None:
          metric.update_state(y_t, y_p, sample_weight=mask)

      for metric in weighted_metrics:
        if metric is not None:
          metric.update_state(y_t, y_p, sample_weight=sw)

  def reset_state(self):
    """Resets the state of all `Metric`s in this container.

    Before `build`, only `Metric` instances found in the user's raw arguments
    are reset; strings and functions there are skipped.
    """
    if not self._built:
      candidates = (nest.flatten(self._user_metrics) +
                    nest.flatten(self._user_weighted_metrics))
    else:
      candidates = self._metrics_in_order

    for candidate in candidates:
      if isinstance(candidate, metrics_mod.Metric):
        candidate.reset_state()

  def _get_metric_objects(self, metrics, y_t, y_p):
    """Convert user-supplied metrics to `Metric` objects."""
    return [self._get_metric_object(m, y_t, y_p) for m in nest.flatten(metrics)]

  def _get_metric_object(self, metric, y_t, y_p):
    """Converts a single user-supplied metric into a `Metric` object.

    The aliases 'accuracy'/'acc' and 'crossentropy'/'ce' are resolved to the
    binary, sparse-categorical or categorical variant by comparing the static
    shapes of `y_t` and `y_p`.

    Args:
      metric: A string, function, or `Metric` object.
      y_t: Sample of label.
      y_p: Sample of output.

    Returns:
      A `Metric` object, or `None` when `metric` is `None`.

    Raises:
      ValueError: if a non-string, non-`Metric` object has no usable name.
    """
    if metric is None:
      return None  # Ok to have no metric for an output.

    alias = str(metric).lower()
    if alias in ('accuracy', 'acc', 'crossentropy', 'ce'):
      t_dims = y_t.shape.as_list()
      p_dims = y_p.shape.as_list()
      t_last, p_last = t_dims[-1], p_dims[-1]

      if alias in ('accuracy', 'acc'):
        binary, sparse, dense = (metrics_mod.binary_accuracy,
                                 metrics_mod.sparse_categorical_accuracy,
                                 metrics_mod.categorical_accuracy)
      else:
        binary, sparse, dense = (metrics_mod.binary_crossentropy,
                                 metrics_mod.sparse_categorical_crossentropy,
                                 metrics_mod.categorical_crossentropy)

      if p_last == 1:
        metric_obj = binary
      elif len(t_dims) < len(p_dims) or (t_last == 1 and p_last > 1):
        metric_obj = sparse
      else:
        metric_obj = dense
    else:
      metric_obj = metrics_mod.get(metric)

    if isinstance(metric_obj, losses_mod.Loss):
      metric_obj._allow_sum_over_batch_size = True  # pylint: disable=protected-access

    if isinstance(metric_obj, metrics_mod.Metric):
      return metric_obj

    if isinstance(metric, str):
      metric_name = metric
    else:
      metric_name = get_custom_object_name(metric)
      if metric_name is None:
        raise ValueError(
            'Metric should be a callable, found: {}'.format(metric))
    return metrics_mod.MeanMetricWrapper(metric_obj, name=metric_name)

  def _should_broadcast(self, obj):
    """True for a single metric (e.g. 'mse') or a flat list (['mse', 'mae'])."""
    if nest.is_nested(obj):
      if not isinstance(obj, (list, tuple)):
        return False
      return all(not nest.is_nested(o) for o in obj)
    return True

  def _copy_object(self, obj):
    """Clones `Metric` instances; functions and `None` are shared as-is."""
    if not isinstance(obj, metrics_mod.Metric):
      return obj  # Can be a function or `None`.
    return obj.__class__.from_config(obj.get_config())
def create_pseudo_output_names(outputs):
  """Returns default output names (prefix 'output_') for a subclassed Model."""
  output_prefix = 'output_'
  return _create_pseudo_names(outputs, prefix=output_prefix)
def create_pseudo_input_names(inputs):
  """Returns default input names (prefix 'input_') for a subclassed Model."""
  input_prefix = 'input_'
  return _create_pseudo_names(inputs, prefix=input_prefix)
def _create_pseudo_names(tensors, prefix):
  """Creates pseudo {input | output} names for subclassed Models.

  Warning: this function should only be used to define default
  names for `Metrics` and `SavedModel`. No other use cases should
  rely on a `Model`'s input or output names.

  Example with dict:

  `{'a': [x1, x2], 'b': x3}` becomes:
  `['a_1', 'a_2', 'b']`

  Example with list (with `prefix='output_'`):

  `[x, y]` becomes:
  `['output_1', 'output_2']`

  A single, non-nested tensor becomes `[prefix + '1']`.

  Args:
    tensors: `Model`'s outputs or inputs.
    prefix: 'output_' for outputs, 'input_' for inputs.

  Returns:
    Flattened list of pseudo names.
  """

  def one_index(ele):
    # Start with "output_1" instead of "output_0".
    if isinstance(ele, int):
      return ele + 1
    return ele

  flat_paths = list(nest.yield_flat_paths(tensors))
  flat_paths = nest.map_structure(one_index, flat_paths)
  names = []
  for path in flat_paths:
    if not path:
      name = prefix + '1'  # Single output.
    else:
      name = '_'.join(str(p) for p in path)
      # A top-level dict key already names the entry; only positional
      # (integer) top-level paths need the prefix.
      if isinstance(path[0], int):
        name = prefix + name
    names.append(name)
  return names
def map_to_output_names(y_pred, output_names, struct):
  """Maps a dict to a list using `output_names` as keys.

  This is a convenience feature only. When a `Model`'s outputs
  are a list, you can specify per-output losses and metrics as
  a dict, where the keys are the output names. If you specify
  per-output losses and metrics via the same structure as the
  `Model`'s outputs (recommended), no mapping is performed.

  For the Functional API, the output names are the names of the
  last layer of each output. For the Subclass API, the output names
  are determined by `create_pseudo_output_names` (For example:
  `['output_1', 'output_2']` for a list of outputs).

  This mapping preserves backwards compatibility for `compile` and
  `fit`.

  Args:
    y_pred: Sample outputs of the Model, to determine if this convenience
      feature should be applied (`struct` is returned unmodified if `y_pred`
      isn't a single output or a flat list/tuple).
    output_names: List. The names of the outputs of the Model.
    struct: The structure to map.

  Returns:
    `struct` mapped to a list in same order as `output_names` (or the single
    mapped value when there is exactly one name); otherwise `struct` itself.

  Raises:
    ValueError: if `struct` has keys that match no output name.
  """
  if not isinstance(struct, dict):
    return struct

  if nest.is_nested(y_pred):
    is_flat_sequence = (isinstance(y_pred, (list, tuple)) and
                        all(not nest.is_nested(y_p) for y_p in y_pred))
    if not is_flat_sequence:
      return struct

  names = output_names or create_pseudo_output_names(y_pred)
  # Work on a shallow copy so the caller's dict is left untouched.
  leftover = copy.copy(struct)
  mapped = [leftover.pop(name, None) for name in names]
  if leftover:
    raise ValueError('Found unexpected keys that do not correspond '
                     'to any Model output: {}. Expected: {}'.format(
                         leftover.keys(), names))
  return mapped[0] if len(mapped) == 1 else mapped
def map_missing_dict_keys(y_pred, struct):
  """Replaces missing dict keys in `struct` with `None` placeholders.

  Only applies when both `y_pred` and `struct` are dicts; any other input is
  returned unchanged.

  Args:
    y_pred: Model predictions (or any structure whose top-level dict keys are
      the expected output keys).
    struct: Structure to complete, e.g. labels, sample weights, losses or
      metrics keyed by output name.

  Returns:
    `struct` itself when no key is missing. Otherwise, a shallow copy of
    `struct` in which every key of `y_pred` absent from `struct` maps to
    `None`. The caller's dict is never modified.
  """
  if not isinstance(y_pred, dict) or not isinstance(struct, dict):
    return struct
  missing = [k for k in y_pred if k not in struct]
  if not missing:
    return struct
  # Copy rather than mutate: `struct` may be the user's own `compile`
  # argument (e.g. the `loss` dict), which is meant to stay untouched for
  # recompiling and serialization.
  struct = copy.copy(struct)
  for k in missing:
    struct[k] = None
  return struct
def match_dtype_and_rank(y_t, y_p, sw):
  """Aligns labels and sample weights with the predictions.

  A rank-1 `y_t` / `sw` is expanded to rank 2 when `y_p` is rank 2. `y_t` is
  cast to `y_p.dtype` when both are floating or both are integer. `sw` (if
  given) is always cast to `y_p.dtype`.

  Returns:
    The tuple `(y_t, y_p, sw)`.
  """
  pred_is_rank_2 = y_p.shape.rank == 2
  if pred_is_rank_2 and y_t.shape.rank == 1:
    y_t = array_ops.expand_dims_v2(y_t, axis=-1)
  if sw is not None and pred_is_rank_2 and sw.shape.rank == 1:
    sw = array_ops.expand_dims_v2(sw, axis=-1)

  # Dtype. This is required mainly for custom loss functions which do not
  # take care of casting dtypes.
  both_floating = y_t.dtype.is_floating and y_p.dtype.is_floating
  both_integer = y_t.dtype.is_integer and y_p.dtype.is_integer
  if both_floating or both_integer:
    y_t = math_ops.cast(y_t, y_p.dtype)

  if sw is not None:
    sw = math_ops.cast(sw, y_p.dtype)
  return y_t, y_p, sw
def get_mask(y_p):
  """Returns the Keras mask attached to `y_p`, or `None` if it has none."""
  try:
    return y_p._keras_mask  # pylint: disable=protected-access
  except AttributeError:
    return None
def apply_mask(y_p, sw, mask):
  """Folds the Keras mask of the predictions into the sample weights.

  Args:
    y_p: Predictions; their dtype is used for the cast mask.
    sw: Sample weights, or `None`.
    mask: Mask for `y_p`, or `None`.

  Returns:
    `sw` unchanged when there is no mask. The cast mask when there are no
    sample weights. Otherwise `sw * mask`, after
    `losses_utils.squeeze_or_expand_dimensions` has aligned their shapes.
  """
  if mask is None:
    return sw
  mask = math_ops.cast(mask, y_p.dtype)
  if sw is None:
    return mask
  mask, _, sw = losses_utils.squeeze_or_expand_dimensions(
      mask, sample_weight=sw)
  sw *= mask
  return sw
def get_custom_object_name(obj):
  """Returns the name to use for a custom loss or metric callable.

  Lookup order:
  1. An explicit `name` attribute (e.g. a `Loss` instance used as a metric).
  2. `__name__` (plain functions).
  3. The snake_cased class name (other callable instances).

  Args:
    obj: Custom loss or metric callable.

  Returns:
    Name to use, or `None` if the object was not recognized.
  """
  if hasattr(obj, 'name'):  # Accept `Loss` instance as `Metric`.
    return obj.name
  if hasattr(obj, '__name__'):  # Function.
    return obj.__name__
  if hasattr(obj, '__class__'):  # Class instance.
    return generic_utils.to_snake_case(obj.__class__.__name__)
  # NOTE(review): ordinary Python objects always expose `__class__`, so this
  # fallback is effectively unreachable; kept as a defensive default.
  return None