Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/losses/contrastive.py: 75% (16 statements)
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements contrastive loss."""

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
from tensorflow_addons.utils.types import TensorLike, Number


@tf.keras.utils.register_keras_serializable(package="Addons")
@tf.function
def contrastive_loss(
    y_true: TensorLike, y_pred: TensorLike, margin: Number = 1.0
) -> tf.Tensor:
29 r"""Computes the contrastive loss between `y_true` and `y_pred`.
31 This loss encourages the embedding to be close to each other for
32 the samples of the same label and the embedding to be far apart at least
33 by the margin constant for the samples of different labels.
35 The euclidean distances `y_pred` between two embedding matrices
36 `a` and `b` with shape `[batch_size, hidden_size]` can be computed
37 as follows:
39 >>> a = tf.constant([[1, 2],
40 ... [3, 4],
41 ... [5, 6]], dtype=tf.float16)
42 >>> b = tf.constant([[5, 9],
43 ... [3, 6],
44 ... [1, 8]], dtype=tf.float16)
45 >>> y_pred = tf.linalg.norm(a - b, axis=1)
46 >>> y_pred
47 <tf.Tensor: shape=(3,), dtype=float16, numpy=array([8.06 , 2. , 4.473],
48 dtype=float16)>
50 <... Note: constants a & b have been used purely for
51 example purposes and have no significant value ...>
53 See: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
55 Args:
56 y_true: 1-D integer `Tensor` with shape `[batch_size]` of
57 binary labels indicating positive vs negative pair.
58 y_pred: 1-D float `Tensor` with shape `[batch_size]` of
59 distances between two embedding matrices.
60 margin: margin term in the loss definition.
62 Returns:
63 contrastive_loss: 1-D float `Tensor` with shape `[batch_size]`.
64 """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.dtypes.cast(y_true, y_pred.dtype)
    return y_true * tf.math.square(y_pred) + (1.0 - y_true) * tf.math.square(
        tf.math.maximum(margin - y_pred, 0.0)
    )
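

# A minimal worked sketch (not part of the upstream TFA module): applies
# `contrastive_loss` to the distances from the docstring example. The pair
# labels below are made up purely for illustration.
def _contrastive_loss_example() -> tf.Tensor:  # hypothetical helper, illustration only
    # Distances from the docstring example and assumed pair labels
    # (1 = similar pair, 0 = dissimilar pair).
    y_pred = tf.constant([8.06, 2.0, 4.473], dtype=tf.float32)
    y_true = tf.constant([0.0, 1.0, 1.0])
    # Similar pairs are penalized by the squared distance; dissimilar pairs are
    # penalized only when they fall inside the margin:
    #   [max(1.0 - 8.06, 0)^2, 2.0^2, 4.473^2] ~= [0.0, 4.0, 20.0]
    return contrastive_loss(y_true, y_pred, margin=1.0)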


@tf.keras.utils.register_keras_serializable(package="Addons")
class ContrastiveLoss(LossFunctionWrapper):
74 r"""Computes the contrastive loss between `y_true` and `y_pred`.
76 This loss encourages the embedding to be close to each other for
77 the samples of the same label and the embedding to be far apart at least
78 by the margin constant for the samples of different labels.
80 See: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
82 We expect labels `y_true` to be provided as 1-D integer `Tensor`
83 with shape `[batch_size]` of binary integer labels. And `y_pred` must be
84 1-D float `Tensor` with shape `[batch_size]` of distances between two
85 embedding matrices.
87 The euclidean distances `y_pred` between two embedding matrices
88 `a` and `b` with shape `[batch_size, hidden_size]` can be computed
89 as follows:
91 >>> a = tf.constant([[1, 2],
92 ... [3, 4],[5, 6]], dtype=tf.float16)
93 >>> b = tf.constant([[5, 9],
94 ... [3, 6],[1, 8]], dtype=tf.float16)
95 >>> y_pred = tf.linalg.norm(a - b, axis=1)
96 >>> y_pred
97 <tf.Tensor: shape=(3,), dtype=float16, numpy=array([8.06 , 2. , 4.473],
98 dtype=float16)>
100 <... Note: constants a & b have been used purely for
101 example purposes and have no significant value ...>
103 Args:
104 margin: `Float`, margin term in the loss definition.
105 Default value is 1.0.
106 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply.
107 Default value is `SUM_OVER_BATCH_SIZE`.
108 name: (Optional) name for the loss.
109 """

    @typechecked
    def __init__(
        self,
        margin: Number = 1.0,
        reduction: str = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
        name: str = "contrastive_loss",
    ):
        super().__init__(
            contrastive_loss, reduction=reduction, name=name, margin=margin
        )
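

# A minimal usage sketch (not part of the upstream TFA module): the wrapper
# class reduces the per-pair losses to a scalar, using SUM_OVER_BATCH_SIZE
# (i.e. the mean when no sample weights are given) by default. Tensor values
# are illustrative only.
def _contrastive_loss_class_example() -> tf.Tensor:  # hypothetical helper, illustration only
    loss_fn = ContrastiveLoss(margin=1.0)
    y_pred = tf.constant([8.06, 2.0, 4.473], dtype=tf.float32)  # pairwise distances
    y_true = tf.constant([0.0, 1.0, 1.0])  # 0 = dissimilar pair, 1 = similar pair
    # Equivalent to tf.reduce_mean(contrastive_loss(y_true, y_pred, margin=1.0)).
    return loss_fn(y_true, y_pred)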