Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/layers/convolutional_recurrent.py: 24%

339 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=protected-access 

16# pylint: disable=g-classes-have-attributes 

17"""Convolutional-recurrent layers.""" 

18 

19import numpy as np 

20 

21from tensorflow.python.keras import activations 

22from tensorflow.python.keras import backend 

23from tensorflow.python.keras import constraints 

24from tensorflow.python.keras import initializers 

25from tensorflow.python.keras import regularizers 

26from tensorflow.python.keras.engine.base_layer import Layer 

27from tensorflow.python.keras.engine.input_spec import InputSpec 

28from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin 

29from tensorflow.python.keras.layers.recurrent import RNN 

30from tensorflow.python.keras.utils import conv_utils 

31from tensorflow.python.keras.utils import generic_utils 

32from tensorflow.python.keras.utils import tf_utils 

33from tensorflow.python.ops import array_ops 

34from tensorflow.python.util.tf_export import keras_export 

35 

36 

37class ConvRNN2D(RNN): 

38 """Base class for convolutional-recurrent layers. 

39 

40 Args: 

41 cell: A RNN cell instance. A RNN cell is a class that has: 

42 - a `call(input_at_t, states_at_t)` method, returning 

43 `(output_at_t, states_at_t_plus_1)`. The call method of the 

44 cell can also take the optional argument `constants`, see 

45 section "Note on passing external constants" below. 

46 - a `state_size` attribute. This can be a single integer 

47 (single state) in which case it is 

48 the number of channels of the recurrent state 

49 (which should be the same as the number of channels of the cell 

50 output). This can also be a list/tuple of integers 

51 (one size per state). In this case, the first entry 

52 (`state_size[0]`) should be the same as 

53 the size of the cell output. 

54 return_sequences: Boolean. Whether to return the last output

55 in the output sequence, or the full sequence. 

56 return_state: Boolean. Whether to return the last state 

57 in addition to the output. 

58 go_backwards: Boolean (default False). 

59 If True, process the input sequence backwards and return the 

60 reversed sequence. 

61 stateful: Boolean (default False). If True, the last state 

62 for each sample at index i in a batch will be used as initial 

63 state for the sample of index i in the following batch. 

64 input_shape: Use this argument to specify the shape of the 

65 input when this layer is the first one in a model. 

66 

67 Call arguments: 

68 inputs: A 5D tensor. 

69 mask: Binary tensor of shape `(samples, timesteps)` indicating whether 

70 a given timestep should be masked. 

71 training: Python boolean indicating whether the layer should behave in 

72 training mode or in inference mode. This argument is passed to the cell 

73 when calling it. This is for use with cells that use dropout. 

74 initial_state: List of initial state tensors to be passed to the first 

75 call of the cell. 

76 constants: List of constant tensors to be passed to the cell at each 

77 timestep. 

78 

79 Input shape: 

80 5D tensor with shape: 

81 `(samples, timesteps, channels, rows, cols)` 

82 if data_format='channels_first' or 5D tensor with shape: 

83 `(samples, timesteps, rows, cols, channels)` 

84 if data_format='channels_last'. 

85 

86 Output shape: 

87 - If `return_state`: a list of tensors. The first tensor is 

88 the output. The remaining tensors are the last states, 

89 each 4D tensor with shape: 

90 `(samples, filters, new_rows, new_cols)` 

91 if data_format='channels_first' 

92 or 4D tensor with shape: 

93 `(samples, new_rows, new_cols, filters)` 

94 if data_format='channels_last'. 

95 `rows` and `cols` values might have changed due to padding. 

96 - If `return_sequences`: 5D tensor with shape: 

97 `(samples, timesteps, filters, new_rows, new_cols)` 

98 if data_format='channels_first' 

99 or 5D tensor with shape: 

100 `(samples, timesteps, new_rows, new_cols, filters)` 

101 if data_format='channels_last'. 

102 - Else, 4D tensor with shape: 

103 `(samples, filters, new_rows, new_cols)` 

104 if data_format='channels_first' 

105 or 4D tensor with shape: 

106 `(samples, new_rows, new_cols, filters)` 

107 if data_format='channels_last'. 

108 

109 Masking: 

110 This layer supports masking for input data with a variable number 

111 of timesteps. 

112 

113 Note on using statefulness in RNNs: 

114 You can set RNN layers to be 'stateful', which means that the states 

115 computed for the samples in one batch will be reused as initial states 

116 for the samples in the next batch. This assumes a one-to-one mapping 

117 between samples in different successive batches. 

118 To enable statefulness: 

119 - Specify `stateful=True` in the layer constructor. 

120 - Specify a fixed batch size for your model, by passing 

121 - If sequential model: 

122 `batch_input_shape=(...)` to the first layer in your model. 

123 - If functional model with 1 or more Input layers: 

124 `batch_shape=(...)` to all the first layers in your model. 

125 This is the expected shape of your inputs 

126 *including the batch size*. 

127 It should be a tuple of integers, 

128 e.g. `(32, 10, 100, 100, 32)`. 

129 Note that the number of rows and columns should be specified 

130 too. 

131 - Specify `shuffle=False` when calling fit(). 

132 To reset the states of your model, call `.reset_states()` on either 

133 a specific layer, or on your entire model. 

134 

135 Note on specifying the initial state of RNNs: 

136 You can specify the initial state of RNN layers symbolically by 

137 calling them with the keyword argument `initial_state`. The value of 

138 `initial_state` should be a tensor or list of tensors representing 

139 the initial state of the RNN layer. 

140 You can specify the initial state of RNN layers numerically by 

141 calling `reset_states` with the keyword argument `states`. The value of 

142 `states` should be a numpy array or list of numpy arrays representing 

143 the initial state of the RNN layer. 

144 

145 Note on passing external constants to RNNs: 

146 You can pass "external" constants to the cell using the `constants` 

147 keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This 

148 requires that the `cell.call` method accepts the same keyword argument 

149 `constants`. Such constants can be used to condition the cell 

150 transformation on additional static inputs (not changing over time), 

151 a.k.a. an attention mechanism. 

152 """ 

