Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/base_conv_lstm.py: 34%

200 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Base class for N-D convolutional LSTM layers.""" 

16 

17 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src import activations 

21from keras.src import backend 

22from keras.src import constraints 

23from keras.src import initializers 

24from keras.src import regularizers 

25from keras.src.engine import base_layer 

26from keras.src.layers.rnn.base_conv_rnn import ConvRNN 

27from keras.src.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin 

28from keras.src.utils import conv_utils 

29 

30 

class ConvLSTMCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
    """Cell class for the ConvLSTM layer.

    Processes a single timestep: both the input transformation and the
    recurrent transformation are convolutions of rank `rank`.

    Args:
      rank: Integer in `{1, 2, 3}`, rank of the convolution, e.g. "2" for 2D
        convolutions.
      filters: Integer, the dimensionality of the output space (i.e. the
        number of output filters in the convolution).
      kernel_size: An integer or tuple/list of n integers, specifying the
        dimensions of the convolution window.
      strides: An integer or tuple/list of n integers, specifying the strides
        of the convolution. Specifying any stride value != 1 is incompatible
        with specifying any `dilation_rate` value != 1.
      padding: One of `"valid"` or `"same"` (case-insensitive). `"valid"`
        means no padding. `"same"` results in padding evenly to the
        left/right or up/down of the input such that output has the same
        height/width dimension as the input.
      data_format: A string, one of `channels_last` (default) or
        `channels_first`. When unspecified, uses `image_data_format` value
        found in your Keras config file at `~/.keras/keras.json` (if exists)
        else 'channels_last'. Defaults to 'channels_last'.
      dilation_rate: An integer or tuple/list of n integers, specifying the
        dilation rate to use for dilated convolution. Currently, specifying
        any `dilation_rate` value != 1 is incompatible with specifying any
        `strides` value != 1.
      activation: Activation function to use. Defaults to `"tanh"`.
      recurrent_activation: Activation function to use for the recurrent
        step. Defaults to `"hard_sigmoid"`.
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix, used
        for the linear transformation of the inputs.
      recurrent_initializer: Initializer for the `recurrent_kernel` weights
        matrix, used for the linear transformation of the recurrent state.
      bias_initializer: Initializer for the bias vector.
      unit_forget_bias: Boolean. If True, add 1 to the bias of the forget
        gate at initialization. Use in combination with
        `bias_initializer="zeros"`. This is recommended in
        [Jozefowicz et al., 2015](
        http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
      kernel_regularizer: Regularizer function applied to the `kernel`
        weights matrix.
      recurrent_regularizer: Regularizer function applied to the
        `recurrent_kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      kernel_constraint: Constraint function applied to the `kernel` weights
        matrix.
      recurrent_constraint: Constraint function applied to the
        `recurrent_kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.
      dropout: Float between 0 and 1. Fraction of the units to drop for the
        linear transformation of the inputs. Values outside `[0, 1]` are
        clamped.
      recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
        for the linear transformation of the recurrent state. Values outside
        `[0, 1]` are clamped.

    Call arguments:
      inputs: A (2+ `rank`)D tensor.
      states: List of state tensors corresponding to the previous timestep.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. Only relevant when `dropout` or
        `recurrent_dropout` is used.

    Raises:
      ValueError: If `rank` is not 1, 2 or 3.
    """

    def __init__(
        self,
        rank,
        filters,
        kernel_size,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
        activation="tanh",
        recurrent_activation="hard_sigmoid",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        unit_forget_bias=True,
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        dropout=0.0,
        recurrent_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.rank = rank
        # Only 1D/2D/3D convolutions have a backend op (see `_conv_func`);
        # reject anything else up front instead of failing later in `call`.
        if not 1 <= self.rank <= 3:
            raise ValueError(
                f"Rank {rank} convolutions are not currently "
                f"implemented. Received: rank={rank}"
            )
        self.filters = filters
        self.kernel_size = conv_utils.normalize_tuple(
            kernel_size, self.rank, "kernel_size"
        )
        self.strides = conv_utils.normalize_tuple(
            strides, self.rank, "strides", allow_zero=True
        )
        self.padding = conv_utils.normalize_padding(padding)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.dilation_rate = conv_utils.normalize_tuple(
            dilation_rate, self.rank, "dilation_rate"
        )
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        # Clamp both dropout rates into [0, 1].
        self.dropout = min(1.0, max(0.0, dropout))
        self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))
        # Two states (h and c), each with `filters` channels.
        self.state_size = (self.filters, self.filters)

    def build(self, input_shape):
        """Creates the `kernel`, `recurrent_kernel` and optional `bias`.

        Args:
          input_shape: Shape of a single-timestep input; its channel axis
            (axis 1 for `channels_first`, otherwise the last axis) must be
            defined.

        Raises:
          ValueError: If the channel dimension of `input_shape` is `None`.
        """
        super().build(input_shape)
        if self.data_format == "channels_first":
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError(
                "The channel dimension of the inputs (last axis) should be "
                "defined. Found None. Full input shape received: "
                f"input_shape={input_shape}"
            )
        input_dim = input_shape[channel_axis]
        # The four gates (i, f, c, o) are stored side by side along the
        # output-channel axis, hence `filters * 4`.
        self.kernel_shape = self.kernel_size + (input_dim, self.filters * 4)
        recurrent_kernel_shape = self.kernel_size + (
            self.filters,
            self.filters * 4,
        )

        self.kernel = self.add_weight(
            shape=self.kernel_shape,
            initializer=self.kernel_initializer,
            name="kernel",
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.recurrent_kernel = self.add_weight(
            shape=recurrent_kernel_shape,
            initializer=self.recurrent_initializer,
            name="recurrent_kernel",
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint,
        )

        if self.use_bias:
            if self.unit_forget_bias:

                def bias_initializer(_, *args, **kwargs):
                    # Layout is [i | f | c, o]: the forget-gate slice is
                    # initialized to ones, the rest with the configured
                    # bias initializer.
                    return backend.concatenate(
                        [
                            self.bias_initializer(
                                (self.filters,), *args, **kwargs
                            ),
                            initializers.get("ones")(
                                (self.filters,), *args, **kwargs
                            ),
                            self.bias_initializer(
                                (self.filters * 2,), *args, **kwargs
                            ),
                        ]
                    )

            else:
                bias_initializer = self.bias_initializer
            self.bias = self.add_weight(
                shape=(self.filters * 4,),
                name="bias",
                initializer=bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
        else:
            self.bias = None
        self.built = True

    def call(self, inputs, states, training=None):
        """Runs one ConvLSTM step.

        Args:
          inputs: Input tensor for the current timestep.
          states: Sequence whose first two items are the previous `h` and
            `c` states.
          training: Forwarded to the dropout-mask helpers.

        Returns:
          A tuple `(h, [h, c])`: the step output and the new states.
        """
        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        # dropout matrices for input units
        dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
        # dropout matrices for recurrent units
        rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
            h_tm1, training, count=4
        )

        # Masks are only applied for rates strictly between 0 and 1; each
        # gate gets its own mask.
        if 0 < self.dropout < 1.0:
            inputs_i = inputs * dp_mask[0]
            inputs_f = inputs * dp_mask[1]
            inputs_c = inputs * dp_mask[2]
            inputs_o = inputs * dp_mask[3]
        else:
            inputs_i = inputs
            inputs_f = inputs
            inputs_c = inputs
            inputs_o = inputs

        if 0 < self.recurrent_dropout < 1.0:
            h_tm1_i = h_tm1 * rec_dp_mask[0]
            h_tm1_f = h_tm1 * rec_dp_mask[1]
            h_tm1_c = h_tm1 * rec_dp_mask[2]
            h_tm1_o = h_tm1 * rec_dp_mask[3]
        else:
            h_tm1_i = h_tm1
            h_tm1_f = h_tm1
            h_tm1_c = h_tm1
            h_tm1_o = h_tm1

        # Split the fused weights back into per-gate pieces along the
        # output-channel axis (the last kernel axis, index `rank + 1`).
        (kernel_i, kernel_f, kernel_c, kernel_o) = tf.split(
            self.kernel, 4, axis=self.rank + 1
        )
        (
            recurrent_kernel_i,
            recurrent_kernel_f,
            recurrent_kernel_c,
            recurrent_kernel_o,
        ) = tf.split(self.recurrent_kernel, 4, axis=self.rank + 1)

        if self.use_bias:
            bias_i, bias_f, bias_c, bias_o = tf.split(self.bias, 4)
        else:
            bias_i, bias_f, bias_c, bias_o = None, None, None, None

        x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
        x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
        x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
        x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
        h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
        h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
        h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
        h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)

        i = self.recurrent_activation(x_i + h_i)
        f = self.recurrent_activation(x_f + h_f)
        c = f * c_tm1 + i * self.activation(x_c + h_c)
        o = self.recurrent_activation(x_o + h_o)
        h = o * self.activation(c)
        return h, [h, c]

    @property
    def _conv_func(self):
        """Backend convolution op matching `self.rank`.

        Raises:
          ValueError: If `self.rank` is not 1, 2 or 3.
        """
        if self.rank == 1:
            return backend.conv1d
        if self.rank == 2:
            return backend.conv2d
        if self.rank == 3:
            return backend.conv3d
        # Fail loudly rather than returning None, which would only surface
        # later as "'NoneType' object is not callable".
        raise ValueError(
            f"Rank {self.rank} convolutions are not currently "
            f"implemented. Received: rank={self.rank}"
        )

    def input_conv(self, x, w, b=None, padding="valid"):
        """Convolves `x` with `w` using the cell's strides and dilation.

        Args:
          x: Input tensor.
          w: Convolution kernel.
          b: Optional bias added to the convolution output.
          padding: Padding mode passed to the backend convolution.

        Returns:
          The convolution output, with `b` added when it is not `None`.
        """
        conv_out = self._conv_func(
            x,
            w,
            strides=self.strides,
            padding=padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )
        if b is not None:
            conv_out = backend.bias_add(
                conv_out, b, data_format=self.data_format
            )
        return conv_out

    def recurrent_conv(self, x, w):
        """Convolves the recurrent state `x` with kernel `w`.

        Always uses unit strides and `"same"` padding, independently of the
        cell's `strides` and `padding` settings.
        """
        strides = conv_utils.normalize_tuple(
            1, self.rank, "strides", allow_zero=True
        )
        conv_out = self._conv_func(
            x, w, strides=strides, padding="same", data_format=self.data_format
        )
        return conv_out

    def get_config(self):
        """Returns the cell configuration as a serializable dict.

        `rank` is included because it is a required constructor argument;
        without it the returned config could not be used to rebuild the cell.
        """
        config = {
            "rank": self.rank,
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "padding": self.padding,
            "data_format": self.data_format,
            "dilation_rate": self.dilation_rate,
            "activation": activations.serialize(self.activation),
            "recurrent_activation": activations.serialize(
                self.recurrent_activation
            ),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "unit_forget_bias": self.unit_forget_bias,
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
            "dropout": self.dropout,
            "recurrent_dropout": self.recurrent_dropout,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

358 

359 

class ConvLSTM(ConvRNN):
    """Abstract N-D Convolutional LSTM layer (used as implementation base).

    Works like an LSTM layer, except that the input transformation and the
    recurrent transformation are both convolutions. All convolution and
    weight hyper-parameters are held by an internal `ConvLSTMCell` and are
    exposed on this layer through read-only properties.

    Args:
      rank: Integer, rank of the convolution, e.g. "2" for 2D convolutions.
      filters: Integer, the dimensionality of the output space
        (i.e. the number of output filters in the convolution).
      kernel_size: An integer or tuple/list of n integers, specifying the
        dimensions of the convolution window.
      strides: An integer or tuple/list of n integers, specifying the
        strides of the convolution. Specifying any stride value != 1 is
        incompatible with specifying any `dilation_rate` value != 1.
      padding: One of `"valid"` or `"same"` (case-insensitive).
        `"valid"` means no padding. `"same"` results in padding evenly to
        the left/right or up/down of the input such that output has the
        same height/width dimension as the input.
      data_format: A string, one of `channels_last` (default) or
        `channels_first`. The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape
        `(batch, time, ..., channels)` while `channels_first` corresponds
        to inputs with shape `(batch, time, channels, ...)`.
        When unspecified, uses `image_data_format` value found in your
        Keras config file at `~/.keras/keras.json` (if exists) else
        'channels_last'. Defaults to 'channels_last'.
      dilation_rate: An integer or tuple/list of n integers, specifying
        the dilation rate to use for dilated convolution. Currently,
        specifying any `dilation_rate` value != 1 is incompatible with
        specifying any `strides` value != 1.
      activation: Activation function to use. By default hyperbolic
        tangent activation function is applied (`tanh(x)`).
      recurrent_activation: Activation function to use for the recurrent
        step.
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs.
      recurrent_initializer: Initializer for the `recurrent_kernel`
        weights matrix, used for the linear transformation of the
        recurrent state.
      bias_initializer: Initializer for the bias vector.
      unit_forget_bias: Boolean. If True, add 1 to the bias of the forget
        gate at initialization. Use in combination with
        `bias_initializer="zeros"`. This is recommended in
        [Jozefowicz et al., 2015](
        http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
      kernel_regularizer: Regularizer function applied to the `kernel`
        weights matrix.
      recurrent_regularizer: Regularizer function applied to the
        `recurrent_kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to the output of
        the layer.
      kernel_constraint: Constraint function applied to the `kernel`
        weights matrix.
      recurrent_constraint: Constraint function applied to the
        `recurrent_kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.
      return_sequences: Boolean. Whether to return the last output in the
        output sequence, or the full sequence. (default False)
      return_state: Boolean. Whether to return the last state in addition
        to the output. (default False)
      go_backwards: Boolean (default False). If True, process the input
        sequence backwards.
      stateful: Boolean (default False). If True, the last state for each
        sample at index i in a batch will be used as initial state for the
        sample of index i in the following batch.
      dropout: Float between 0 and 1. Fraction of the units to drop for
        the linear transformation of the inputs.
      recurrent_dropout: Float between 0 and 1. Fraction of the units to
        drop for the linear transformation of the recurrent state.
    """

    def __init__(
        self,
        rank,
        filters,
        kernel_size,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
        activation="tanh",
        recurrent_activation="hard_sigmoid",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        unit_forget_bias=True,
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        dropout=0.0,
        recurrent_dropout=0.0,
        **kwargs,
    ):
        # Everything describing the per-step computation is handed to the
        # cell; the cell shares this layer's dtype when one was supplied.
        cell_options = {
            "rank": rank,
            "filters": filters,
            "kernel_size": kernel_size,
            "strides": strides,
            "padding": padding,
            "data_format": data_format,
            "dilation_rate": dilation_rate,
            "activation": activation,
            "recurrent_activation": recurrent_activation,
            "use_bias": use_bias,
            "kernel_initializer": kernel_initializer,
            "recurrent_initializer": recurrent_initializer,
            "bias_initializer": bias_initializer,
            "unit_forget_bias": unit_forget_bias,
            "kernel_regularizer": kernel_regularizer,
            "recurrent_regularizer": recurrent_regularizer,
            "bias_regularizer": bias_regularizer,
            "kernel_constraint": kernel_constraint,
            "recurrent_constraint": recurrent_constraint,
            "bias_constraint": bias_constraint,
            "dropout": dropout,
            "recurrent_dropout": recurrent_dropout,
            "name": "conv_lstm_cell",
            "dtype": kwargs.get("dtype"),
        }
        cell = ConvLSTMCell(**cell_options)

        # Sequence-level options belong to the wrapping RNN layer.
        super().__init__(
            rank,
            cell,
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            **kwargs,
        )
        self.activity_regularizer = regularizers.get(activity_regularizer)

    def call(self, inputs, mask=None, training=None, initial_state=None):
        """Delegates to the parent `call` with the same arguments."""
        parent_call = super().call
        return parent_call(
            inputs, mask=mask, training=training, initial_state=initial_state
        )

    # ------------------------------------------------------------------
    # Read-only views onto the wrapped cell's hyper-parameters.
    # ------------------------------------------------------------------

    @property
    def filters(self):
        """`filters` of the wrapped cell."""
        return self.cell.filters

    @property
    def kernel_size(self):
        """`kernel_size` of the wrapped cell."""
        return self.cell.kernel_size

    @property
    def strides(self):
        """`strides` of the wrapped cell."""
        return self.cell.strides

    @property
    def padding(self):
        """`padding` of the wrapped cell."""
        return self.cell.padding

    @property
    def data_format(self):
        """`data_format` of the wrapped cell."""
        return self.cell.data_format

    @property
    def dilation_rate(self):
        """`dilation_rate` of the wrapped cell."""
        return self.cell.dilation_rate

    @property
    def activation(self):
        """`activation` of the wrapped cell."""
        return self.cell.activation

    @property
    def recurrent_activation(self):
        """`recurrent_activation` of the wrapped cell."""
        return self.cell.recurrent_activation

    @property
    def use_bias(self):
        """`use_bias` of the wrapped cell."""
        return self.cell.use_bias

    @property
    def kernel_initializer(self):
        """`kernel_initializer` of the wrapped cell."""
        return self.cell.kernel_initializer

    @property
    def recurrent_initializer(self):
        """`recurrent_initializer` of the wrapped cell."""
        return self.cell.recurrent_initializer

    @property
    def bias_initializer(self):
        """`bias_initializer` of the wrapped cell."""
        return self.cell.bias_initializer

    @property
    def unit_forget_bias(self):
        """`unit_forget_bias` of the wrapped cell."""
        return self.cell.unit_forget_bias

    @property
    def kernel_regularizer(self):
        """`kernel_regularizer` of the wrapped cell."""
        return self.cell.kernel_regularizer

    @property
    def recurrent_regularizer(self):
        """`recurrent_regularizer` of the wrapped cell."""
        return self.cell.recurrent_regularizer

    @property
    def bias_regularizer(self):
        """`bias_regularizer` of the wrapped cell."""
        return self.cell.bias_regularizer

    @property
    def kernel_constraint(self):
        """`kernel_constraint` of the wrapped cell."""
        return self.cell.kernel_constraint

    @property
    def recurrent_constraint(self):
        """`recurrent_constraint` of the wrapped cell."""
        return self.cell.recurrent_constraint

    @property
    def bias_constraint(self):
        """`bias_constraint` of the wrapped cell."""
        return self.cell.bias_constraint

    @property
    def dropout(self):
        """`dropout` of the wrapped cell."""
        return self.cell.dropout

    @property
    def recurrent_dropout(self):
        """`recurrent_dropout` of the wrapped cell."""
        return self.cell.recurrent_dropout

    def get_config(self):
        """Returns the layer configuration as a flat, serializable dict.

        The parent's nested `"cell"` entry is removed; the cell's
        hyper-parameters are written at the top level instead, so that
        `from_config` can pass them straight to the constructor.
        """
        layer_config = {
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "padding": self.padding,
            "data_format": self.data_format,
            "dilation_rate": self.dilation_rate,
            "activation": activations.serialize(self.activation),
            "recurrent_activation": activations.serialize(
                self.recurrent_activation
            ),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "unit_forget_bias": self.unit_forget_bias,
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self.activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
            "dropout": self.dropout,
            "recurrent_dropout": self.recurrent_dropout,
        }
        parent_config = super().get_config()
        # Raises KeyError if the parent did not provide a "cell" entry.
        parent_config.pop("cell")
        # Parent keys come first; layer keys override on collision.
        merged = dict(parent_config)
        merged.update(layer_config)
        return merged

    @classmethod
    def from_config(cls, config):
        """Builds a layer by passing `config` as constructor keywords."""
        layer = cls(**config)
        return layer

643