Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/losses/metric_learning.py: 32%

22 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Functions of metric learning.""" 

16 

17import tensorflow as tf 

18from tensorflow_addons.utils.types import TensorLike 

19 

20 

@tf.function
def pairwise_distance(feature: TensorLike, squared: bool = False):
    """Computes the pairwise distance matrix with numerical stability.

    output[i, j] = || feature[i, :] - feature[j, :] ||_2

    Args:
      feature: 2-D Tensor of size `[number of data, feature dimension]`.
        Any floating-point dtype is accepted; all intermediate masks are
        built in the same dtype, so the result has the dtype of `feature`.
      squared: Boolean, whether or not to square the pairwise distances.

    Returns:
      pairwise_distances: 2-D Tensor of size `[number of data, number of data]`
        whose diagonal is exactly zero.
    """
    feature = tf.convert_to_tensor(feature)
    dtype = feature.dtype

    # Expand ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b> for every pair at once:
    # a column of squared row norms broadcast against the same values as a row.
    squared_norms = tf.math.reduce_sum(
        tf.math.square(feature), axis=[1], keepdims=True
    )
    pairwise_distances_squared = tf.math.add(
        squared_norms, tf.transpose(squared_norms)
    ) - 2.0 * tf.matmul(feature, feature, transpose_b=True)

    # Deal with numerical inaccuracies. Set small negatives to zero.
    pairwise_distances_squared = tf.math.maximum(pairwise_distances_squared, 0.0)
    # Get the mask where the zero distances are at.
    error_mask = tf.math.less_equal(pairwise_distances_squared, 0.0)

    # Optionally take the sqrt.
    if squared:
        pairwise_distances = pairwise_distances_squared
    else:
        # Add a tiny epsilon only where the squared distance is zero so that
        # the gradient of sqrt stays finite there.
        # NOTE(review): 1e-16 underflows to 0 in float16, so gradients at zero
        # distances may still be non-finite for half-precision inputs.
        pairwise_distances = tf.math.sqrt(
            pairwise_distances_squared + tf.cast(error_mask, dtype=dtype) * 1e-16
        )

    # Undo conditionally adding 1e-16.
    pairwise_distances = tf.math.multiply(
        pairwise_distances,
        tf.cast(tf.math.logical_not(error_mask), dtype=dtype),
    )

    # Explicitly set diagonals to zero (distance of a point to itself).
    num_data = tf.shape(feature)[0]
    pairwise_distances = tf.linalg.set_diag(
        pairwise_distances, tf.zeros([num_data], dtype=dtype)
    )
    return pairwise_distances

68 

69 

@tf.function
def angular_distance(feature: TensorLike):
    """Computes the angular distance matrix.

    output[i, j] = 1 - cosine_similarity(feature[i, :], feature[j, :])

    Args:
      feature: 2-D Tensor of size `[number of data, feature dimension]`.

    Returns:
      angular_distances: 2-D Tensor of size `[number of data, number of data]`
        with the floating-point dtype of `feature`. Values are clamped below
        at 0.

    Note:
      Rows with zero L2 norm are left as all-zero vectors by
      `tf.math.l2_normalize`, so their cosine similarity with every row is 0
      and their angular distance to every row (including themselves) is 1.
    """
    # Normalize each row to unit L2 norm so dot products are cosine similarities.
    feature = tf.math.l2_normalize(feature, axis=1)

    # Cosine-similarity matrix of all row pairs, converted into distances.
    angular_distances = 1.0 - tf.matmul(feature, feature, transpose_b=True)

    # Rounding can make 1 - cos slightly negative for (near-)parallel rows;
    # clamp those entries to exactly zero so every distance is >= 0.
    angular_distances = tf.math.maximum(angular_distances, 0.0)

    return angular_distances