153 

def __init__(self,
             cell,
             return_sequences=False,
             return_state=False,
             go_backwards=False,
             stateful=False,
             unroll=False,
             **kwargs):
  """Wraps a single convolutional RNN cell.

  Args:
    cell: The RNN cell instance to wrap. A list or tuple of cells is
      rejected because stacking convolutional cells is not implemented.
    return_sequences: Forwarded to the parent constructor.
    return_state: Forwarded to the parent constructor.
    go_backwards: Forwarded to the parent constructor.
    stateful: Forwarded to the parent constructor.
    unroll: Must be falsy; unrolling is not supported by this layer.
    **kwargs: Extra keyword arguments forwarded to the parent constructor.

  Raises:
    TypeError: If `unroll` is truthy, or if `cell` is a list or tuple.
  """
  if unroll:
    raise TypeError('Unrolling isn\'t possible with '
                    'convolutional RNNs.')
  if isinstance(cell, (list, tuple)):
    # The StackedConvRNN2DCells isn't implemented yet.
    # The trailing space after "to" keeps the two literal fragments from
    # running together as "tostack" in the final message.
    raise TypeError('It is not possible at the moment to '
                    'stack convolutional cells.')
  super(ConvRNN2D, self).__init__(cell,
                                  return_sequences,
                                  return_state,
                                  go_backwards,
                                  stateful,
                                  unroll,
                                  **kwargs)
  # The layer itself consumes 5D inputs.
  self.input_spec = [InputSpec(ndim=5)]
  self.states = None
  self._num_constants = None

179 

@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
  """Derives the layer's output shape(s) from its input shape.

  Args:
    input_shape: Shape of the layer input, or a list of shapes whose first
      entry is the input shape (only that entry is used).

  Returns:
    The output shape. The timestep axis is kept only when
    `self.return_sequences` is set. When `self.return_state` is set, a list
    is returned instead: the output shape followed by two state shapes.
  """
  if isinstance(input_shape, list):
    input_shape = input_shape[0]

  cell = self.cell

  # Locate the two spatial axes of the input for the cell's data format.
  if cell.data_format == 'channels_first':
    rows, cols = input_shape[3], input_shape[4]
  elif cell.data_format == 'channels_last':
    rows, cols = input_shape[2], input_shape[3]

  # Spatial extent after the cell's input convolution, one axis at a time.
  new_rows = conv_utils.conv_output_length(rows,
                                           cell.kernel_size[0],
                                           padding=cell.padding,
                                           stride=cell.strides[0],
                                           dilation=cell.dilation_rate[0])
  new_cols = conv_utils.conv_output_length(cols,
                                           cell.kernel_size[1],
                                           padding=cell.padding,
                                           stride=cell.strides[1],
                                           dilation=cell.dilation_rate[1])

  # Per-timestep shape without the leading batch/time axes.
  if cell.data_format == 'channels_first':
    frame_shape = (cell.filters, new_rows, new_cols)
  elif cell.data_format == 'channels_last':
    frame_shape = (new_rows, new_cols, cell.filters)

  # Keep (batch, time) for full sequences, only (batch,) otherwise.
  leading = input_shape[:2] if self.return_sequences else input_shape[:1]
  output_shape = leading + frame_shape

  if self.return_state:
    # NOTE(review): exactly two state shapes with `cell.filters` channels
    # are reported regardless of `cell.state_size` -- confirm this is
    # intended for cells that do not carry two states.
    state_shape = (input_shape[0],) + frame_shape
    output_shape = [output_shape, state_shape, state_shape]
  return output_shape

220 

@tf_utils.shape_type_conversion
def build(self, input_shape):
  """Builds the wrapped cell and sets or validates the state spec.

  Args:
    input_shape: Shape of the input, or a list of shapes when initial states
      and/or constants are passed in `__call__`. In the list form the first
      entry is the input shape and the trailing `self._num_constants`
      entries are taken as the shapes of the constants.

  Raises:
    ValueError: If `self.state_spec` is already set and its channel sizes
      do not match `cell.state_size`.
  """
  # Note input_shape will be list of shapes of initial states and
  # constants if these are passed in __call__.
  constants_shape = None
  if self._num_constants is not None:
    constants_shape = input_shape[-self._num_constants:]  # pylint: disable=E1130

  if isinstance(input_shape, list):
    input_shape = input_shape[0]

  # Only a stateful layer pins the batch size; the timestep axis stays free.
  batch_size = input_shape[0] if self.stateful else None
  self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:5])

  cell = self.cell

  # Allow the cell (if it is a layer) to build before the state spec is set
  # or validated. The cell sees a single timestep, so the time axis is cut.
  if isinstance(cell, Layer):
    step_input_shape = (input_shape[0],) + input_shape[2:]
    if constants_shape is None:
      cell.build(step_input_shape)
    else:
      cell.build([step_input_shape] + constants_shape)

  # Normalize `state_size` to a list of per-state channel counts.
  raw_state_size = cell.state_size
  state_size = (list(raw_state_size)
                if hasattr(raw_state_size, '__len__') else [raw_state_size])

  if self.state_spec is None:
    # No initial_state was supplied: derive the spec from the cell.
    if cell.data_format == 'channels_first':
      self.state_spec = [InputSpec(shape=(None, dim, None, None))
                         for dim in state_size]
    elif cell.data_format == 'channels_last':
      self.state_spec = [InputSpec(shape=(None, None, None, dim))
                         for dim in state_size]
  else:
    # initial_state was passed in call, check compatibility
    if cell.data_format == 'channels_first':
      ch_dim = 1
    elif cell.data_format == 'channels_last':
      ch_dim = 3
    spec_channels = [spec.shape[ch_dim] for spec in self.state_spec]
    if spec_channels != state_size:
      raise ValueError(
          'An initial_state was passed that is not compatible with '
          '`cell.state_size`. Received `state_spec`={}; '
          'However `cell.state_size` is '
          '{}'.format([spec.shape for spec in self.state_spec],
                      raw_state_size))

  if self.stateful:
    self.reset_states()
  self.built = True

