Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/merging/concatenate.py: 24%

75 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Layer that concatenates several inputs.""" 

16 

17 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src import backend 

21from keras.src.layers.merging.base_merge import _Merge 

22from keras.src.utils import tf_utils 

23 

24# isort: off 

25from tensorflow.python.util.tf_export import keras_export 

26 

27 

@keras_export("keras.layers.Concatenate")
class Concatenate(_Merge):
    """Layer that concatenates a list of inputs.

    It takes as input a list of tensors, all of the same shape except
    for the concatenation axis, and returns a single tensor that is the
    concatenation of all inputs.

    >>> x = np.arange(20).reshape(2, 2, 5)
    >>> print(x)
    [[[ 0  1  2  3  4]
      [ 5  6  7  8  9]]
     [[10 11 12 13 14]
      [15 16 17 18 19]]]
    >>> y = np.arange(20, 30).reshape(2, 1, 5)
    >>> print(y)
    [[[20 21 22 23 24]]
     [[25 26 27 28 29]]]
    >>> tf.keras.layers.Concatenate(axis=1)([x, y])
    <tf.Tensor: shape=(2, 3, 5), dtype=int64, numpy=
    array([[[ 0,  1,  2,  3,  4],
            [ 5,  6,  7,  8,  9],
            [20, 21, 22, 23, 24]],
           [[10, 11, 12, 13, 14],
            [15, 16, 17, 18, 19],
            [25, 26, 27, 28, 29]]])>

    >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
    >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
    >>> concatted = tf.keras.layers.Concatenate()([x1, x2])
    >>> concatted.shape
    TensorShape([5, 16])

    """

    def __init__(self, axis=-1, **kwargs):
        """Instantiates a Concatenate layer.

        See the class docstring for a usage example.

        Args:
            axis: Axis along which to concatenate.
            **kwargs: standard layer keyword arguments.
        """
        super().__init__(**kwargs)
        self.axis = axis
        self.supports_masking = True
        self._reshape_required = False

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        """Checks that the input shapes can be concatenated on `self.axis`.

        Used purely for shape validation.

        Args:
            input_shape: A sequence of input shapes, one per input.

        Raises:
            ValueError: If `input_shape` is empty or its first element is
                not a tuple; if any input has no dimension at `self.axis`;
                or if, ignoring the concatenation axis, the inputs differ
                in rank or in any statically known dimension.
        """
        if len(input_shape) < 1 or not isinstance(input_shape[0], tuple):
            raise ValueError(
                "A `Concatenate` layer should be called on a list of "
                f"at least 1 input. Received: input_shape={input_shape}"
            )
        if all(shape is None for shape in input_shape):
            return

        # Collect every input shape with the concatenation axis removed.
        shape_set = set()
        for shape in input_shape:
            reduced_shape = list(shape)
            try:
                del reduced_shape[self.axis]
            except IndexError as err:
                # Without this, an input lacking the concatenation axis
                # would leak a bare IndexError out of shape validation.
                raise ValueError(
                    "A `Concatenate` layer requires every input to have a "
                    "dimension at the concatenation axis. "
                    f"Received: axis={self.axis}, input_shape={input_shape}"
                ) from err
            shape_set.add(tuple(reduced_shape))

        if len(shape_set) != 1:
            err_msg = (
                "A `Concatenate` layer requires inputs with matching shapes "
                "except for the concatenation axis. "
                f"Received: input_shape={input_shape}"
            )
            # Make sure all the shapes have same ranks.
            ranks = {len(shape) for shape in shape_set}
            if len(ranks) != 1:
                raise ValueError(err_msg)
            # Get the only rank for the set.
            (rank,) = ranks
            for axis in range(rank):
                # Skip the Nones in the shape since they are dynamic, also the
                # axis for concat has been removed above.
                unique_dims = {
                    shape[axis]
                    for shape in shape_set
                    if shape[axis] is not None
                }
                if len(unique_dims) > 1:
                    raise ValueError(err_msg)

    def _merge_function(self, inputs):
        """Concatenates `inputs` along `self.axis` using the backend."""
        return backend.concatenate(inputs, axis=self.axis)

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        """Computes the shape of the concatenated output.

        Args:
            input_shape: A list/tuple of input shapes, each itself a
                list/tuple.

        Returns:
            A tuple equal to the first input shape, except that the entry
            at `self.axis` is the sum of the inputs' sizes along that axis,
            or `None` as soon as one of the sizes involved is `None`.

        Raises:
            ValueError: If `input_shape` is not a list/tuple of
                lists/tuples.
        """
        if (not isinstance(input_shape, (tuple, list))) or (
            not isinstance(input_shape[0], (tuple, list))
        ):
            # The tf_utils.shape_type_conversion decorator turns tensorshapes
            # into tuples, so we need to verify that `input_shape` is a
            # list/tuple, *and* that the individual elements are themselves
            # shape tuples.
            raise ValueError(
                "A `Concatenate` layer should be called on a list of inputs. "
                f"Received: input_shape={input_shape}"
            )
        input_shapes = input_shape
        output_shape = list(input_shapes[0])
        for shape in input_shapes[1:]:
            # An unknown size on either side makes the total unknown.
            if output_shape[self.axis] is None or shape[self.axis] is None:
                output_shape[self.axis] = None
                break
            output_shape[self.axis] += shape[self.axis]
        return tuple(output_shape)

    def compute_mask(self, inputs, mask=None):
        """Computes the output mask from the per-input masks.

        Args:
            inputs: A list/tuple of input tensors.
            mask: `None`, or a list/tuple with one entry (a mask or `None`)
                per input.

        Returns:
            `None` if `mask` is `None` or every entry of `mask` is `None`;
            otherwise the per-input masks concatenated along `self.axis`
            and reduced with a logical "all" over the last axis.

        Raises:
            ValueError: If `mask` or `inputs` is not a list/tuple, or if
                their lengths differ.
        """
        if mask is None:
            return None
        if not isinstance(mask, (tuple, list)):
            raise ValueError(f"`mask` should be a list. Received mask={mask}")
        if not isinstance(inputs, (tuple, list)):
            raise ValueError(
                f"`inputs` should be a list. Received: inputs={inputs}"
            )
        if len(mask) != len(inputs):
            raise ValueError(
                "The lists `inputs` and `mask` should have the same length. "
                f"Received: inputs={inputs} of length {len(inputs)}, and "
                f"mask={mask} of length {len(mask)}"
            )
        if all(m is None for m in mask):
            return None
        # Make a list of masks while making sure
        # the dimensionality of each mask
        # is the same as the corresponding input.
        # NOTE(review): an unmasked input yields a mask with the input's full
        # shape, while an expanded mask gets a trailing dimension of size 1;
        # when `self.axis` is not the last axis these may not be
        # concatenable -- confirm against callers.
        masks = []
        for input_i, mask_i in zip(inputs, mask):
            if mask_i is None:
                # Input is unmasked. Append all 1s to masks,
                masks.append(tf.ones_like(input_i, dtype="bool"))
            elif backend.ndim(mask_i) < backend.ndim(input_i):
                # Mask is smaller than the input, expand it
                masks.append(tf.expand_dims(mask_i, axis=-1))
            else:
                masks.append(mask_i)
        concatenated = backend.concatenate(masks, axis=self.axis)
        return backend.all(concatenated, axis=-1, keepdims=False)

    def get_config(self):
        """Returns the base layer config extended with the `axis` setting."""
        config = {
            "axis": self.axis,
        }
        base_config = super().get_config()
        # Entries in `config` take precedence over the base config.
        return {**base_config, **config}

197 

198 

@keras_export("keras.layers.concatenate")
def concatenate(inputs, axis=-1, **kwargs):
    """Functional interface to the `Concatenate` layer.

    Builds a `Concatenate` layer configured with `axis` and `kwargs`, and
    immediately applies it to `inputs`.

    >>> x = np.arange(20).reshape(2, 2, 5)
    >>> print(x)
    [[[ 0  1  2  3  4]
      [ 5  6  7  8  9]]
     [[10 11 12 13 14]
      [15 16 17 18 19]]]
    >>> y = np.arange(20, 30).reshape(2, 1, 5)
    >>> print(y)
    [[[20 21 22 23 24]]
     [[25 26 27 28 29]]]
    >>> tf.keras.layers.concatenate([x, y],
    ...                             axis=1)
    <tf.Tensor: shape=(2, 3, 5), dtype=int64, numpy=
    array([[[ 0,  1,  2,  3,  4],
            [ 5,  6,  7,  8,  9],
            [20, 21, 22, 23, 24]],
           [[10, 11, 12, 13, 14],
            [15, 16, 17, 18, 19],
            [25, 26, 27, 28, 29]]])>

    Args:
        inputs: A list of input tensors.
        axis: Concatenation axis.
        **kwargs: Standard layer keyword arguments.

    Returns:
        A tensor, the concatenation of the inputs alongside axis `axis`.
    """
    layer = Concatenate(axis=axis, **kwargs)
    return layer(inputs)

232