Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/losses/metric_learning.py: 32%
22 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functions of metric learning."""
17import tensorflow as tf
18from tensorflow_addons.utils.types import TensorLike
@tf.function
def pairwise_distance(feature: TensorLike, squared: bool = False):
    """Computes the pairwise distance matrix with numerical stability.

    output[i, j] = || feature[i, :] - feature[j, :] ||_2

    Squared distances are computed with the expansion
    ||a||^2 + ||b||^2 - 2 <a, b>. Floating-point cancellation in this
    expansion can yield small negative values, which are clamped to zero.

    Args:
      feature: 2-D Tensor of size `[number of data, feature dimension]`.
        Any floating-point dtype is accepted; every intermediate tensor is
        built in that dtype, so the result has the same dtype as the input.
      squared: Boolean, whether or not to square the pairwise distances.

    Returns:
      pairwise_distances: 2-D Tensor of size `[number of data, number of data]`.
    """
    feature = tf.convert_to_tensor(feature)
    # Build every helper tensor in the input's dtype. Hard-coding float32
    # here would make float64/float16 inputs fail with a dtype mismatch.
    dtype = feature.dtype

    pairwise_distances_squared = tf.math.add(
        tf.math.reduce_sum(tf.math.square(feature), axis=[1], keepdims=True),
        tf.math.reduce_sum(
            tf.math.square(tf.transpose(feature)), axis=[0], keepdims=True
        ),
    ) - 2.0 * tf.matmul(feature, tf.transpose(feature))

    # Deal with numerical inaccuracies. Set small negatives to zero.
    pairwise_distances_squared = tf.math.maximum(pairwise_distances_squared, 0.0)
    # Get the mask where the zero distances are at.
    error_mask = tf.math.less_equal(pairwise_distances_squared, 0.0)

    # Optionally take the sqrt.
    if squared:
        pairwise_distances = pairwise_distances_squared
    else:
        # Add a tiny offset where the distance is exactly zero so that sqrt
        # is not evaluated at 0 (its derivative is infinite there).
        # NOTE: in float16, 1e-16 underflows to 0, so this guard has no
        # effect at that precision.
        pairwise_distances = tf.math.sqrt(
            pairwise_distances_squared + tf.cast(error_mask, dtype=dtype) * 1e-16
        )

    # Undo conditionally adding 1e-16.
    pairwise_distances = tf.math.multiply(
        pairwise_distances,
        tf.cast(tf.math.logical_not(error_mask), dtype=dtype),
    )

    num_data = tf.shape(feature)[0]
    # Explicitly set diagonals to zero.
    mask_offdiagonals = tf.ones_like(pairwise_distances) - tf.linalg.diag(
        tf.ones([num_data], dtype=dtype)
    )
    pairwise_distances = tf.math.multiply(pairwise_distances, mask_offdiagonals)
    return pairwise_distances
@tf.function
def angular_distance(feature: TensorLike):
    """Computes the angular distance matrix.

    output[i, j] = 1 - cosine_similarity(feature[i, :], feature[j, :])

    Each row is L2-normalized first, so the inner product of two normalized
    rows is their cosine similarity. Rounding can push `1 - cos` slightly
    below zero; such entries are clamped to zero.

    Args:
      feature: 2-D Tensor of size `[number of data, feature dimension]`.

    Returns:
      angular_distances: 2-D Tensor of size `[number of data, number of data]`.
    """
    unit_rows = tf.math.l2_normalize(feature, axis=1)
    # Gram matrix of the unit rows == matrix of pairwise cosine similarities.
    cosine_similarity = tf.matmul(unit_rows, unit_rows, transpose_b=True)
    # Clamp to [0, inf) to remove small negative rounding artifacts.
    return tf.maximum(1 - cosine_similarity, 0.0)