273 

def get_initial_state(self, inputs):
  """Creates all-zero initial states shaped like the cell's output.

  The zeros are produced by running the cell's input convolution with an
  all-zero kernel over a zeroed, time-reduced copy of `inputs`, so the
  result takes whatever spatial shape that convolution yields.

  Args:
    inputs: The layer input tensor.

  Returns:
    A list holding the same zero tensor once per entry of
    `cell.state_size`, or a single-element list when `state_size` has no
    length.
  """
  cell = self.cell

  # (samples, timesteps, rows, cols, filters) -> (samples, rows, cols,
  # filters): zero everything, then reduce away the timestep axis.
  zeroed = backend.sum(backend.zeros_like(inputs), axis=1)

  # Same kernel shape as the cell's input kernel, but with `filters`
  # output channels.
  zero_kernel_shape = list(cell.kernel_shape)
  zero_kernel_shape[-1] = cell.filters
  zero_kernel = array_ops.zeros(tuple(zero_kernel_shape), zeroed.dtype)

  zero_state = cell.input_conv(zeroed, zero_kernel, padding=cell.padding)

  if not hasattr(cell.state_size, '__len__'):
    return [zero_state]
  return [zero_state for _ in cell.state_size]

290 

def call(self,
         inputs,
         mask=None,
         training=None,
         initial_state=None,
         constants=None):
  """Runs the cell over the timestep axis of `inputs`.

  Args:
    inputs: The layer input.
    mask: Optional mask; when a list is given only its first entry is used.
    training: Passed to the cell's `call` when the cell accepts a
      `training` argument.
    initial_state: Optional initial state(s) for the cell.
    constants: Optional constants passed to the cell at each step.

  Returns:
    The full output sequence if `self.return_sequences`, otherwise the last
    output. If `self.return_state`, a list with that output followed by the
    final states.

  Raises:
    ValueError: If constants are given but the cell's `call` does not
      accept a `constants` argument.
  """
  # note that the .build() method of subclasses MUST define
  # self.input_spec and self.state_spec with complete input shapes.
  inputs, initial_state, constants = self._process_inputs(
      inputs, initial_state, constants)

  if isinstance(mask, list):
    mask = mask[0]
  timesteps = backend.int_shape(inputs)[1]

  # Only forward `training` to cells that declare it.
  cell_kwargs = {}
  if generic_utils.has_arg(self.cell.call, 'training'):
    cell_kwargs['training'] = training

  if not constants:
    def step(inputs, states):
      return self.cell.call(inputs, states, **cell_kwargs)
  else:
    if not generic_utils.has_arg(self.cell.call, 'constants'):
      raise ValueError('RNN cell does not support constants')

    def step(inputs, states):
      # The trailing `_num_constants` entries of `states` are the constants;
      # everything before them is the real recurrent state.
      split_at = -self._num_constants  # pylint: disable=invalid-unary-operand-type
      return self.cell.call(inputs,
                            states[:split_at],
                            constants=states[split_at:],
                            **cell_kwargs)

  last_output, outputs, states = backend.rnn(step,
                                             inputs,
                                             initial_state,
                                             constants=constants,
                                             go_backwards=self.go_backwards,
                                             mask=mask,
                                             input_length=timesteps)

  if self.stateful:
    # Record the final states so the next batch can start from them.
    self.add_update([
        backend.update(self_state, state)
        for self_state, state in zip(self.states, states)
    ])

  output = outputs if self.return_sequences else last_output

  if not self.return_state:
    return output
  final_states = list(states) if isinstance(states, (list, tuple)) else [states]
  return [output] + final_states

349 

