Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/losses/lifted.py: 35% (40 statements)
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Implements lifted_struct_loss."""
import tensorflow as tf
from tensorflow_addons.losses import metric_learning
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike
from typeguard import typechecked
from typing import Optional


@tf.keras.utils.register_keras_serializable(package="Addons")
@tf.function
def lifted_struct_loss(
    labels: TensorLike, embeddings: TensorLike, margin: FloatTensorLike = 1.0
) -> tf.Tensor:
    """Computes the lifted structured loss.

    Args:
      labels: 1-D tf.int32 `Tensor` with shape `[batch_size]` of
        multiclass integer labels.
      embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should
        not be L2 normalized.
      margin: Float, margin term in the loss definition.

    Returns:
      lifted_loss: float scalar with the dtype of `embeddings`.

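    Usage (a minimal illustrative sketch; the shapes and values below are
    arbitrary toy data):

    >>> # Four 8-dimensional embeddings, two classes of two samples each.
    >>> labels = tf.constant([0, 0, 1, 1], dtype=tf.int32)
    >>> embeddings = tf.random.normal([4, 8])
    >>> loss = lifted_struct_loss(labels, embeddings, margin=1.0)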
42 """
    convert_to_float32 = (
        embeddings.dtype == tf.dtypes.float16 or embeddings.dtype == tf.dtypes.bfloat16
    )
    # Upcast half-precision inputs so the distance and exp computations below
    # run in float32.
    precise_embeddings = (
        tf.cast(embeddings, tf.dtypes.float32) if convert_to_float32 else embeddings
    )

    # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor.
    lshape = tf.shape(labels)
    labels = tf.reshape(labels, [lshape[0], 1])

    # Build the pairwise distance matrix. Note that `pairwise_distance`
    # returns non-squared Euclidean distances by default.
    pairwise_distances = metric_learning.pairwise_distance(precise_embeddings)

    # Build pairwise binary adjacency matrix.
    adjacency = tf.math.equal(labels, tf.transpose(labels))
    # Invert so we can select negatives only.
    adjacency_not = tf.math.logical_not(adjacency)

    batch_size = tf.size(labels)

    diff = margin - pairwise_distances
    mask = tf.cast(adjacency_not, dtype=tf.dtypes.float32)
    # Safe maximum: temporarily shift every row to be non-negative (by
    # subtracting the row minimum) before masking, so that the max is taken
    # only among the negatives of each row.
    row_minimums = tf.math.reduce_min(diff, 1, keepdims=True)
    row_negative_maximums = (
        tf.math.reduce_max(
            tf.math.multiply(diff - row_minimums, mask), 1, keepdims=True
        )
        + row_minimums
    )
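    # After this step, row_negative_maximums[i] holds max_k (margin - D_ik)
    # over the negatives k of sample i; it anchors the numerically stable
    # log-sum-exp over negative pairs below.
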
    # Compute the loss.
    # Keep track of matrix of maximums where M_ij = max(m_i, m_j)
    # where m_i is the max of alpha - negative D_i's.
    # This matches the Caffe loss layer implementation at:
    # https://github.com/rksltnl/Caffe-Deep-Metric-Learning-CVPR16/blob/0efd7544a9846f58df923c8b992198ba5c355454/src/caffe/layers/lifted_struct_similarity_softmax_layer.cpp

    max_elements = tf.math.maximum(
        row_negative_maximums, tf.transpose(row_negative_maximums)
    )
    diff_tiled = tf.tile(diff, [batch_size, 1])
    mask_tiled = tf.tile(mask, [batch_size, 1])
    max_elements_vect = tf.reshape(tf.transpose(max_elements), [-1, 1])

    loss_exp_left = tf.reshape(
        tf.math.reduce_sum(
            tf.math.multiply(tf.math.exp(diff_tiled - max_elements_vect), mask_tiled),
            1,
            keepdims=True,
        ),
        [batch_size, batch_size],
    )
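    # loss_exp_left[i][j] is the sum over the negatives k of j of
    # exp(margin - D_jk - max_elements[i][j]); adding its transpose below
    # accumulates the matching sum over the negatives of i.
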
    loss_mat = max_elements + tf.math.log(loss_exp_left + tf.transpose(loss_exp_left))
    # Add the positive distance.
    loss_mat += pairwise_distances

    mask_positives = tf.cast(adjacency, dtype=tf.dtypes.float32) - tf.linalg.diag(
        tf.ones([batch_size])
    )

    # *0.5 for upper triangular, and another *0.5 for the 1/2 factor in loss^2.
    num_positives = tf.math.reduce_sum(mask_positives) / 2.0

    lifted_loss = tf.math.truediv(
        0.25
        * tf.math.reduce_sum(
            tf.math.square(
                tf.math.maximum(tf.math.multiply(loss_mat, mask_positives), 0.0)
            )
        ),
        num_positives,
    )

    if convert_to_float32:
        return tf.cast(lifted_loss, embeddings.dtype)
    else:
        return lifted_loss


@tf.keras.utils.register_keras_serializable(package="Addons")
class LiftedStructLoss(LossFunctionWrapper):
    """Computes the lifted structured loss.

    The loss encourages the positive distances (between a pair of embeddings
    with the same labels) to be smaller than any negative distances (between
    a pair of embeddings with different labels) in the mini-batch, in a way
    that is differentiable with respect to the embedding vectors.
    See: https://arxiv.org/abs/1511.06452.

    Args:
      margin: Float, margin term in the loss definition.
      name: Optional name for the op.

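    Usage (a minimal illustrative sketch with arbitrary toy data):

    >>> loss_fn = LiftedStructLoss(margin=1.0)
    >>> labels = tf.constant([0, 0, 1, 1], dtype=tf.int32)
    >>> embeddings = tf.random.normal([4, 8])
    >>> loss = loss_fn(labels, embeddings)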
139 """
    @typechecked
    def __init__(
        self, margin: FloatTensorLike = 1.0, name: Optional[str] = None, **kwargs
    ):
        super().__init__(
            lifted_struct_loss,
            name=name,
            reduction=tf.keras.losses.Reduction.NONE,
            margin=margin,
        )