Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/merging/base_merge.py: 13%

127 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Private base class for layers that can merge several inputs into one."""

import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.engine.base_layer import Layer
from keras.src.utils import tf_utils


class _Merge(Layer):
    """Generic merge layer for elementwise merge functions.

    Used to implement `Sum`, `Average`, etc.
    """

    def __init__(self, **kwargs):
        """Initializes a Merge layer.

        Args:
            **kwargs: standard layer keyword arguments.
        """
        super().__init__(**kwargs)
        self.supports_masking = True

    def _merge_function(self, inputs):
        raise NotImplementedError

    def _compute_elemwise_op_output_shape(self, shape1, shape2):
        """Computes the output shape of an elementwise operation.

        Args:
            shape1: tuple or None. Shape of the first tensor.
            shape2: tuple or None. Shape of the second tensor.

        Returns:
            The expected output shape when an elementwise operation is
            carried out on two tensors with shapes `shape1` and `shape2`.
            tuple or None.

        Raises:
            ValueError: if `shape1` and `shape2` are not compatible for
                elementwise operations.
        """
        if None in [shape1, shape2]:
            return None
        elif len(shape1) < len(shape2):
            return self._compute_elemwise_op_output_shape(shape2, shape1)
        elif not shape2:
            return shape1
        output_shape = list(shape1[: -len(shape2)])
        for i, j in zip(shape1[-len(shape2) :], shape2):
            if i is None or j is None:
                output_shape.append(None)
            elif i == 1:
                output_shape.append(j)
            elif j == 1:
                output_shape.append(i)
            else:
                if i != j:
                    raise ValueError(
                        "Inputs have incompatible shapes. "
                        f"Received shapes {shape1} and {shape2}"
                    )
                output_shape.append(i)
        return tuple(output_shape)
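
    # Illustrative note (added for clarity, not in the original source): the
    # rule above mirrors NumPy-style broadcasting over the trailing axes.
    # For example:
    #     (None, 3, 1) and (5,)      -> (None, 3, 5)
    #     (None, 4)    and (None, 4) -> (None, 4)
    #     (None, 3)    and (None, 4) -> ValueError (incompatible shapes)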

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        # Used purely for shape validation.
        if not isinstance(input_shape[0], tuple):
            raise ValueError(
                "A merge layer should be called on a list of inputs. "
                f"Received: input_shape={input_shape} (not a list of shapes)"
            )
        if len(input_shape) < 1:
            raise ValueError(
                "A merge layer should be called "
                "on a list of at least 1 input. "
                f"Got {len(input_shape)} inputs. "
                f"Full input_shape received: {input_shape}"
            )
        batch_sizes = {s[0] for s in input_shape if s} - {None}
        if len(batch_sizes) > 1:
            raise ValueError(
                "Cannot merge tensors with different batch sizes. "
                f"Got tensors with shapes {input_shape}"
            )
        if input_shape[0] is None:
            output_shape = None
        else:
            output_shape = input_shape[0][1:]
        for i in range(1, len(input_shape)):
            if input_shape[i] is None:
                shape = None
            else:
                shape = input_shape[i][1:]
            output_shape = self._compute_elemwise_op_output_shape(
                output_shape, shape
            )
        # If the inputs have different ranks, we have to reshape them
        # to make them broadcastable.
        if None not in input_shape and len(set(map(len, input_shape))) == 1:
            self._reshape_required = False
        else:
            self._reshape_required = True

    def call(self, inputs):
        if not isinstance(inputs, (list, tuple)):
            raise ValueError(
                "A merge layer should be called on a list of inputs. "
                f"Received: inputs={inputs} (not a list of tensors)"
            )
        if self._reshape_required:
            reshaped_inputs = []
            input_ndims = list(map(backend.ndim, inputs))
            if None not in input_ndims:
                # If ranks of all inputs are available,
                # we simply expand each of them at axis=1
                # until all of them have the same rank.
                max_ndim = max(input_ndims)
                for x in inputs:
                    x_ndim = backend.ndim(x)
                    for _ in range(max_ndim - x_ndim):
                        x = tf.expand_dims(x, axis=1)
                    reshaped_inputs.append(x)
                return self._merge_function(reshaped_inputs)
            else:
                # Transpose all inputs so that batch size is the last
                # dimension.
                # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... ,
                # batch_size)
                transposed = False
                for x in inputs:
                    x_ndim = backend.ndim(x)
                    if x_ndim is None:
                        x_shape = tf.shape(x)
                        batch_size = x_shape[0]
                        new_shape = backend.concatenate(
                            [x_shape[1:], tf.expand_dims(batch_size, axis=-1)]
                        )
                        x_transposed = tf.reshape(
                            x,
                            tf.stack(
                                [batch_size, tf.reduce_prod(x_shape[1:])],
                                axis=0,
                            ),
                        )
                        x_transposed = tf.transpose(x_transposed, perm=(1, 0))
                        x_transposed = tf.reshape(x_transposed, new_shape)
                        reshaped_inputs.append(x_transposed)
                        transposed = True
                    elif x_ndim > 1:
                        dims = list(range(1, x_ndim)) + [0]
                        reshaped_inputs.append(tf.transpose(x, perm=dims))
                        transposed = True
                    else:
                        # We don't transpose inputs if they are 1D vectors or
                        # scalars.
                        reshaped_inputs.append(x)
                y = self._merge_function(reshaped_inputs)
                y_ndim = backend.ndim(y)
                if transposed:
                    # If inputs have been transposed, we have to transpose the
                    # output too.
                    if y_ndim is None:
                        y_shape = tf.shape(y)
                        y_ndim = tf.shape(y_shape)[0]
                        batch_size = y_shape[y_ndim - 1]
                        new_shape = backend.concatenate(
                            [
                                tf.expand_dims(batch_size, axis=-1),
                                y_shape[: y_ndim - 1],
                            ]
                        )
                        y = tf.reshape(y, (-1, batch_size))
                        y = tf.transpose(y, perm=(1, 0))
                        y = tf.reshape(y, new_shape)
                    elif y_ndim > 1:
                        dims = [y_ndim - 1] + list(range(y_ndim - 1))
                        y = tf.transpose(y, perm=dims)
                return y
        else:
            return self._merge_function(inputs)

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        if input_shape[0] is None:
            output_shape = None
        else:
            output_shape = input_shape[0][1:]
        for i in range(1, len(input_shape)):
            if input_shape[i] is None:
                shape = None
            else:
                shape = input_shape[i][1:]
            output_shape = self._compute_elemwise_op_output_shape(
                output_shape, shape
            )
        batch_sizes = {s[0] for s in input_shape if s is not None} - {None}
        if len(batch_sizes) == 1:
            output_shape = (list(batch_sizes)[0],) + output_shape
        else:
            output_shape = (None,) + output_shape
        return output_shape

    def compute_mask(self, inputs, mask=None):
        if mask is None:
            return None
        if not isinstance(mask, (tuple, list)):
            raise ValueError(f"`mask` should be a list. Received: mask={mask}")
        if not isinstance(inputs, (tuple, list)):
            raise ValueError(
                f"`inputs` should be a list. Received: inputs={inputs}"
            )
        if len(mask) != len(inputs):
            raise ValueError(
                "The lists `inputs` and `mask` should have the same length. "
                f"Received: inputs={inputs} of length {len(inputs)}, and "
                f"mask={mask} of length {len(mask)}"
            )
        if all(m is None for m in mask):
            return None
        masks = [tf.expand_dims(m, axis=0) for m in mask if m is not None]
        return backend.all(
            backend.concatenate(masks, axis=0), axis=0, keepdims=False
        )


    def get_config(self):
        return super().get_config()
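

# ---------------------------------------------------------------------------
# Illustrative sketch (added for clarity; not part of the original module).
# It shows how concrete merge layers such as `Add` build on `_Merge` by
# overriding `_merge_function`. The class name `_ExampleAdd` and the demo
# below are hypothetical and exist only for illustration.
# ---------------------------------------------------------------------------
class _ExampleAdd(_Merge):
    """Hypothetical example subclass: elementwise sum of its inputs."""

    def _merge_function(self, inputs):
        # Sum all input tensors elementwise. When input ranks differ, the
        # base class `call` has already expanded them to a common rank.
        output = inputs[0]
        for x in inputs[1:]:
            output = output + x
        return output


if __name__ == "__main__":
    # Tiny usage sketch: merge two rank-2 tensors with matching shapes.
    a = tf.constant([[1.0, 2.0]])
    b = tf.constant([[3.0, 4.0]])
    print(_ExampleAdd()([a, b]))  # -> [[4. 6.]]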