def reset_states(self, states=None):
  """Initializes, zeroes, or overwrites the layer's recorded states.

  Args:
    states: Optional value, or list/tuple of values, to load into the
      existing states. When omitted, the states are filled with zeros.

  Raises:
    AttributeError: If the layer is not stateful.
    ValueError: If the state shape contains `None`, if the number of
      provided values differs from the number of states, or if a provided
      value does not have the expected shape.
  """
  if not self.stateful:
    raise AttributeError('Layer must be stateful.')

  # The state shape is derived from the output shape of the layer.
  state_shape = self.compute_output_shape(self.input_spec[0].shape)
  if self.return_state:
    state_shape = state_shape[0]
  if self.return_sequences:
    # Drop the timestep axis.
    state_shape = state_shape[:1].concatenate(state_shape[2:])
  if None in state_shape:
    raise ValueError('If a RNN is stateful, it needs to know '
                     'its batch size. Specify the batch size '
                     'of your input tensors: \n'
                     '- If using a Sequential model, '
                     'specify the batch size by passing '
                     'a `batch_input_shape` '
                     'argument to your first layer.\n'
                     '- If using the functional API, specify '
                     'the time dimension by passing a '
                     '`batch_shape` argument to your Input layer.\n'
                     'The same thing goes for the number of rows and '
                     'columns.')

  # helper function
  def get_tuple_shape(nb_channels):
    """Returns `state_shape` as a tuple with the channel axis replaced."""
    result = list(state_shape)
    if self.cell.data_format == 'channels_first':
      result[1] = nb_channels
    elif self.cell.data_format == 'channels_last':
      result[3] = nb_channels
    else:
      raise KeyError
    return tuple(result)

  cell_state_size = self.cell.state_size
  multi_state = hasattr(cell_state_size, '__len__')
  # Per-state channel counts, whether `state_size` is a scalar or a sequence.
  channel_dims = cell_state_size if multi_state else [cell_state_size]

  if self.states[0] is None:
    # initialize state if None
    self.states = [backend.zeros(get_tuple_shape(dim))
                   for dim in channel_dims]
  elif states is None:
    # Existing states, no replacement values: zero them in place.
    for state, dim in zip(self.states, channel_dims):
      backend.set_value(state, np.zeros(get_tuple_shape(dim)))
  else:
    if not isinstance(states, (list, tuple)):
      states = [states]
    if len(states) != len(self.states):
      raise ValueError('Layer ' + self.name + ' expects ' +
                       str(len(self.states)) + ' states, ' +
                       'but it received ' + str(len(states)) +
                       ' state values. Input received: ' + str(states))
    for index, (value, state) in enumerate(zip(states, self.states)):
      dim = cell_state_size[index] if multi_state else cell_state_size
      expected_shape = get_tuple_shape(dim)
      if value.shape != expected_shape:
        raise ValueError('State ' + str(index) +
                         ' is incompatible with layer ' +
                         self.name + ': expected shape=' +
                         str(expected_shape) +
                         ', found shape=' + str(value.shape))
      # TODO(anjalisridhar): consider batch calls to `set_value`.
      backend.set_value(state, value)

419 

420 

421class ConvLSTM2DCell(DropoutRNNCellMixin, Layer): 

422 """Cell class for the ConvLSTM2D layer. 

423 

424 Args: 

425 filters: Integer, the dimensionality of the output space 

426 (i.e. the number of output filters in the convolution). 

427 kernel_size: An integer or tuple/list of n integers, specifying the 

428 dimensions of the convolution window. 

429 strides: An integer or tuple/list of n integers, 

430 specifying the strides of the convolution. 

431 Specifying any stride value != 1 is incompatible with specifying 

432 any `dilation_rate` value != 1. 

433 padding: One of `"valid"` or `"same"` (case-insensitive). 

434 `"valid"` means no padding. `"same"` results in padding evenly to  

435 the left/right or up/down of the input such that output has the same  

436 height/width dimension as the input. 

437 data_format: A string, 

438 one of `channels_last` (default) or `channels_first`. 

439 It defaults to the `image_data_format` value found in your 

440 Keras config file at `~/.keras/keras.json`. 

441 If you never set it, then it will be "channels_last". 

442 dilation_rate: An integer or tuple/list of n integers, specifying 

443 the dilation rate to use for dilated convolution. 

444 Currently, specifying any `dilation_rate` value != 1 is 

445 incompatible with specifying any `strides` value != 1. 

446 activation: Activation function to use. 

447 By default hyperbolic tangent activation function is applied 

448 (`tanh(x)`). 

449 recurrent_activation: Activation function to use 

450 for the recurrent step. 

451 use_bias: Boolean, whether the layer uses a bias vector. 

452 kernel_initializer: Initializer for the `kernel` weights matrix, 

453 used for the linear transformation of the inputs. 

454 recurrent_initializer: Initializer for the `recurrent_kernel` 

455 weights matrix, 

456 used for the linear transformation of the recurrent state. 

457 bias_initializer: Initializer for the bias vector. 

458 unit_forget_bias: Boolean. 

459 If True, add 1 to the bias of the forget gate at initialization. 

460 Use in combination with `bias_initializer="zeros"`. 

461 This is recommended in [Jozefowicz et al., 2015]( 

462 http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) 

463 kernel_regularizer: Regularizer function applied to 

464 the `kernel` weights matrix. 

465 recurrent_regularizer: Regularizer function applied to 

466 the `recurrent_kernel` weights matrix. 

467 bias_regularizer: Regularizer function applied to the bias vector. 

468 kernel_constraint: Constraint function applied to 

469 the `kernel` weights matrix. 

470 recurrent_constraint: Constraint function applied to 

471 the `recurrent_kernel` weights matrix. 

472 bias_constraint: Constraint function applied to the bias vector. 

473 dropout: Float between 0 and 1. 

474 Fraction of the units to drop for 

475 the linear transformation of the inputs. 

476 recurrent_dropout: Float between 0 and 1. 

477 Fraction of the units to drop for 

478 the linear transformation of the recurrent state. 

479 

480 Call arguments: 

481 inputs: A 4D tensor. 

482 states: List of state tensors corresponding to the previous timestep. 

483 training: Python boolean indicating whether the layer should behave in 

484 training mode or in inference mode. Only relevant when `dropout` or 

485 `recurrent_dropout` is used. 

486 """ 

487 

