Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/base_conv_rnn.py: 12%

154 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Base class for convolutional-recurrent layers.""" 

16 

17 

18import numpy as np 

19import tensorflow.compat.v2 as tf 

20 

21from keras.src import backend 

22from keras.src.engine import base_layer 

23from keras.src.engine.input_spec import InputSpec 

24from keras.src.layers.rnn.base_rnn import RNN 

25from keras.src.utils import conv_utils 

26from keras.src.utils import generic_utils 

27from keras.src.utils import tf_utils 

28 

29 

class ConvRNN(RNN):
    """N-Dimensional Base class for convolutional-recurrent layers.

    Args:
      rank: Integer, rank of the convolution, e.g. "2" for 2D convolutions.
      cell: A RNN cell instance. A RNN cell is a class that has:
        - a `call(input_at_t, states_at_t)` method, returning
          `(output_at_t, states_at_t_plus_1)`. The call method of the cell
          can also take the optional argument `constants`, see section
          "Note on passing external constants" below.
        - a `state_size` attribute. This can be a single integer (single
          state) in which case it is the number of channels of the recurrent
          state (which should be the same as the number of channels of the
          cell output). This can also be a list/tuple of integers (one size
          per state). In this case, the first entry (`state_size[0]`) should
          be the same as the size of the cell output.
      return_sequences: Boolean. Whether to return the last output in the
        output sequence, or the full sequence.
      return_state: Boolean. Whether to return the last state in addition to
        the output.
      go_backwards: Boolean (default False). If True, process the input
        sequence backwards and return the reversed sequence.
      stateful: Boolean (default False). If True, the last state for each
        sample at index i in a batch will be used as initial state for the
        sample of index i in the following batch.
      unroll: Must be False. Unrolling is not supported for convolutional
        RNNs; passing a truthy value raises a `TypeError`.
      input_shape: Use this argument to specify the shape of the input when
        this layer is the first one in a model.

    Call arguments:
      inputs: A (3 + `rank`)D tensor.
      mask: Binary tensor of shape `(samples, timesteps)` indicating whether a
        given timestep should be masked.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. This argument is passed to the
        cell when calling it. This is for use with cells that use dropout.
      initial_state: List of initial state tensors to be passed to the first
        call of the cell.
      constants: List of constant tensors to be passed to the cell at each
        timestep.

    Input shape:
      (3 + `rank`)D tensor with shape:
      `(samples, timesteps, channels, img_dimensions...)` if
      data_format='channels_first', or shape:
      `(samples, timesteps, img_dimensions..., channels)` if
      data_format='channels_last'.

    Output shape:
      - If `return_state`: a list of tensors. The first tensor is the output.
        The remaining tensors are the last states, one per entry of
        `cell.state_size`, each a (2 + `rank`)D tensor with shape:
        `(samples, state_channels, new_img_dimensions...)` if
        data_format='channels_first', or shape:
        `(samples, new_img_dimensions..., state_channels)` if
        data_format='channels_last'. img_dimension values might have changed
        due to padding.
      - If `return_sequences`: (3 + `rank`)D tensor with shape:
        `(samples, timesteps, filters, new_img_dimensions...)` if
        data_format='channels_first', or shape:
        `(samples, timesteps, new_img_dimensions..., filters)` if
        data_format='channels_last'.
      - Else, (2 + `rank`)D tensor with shape:
        `(samples, filters, new_img_dimensions...)` if
        data_format='channels_first', or shape:
        `(samples, new_img_dimensions..., filters)` if
        data_format='channels_last'.

    Masking:
      This layer supports masking for input data with a variable number of
      timesteps.

    Note on using statefulness in RNNs:
      You can set RNN layers to be 'stateful', which means that the states
      computed for the samples in one batch will be reused as initial states
      for the samples in the next batch. This assumes a one-to-one mapping
      between samples in different successive batches.
      To enable statefulness:
        - Specify `stateful=True` in the layer constructor.
        - Specify a fixed batch size for your model, by passing
          - If sequential model: `batch_input_shape=(...)` to the first
            layer in your model.
          - If functional model with 1 or more Input layers:
            `batch_shape=(...)` to all the first layers in your model.
            This is the expected shape of your inputs *including the batch
            size*. It should be a tuple of integers, e.g.
            `(32, 10, 100, 100, 32)` for a rank-2 convolution. Note that the
            image dimensions should be specified too.
        - Specify `shuffle=False` when calling `fit()`.
      To reset the states of your model, call `.reset_states()` on either a
      specific layer, or on your entire model.

    Note on specifying the initial state of RNNs:
      You can specify the initial state of RNN layers symbolically by calling
      them with the keyword argument `initial_state`. The value of
      `initial_state` should be a tensor or list of tensors representing the
      initial state of the RNN layer. You can specify the initial state of
      RNN layers numerically by calling `reset_states` with the keyword
      argument `states`. The value of `states` should be a numpy array or
      list of numpy arrays representing the initial state of the RNN layer.

    Note on passing external constants to RNNs:
      You can pass "external" constants to the cell using the `constants`
      keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This
      requires that the `cell.call` method accepts the same keyword argument
      `constants`. Such constants can be used to condition the cell
      transformation on additional static inputs (not changing over time),
      a.k.a. an attention mechanism.
    """

    def __init__(
        self,
        rank,
        cell,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        unroll=False,
        **kwargs,
    ):
        if unroll:
            raise TypeError(
                "Unrolling is not possible with convolutional RNNs. "
                f"Received: unroll={unroll}"
            )
        if isinstance(cell, (list, tuple)):
            # The StackedConvRNN3DCells isn't implemented yet, so a list or
            # tuple of cells is rejected.
            raise TypeError(
                "It is not possible at the moment to "
                "stack convolutional cells. Only pass a single cell "
                "instance as the `cell` argument. Received: "
                f"cell={cell}"
            )
        super().__init__(
            cell,
            return_sequences,
            return_state,
            go_backwards,
            stateful,
            unroll,
            **kwargs,
        )
        self.rank = rank
        # Inputs are (batch, time, <rank spatial dims>, channels) in some
        # order, i.e. rank + 3 dimensions in total.
        self.input_spec = [InputSpec(ndim=rank + 3)]
        self.states = None
        # NOTE(review): presumably set by the base RNN when `constants` are
        # passed to __call__; `build` and `call` read it here.
        self._num_constants = None

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        """Computes the output shape(s) of the layer.

        Args:
          input_shape: Shape of the input, or a list of shapes whose first
            entry is the input shape (e.g. when initial states or constants
            are also passed).

        Returns:
          The output shape. When `return_state` is True, a list whose first
          entry is the output shape, followed by one shape per recurrent
          state as described by `cell.state_size`.
        """
        if isinstance(input_shape, list):
            input_shape = input_shape[0]

        cell = self.cell
        if cell.data_format == "channels_first":
            img_dims = input_shape[3:]
        elif cell.data_format == "channels_last":
            img_dims = input_shape[2:-1]

        # Spatial dims after the cell's convolution; padding, strides and
        # dilation may change them.
        norm_img_dims = tuple(
            conv_utils.conv_output_length(
                dim,
                cell.kernel_size[idx],
                padding=cell.padding,
                stride=cell.strides[idx],
                dilation=cell.dilation_rate[idx],
            )
            for idx, dim in enumerate(img_dims)
        )

        if cell.data_format == "channels_first":
            output_shape = input_shape[:2] + (cell.filters,) + norm_img_dims
        elif cell.data_format == "channels_last":
            output_shape = input_shape[:2] + norm_img_dims + (cell.filters,)

        if not self.return_sequences:
            # Drop the time axis.
            output_shape = output_shape[:1] + output_shape[2:]

        if self.return_state:
            # Emit exactly one shape per state, matching the number of
            # states `call` returns and the channel sizes `build` uses for
            # `state_spec` (both derived from `cell.state_size`).
            if hasattr(cell.state_size, "__len__"):
                state_channels = list(cell.state_size)
            else:
                state_channels = [cell.state_size]
            output_shape = [output_shape]
            if cell.data_format == "channels_first":
                output_shape += [
                    (input_shape[0], dim) + norm_img_dims
                    for dim in state_channels
                ]
            elif cell.data_format == "channels_last":
                output_shape += [
                    (input_shape[0],) + norm_img_dims + (dim,)
                    for dim in state_channels
                ]
        return output_shape

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        """Builds the cell and sets up `input_spec` and `state_spec`.

        Args:
          input_shape: Input shape, or a list of shapes whose first entry is
            the input shape; when constants were passed to `__call__`, the
            last `self._num_constants` entries are the constants' shapes.

        Raises:
          ValueError: if an `initial_state` was passed whose channel sizes
            do not match `cell.state_size`.
        """
        # Note input_shape will be list of shapes of initial states and
        # constants if these are passed in __call__.
        if self._num_constants is not None:
            constants_shape = input_shape[-self._num_constants :]
        else:
            constants_shape = None

        if isinstance(input_shape, list):
            input_shape = input_shape[0]

        # Stateful layers pin the batch size; the time axis stays free.
        batch_size = input_shape[0] if self.stateful else None
        self.input_spec[0] = InputSpec(
            shape=(batch_size, None) + input_shape[2 : self.rank + 3]
        )

        # allow cell (if layer) to build before we set or validate state_spec
        if isinstance(self.cell, base_layer.Layer):
            # A single step sees the input without its time axis.
            step_input_shape = (input_shape[0],) + input_shape[2:]
            if constants_shape is not None:
                self.cell.build([step_input_shape] + constants_shape)
            else:
                self.cell.build(step_input_shape)

        # set or validate state_spec
        if hasattr(self.cell.state_size, "__len__"):
            state_size = list(self.cell.state_size)
        else:
            state_size = [self.cell.state_size]

        if self.state_spec is not None:
            # initial_state was passed in call, check compatibility
            if self.cell.data_format == "channels_first":
                ch_dim = 1
            elif self.cell.data_format == "channels_last":
                ch_dim = self.rank + 1
            if [spec.shape[ch_dim] for spec in self.state_spec] != state_size:
                raise ValueError(
                    "An `initial_state` was passed that is not compatible with "
                    "`cell.state_size`. Received state shapes "
                    f"{[spec.shape for spec in self.state_spec]}. "
                    f"However `cell.state_size` is {self.cell.state_size}"
                )
        else:
            # Only the channel dim of each state is known up front.
            img_dims = tuple(None for _ in range(self.rank))
            if self.cell.data_format == "channels_first":
                self.state_spec = [
                    InputSpec(shape=(None, dim) + img_dims)
                    for dim in state_size
                ]
            elif self.cell.data_format == "channels_last":
                self.state_spec = [
                    InputSpec(shape=(None,) + img_dims + (dim,))
                    for dim in state_size
                ]
        if self.stateful:
            self.reset_states()
        self.built = True

    def get_initial_state(self, inputs):
        """Returns all-zero initial states derived from `inputs`.

        The states are built from `inputs` rather than from a static shape,
        so the batch and spatial dims follow the actual input tensor.

        Args:
          inputs: The layer's input tensor, with time on axis 1.

        Returns:
          A list with one zero tensor per entry of `cell.state_size` (a
          single-element list when `state_size` is a scalar). Every entry is
          the same tensor and has `cell.filters` channels.
        """
        # Zeros with the same shape as `inputs` (time axis still present).
        initial_state = backend.zeros_like(inputs)
        # Collapse the time axis: (samples, <spatial dims + input channels>).
        initial_state = backend.sum(initial_state, axis=1)
        # An all-zero kernel whose output-channel dim is `filters`.
        # NOTE(review): assumes the last entry of `cell.kernel_shape` is the
        # output-channel dim; confirm against the cell implementation.
        shape = list(self.cell.kernel_shape)
        shape[-1] = self.cell.filters
        # Convolving with the zero kernel yields zeros with the cell's
        # output channel count and post-padding spatial dims.
        initial_state = self.cell.input_conv(
            initial_state,
            tf.zeros(tuple(shape), initial_state.dtype),
            padding=self.cell.padding,
        )

        if hasattr(self.cell.state_size, "__len__"):
            return [initial_state for _ in self.cell.state_size]
        else:
            return [initial_state]

    def call(
        self,
        inputs,
        mask=None,
        training=None,
        initial_state=None,
        constants=None,
    ):
        """Runs the cell over the time axis of `inputs`.

        Args:
          inputs: Input tensor with time on axis 1.
          mask: Optional mask tensor, or a list whose first entry is the
            mask.
          training: Forwarded to `cell.call` if it accepts `training`.
          initial_state: Optional initial state tensor(s).
          constants: Optional constant tensor(s), forwarded to `cell.call`
            at every step.

        Returns:
          The last output (or the full output sequence if
          `return_sequences`). If `return_state` is True, returns the list
          `[output] + states`.

        Raises:
          ValueError: if `constants` are given but `cell.call` does not
            accept a `constants` argument.
        """
        # note that the .build() method of subclasses MUST define
        # self.input_spec and self.state_spec with complete input shapes.
        inputs, initial_state, constants = self._process_inputs(
            inputs, initial_state, constants
        )

        if isinstance(mask, list):
            mask = mask[0]
        timesteps = backend.int_shape(inputs)[1]

        kwargs = {}
        if generic_utils.has_arg(self.cell.call, "training"):
            kwargs["training"] = training

        if constants:
            if not generic_utils.has_arg(self.cell.call, "constants"):
                raise ValueError(
                    f"RNN cell {self.cell} does not support constants. "
                    f"Received: constants={constants}"
                )

            def step(inputs, states):
                # NOTE(review): the constants are presumably appended to the
                # states by `backend.rnn`; split them back out here.
                constants = states[-self._num_constants :]
                states = states[: -self._num_constants]
                return self.cell.call(
                    inputs, states, constants=constants, **kwargs
                )

        else:

            def step(inputs, states):
                return self.cell.call(inputs, states, **kwargs)

        last_output, outputs, states = backend.rnn(
            step,
            inputs,
            initial_state,
            constants=constants,
            go_backwards=self.go_backwards,
            mask=mask,
            input_length=timesteps,
            return_all_outputs=self.return_sequences,
        )
        if self.stateful:
            # Write the final states back into the layer's stored states.
            updates = [
                backend.update(self_state, state)
                for self_state, state in zip(self.states, states)
            ]
            self.add_update(updates)

        if self.return_sequences:
            output = outputs
        else:
            output = last_output

        if self.return_state:
            if not isinstance(states, (list, tuple)):
                states = [states]
            else:
                states = list(states)
            return [output] + states
        return output

    def reset_states(self, states=None):
        """Resets (or assigns) the recurrent states of a stateful layer.

        If the state variables do not exist yet, they are created as zeros
        and `states` is ignored. Otherwise they are set to zeros when
        `states` is None, or to the given value(s) otherwise.

        Args:
          states: Optional value or list of values (e.g. numpy arrays), one
            per entry of `cell.state_size`.

        Raises:
          AttributeError: if the layer is not stateful.
          ValueError: if the full input shape (including batch size) is not
            known, if the wrong number of states is given, or if a state
            value has the wrong shape.
          KeyError: if `cell.data_format` is neither "channels_first" nor
            "channels_last".
        """
        if not self.stateful:
            raise AttributeError("Layer must be stateful.")
        input_shape = self.input_spec[0].shape
        state_shape = self.compute_output_shape(input_shape)
        if self.return_state:
            # The first entry is the layer output shape.
            state_shape = state_shape[0]
        if self.return_sequences:
            # Drop the time axis so the shape describes a single state.
            state_shape = state_shape[:1].concatenate(state_shape[2:])
        if None in state_shape:
            raise ValueError(
                "If a RNN is stateful, it needs to know "
                "its batch size. Specify the batch size "
                "of your input tensors: \n"
                "- If using a Sequential model, "
                "specify the batch size by passing "
                "a `batch_input_shape` "
                "argument to your first layer.\n"
                "- If using the functional API, specify "
                "the batch size by passing a "
                "`batch_shape` argument to your Input layer.\n"
                "The same thing goes for the number of rows and "
                "columns."
            )

        # helper function: the state shape with its channel dim replaced.
        def get_tuple_shape(nb_channels):
            result = list(state_shape)
            if self.cell.data_format == "channels_first":
                result[1] = nb_channels
            elif self.cell.data_format == "channels_last":
                result[self.rank + 1] = nb_channels
            else:
                raise KeyError(
                    "Cell data format must be one of "
                    '{"channels_first", "channels_last"}. Received: '
                    f"cell.data_format={self.cell.data_format}"
                )
            return tuple(result)

        # initialize state if None
        if self.states[0] is None:
            if hasattr(self.cell.state_size, "__len__"):
                self.states = [
                    backend.zeros(get_tuple_shape(dim))
                    for dim in self.cell.state_size
                ]
            else:
                self.states = [
                    backend.zeros(get_tuple_shape(self.cell.state_size))
                ]
        elif states is None:
            if hasattr(self.cell.state_size, "__len__"):
                for state, dim in zip(self.states, self.cell.state_size):
                    backend.set_value(state, np.zeros(get_tuple_shape(dim)))
            else:
                backend.set_value(
                    self.states[0],
                    np.zeros(get_tuple_shape(self.cell.state_size)),
                )
        else:
            if not isinstance(states, (list, tuple)):
                states = [states]
            if len(states) != len(self.states):
                raise ValueError(
                    f"Layer {self.name} expects {len(self.states)} states, "
                    f"but it received {len(states)} state values. "
                    f"States received: {states}"
                )
            for index, (value, state) in enumerate(zip(states, self.states)):
                if hasattr(self.cell.state_size, "__len__"):
                    dim = self.cell.state_size[index]
                else:
                    dim = self.cell.state_size
                if value.shape != get_tuple_shape(dim):
                    raise ValueError(
                        f"State {index} is incompatible with layer "
                        f"{self.name}: expected shape={get_tuple_shape(dim)}, "
                        f"found shape={value.shape}"
                    )
                backend.set_value(state, value)