Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/merging/dot.py: 23%
84 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Layer that computes the dot product between two inputs."""
18import tensorflow.compat.v2 as tf
20from keras.src import backend
21from keras.src.engine import base_layer_utils
22from keras.src.layers.merging.base_merge import _Merge
23from keras.src.utils import tf_utils
25# isort: off
26from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.Dot")
class Dot(_Merge):
    """Layer that computes a dot product between samples in two tensors.

    E.g. if applied to a list of two tensors `a` and `b` of shape
    `(batch_size, n)`, the output will be a tensor of shape `(batch_size, 1)`
    where each entry `i` will be the dot product between
    `a[i]` and `b[i]`.

    >>> x = np.arange(10).reshape(1, 5, 2)
    >>> print(x)
    [[[0 1]
      [2 3]
      [4 5]
      [6 7]
      [8 9]]]
    >>> y = np.arange(10, 20).reshape(1, 2, 5)
    >>> print(y)
    [[[10 11 12 13 14]
      [15 16 17 18 19]]]
    >>> tf.keras.layers.Dot(axes=(1, 2))([x, y])
    <tf.Tensor: shape=(1, 2, 2), dtype=int64, numpy=
    array([[[260, 360],
            [320, 445]]])>

    >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
    >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
    >>> dotted = tf.keras.layers.Dot(axes=1)([x1, x2])
    >>> dotted.shape
    TensorShape([5, 1])
    """

    def __init__(self, axes, normalize=False, **kwargs):
        """Initializes a layer that computes the dot product of two inputs.

        >>> x = np.arange(10).reshape(1, 5, 2)
        >>> print(x)
        [[[0 1]
          [2 3]
          [4 5]
          [6 7]
          [8 9]]]
        >>> y = np.arange(10, 20).reshape(1, 2, 5)
        >>> print(y)
        [[[10 11 12 13 14]
          [15 16 17 18 19]]]
        >>> tf.keras.layers.Dot(axes=(1, 2))([x, y])
        <tf.Tensor: shape=(1, 2, 2), dtype=int64, numpy=
        array([[[260, 360],
                [320, 445]]])>

        Args:
          axes: Integer, or list/tuple of two integers: the axis or axes
            along which to take the dot product. If a list/tuple, it should
            contain two integers corresponding to the desired axis from the
            first input and the desired axis from the second input,
            respectively. Note that the size of the two selected axes must
            match.
          normalize: Whether to L2-normalize samples along the
            dot product axis before taking the dot product.
            If set to True, then the output of the dot product
            is the cosine proximity between the two samples.
          **kwargs: Standard layer keyword arguments.

        Raises:
          TypeError: If `axes` is neither an int nor a list/tuple.
          ValueError: If `axes` is a list/tuple that does not contain
            exactly two integers.
        """
        super().__init__(**kwargs)
        if not isinstance(axes, int):
            if not isinstance(axes, (list, tuple)):
                raise TypeError(
                    "Invalid type for argument `axes`: it should be "
                    f"a list or an int. Received: axes={axes}"
                )
            if len(axes) != 2:
                raise ValueError(
                    "Invalid format for argument `axes`: it should contain "
                    "two "
                    f"elements. Received: axes={axes}"
                )
            if not isinstance(axes[0], int) or not isinstance(axes[1], int):
                raise ValueError(
                    "Invalid format for argument `axes`: list elements should "
                    f"be integers. Received: axes={axes}"
                )
        self.axes = axes
        self.normalize = normalize
        # Masks are accepted as input, but `compute_mask` below always
        # returns None, so no mask is propagated downstream.
        self.supports_masking = True
        # NOTE(review): presumably tells the `_Merge` base class to skip its
        # rank-equalizing reshape of the inputs; confirm in base_merge.py.
        self._reshape_required = False

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        """Validates that the two input shapes are compatible.

        No weights are created; this is used purely for shape validation.

        Args:
          input_shape: A list/tuple holding exactly two shape tuples.

        Raises:
          ValueError: If `input_shape` does not describe exactly two inputs,
            or if the sizes of the two contraction axes are both statically
            known and differ.
        """
        # Check the length before indexing so that an empty list raises the
        # intended ValueError rather than an IndexError.
        if len(input_shape) != 2 or not isinstance(input_shape[0], tuple):
            raise ValueError(
                "A `Dot` layer should be called on a list of 2 inputs. "
                f"Received: input_shape={input_shape}"
            )
        shape1 = input_shape[0]
        shape2 = input_shape[1]
        if shape1 is None or shape2 is None:
            return
        if isinstance(self.axes, int):
            if self.axes < 0:
                axes = [self.axes % len(shape1), self.axes % len(shape2)]
            else:
                axes = [self.axes] * 2
        else:
            axes = self.axes
        dim1 = shape1[axes[0]]
        dim2 = shape2[axes[1]]
        # An unknown (None) dimension can only be checked at runtime, so only
        # reject the pair when both sizes are statically known and differ.
        if dim1 is not None and dim2 is not None and dim1 != dim2:
            raise ValueError(
                "Incompatible input shapes: "
                f"axis values {dim1} (at axis {axes[0]}) != "
                f"{dim2} (at axis {axes[1]}). "
                f"Full input shapes: {shape1}, {shape2}"
            )

    def _merge_function(self, inputs):
        """Computes the (optionally L2-normalized) batched dot product.

        Args:
          inputs: A list of exactly two tensors.

        Returns:
          The result of `backend.batch_dot` over the resolved axes.

        Raises:
          ValueError: If `inputs` does not contain exactly two tensors.
        """
        base_layer_utils.no_ragged_support(inputs, self.name)
        if len(inputs) != 2:
            raise ValueError(
                "A `Dot` layer should be called on exactly 2 inputs. "
                f"Received: inputs={inputs}"
            )
        x1, x2 = inputs
        if isinstance(self.axes, int):
            if self.axes < 0:
                axes = [
                    self.axes % backend.ndim(x1),
                    self.axes % backend.ndim(x2),
                ]
            else:
                axes = [self.axes] * 2
        else:
            # Resolve each negative axis against its own input's rank.
            axes = [
                axis % backend.ndim(x) if axis < 0 else axis
                for axis, x in zip(self.axes, inputs)
            ]
        if self.normalize:
            x1 = tf.linalg.l2_normalize(x1, axis=axes[0])
            x2 = tf.linalg.l2_normalize(x2, axis=axes[1])
        return backend.batch_dot(x1, x2, axes)

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        """Returns the static output shape of the dot product.

        The contraction axis is removed from each input shape, and the
        leading (batch) dim of the second shape is dropped. The remaining
        dims of the first input are then followed by the remaining dims of
        the second. A rank-1 result is expanded with a trailing 1.

        Raises:
          ValueError: If `input_shape` does not describe exactly two inputs.
        """
        if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2:
            raise ValueError(
                "A `Dot` layer should be called on a list of 2 inputs. "
                f"Received: input_shape={input_shape}"
            )
        shape1 = list(input_shape[0])
        shape2 = list(input_shape[1])
        if isinstance(self.axes, int):
            if self.axes < 0:
                axes = [self.axes % len(shape1), self.axes % len(shape2)]
            else:
                axes = [self.axes] * 2
        else:
            axes = self.axes
        shape1.pop(axes[0])
        shape2.pop(axes[1])
        shape2.pop(0)
        output_shape = shape1 + shape2
        if len(output_shape) == 1:
            output_shape += [1]
        return tuple(output_shape)

    def compute_mask(self, inputs, mask=None):
        """Returns None: this layer does not propagate input masks."""
        return None

    def get_config(self):
        """Returns the layer config, adding `axes` and `normalize`."""
        config = {
            "axes": self.axes,
            "normalize": self.normalize,
        }
        base_config = super().get_config()
        # Layer-specific keys override base keys, as in a dict merge.
        return {**base_config, **config}
@keras_export("keras.layers.dot")
def dot(inputs, axes, normalize=False, **kwargs):
    """Functional interface to the `Dot` layer.

    Args:
      inputs: A list of exactly two input tensors; `Dot` raises a
        `ValueError` for any other count.
      axes: Integer, or list/tuple of two integers: the axis or axes
        along which to take the dot product.
      normalize: Whether to L2-normalize samples along the
        dot product axis before taking the dot product.
        If set to True, then the output of the dot product
        is the cosine proximity between the two samples.
      **kwargs: Standard layer keyword arguments.

    Returns:
      A tensor, the dot product of the samples from the inputs.
    """
    return Dot(axes=axes, normalize=normalize, **kwargs)(inputs)