def __init__(self,
             filters,
             kernel_size,
             strides=(1, 1),
             padding='valid',
             data_format=None,
             dilation_rate=(1, 1),
             activation='tanh',
             recurrent_activation='hard_sigmoid',
             use_bias=True,
             kernel_initializer='glorot_uniform',
             recurrent_initializer='orthogonal',
             bias_initializer='zeros',
             unit_forget_bias=True,
             kernel_regularizer=None,
             recurrent_regularizer=None,
             bias_regularizer=None,
             kernel_constraint=None,
             recurrent_constraint=None,
             bias_constraint=None,
             dropout=0.,
             recurrent_dropout=0.,
             **kwargs):
  """Stores the cell configuration after normalizing each argument.

  Size-like arguments are normalized to 2-tuples, and activations,
  initializers, regularizers and constraints are resolved through the
  corresponding `get` helpers. Both dropout rates are clamped to `[0, 1]`.
  `**kwargs` is forwarded to the parent constructor.
  """
  super(ConvLSTM2DCell, self).__init__(**kwargs)

  # Convolution geometry.
  self.filters = filters
  self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
  self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
  self.padding = conv_utils.normalize_padding(padding)
  self.data_format = conv_utils.normalize_data_format(data_format)
  self.dilation_rate = conv_utils.normalize_tuple(
      dilation_rate, 2, 'dilation_rate')

  # Non-linearities and bias switch.
  self.activation = activations.get(activation)
  self.recurrent_activation = activations.get(recurrent_activation)
  self.use_bias = use_bias

  # Weight initialization.
  self.kernel_initializer = initializers.get(kernel_initializer)
  self.recurrent_initializer = initializers.get(recurrent_initializer)
  self.bias_initializer = initializers.get(bias_initializer)
  self.unit_forget_bias = unit_forget_bias

  # Weight regularization.
  self.kernel_regularizer = regularizers.get(kernel_regularizer)
  self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
  self.bias_regularizer = regularizers.get(bias_regularizer)

  # Weight constraints.
  self.kernel_constraint = constraints.get(kernel_constraint)
  self.recurrent_constraint = constraints.get(recurrent_constraint)
  self.bias_constraint = constraints.get(bias_constraint)

  # Clamp both dropout rates into [0, 1].
  self.dropout = min(1., max(0., dropout))
  self.recurrent_dropout = min(1., max(0., recurrent_dropout))

  # Two states (h and c), each with `filters` channels.
  self.state_size = (self.filters, self.filters)

539 

def build(self, input_shape):
  """Creates the input kernel, the recurrent kernel and the optional bias.

  The last axis of every weight is `filters * 4` wide, one quarter per
  gate. When both `use_bias` and `unit_forget_bias` are set, the second
  quarter of the bias is initialized with ones and the rest with
  `self.bias_initializer`.

  Args:
    input_shape: Shape of a single-timestep input; its channel axis must be
      defined.

  Raises:
    ValueError: If the channel dimension of `input_shape` is `None`.
  """
  channel_axis = 1 if self.data_format == 'channels_first' else -1
  input_dim = input_shape[channel_axis]
  if input_dim is None:
    raise ValueError('The channel dimension of the inputs '
                     'should be defined. Found `None`.')

  gate_channels = self.filters * 4
  # Kept on the instance: `kernel_shape` is read outside this method.
  self.kernel_shape = self.kernel_size + (input_dim, gate_channels)
  recurrent_kernel_shape = self.kernel_size + (self.filters, gate_channels)

  self.kernel = self.add_weight(shape=self.kernel_shape,
                                initializer=self.kernel_initializer,
                                name='kernel',
                                regularizer=self.kernel_regularizer,
                                constraint=self.kernel_constraint)
  self.recurrent_kernel = self.add_weight(
      shape=recurrent_kernel_shape,
      initializer=self.recurrent_initializer,
      name='recurrent_kernel',
      regularizer=self.recurrent_regularizer,
      constraint=self.recurrent_constraint)

  if not self.use_bias:
    self.bias = None
  else:
    if self.unit_forget_bias:

      def bias_initializer(_, *args, **kwargs):
        # Requested shape is ignored; the bias is assembled per gate so the
        # second (forget) quarter starts at one.
        return backend.concatenate([
            self.bias_initializer((self.filters,), *args, **kwargs),
            initializers.get('ones')((self.filters,), *args, **kwargs),
            self.bias_initializer((self.filters * 2,), *args, **kwargs),
        ])
    else:
      bias_initializer = self.bias_initializer
    self.bias = self.add_weight(
        shape=(gate_channels,),
        name='bias',
        initializer=bias_initializer,
        regularizer=self.bias_regularizer,
        constraint=self.bias_constraint)
  self.built = True

586 

def call(self, inputs, states, training=None):
  """Computes one ConvLSTM step.

  Args:
    inputs: Input for the current timestep.
    states: Sequence whose first two entries are the previous `h` and `c`.
    training: Forwarded to the dropout-mask helpers.

  Returns:
    A pair `(h, [h, c])`: the step output and the new states.
  """
  h_tm1 = states[0]  # previous memory state
  c_tm1 = states[1]  # previous carry state

  # dropout matrices for input units
  dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
  # dropout matrices for recurrent units
  rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
      h_tm1, training, count=4)

  # One (possibly masked) copy of the input per gate, in i, f, c, o order.
  if 0 < self.dropout < 1.:
    gate_inputs = [inputs * dp_mask[gate] for gate in range(4)]
  else:
    gate_inputs = [inputs] * 4

  # Likewise for the previous hidden state.
  if 0 < self.recurrent_dropout < 1.:
    gate_hiddens = [h_tm1 * rec_dp_mask[gate] for gate in range(4)]
  else:
    gate_hiddens = [h_tm1] * 4

  # Each weight holds the four gates side by side on its last axis.
  kernels = array_ops.split(self.kernel, 4, axis=3)
  recurrent_kernels = array_ops.split(self.recurrent_kernel, 4, axis=3)
  if self.use_bias:
    biases = array_ops.split(self.bias, 4)
  else:
    biases = [None, None, None, None]

  x_i, x_f, x_c, x_o = [
      self.input_conv(gate_input, kernel, bias, padding=self.padding)
      for gate_input, kernel, bias in zip(gate_inputs, kernels, biases)
  ]
  h_i, h_f, h_c, h_o = [
      self.recurrent_conv(gate_hidden, kernel)
      for gate_hidden, kernel in zip(gate_hiddens, recurrent_kernels)
  ]

  i = self.recurrent_activation(x_i + h_i)
  f = self.recurrent_activation(x_f + h_f)
  c = f * c_tm1 + i * self.activation(x_c + h_c)
  o = self.recurrent_activation(x_o + h_o)
  h = o * self.activation(c)
  return h, [h, c]

