Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/merging/concatenate.py: 24%
75 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Layer that concatenates several inputs."""
18import tensorflow.compat.v2 as tf
20from keras.src import backend
21from keras.src.layers.merging.base_merge import _Merge
22from keras.src.utils import tf_utils
24# isort: off
25from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.Concatenate")
class Concatenate(_Merge):
    """Layer that concatenates a list of inputs.

    It takes as input a list of tensors, all of the same shape except
    for the concatenation axis, and returns a single tensor that is the
    concatenation of all inputs.

    >>> x = np.arange(20).reshape(2, 2, 5)
    >>> print(x)
    [[[ 0  1  2  3  4]
      [ 5  6  7  8  9]]
     [[10 11 12 13 14]
      [15 16 17 18 19]]]
    >>> y = np.arange(20, 30).reshape(2, 1, 5)
    >>> print(y)
    [[[20 21 22 23 24]]
     [[25 26 27 28 29]]]
    >>> tf.keras.layers.Concatenate(axis=1)([x, y])
    <tf.Tensor: shape=(2, 3, 5), dtype=int64, numpy=
    array([[[ 0,  1,  2,  3,  4],
            [ 5,  6,  7,  8,  9],
            [20, 21, 22, 23, 24]],
           [[10, 11, 12, 13, 14],
            [15, 16, 17, 18, 19],
            [25, 26, 27, 28, 29]]])>

    >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
    >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
    >>> concatted = tf.keras.layers.Concatenate()([x1, x2])
    >>> concatted.shape
    TensorShape([5, 16])

    """

    def __init__(self, axis=-1, **kwargs):
        """Instantiates a Concatenate layer.

        >>> x = np.arange(20).reshape(2, 2, 5)
        >>> y = np.arange(20, 30).reshape(2, 1, 5)
        >>> tf.keras.layers.Concatenate(axis=1)([x, y])
        <tf.Tensor: shape=(2, 3, 5), dtype=int64, numpy=
        array([[[ 0,  1,  2,  3,  4],
                [ 5,  6,  7,  8,  9],
                [20, 21, 22, 23, 24]],
               [[10, 11, 12, 13, 14],
                [15, 16, 17, 18, 19],
                [25, 26, 27, 28, 29]]])>

        Args:
            axis: Axis along which to concatenate.
            **kwargs: Standard layer keyword arguments, forwarded to the
                `_Merge` base class.
        """
        super().__init__(**kwargs)
        self.axis = axis
        self.supports_masking = True
        self._reshape_required = False

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        """Validates that the input shapes can be concatenated.

        The layer creates no weights; `build` is used purely for shape
        validation. After removing `self.axis` from every input shape, the
        remaining shapes must have the same rank and agree on every
        dimension. `None` (unknown) dimensions are compatible with
        anything.

        Args:
            input_shape: A list of shape tuples, one per input (the
                `shape_type_conversion` decorator turns `TensorShape`s into
                tuples).

        Raises:
            ValueError: If `input_shape` is not a non-empty list of shape
                tuples, if `self.axis` is out of range for some input, or
                if the shapes disagree outside the concatenation axis.
        """
        if len(input_shape) < 1 or not isinstance(input_shape[0], tuple):
            raise ValueError(
                "A `Concatenate` layer should be called on a list of "
                f"at least 1 input. Received: input_shape={input_shape}"
            )
        if all(shape is None for shape in input_shape):
            return
        err_msg = (
            "A `Concatenate` layer requires inputs with matching shapes "
            "except for the concatenation axis. "
            f"Received: input_shape={input_shape}"
        )
        shape_set = set()
        for shape in input_shape:
            reduced_shape = list(shape)
            try:
                # Drop the concatenation axis: only the remaining
                # dimensions have to agree between inputs.
                del reduced_shape[self.axis]
            except IndexError as err:
                # `self.axis` does not exist for this input's rank. Report
                # a validation error instead of leaking an IndexError.
                raise ValueError(
                    f"Invalid concatenation axis {self.axis} for an input "
                    f"of rank {len(shape)}. Received: "
                    f"input_shape={input_shape}"
                ) from err
            shape_set.add(tuple(reduced_shape))

        if len(shape_set) != 1:
            # Make sure all the shapes have the same rank.
            ranks = {len(shape) for shape in shape_set}
            if len(ranks) != 1:
                raise ValueError(err_msg)
            # Get the only rank for the set.
            (rank,) = ranks
            for dim_index in range(rank):
                # Skip the Nones in the shape since they are dynamic; the
                # concatenation axis has already been removed above.
                unique_dims = {
                    shape[dim_index]
                    for shape in shape_set
                    if shape[dim_index] is not None
                }
                if len(unique_dims) > 1:
                    raise ValueError(err_msg)

    def _merge_function(self, inputs):
        """Concatenates `inputs` along `self.axis` via the Keras backend."""
        return backend.concatenate(inputs, axis=self.axis)

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        """Computes the output shape of the concatenation.

        Args:
            input_shape: A non-empty list/tuple of shape tuples, one per
                input.

        Returns:
            A tuple equal to the first input's shape, except that the
            concatenation axis holds the sum of all inputs' sizes along
            that axis. It holds `None` if any of those sizes is unknown.

        Raises:
            ValueError: If `input_shape` is not a non-empty list of shape
                tuples/lists.
        """
        if (
            not isinstance(input_shape, (tuple, list))
            or not input_shape
            or not isinstance(input_shape[0], (tuple, list))
        ):
            # The tf_utils.shape_type_conversion decorator turns tensorshapes
            # into tuples, so we need to verify that `input_shape` is a
            # non-empty list/tuple, *and* that the individual elements are
            # themselves shape tuples.
            raise ValueError(
                "A `Concatenate` layer should be called on a list of inputs. "
                f"Received: input_shape={input_shape}"
            )
        output_shape = list(input_shape[0])
        for shape in input_shape[1:]:
            if output_shape[self.axis] is None or shape[self.axis] is None:
                # One unknown size makes the total unknown.
                output_shape[self.axis] = None
                break
            output_shape[self.axis] += shape[self.axis]
        return tuple(output_shape)

    def compute_mask(self, inputs, mask=None):
        """Computes the output mask from the per-input masks.

        Each input's mask (or an all-`True` mask for unmasked inputs) is
        brought to the input's full shape. The masks are concatenated
        along `self.axis`, then reduced with a logical AND over the last
        axis.

        Args:
            inputs: List of input tensors.
            mask: `None`, or a list with one mask (or `None`) per input.

        Returns:
            `None` if no input is masked, otherwise a boolean tensor.

        Raises:
            ValueError: If `mask` or `inputs` is not a list/tuple, or if
                their lengths differ.
        """
        if mask is None:
            return None
        if not isinstance(mask, (tuple, list)):
            raise ValueError(f"`mask` should be a list. Received mask={mask}")
        if not isinstance(inputs, (tuple, list)):
            raise ValueError(
                f"`inputs` should be a list. Received: inputs={inputs}"
            )
        if len(mask) != len(inputs):
            raise ValueError(
                "The lists `inputs` and `mask` should have the same length. "
                f"Received: inputs={inputs} of length {len(inputs)}, and "
                f"mask={mask} of length {len(mask)}"
            )
        if all(m is None for m in mask):
            return None
        # Make a list of masks, each with the same shape as its input, so
        # they can be concatenated along any axis.
        masks = []
        for input_i, mask_i in zip(inputs, mask):
            if mask_i is None:
                # Input is unmasked: every position is valid.
                masks.append(tf.ones_like(input_i, dtype="bool"))
            elif backend.ndim(mask_i) < backend.ndim(input_i):
                # Mask has lower rank than the input (the single
                # expand_dims below assumes it lacks only the trailing
                # dimension). Expand it, then broadcast it to the input's
                # full shape. Otherwise its trailing dim of 1 would clash
                # with the full-shape masks of unmasked inputs when
                # concatenating along a non-last axis.
                masks.append(
                    tf.broadcast_to(
                        tf.expand_dims(mask_i, axis=-1), tf.shape(input_i)
                    )
                )
            else:
                masks.append(mask_i)
        concatenated = backend.concatenate(masks, axis=self.axis)
        return backend.all(concatenated, axis=-1, keepdims=False)

    def get_config(self):
        """Returns the base layer config extended with `axis`."""
        config = {
            "axis": self.axis,
        }
        base_config = super().get_config()
        # Keys from `config` take precedence over the base config.
        return {**base_config, **config}
@keras_export("keras.layers.concatenate")
def concatenate(inputs, axis=-1, **kwargs):
    """Functional interface to the `Concatenate` layer.

    Constructs a `Concatenate` layer for `axis` and immediately applies it
    to `inputs`.

    >>> x = np.arange(20).reshape(2, 2, 5)
    >>> print(x)
    [[[ 0  1  2  3  4]
      [ 5  6  7  8  9]]
     [[10 11 12 13 14]
      [15 16 17 18 19]]]
    >>> y = np.arange(20, 30).reshape(2, 1, 5)
    >>> print(y)
    [[[20 21 22 23 24]]
     [[25 26 27 28 29]]]
    >>> tf.keras.layers.concatenate([x, y],
    ...                             axis=1)
    <tf.Tensor: shape=(2, 3, 5), dtype=int64, numpy=
    array([[[ 0,  1,  2,  3,  4],
            [ 5,  6,  7,  8,  9],
            [20, 21, 22, 23, 24]],
           [[10, 11, 12, 13, 14],
            [15, 16, 17, 18, 19],
            [25, 26, 27, 28, 29]]])>

    Args:
        inputs: A list of input tensors.
        axis: Concatenation axis.
        **kwargs: Standard layer keyword arguments, passed to the
            `Concatenate` constructor.

    Returns:
        A tensor, the concatenation of the inputs along axis `axis`.
    """
    concat_layer = Concatenate(axis=axis, **kwargs)
    return concat_layer(inputs)