# Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/merging/base_merge.py: 13% of 127 statements (coverage.py v7.4.0, 2024-01-03 07:57 +0000)
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Private base class for layers that can merge several inputs into one."""

import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.engine.base_layer import Layer
from keras.src.utils import tf_utils


class _Merge(Layer):
    """Generic merge layer for elementwise merge functions.

    Used to implement `Sum`, `Average`, etc.
    """

    def __init__(self, **kwargs):
        """Initializes a Merge layer.

        Args:
            **kwargs: standard layer keyword arguments.
        """
        super().__init__(**kwargs)
        self.supports_masking = True
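        # `supports_masking = True` tells Keras that this layer can consume
        # and propagate masks; the merged mask is produced by `compute_mask`
        # below.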

    def _merge_function(self, inputs):
        raise NotImplementedError
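        # Subclasses (e.g. the layers backing `Add` or `Average`) override
        # `_merge_function` with the actual elementwise op; a hypothetical
        # subclass is sketched in the comment block at the end of this file.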

    def _compute_elemwise_op_output_shape(self, shape1, shape2):
        """Computes the shape of the resultant of an elementwise operation.

        Args:
            shape1: tuple or None. Shape of the first tensor
            shape2: tuple or None. Shape of the second tensor

        Returns:
            expected output shape when an element-wise operation is
            carried out on 2 tensors with shapes shape1 and shape2.
            tuple or None.

        Raises:
            ValueError: if shape1 and shape2 are not compatible for
                element-wise operations.
        """
        if None in [shape1, shape2]:
            return None
        elif len(shape1) < len(shape2):
            return self._compute_elemwise_op_output_shape(shape2, shape1)
        elif not shape2:
            return shape1
        output_shape = list(shape1[: -len(shape2)])
        for i, j in zip(shape1[-len(shape2) :], shape2):
            if i is None or j is None:
                output_shape.append(None)
            elif i == 1:
                output_shape.append(j)
            elif j == 1:
                output_shape.append(i)
            else:
                if i != j:
                    raise ValueError(
                        "Inputs have incompatible shapes. "
                        f"Received shapes {shape1} and {shape2}"
                    )
                output_shape.append(i)
        return tuple(output_shape)
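        # Worked example (illustrative, not part of the original source):
        # with shape1=(5, None, 3, 1) and shape2=(3, 4), the trailing dims
        # are broadcast pairwise (3 against 3, then 1 against 4), so
        #     self._compute_elemwise_op_output_shape((5, None, 3, 1), (3, 4))
        # returns (5, None, 3, 4).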

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        # Used purely for shape validation.
        if not isinstance(input_shape[0], tuple):
            raise ValueError(
                "A merge layer should be called on a list of inputs. "
                f"Received: input_shape={input_shape} (not a list of shapes)"
            )
        if len(input_shape) < 1:
            raise ValueError(
                "A merge layer should be called "
                "on a list of at least 1 input. "
                f"Got {len(input_shape)} inputs. "
                f"Full input_shape received: {input_shape}"
            )
        batch_sizes = {s[0] for s in input_shape if s} - {None}
        if len(batch_sizes) > 1:
            raise ValueError(
                "Cannot merge tensors with different batch sizes. "
                f"Got tensors with shapes {input_shape}"
            )
        if input_shape[0] is None:
            output_shape = None
        else:
            output_shape = input_shape[0][1:]
        for i in range(1, len(input_shape)):
            if input_shape[i] is None:
                shape = None
            else:
                shape = input_shape[i][1:]
            output_shape = self._compute_elemwise_op_output_shape(
                output_shape, shape
            )
        # If the inputs have different ranks, we have to reshape them
        # to make them broadcastable.
        if None not in input_shape and len(set(map(len, input_shape))) == 1:
            self._reshape_required = False
        else:
            self._reshape_required = True
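        # For example (illustrative): input shapes [(None, 3, 4), (None, 4)]
        # have ranks {3, 2}, so `_reshape_required` is set to True and
        # `call` will expand the rank-2 input before merging.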

    def call(self, inputs):
        if not isinstance(inputs, (list, tuple)):
            raise ValueError(
                "A merge layer should be called on a list of inputs. "
                f"Received: inputs={inputs} (not a list of tensors)"
            )
        if self._reshape_required:
            reshaped_inputs = []
            input_ndims = list(map(backend.ndim, inputs))
            if None not in input_ndims:
                # If ranks of all inputs are available,
                # we simply expand each of them at axis=1
                # until all of them have the same rank.
                max_ndim = max(input_ndims)
                for x in inputs:
                    x_ndim = backend.ndim(x)
                    for _ in range(max_ndim - x_ndim):
                        x = tf.expand_dims(x, axis=1)
                    reshaped_inputs.append(x)
                return self._merge_function(reshaped_inputs)
            else:
                # Transpose all inputs so that batch size is the last
                # dimension.
                # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... ,
                # batch_size)
                transposed = False
                for x in inputs:
                    x_ndim = backend.ndim(x)
                    if x_ndim is None:
                        x_shape = tf.shape(x)
                        batch_size = x_shape[0]
                        new_shape = backend.concatenate(
                            [x_shape[1:], tf.expand_dims(batch_size, axis=-1)]
                        )
                        x_transposed = tf.reshape(
                            x,
                            tf.stack(
                                [batch_size, tf.reduce_prod(x_shape[1:])],
                                axis=0,
                            ),
                        )
                        x_transposed = tf.transpose(x_transposed, perm=(1, 0))
                        x_transposed = tf.reshape(x_transposed, new_shape)
                        reshaped_inputs.append(x_transposed)
                        transposed = True
                    elif x_ndim > 1:
                        dims = list(range(1, x_ndim)) + [0]
                        reshaped_inputs.append(tf.transpose(x, perm=dims))
                        transposed = True
                    else:
                        # We don't transpose inputs if they are 1D vectors or
                        # scalars.
                        reshaped_inputs.append(x)
                y = self._merge_function(reshaped_inputs)
                y_ndim = backend.ndim(y)
                if transposed:
                    # If inputs have been transposed, we have to transpose
                    # the output too.
                    if y_ndim is None:
                        y_shape = tf.shape(y)
                        y_ndim = tf.shape(y_shape)[0]
                        batch_size = y_shape[y_ndim - 1]
                        new_shape = backend.concatenate(
                            [
                                tf.expand_dims(batch_size, axis=-1),
                                y_shape[: y_ndim - 1],
                            ]
                        )
                        y = tf.reshape(y, (-1, batch_size))
                        y = tf.transpose(y, perm=(1, 0))
                        y = tf.reshape(y, new_shape)
                    elif y_ndim > 1:
                        dims = [y_ndim - 1] + list(range(y_ndim - 1))
                        y = tf.transpose(y, perm=dims)
                return y
        else:
            return self._merge_function(inputs)
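        # Illustrative example (not part of the original source): merging
        # inputs of shapes (2, 3, 4) and (2, 4) takes the static-rank branch
        # above; the second input is expanded once at axis=1 to (2, 1, 4),
        # and `_merge_function` then broadcasts it against (2, 3, 4).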

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        if input_shape[0] is None:
            output_shape = None
        else:
            output_shape = input_shape[0][1:]
        for i in range(1, len(input_shape)):
            if input_shape[i] is None:
                shape = None
            else:
                shape = input_shape[i][1:]
            output_shape = self._compute_elemwise_op_output_shape(
                output_shape, shape
            )
        batch_sizes = {s[0] for s in input_shape if s is not None} - {None}
        if len(batch_sizes) == 1:
            output_shape = (list(batch_sizes)[0],) + output_shape
        else:
            output_shape = (None,) + output_shape
        return output_shape
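        # For example (illustrative): input shapes [(32, 3, 1), (32, 3, 4)]
        # reduce to (3, 4) for the non-batch dims, and the shared batch
        # size 32 is prepended, giving an output shape of (32, 3, 4).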

    def compute_mask(self, inputs, mask=None):
        if mask is None:
            return None
        if not isinstance(mask, (tuple, list)):
            raise ValueError(f"`mask` should be a list. Received: mask={mask}")
        if not isinstance(inputs, (tuple, list)):
            raise ValueError(
                f"`inputs` should be a list. Received: inputs={inputs}"
            )
        if len(mask) != len(inputs):
            raise ValueError(
                "The lists `inputs` and `mask` should have the same length. "
                f"Received: inputs={inputs} of length {len(inputs)}, and "
                f"mask={mask} of length {len(mask)}"
            )
        if all(m is None for m in mask):
            return None
        masks = [tf.expand_dims(m, axis=0) for m in mask if m is not None]
        return backend.all(
            backend.concatenate(masks, axis=0), axis=0, keepdims=False
        )
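        # That is: the non-None masks are stacked along a temporary leading
        # axis and reduced with a logical AND, so a position in the merged
        # output stays unmasked only if it is valid in every masked input.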

    def get_config(self):
        return super().get_config()
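

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). Subclasses only
# need to implement `_merge_function`; shape validation, rank alignment, and
# mask merging all come from `_Merge`. The class name `_ElementwiseMax`
# below is hypothetical:
#
#     class _ElementwiseMax(_Merge):
#         def _merge_function(self, inputs):
#             output = inputs[0]
#             for x in inputs[1:]:
#                 output = tf.maximum(output, x)
#             return output
#
#     layer = _ElementwiseMax()
#     y = layer([tf.ones((2, 3)), tf.zeros((2, 3))])  # all-ones, shape (2, 3)
# ---------------------------------------------------------------------------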