646 

def input_conv(self, x, w, b=None, padding='valid'):
  """Convolves `x` with `w` using the cell's strides and dilation.

  Args:
    x: Tensor to convolve.
    w: Convolution kernel.
    b: Optional bias added to the convolution result.
    padding: Padding mode passed to the convolution.

  Returns:
    The convolution result, with `b` added when it is not `None`.
  """
  convolved = backend.conv2d(x,
                             w,
                             strides=self.strides,
                             padding=padding,
                             data_format=self.data_format,
                             dilation_rate=self.dilation_rate)
  if b is None:
    return convolved
  return backend.bias_add(convolved, b, data_format=self.data_format)

656 

def recurrent_conv(self, x, w):
  """Convolves the recurrent state `x` with kernel `w`.

  Always uses unit strides and 'same' padding, independent of the cell's
  own `strides` and `padding` settings.
  """
  return backend.conv2d(x,
                        w,
                        strides=(1, 1),
                        padding='same',
                        data_format=self.data_format)

662 

def get_config(self):
  """Returns the cell configuration as a dict.

  The cell's own settings are merged on top of the parent class config, so
  a key present in both takes the value from this cell.
  """
  config = {
      'filters': self.filters,
      'kernel_size': self.kernel_size,
      'strides': self.strides,
      'padding': self.padding,
      'data_format': self.data_format,
      'dilation_rate': self.dilation_rate,
      'activation': activations.serialize(self.activation),
      'recurrent_activation':
          activations.serialize(self.recurrent_activation),
      'use_bias': self.use_bias,
      'kernel_initializer':
          initializers.serialize(self.kernel_initializer),
      'recurrent_initializer':
          initializers.serialize(self.recurrent_initializer),
      'bias_initializer': initializers.serialize(self.bias_initializer),
      'unit_forget_bias': self.unit_forget_bias,
      'kernel_regularizer':
          regularizers.serialize(self.kernel_regularizer),
      'recurrent_regularizer':
          regularizers.serialize(self.recurrent_regularizer),
      'bias_regularizer': regularizers.serialize(self.bias_regularizer),
      'kernel_constraint': constraints.serialize(self.kernel_constraint),
      'recurrent_constraint':
          constraints.serialize(self.recurrent_constraint),
      'bias_constraint': constraints.serialize(self.bias_constraint),
      'dropout': self.dropout,
      'recurrent_dropout': self.recurrent_dropout,
  }
  # Start from the parent config and let this cell's entries win.
  merged = dict(super(ConvLSTM2DCell, self).get_config())
  merged.update(config)
  return merged

694 

695 

696@keras_export('keras.layers.ConvLSTM2D') 

697class ConvLSTM2D(ConvRNN2D): 

