Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/merging/dot.py: 23%

84 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Layer that computes the dot product between two inputs.""" 

16 

17 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src import backend 

21from keras.src.engine import base_layer_utils 

22from keras.src.layers.merging.base_merge import _Merge 

23from keras.src.utils import tf_utils 

24 

25# isort: off 

26from tensorflow.python.util.tf_export import keras_export 

27 

28 

@keras_export("keras.layers.Dot")
class Dot(_Merge):
    """Layer that computes a dot product between samples in two tensors.

    E.g. if applied to a list of two tensors `a` and `b` of shape
    `(batch_size, n)`, the output will be a tensor of shape `(batch_size, 1)`
    where each entry `i` will be the dot product between
    `a[i]` and `b[i]`.

    >>> x = np.arange(10).reshape(1, 5, 2)
    >>> print(x)
    [[[0 1]
      [2 3]
      [4 5]
      [6 7]
      [8 9]]]
    >>> y = np.arange(10, 20).reshape(1, 2, 5)
    >>> print(y)
    [[[10 11 12 13 14]
      [15 16 17 18 19]]]
    >>> tf.keras.layers.Dot(axes=(1, 2))([x, y])
    <tf.Tensor: shape=(1, 2, 2), dtype=int64, numpy=
    array([[[260, 360],
            [320, 445]]])>

    >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
    >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
    >>> dotted = tf.keras.layers.Dot(axes=1)([x1, x2])
    >>> dotted.shape
    TensorShape([5, 1])


    """

    def __init__(self, axes, normalize=False, **kwargs):
        """Initializes a layer that computes a dot product along given axes.

        >>> x = np.arange(10).reshape(1, 5, 2)
        >>> print(x)
        [[[0 1]
          [2 3]
          [4 5]
          [6 7]
          [8 9]]]
        >>> y = np.arange(10, 20).reshape(1, 2, 5)
        >>> print(y)
        [[[10 11 12 13 14]
          [15 16 17 18 19]]]
        >>> tf.keras.layers.Dot(axes=(1, 2))([x, y])
        <tf.Tensor: shape=(1, 2, 2), dtype=int64, numpy=
        array([[[260, 360],
                [320, 445]]])>

        Args:
          axes: Integer or tuple of integers,
            axis or axes along which to take the dot product. If a tuple, should
            be two integers corresponding to the desired axis from the first
            input and the desired axis from the second input, respectively. Note
            that the size of the two selected axes must match.
          normalize: Whether to L2-normalize samples along the
            dot product axis before taking the dot product.
            If set to True, then the output of the dot product
            is the cosine proximity between the two samples.
          **kwargs: Standard layer keyword arguments.

        Raises:
          TypeError: If `axes` is neither an int nor a list/tuple.
          ValueError: If `axes` is a list/tuple that does not hold exactly two
            integers.
        """
        super().__init__(**kwargs)
        if not isinstance(axes, int):
            if not isinstance(axes, (list, tuple)):
                raise TypeError(
                    "Invalid type for argument `axes`: it should be "
                    f"a list or an int. Received: axes={axes}"
                )
            if len(axes) != 2:
                raise ValueError(
                    "Invalid format for argument `axes`: it should contain two "
                    f"elements. Received: axes={axes}"
                )
            if not isinstance(axes[0], int) or not isinstance(axes[1], int):
                raise ValueError(
                    "Invalid format for argument `axes`: list elements should "
                    f"be integers. Received: axes={axes}"
                )
        self.axes = axes
        self.normalize = normalize
        self.supports_masking = True
        # NOTE(review): presumably consumed by the `_Merge` base class to skip
        # its rank-alignment reshaping; confirm against base_merge.py.
        self._reshape_required = False

    def _normalize_axes(self, rank1, rank2):
        """Resolves `self.axes` into a pair of non-negative axis indices.

        An int `axes` is applied to both inputs. Negative axes are wrapped
        with `%` against the corresponding input's rank; non-negative axes
        are returned unchanged.

        Args:
          rank1: Rank (number of dimensions) of the first input.
          rank2: Rank (number of dimensions) of the second input.

        Returns:
          A two-element list `[axis_for_input1, axis_for_input2]`.
        """
        if isinstance(self.axes, int):
            if self.axes < 0:
                return [self.axes % rank1, self.axes % rank2]
            return [self.axes] * 2
        return [
            axis % rank if axis < 0 else axis
            for axis, rank in zip(self.axes, (rank1, rank2))
        ]

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        """Validates that the two input shapes agree along the dot axes.

        Used purely for shape validation; the layer creates no weights here.

        Raises:
          ValueError: If `input_shape` is not a list of two shapes, or if the
            (known) sizes of the selected axes differ.
        """
        # Check the container and its length before indexing into it, so a
        # malformed/empty `input_shape` raises ValueError instead of
        # IndexError. Requiring the first element to be a tuple/list rejects
        # a single flat shape such as `(None, 5)`.
        if (
            not isinstance(input_shape, (list, tuple))
            or len(input_shape) != 2
            or not isinstance(input_shape[0], (tuple, list))
        ):
            raise ValueError(
                "A `Dot` layer should be called on a list of 2 inputs. "
                f"Received: input_shape={input_shape}"
            )
        shape1, shape2 = input_shape
        if shape1 is None or shape2 is None:
            return
        axes = self._normalize_axes(len(shape1), len(shape2))
        dim1 = shape1[axes[0]]
        dim2 = shape2[axes[1]]
        # An unknown (None) size cannot be compared statically; only reject
        # shapes whose contracted sizes are both known and different.
        if dim1 is not None and dim2 is not None and dim1 != dim2:
            raise ValueError(
                "Incompatible input shapes: "
                f"axis values {dim1} (at axis {axes[0]}) != "
                f"{dim2} (at axis {axes[1]}). "
                f"Full input shapes: {shape1}, {shape2}"
            )

    def _merge_function(self, inputs):
        """Computes `backend.batch_dot` of the two inputs along `self.axes`.

        When `self.normalize` is True, each input is first L2-normalized
        along its own dot axis.

        Raises:
          ValueError: If `inputs` does not contain exactly two tensors.
        """
        base_layer_utils.no_ragged_support(inputs, self.name)
        if len(inputs) != 2:
            raise ValueError(
                "A `Dot` layer should be called on exactly 2 inputs. "
                f"Received: inputs={inputs}"
            )
        x1, x2 = inputs
        axes = self._normalize_axes(backend.ndim(x1), backend.ndim(x2))
        if self.normalize:
            x1 = tf.linalg.l2_normalize(x1, axis=axes[0])
            x2 = tf.linalg.l2_normalize(x2, axis=axes[1])
        return backend.batch_dot(x1, x2, axes)

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        """Returns the output shape for a list of two input shapes.

        The dot axis is removed from each shape, the second shape's leading
        dimension is dropped, and the remainders are concatenated. A rank-1
        result is padded to `(dim, 1)`.

        Raises:
          ValueError: If `input_shape` is not a list/tuple of two shapes.
        """
        if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2:
            raise ValueError(
                "A `Dot` layer should be called on a list of 2 inputs. "
                f"Received: input_shape={input_shape}"
            )
        shape1 = list(input_shape[0])
        shape2 = list(input_shape[1])
        axes = self._normalize_axes(len(shape1), len(shape2))
        shape1.pop(axes[0])
        shape2.pop(axes[1])
        # Drop the second input's leading (batch) dimension so it is not
        # duplicated in the concatenated output shape.
        shape2.pop(0)
        output_shape = shape1 + shape2
        if len(output_shape) == 1:
            output_shape += [1]
        return tuple(output_shape)

    def compute_mask(self, inputs, mask=None):
        # Input masks are accepted (`supports_masking = True`) but no mask is
        # propagated to the output.
        return None

    def get_config(self):
        """Returns the layer config, with `axes` and `normalize` added."""
        config = {
            "axes": self.axes,
            "normalize": self.normalize,
        }
        base_config = super().get_config()
        # Layer-specific keys override base keys, as in the original merge.
        return {**base_config, **config}

207 

208 

@keras_export("keras.layers.dot")
def dot(inputs, axes, normalize=False, **kwargs):
    """Functional interface to the `Dot` layer.

    Builds a `Dot` layer with the given configuration and immediately calls
    it on `inputs`.

    Args:
      inputs: A list of exactly two input tensors.
      axes: Integer or tuple of integers,
        axis or axes along which to take the dot product. If a tuple, should
        be two integers: the axis from the first input and the axis from the
        second input, respectively.
      normalize: Whether to L2-normalize samples along the
        dot product axis before taking the dot product.
        If set to True, then the output of the dot product
        is the cosine proximity between the two samples.
      **kwargs: Standard layer keyword arguments.

    Returns:
      A tensor, the dot product of the samples from the inputs.
    """
    return Dot(axes=axes, normalize=normalize, **kwargs)(inputs)

227