
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based einsum dense layer."""


import re

import tensorflow.compat.v2 as tf

from keras.src import activations
from keras.src import constraints
from keras.src import initializers
from keras.src import regularizers
from keras.src.engine.base_layer import Layer

# isort: off
from tensorflow.python.util.tf_export import keras_export



@keras_export(
    "keras.layers.EinsumDense", "keras.layers.experimental.EinsumDense"
)
class EinsumDense(Layer):
    """A layer that uses `tf.einsum` as the backing computation.

    This layer can perform einsum calculations of arbitrary dimensionality.

    Args:
        equation: An equation describing the einsum to perform. This equation
            must be a valid einsum string of the form `ab,bc->ac`,
            `...ab,bc->...ac`, or `ab...,bc->ac...` where 'ab', 'bc', and 'ac'
            can be any valid einsum axis expression sequence.
        output_shape: The expected shape of the output tensor (excluding the
            batch dimension and any dimensions represented by ellipses). You
            can specify `None` for any dimension that is unknown or can be
            inferred from the input shape.
        activation: Activation function to use. If you don't specify anything,
            no activation is applied (that is, a "linear" activation:
            `a(x) = x`).
        bias_axes: A string containing the output dimension(s) to apply a bias
            to. Each character in the `bias_axes` string should correspond to
            a character in the output portion of the `equation` string.
        kernel_initializer: Initializer for the `kernel` weights matrix.
        bias_initializer: Initializer for the bias vector.
        kernel_regularizer: Regularizer function applied to the `kernel`
            weights matrix.
        bias_regularizer: Regularizer function applied to the bias vector.
        activity_regularizer: Regularizer function applied to the output of
            the layer (its "activation").
        kernel_constraint: Constraint function applied to the `kernel` weights
            matrix.
        bias_constraint: Constraint function applied to the bias vector.

    Examples:

    **Biased dense layer with einsums**

    This example shows how to instantiate a standard Keras dense layer using
    einsum operations. This example is equivalent to
    `tf.keras.layers.Dense(64, use_bias=True)`.

    >>> layer = tf.keras.layers.EinsumDense("ab,bc->ac",
    ...                                     output_shape=64,
    ...                                     bias_axes="c")
    >>> input_tensor = tf.keras.Input(shape=[32])
    >>> output_tensor = layer(input_tensor)
    >>> output_tensor
    <... shape=(None, 64) dtype=...>

    **Applying a dense layer to a sequence**

    This example shows how to instantiate a layer that applies the same dense
    operation to every element in a sequence. Here, the `output_shape` has two
    values (since there are two non-batch dimensions in the output); the first
    dimension in the `output_shape` is `None`, because the sequence dimension
    `b` has an unknown shape.

    >>> layer = tf.keras.layers.EinsumDense("abc,cd->abd",
    ...                                     output_shape=(None, 64),
    ...                                     bias_axes="d")
    >>> input_tensor = tf.keras.Input(shape=[32, 128])
    >>> output_tensor = layer(input_tensor)
    >>> output_tensor
    <... shape=(None, 32, 64) dtype=...>

    **Applying a dense layer to a sequence using ellipses**

    This example shows how to instantiate a layer that applies the same dense
    operation to every element in a sequence, but uses the ellipsis notation
    instead of specifying the batch and sequence dimensions.

    Because we are using ellipsis notation and have specified only one axis,
    the `output_shape` arg is a single value. When instantiated in this way,
    the layer can handle any number of sequence dimensions - including the
    case where no sequence dimension exists.

    >>> layer = tf.keras.layers.EinsumDense("...x,xy->...y",
    ...                                     output_shape=64,
    ...                                     bias_axes="y")
    >>> input_tensor = tf.keras.Input(shape=[32, 128])
    >>> output_tensor = layer(input_tensor)
    >>> output_tensor
    <... shape=(None, 32, 64) dtype=...>
    """


    def __init__(
        self,
        equation,
        output_shape,
        activation=None,
        bias_axes=None,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.equation = equation
        if isinstance(output_shape, int):
            self.partial_output_shape = [output_shape]
        else:
            self.partial_output_shape = list(output_shape)
        self.bias_axes = bias_axes
        self.activation = activations.get(activation)
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

    def build(self, input_shape):
        input_shape = tf.TensorShape(input_shape)
        shape_data = _analyze_einsum_string(
            self.equation,
            self.bias_axes,
            input_shape,
            self.partial_output_shape,
        )
        kernel_shape, bias_shape, self.full_output_shape = shape_data
        self.kernel = self.add_weight(
            "kernel",
            shape=kernel_shape,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True,
        )

        if bias_shape is not None:
            self.bias = self.add_weight(
                "bias",
                shape=bias_shape,
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                dtype=self.dtype,
                trainable=True,
            )
        else:
            self.bias = None
        super().build(input_shape)

    def compute_output_shape(self, _):
        return tf.TensorShape(self.full_output_shape)

    def get_config(self):
        config = {
            "output_shape": self.partial_output_shape,
            "equation": self.equation,
            "activation": activations.serialize(self.activation),
            "bias_axes": self.bias_axes,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self.activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "bias_constraint": constraints.serialize(self.bias_constraint),
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs):
        ret = tf.einsum(self.equation, inputs, self.kernel)
        if self.bias is not None:
            ret += self.bias
        if self.activation is not None:
            ret = self.activation(ret)
        return ret

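# As an illustration of what `build()` above creates (a sketch derived by
# tracing the shape-analysis helpers below, not an exhaustive specification):
# for the sequence example in the class docstring,
# `EinsumDense("abc,cd->abd", output_shape=(None, 64), bias_axes="d")`
# applied to an input of shape (batch, 32, 128) builds a kernel of shape
# (128, 64) and a bias of shape (64,), so the same dense projection is
# shared across the sequence axis.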


def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape):
    """Analyzes an einsum string to determine the required weight shape."""

    dot_replaced_string = re.sub(r"\.\.\.", "0", equation)

    # This is the case where no ellipses are present in the string.
    split_string = re.match(
        "([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)", dot_replaced_string
    )
    if split_string:
        return _analyze_split_string(
            split_string, bias_axes, input_shape, output_shape
        )

    # This is the case where ellipses are present on the left.
    split_string = re.match(
        "0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)", dot_replaced_string
    )
    if split_string:
        return _analyze_split_string(
            split_string, bias_axes, input_shape, output_shape, left_elided=True
        )

    # This is the case where ellipses are present on the right.
    split_string = re.match(
        "([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0", dot_replaced_string
    )
    if split_string:
        return _analyze_split_string(
            split_string, bias_axes, input_shape, output_shape
        )

    raise ValueError(
        f"Invalid einsum equation '{equation}'. Equations must be in the form "
        "[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]...."
    )

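# A minimal worked example of the analysis above (values obtained by tracing
# the left-elided branch; the concrete shapes are illustrative only): for the
# docstring equation "...x,xy->...y" with bias_axes="y", an input shape of
# (None, 32, 128), and a partial output shape of [64], the ellipsis is
# rewritten to "0" (giving "0x,xy->0y"), the left-elided pattern matches, and
# `_analyze_split_string` returns
#     kernel_shape      = [128, 64]
#     bias_shape        = [64]
#     full_output_shape = [None, 32, 64]
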

def _analyze_split_string(
    split_string, bias_axes, input_shape, output_shape, left_elided=False
):
    """Analyzes a pre-split einsum string to find the weight shape."""
    input_spec = split_string.group(1)
    weight_spec = split_string.group(2)
    output_spec = split_string.group(3)
    elided = len(input_shape) - len(input_spec)

    if isinstance(output_shape, int):
        output_shape = [output_shape]
    else:
        output_shape = list(output_shape)

    output_shape.insert(0, input_shape[0])

    if elided > 0 and left_elided:
        for i in range(1, elided):
            # We already inserted the 0th input dimension at dim 0, so we need
            # to start at location 1 here.
            output_shape.insert(1, input_shape[i])
    elif elided > 0 and not left_elided:
        for i in range(len(input_shape) - elided, len(input_shape)):
            output_shape.append(input_shape[i])

    if left_elided:
        # If we have beginning dimensions elided, we need to use negative
        # indexing to determine where in the input dimension our values are.
        input_dim_map = {
            dim: (i + elided) - len(input_shape)
            for i, dim in enumerate(input_spec)
        }
        # Because we've constructed the full output shape already, we don't
        # need to do negative indexing.
        output_dim_map = {
            dim: (i + elided) for i, dim in enumerate(output_spec)
        }
    else:
        input_dim_map = {dim: i for i, dim in enumerate(input_spec)}
        output_dim_map = {dim: i for i, dim in enumerate(output_spec)}

    for dim in input_spec:
        input_shape_at_dim = input_shape[input_dim_map[dim]]
        if dim in output_dim_map:
            output_shape_at_dim = output_shape[output_dim_map[dim]]
            if (
                output_shape_at_dim is not None
                and output_shape_at_dim != input_shape_at_dim
            ):
                raise ValueError(
                    "Input shape and output shape do not match at shared "
                    f"dimension '{dim}'. Input shape is {input_shape_at_dim}, "
                    "and output shape "
                    f"is {output_shape[output_dim_map[dim]]}."
                )

    for dim in output_spec:
        if dim not in input_spec and dim not in weight_spec:
            raise ValueError(
                f"Dimension '{dim}' was specified in the output "
                f"'{output_spec}' but has no corresponding dim in the input "
                f"spec '{input_spec}' or weight spec '{weight_spec}'"
            )

    weight_shape = []
    for dim in weight_spec:
        if dim in input_dim_map:
            weight_shape.append(input_shape[input_dim_map[dim]])
        elif dim in output_dim_map:
            weight_shape.append(output_shape[output_dim_map[dim]])
        else:
            raise ValueError(
                f"Weight dimension '{dim}' did not have a match in either "
                f"the input spec '{input_spec}' or the output "
                f"spec '{output_spec}'. For this layer, the weight must "
                "be fully specified."
            )

    if bias_axes is not None:
        num_left_elided = elided if left_elided else 0
        idx_map = {
            char: output_shape[i + num_left_elided]
            for i, char in enumerate(output_spec)
        }

        for char in bias_axes:
            if char not in output_spec:
                raise ValueError(
                    f"Bias dimension '{char}' was requested, but is not part "
                    f"of the output spec '{output_spec}'"
                )

        first_bias_location = min(
            [output_spec.find(char) for char in bias_axes]
        )
        bias_output_spec = output_spec[first_bias_location:]

        bias_shape = [
            idx_map[char] if char in bias_axes else 1
            for char in bias_output_spec
        ]

        if not left_elided:
            for _ in range(elided):
                bias_shape.append(1)
    else:
        bias_shape = None

    return weight_shape, bias_shape, output_shape
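

# An illustrative trace of the right-elided case (a sketch based on the code
# above; the concrete shapes are hypothetical): for equation "ab...,bc->ac..."
# with bias_axes="c", an input shape of (None, 32, 10, 11), and a partial
# output shape of [64], two dimensions are elided on the right, so the
# function returns
#     weight_shape = [32, 64]
#     bias_shape   = [64, 1, 1]   # trailing 1s broadcast over the elided dims
#     output_shape = [None, 64, 10, 11]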