698 """2D Convolutional LSTM layer. 

699 

700 A convolutional LSTM is similar to an LSTM, but the input transformations 

701 and recurrent transformations are both convolutional. This layer is typically 

702 used to process timeseries of images (i.e. video-like data). 

703 

704 It is known to perform well for weather data forecasting, 

705 using inputs that are timeseries of 2D grids of sensor values. 

706 It isn't usually applied to regular video data, due to its high computational 

707 cost. 

708 

709 Args: 

710 filters: Integer, the dimensionality of the output space 

711 (i.e. the number of output filters in the convolution). 

712 kernel_size: An integer or tuple/list of n integers, specifying the 

713 dimensions of the convolution window. 

714 strides: An integer or tuple/list of n integers, 

715 specifying the strides of the convolution. 

716 Specifying any stride value != 1 is incompatible with specifying 

717 any `dilation_rate` value != 1. 

718 padding: One of `"valid"` or `"same"` (case-insensitive). 

719 `"valid"` means no padding. `"same"` results in padding evenly to 

720 the left/right or up/down of the input such that output has the same 

721 height/width dimension as the input. 

722 data_format: A string, 

723 one of `channels_last` (default) or `channels_first`. 

724 The ordering of the dimensions in the inputs. 

725 `channels_last` corresponds to inputs with shape 

726 `(batch, time, ..., channels)` 

727 while `channels_first` corresponds to 

728 inputs with shape `(batch, time, channels, ...)`. 

729 It defaults to the `image_data_format` value found in your 

730 Keras config file at `~/.keras/keras.json`. 

731 If you never set it, then it will be "channels_last". 

732 dilation_rate: An integer or tuple/list of n integers, specifying 

733 the dilation rate to use for dilated convolution. 

734 Currently, specifying any `dilation_rate` value != 1 is 

735 incompatible with specifying any `strides` value != 1. 

736 activation: Activation function to use. 

737 By default hyperbolic tangent activation function is applied 

738 (`tanh(x)`). 

739 recurrent_activation: Activation function to use 

740 for the recurrent step. 

741 use_bias: Boolean, whether the layer uses a bias vector. 

742 kernel_initializer: Initializer for the `kernel` weights matrix, 

743 used for the linear transformation of the inputs. 

744 recurrent_initializer: Initializer for the `recurrent_kernel` 

745 weights matrix, 

746 used for the linear transformation of the recurrent state. 

747 bias_initializer: Initializer for the bias vector. 

748 unit_forget_bias: Boolean. 

749 If True, add 1 to the bias of the forget gate at initialization. 

750 Use in combination with `bias_initializer="zeros"`. 

751 This is recommended in [Jozefowicz et al., 2015]( 

752 http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) 

753 kernel_regularizer: Regularizer function applied to 

754 the `kernel` weights matrix. 

755 recurrent_regularizer: Regularizer function applied to 

756 the `recurrent_kernel` weights matrix. 

757 bias_regularizer: Regularizer function applied to the bias vector. 

758 activity_regularizer: Regularizer function applied to. 

759 kernel_constraint: Constraint function applied to 

760 the `kernel` weights matrix. 

761 recurrent_constraint: Constraint function applied to 

762 the `recurrent_kernel` weights matrix. 

763 bias_constraint: Constraint function applied to the bias vector. 

764 return_sequences: Boolean. Whether to return the last output 

765 in the output sequence, or the full sequence. (default False) 

766 return_state: Boolean. Whether to return the last state 

767 in addition to the output. (default False) 

768 go_backwards: Boolean (default False). 

769 If True, process the input sequence backwards. 

770 stateful: Boolean (default False). If True, the last state 

771 for each sample at index i in a batch will be used as initial 

772 state for the sample of index i in the following batch. 

773 dropout: Float between 0 and 1. 

774 Fraction of the units to drop for 

775 the linear transformation of the inputs. 

776 recurrent_dropout: Float between 0 and 1. 

777 Fraction of the units to drop for 

778 the linear transformation of the recurrent state. 

779 

780 Call arguments: 

781 inputs: A 5D float tensor (see input shape description below). 

782 mask: Binary tensor of shape `(samples, timesteps)` indicating whether 

783 a given timestep should be masked. 

784 training: Python boolean indicating whether the layer should behave in 

785 training mode or in inference mode. This argument is passed to the cell 

786 when calling it. This is only relevant if `dropout` or `recurrent_dropout` 

787 are set. 

788 initial_state: List of initial state tensors to be passed to the first 

789 call of the cell. 

790 

791 Input shape: 

792 - If data_format='channels_first' 

793 5D tensor with shape: 

794 `(samples, time, channels, rows, cols)` 

795 - If data_format='channels_last' 

796 5D tensor with shape: 

797 `(samples, time, rows, cols, channels)` 

798 

799 Output shape: 

800 - If `return_state`: a list of tensors. The first tensor is 

801 the output. The remaining tensors are the last states, 

802 each 4D tensor with shape: 

803 `(samples, filters, new_rows, new_cols)` 

804 if data_format='channels_first' 

805 or 4D tensor with shape: 

806 `(samples, new_rows, new_cols, filters)` 

807 if data_format='channels_last'. 

808 `rows` and `cols` values might have changed due to padding. 

809 - If `return_sequences`: 5D tensor with shape: 

810 `(samples, timesteps, filters, new_rows, new_cols)` 

811 if data_format='channels_first' 

812 or 5D tensor with shape: 

813 `(samples, timesteps, new_rows, new_cols, filters)` 

814 if data_format='channels_last'. 

815 - Else, 4D tensor with shape: 

816 `(samples, filters, new_rows, new_cols)` 

817 if data_format='channels_first' 

818 or 4D tensor with shape: 

819 `(samples, new_rows, new_cols, filters)` 

820 if data_format='channels_last'. 

821 

822 Raises: 

823 ValueError: in case of invalid constructor arguments. 

824 

825 References: 

826 - [Shi et al., 2015](http://arxiv.org/abs/1506.04214v1) 

827 (the current implementation does not include the feedback loop on the 

828 cell's output). 

829 

830 Example: 

831 

832 ```python 

833 steps = 10 

834 height = 32 

835 width = 32 

836 input_channels = 3 

837 output_channels = 6 

838 

839 inputs = tf.keras.Input(shape=(steps, height, width, input_channels)) 

840 layer = tf.keras.layers.ConvLSTM2D(filters=output_channels, kernel_size=3) 

841 outputs = layer(inputs) 

842 ``` 

843 """ 

844 

def __init__(self,
             filters,
             kernel_size,
             strides=(1, 1),
             padding='valid',
             data_format=None,
             dilation_rate=(1, 1),
             activation='tanh',
             recurrent_activation='hard_sigmoid',
             use_bias=True,
             kernel_initializer='glorot_uniform',
             recurrent_initializer='orthogonal',
             bias_initializer='zeros',
             unit_forget_bias=True,
             kernel_regularizer=None,
             recurrent_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             recurrent_constraint=None,
             bias_constraint=None,
             return_sequences=False,
             return_state=False,
             go_backwards=False,
             stateful=False,
             dropout=0.,
             recurrent_dropout=0.,
             **kwargs):
  """Creates the layer by wrapping a freshly built `ConvLSTM2DCell`.

  The convolution, activation, initializer, regularizer, constraint and
  dropout arguments are forwarded to the cell. The sequence-handling flags
  (`return_sequences`, `return_state`, `go_backwards`, `stateful`) and any
  extra `**kwargs` are forwarded to the parent constructor.

  `activity_regularizer` is not given to the cell; it is resolved with
  `regularizers.get` and stored on the layer itself.
  """
  # `dtype` is only peeked at (not popped), so it reaches both the cell and
  # the parent constructor through `**kwargs`.
  cell_options = dict(
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      padding=padding,
      data_format=data_format,
      dilation_rate=dilation_rate,
      activation=activation,
      recurrent_activation=recurrent_activation,
      use_bias=use_bias,
      kernel_initializer=kernel_initializer,
      recurrent_initializer=recurrent_initializer,
      bias_initializer=bias_initializer,
      unit_forget_bias=unit_forget_bias,
      kernel_regularizer=kernel_regularizer,
      recurrent_regularizer=recurrent_regularizer,
      bias_regularizer=bias_regularizer,
      kernel_constraint=kernel_constraint,
      recurrent_constraint=recurrent_constraint,
      bias_constraint=bias_constraint,
      dropout=dropout,
      recurrent_dropout=recurrent_dropout,
      dtype=kwargs.get('dtype'))
  lstm_cell = ConvLSTM2DCell(**cell_options)

  super(ConvLSTM2D, self).__init__(
      lstm_cell,
      return_sequences=return_sequences,
      return_state=return_state,
      go_backwards=go_backwards,
      stateful=stateful,
      **kwargs)
  self.activity_regularizer = regularizers.get(activity_regularizer)

