# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based einsum dense layer."""


import re

import tensorflow.compat.v2 as tf

from keras.src import activations
from keras.src import constraints
from keras.src import initializers
from keras.src import regularizers
from keras.src.engine.base_layer import Layer

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export(
    "keras.layers.EinsumDense", "keras.layers.experimental.EinsumDense"
)
class EinsumDense(Layer):
    """A layer that uses `tf.einsum` as the backing computation.

    This layer can perform einsum calculations of arbitrary dimensionality.

    Args:
        equation: An equation describing the einsum to perform. This equation
            must be a valid einsum string of the form `ab,bc->ac`,
            `...ab,bc->...ac`, or `ab...,bc->ac...` where 'ab', 'bc', and 'ac'
            can be any valid einsum axis expression sequence.
        output_shape: The expected shape of the output tensor (excluding the
            batch dimension and any dimensions represented by ellipses). You
            can specify `None` for any dimension that is unknown or can be
            inferred from the input shape.
        activation: Activation function to use. If you don't specify anything,
            no activation is applied (that is, a "linear" activation:
            `a(x) = x`).
        bias_axes: A string containing the output dimension(s) to apply a bias
            to. Each character in the `bias_axes` string should correspond to
            a character in the output portion of the `equation` string.
        kernel_initializer: Initializer for the `kernel` weights matrix.
        bias_initializer: Initializer for the bias vector.
        kernel_regularizer: Regularizer function applied to the `kernel`
            weights matrix.
        bias_regularizer: Regularizer function applied to the bias vector.
        activity_regularizer: Regularizer function applied to the output of
            the layer (its "activation").
        kernel_constraint: Constraint function applied to the `kernel` weights
            matrix.
        bias_constraint: Constraint function applied to the bias vector.

    Examples:

    **Biased dense layer with einsums**

    This example shows how to instantiate a standard Keras dense layer using
    einsum operations. This example is equivalent to
    `tf.keras.layers.Dense(64, use_bias=True)`.

    >>> layer = tf.keras.layers.EinsumDense("ab,bc->ac",
    ...                                     output_shape=64,
    ...                                     bias_axes="c")
    >>> input_tensor = tf.keras.Input(shape=[32])
    >>> output_tensor = layer(input_tensor)
    >>> output_tensor
    <... shape=(None, 64) dtype=...>

    **Applying a dense layer to a sequence**

    This example shows how to instantiate a layer that applies the same dense
    operation to every element in a sequence. Here, the `output_shape` has two
    values (since there are two non-batch dimensions in the output); the first
    dimension in the `output_shape` is `None`, because the sequence dimension
    `b` has an unknown shape.

    >>> layer = tf.keras.layers.EinsumDense("abc,cd->abd",
    ...                                     output_shape=(None, 64),
    ...                                     bias_axes="d")
    >>> input_tensor = tf.keras.Input(shape=[32, 128])
    >>> output_tensor = layer(input_tensor)
    >>> output_tensor
    <... shape=(None, 32, 64) dtype=...>

    **Applying a dense layer to a sequence using ellipses**

    This example shows how to instantiate a layer that applies the same dense
    operation to every element in a sequence, but uses the ellipsis notation
    instead of specifying the batch and sequence dimensions.

    Because we are using ellipsis notation and have specified only one axis,
    the `output_shape` arg is a single value. When instantiated in this way,
    the layer can handle any number of sequence dimensions, including the case
    where no sequence dimension exists.

    >>> layer = tf.keras.layers.EinsumDense("...x,xy->...y",
    ...                                     output_shape=64,
    ...                                     bias_axes="y")
    >>> input_tensor = tf.keras.Input(shape=[32, 128])
    >>> output_tensor = layer(input_tensor)
    >>> output_tensor
    <... shape=(None, 32, 64) dtype=...>
    """

    def __init__(
        self,
        equation,
        output_shape,
        activation=None,
        bias_axes=None,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):
        # Forward the activity regularizer so the base Layer stores it and
        # `get_config` can serialize it.
        super().__init__(activity_regularizer=activity_regularizer, **kwargs)
        self.equation = equation
        if isinstance(output_shape, int):
            self.partial_output_shape = [output_shape]
        else:
            self.partial_output_shape = list(output_shape)
        self.bias_axes = bias_axes
        self.activation = activations.get(activation)
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

    def build(self, input_shape):
        input_shape = tf.TensorShape(input_shape)
        shape_data = _analyze_einsum_string(
            self.equation,
            self.bias_axes,
            input_shape,
            self.partial_output_shape,
        )
        kernel_shape, bias_shape, self.full_output_shape = shape_data
        self.kernel = self.add_weight(
            "kernel",
            shape=kernel_shape,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True,
        )

        if bias_shape is not None:
            self.bias = self.add_weight(
                "bias",
                shape=bias_shape,
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                dtype=self.dtype,
                trainable=True,
            )
        else:
            self.bias = None
        super().build(input_shape)

    def compute_output_shape(self, _):
        return tf.TensorShape(self.full_output_shape)

    def get_config(self):
        config = {
            "output_shape": self.partial_output_shape,
            "equation": self.equation,
            "activation": activations.serialize(self.activation),
            "bias_axes": self.bias_axes,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self.activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "bias_constraint": constraints.serialize(self.bias_constraint),
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs):
        ret = tf.einsum(self.equation, inputs, self.kernel)
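        # For example, with equation "ab,bc->ac" this contracts the shared
        # axis `b`, so it is equivalent to tf.matmul(inputs, self.kernel).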
        if self.bias is not None:
            ret += self.bias
        if self.activation is not None:
            ret = self.activation(ret)
        return ret


def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape):
    """Analyzes an einsum string to determine the required weight shape."""

    dot_replaced_string = re.sub(r"\.\.\.", "0", equation)
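    # For instance, "...ab,bc->...ac" becomes "0ab,bc->0ac", so the regexes
    # below can treat the elided ellipsis as the ordinary token "0".
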
    # This is the case where no ellipses are present in the string.
    split_string = re.match(
        "([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)", dot_replaced_string
    )
    if split_string:
        return _analyze_split_string(
            split_string, bias_axes, input_shape, output_shape
        )

    # This is the case where ellipses are present on the left.
    split_string = re.match(
        "0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)", dot_replaced_string
    )
    if split_string:
        return _analyze_split_string(
            split_string, bias_axes, input_shape, output_shape, left_elided=True
        )

    # This is the case where ellipses are present on the right.
    split_string = re.match(
        "([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0", dot_replaced_string
    )
    if split_string:
        return _analyze_split_string(
            split_string, bias_axes, input_shape, output_shape
        )

    raise ValueError(
        f"Invalid einsum equation '{equation}'. Equations must be in the form "
        "[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]...."
    )


def _analyze_split_string(
    split_string, bias_axes, input_shape, output_shape, left_elided=False
):
    """Analyzes a pre-split einsum string to find the weight shape."""
    input_spec = split_string.group(1)
    weight_spec = split_string.group(2)
    output_spec = split_string.group(3)
    elided = len(input_shape) - len(input_spec)

    if isinstance(output_shape, int):
        output_shape = [output_shape]
    else:
        output_shape = list(output_shape)

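    # The batch dimension (input_shape[0]) is never part of the user-supplied
    # output_shape, so prepend it before resolving the remaining dimensions.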
    output_shape.insert(0, input_shape[0])

    if elided > 0 and left_elided:
        for i in range(1, elided):
            # We already inserted the 0th input dimension at dim 0, so insert
            # each remaining elided dimension at its own index to preserve
            # the original dimension order.
            output_shape.insert(i, input_shape[i])
    elif elided > 0 and not left_elided:
        for i in range(len(input_shape) - elided, len(input_shape)):
            output_shape.append(input_shape[i])

    if left_elided:
        # If we have beginning dimensions elided, we need to use negative
        # indexing to determine where in the input dimension our values are.
        input_dim_map = {
            dim: (i + elided) - len(input_shape)
            for i, dim in enumerate(input_spec)
        }
        # Because we've constructed the full output shape already, we don't
        # need to do negative indexing.
        output_dim_map = {
            dim: (i + elided) for i, dim in enumerate(output_spec)
        }
    else:
        input_dim_map = {dim: i for i, dim in enumerate(input_spec)}
        output_dim_map = {dim: i for i, dim in enumerate(output_spec)}

    for dim in input_spec:
        input_shape_at_dim = input_shape[input_dim_map[dim]]
        if dim in output_dim_map:
            output_shape_at_dim = output_shape[output_dim_map[dim]]
            if (
                output_shape_at_dim is not None
                and output_shape_at_dim != input_shape_at_dim
            ):
                raise ValueError(
                    "Input shape and output shape do not match at shared "
                    f"dimension '{dim}'. Input shape is {input_shape_at_dim}, "
                    f"and output shape is {output_shape_at_dim}."
                )

    for dim in output_spec:
        if dim not in input_spec and dim not in weight_spec:
            raise ValueError(
                f"Dimension '{dim}' was specified in the output "
                f"'{output_spec}' but has no corresponding dim in the input "
                f"spec '{input_spec}' or weight spec '{weight_spec}'"
            )

    weight_shape = []
    for dim in weight_spec:
        if dim in input_dim_map:
            weight_shape.append(input_shape[input_dim_map[dim]])
        elif dim in output_dim_map:
            weight_shape.append(output_shape[output_dim_map[dim]])
        else:
            raise ValueError(
                f"Weight dimension '{dim}' did not have a match in either "
                f"the input spec '{input_spec}' or the output "
                f"spec '{output_spec}'. For this layer, the weight must "
                "be fully specified."
            )

    if bias_axes is not None:
        num_left_elided = elided if left_elided else 0
        idx_map = {
            char: output_shape[i + num_left_elided]
            for i, char in enumerate(output_spec)
        }

        for char in bias_axes:
            if char not in output_spec:
                raise ValueError(
                    f"Bias dimension '{char}' was requested, but is not part "
                    f"of the output spec '{output_spec}'"
                )

        first_bias_location = min(
            [output_spec.find(char) for char in bias_axes]
        )
        bias_output_spec = output_spec[first_bias_location:]

        bias_shape = [
            idx_map[char] if char in bias_axes else 1
            for char in bias_output_spec
        ]
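        # Illustrative: for output_spec "abd" with bias_axes "d" and full
        # output shape [None, 32, 64], idx_map is {"a": None, "b": 32,
        # "d": 64}, bias_output_spec is "d", and bias_shape is [64]. Output
        # axes inside the bias span that are not bias axes get size 1 so the
        # bias broadcasts across them.
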
        if not left_elided:
            for _ in range(elided):
                bias_shape.append(1)
    else:
        bias_shape = None

    return weight_shape, bias_shape, output_shape


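if __name__ == "__main__":
    # A minimal runnable sketch (not part of the original module) tracing the
    # shape analysis for the right-elided case `ab...,bc->ac...` on a rank-3
    # input; the expected values in the comments follow the logic above.
    kernel_shape, bias_shape, full_output_shape = _analyze_einsum_string(
        "ab...,bc->ac...", "c", tf.TensorShape([None, 32, 128]), [64]
    )
    print(kernel_shape)  # [32, 64]
    print(bias_shape)  # [64, 1]; the trailing 1 broadcasts over the elided dim
    print(full_output_shape)  # [None, 64, 128]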