902 

def call(self, inputs, mask=None, training=None, initial_state=None):
  """Runs the layer by delegating, unchanged, to the parent class `call`.

  Args:
    inputs: Input passed through to the parent implementation.
    mask: Optional mask, passed through as-is.
    training: Optional training flag, passed through as-is.
    initial_state: Optional initial state(s), passed through as-is.

  Returns:
    Whatever the parent class `call` returns.
  """
  parent_call = super(ConvLSTM2D, self).call
  return parent_call(
      inputs, mask=mask, training=training, initial_state=initial_state)

908 

@property
def filters(self):
  """The `filters` value, read from the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.filters

912 

@property
def kernel_size(self):
  """The `kernel_size` value, read from the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.kernel_size

916 

@property
def strides(self):
  """The `strides` value, read from the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.strides

920 

@property
def padding(self):
  """The `padding` value, read from the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.padding

924 

@property
def data_format(self):
  """The `data_format` value, read from the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.data_format

928 

@property
def dilation_rate(self):
  """The `dilation_rate` value, read from the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.dilation_rate

932 

@property
def activation(self):
  """The `activation` attribute, read from the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.activation

936 

@property
def recurrent_activation(self):
  """The `recurrent_activation` attribute of the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.recurrent_activation

940 

@property
def use_bias(self):
  """The `use_bias` flag, read from the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.use_bias

944 

@property
def kernel_initializer(self):
  """The `kernel_initializer` attribute of the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.kernel_initializer

948 

@property
def recurrent_initializer(self):
  """The `recurrent_initializer` attribute of the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.recurrent_initializer

952 

@property
def bias_initializer(self):
  """The `bias_initializer` attribute of the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.bias_initializer

956 

@property
def unit_forget_bias(self):
  """The `unit_forget_bias` flag, read from the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.unit_forget_bias

960 

@property
def kernel_regularizer(self):
  """The `kernel_regularizer` attribute of the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.kernel_regularizer

964 

@property
def recurrent_regularizer(self):
  """The `recurrent_regularizer` attribute of the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.recurrent_regularizer

968 

@property
def bias_regularizer(self):
  """The `bias_regularizer` attribute of the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.bias_regularizer

972 

@property
def kernel_constraint(self):
  """The `kernel_constraint` attribute of the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.kernel_constraint

976 

@property
def recurrent_constraint(self):
  """The `recurrent_constraint` attribute of the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.recurrent_constraint

980 

@property
def bias_constraint(self):
  """The `bias_constraint` attribute of the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.bias_constraint

984 

@property
def dropout(self):
  """The `dropout` value, read from the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.dropout

988 

@property
def recurrent_dropout(self):
  """The `recurrent_dropout` value, read from the wrapped cell (`self.cell`)."""
  cell = self.cell
  return cell.recurrent_dropout

992 

def get_config(self):
  """Returns the layer configuration as a flat dict.

  The cell's hyperparameters are exposed as top-level keys. Activations,
  initializers, regularizers and constraints are stored in serialized form.
  The parent config's `'cell'` entry is dropped, and on a key clash the
  values computed here override the parent's.
  """
  layer_config = {
      'filters': self.filters,
      'kernel_size': self.kernel_size,
      'strides': self.strides,
      'padding': self.padding,
      'data_format': self.data_format,
      'dilation_rate': self.dilation_rate,
      'activation': activations.serialize(self.activation),
      'recurrent_activation':
          activations.serialize(self.recurrent_activation),
      'use_bias': self.use_bias,
      'kernel_initializer':
          initializers.serialize(self.kernel_initializer),
      'recurrent_initializer':
          initializers.serialize(self.recurrent_initializer),
      'bias_initializer': initializers.serialize(self.bias_initializer),
      'unit_forget_bias': self.unit_forget_bias,
      'kernel_regularizer':
          regularizers.serialize(self.kernel_regularizer),
      'recurrent_regularizer':
          regularizers.serialize(self.recurrent_regularizer),
      'bias_regularizer': regularizers.serialize(self.bias_regularizer),
      'activity_regularizer':
          regularizers.serialize(self.activity_regularizer),
      'kernel_constraint':
          constraints.serialize(self.kernel_constraint),
      'recurrent_constraint':
          constraints.serialize(self.recurrent_constraint),
      'bias_constraint': constraints.serialize(self.bias_constraint),
      'dropout': self.dropout,
      'recurrent_dropout': self.recurrent_dropout,
  }
  parent_config = super(ConvLSTM2D, self).get_config()
  # The cell is described by the flat keys above, so the nested entry goes.
  del parent_config['cell']

  # Parent entries first, then this layer's entries (which win on conflict).
  merged = dict(parent_config)
  merged.update(layer_config)
  return merged

1027 

@classmethod
def from_config(cls, config):
  """Instantiates `cls`, passing each `config` entry as a keyword argument."""
  layer = cls(**config)
  return layer