Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/simple_rnn.py: 41%

141 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fully connected RNN layer."""


import tensorflow.compat.v2 as tf

from keras.src import activations
from keras.src import backend
from keras.src import constraints
from keras.src import initializers
from keras.src import regularizers
from keras.src.engine import base_layer
from keras.src.engine.input_spec import InputSpec
from keras.src.layers.rnn import rnn_utils
from keras.src.layers.rnn.base_rnn import RNN
from keras.src.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.SimpleRNNCell")
class SimpleRNNCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
    """Cell class for SimpleRNN.

    See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
    for details about the usage of RNN API.

    This class processes one step within the whole time sequence input, whereas
    `tf.keras.layers.SimpleRNN` processes the whole sequence.

    Args:
      units: Positive integer, dimensionality of the output space.
      activation: Activation function to use.
        Default: hyperbolic tangent (`tanh`).
        If you pass `None`, no activation is applied
        (i.e. "linear" activation: `a(x) = x`).
      use_bias: Boolean (default `True`), whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs. Default:
        `glorot_uniform`.
      recurrent_initializer: Initializer for the `recurrent_kernel`
        weights matrix, used for the linear transformation of the recurrent
        state. Default: `orthogonal`.
      bias_initializer: Initializer for the bias vector. Default: `zeros`.
      kernel_regularizer: Regularizer function applied to the `kernel` weights
        matrix. Default: `None`.
      recurrent_regularizer: Regularizer function applied to the
        `recurrent_kernel` weights matrix. Default: `None`.
      bias_regularizer: Regularizer function applied to the bias vector.
        Default: `None`.
      kernel_constraint: Constraint function applied to the `kernel` weights
        matrix. Default: `None`.
      recurrent_constraint: Constraint function applied to the
        `recurrent_kernel` weights matrix. Default: `None`.
      bias_constraint: Constraint function applied to the bias vector. Default:
        `None`.
      dropout: Float between 0 and 1. Fraction of the units to drop for the
        linear transformation of the inputs. Default: 0.
      recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
        for the linear transformation of the recurrent state. Default: 0.

    Call arguments:
      inputs: A 2D tensor, with shape of `[batch, feature]`.
      states: A 2D tensor with shape of `[batch, units]`, which is the state
        from the previous time step. For timestep 0, the initial state
        provided by the user will be fed to the cell.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. Only relevant when `dropout` or
        `recurrent_dropout` is used.

    Examples:

    ```python
    inputs = np.random.random([32, 10, 8]).astype(np.float32)
    rnn = tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(4))

    output = rnn(inputs)  # The output has shape `[32, 4]`.

    rnn = tf.keras.layers.RNN(
        tf.keras.layers.SimpleRNNCell(4),
        return_sequences=True,
        return_state=True)

    # whole_sequence_output has shape `[32, 10, 4]`.
    # final_state has shape `[32, 4]`.
    whole_sequence_output, final_state = rnn(inputs)
    ```
    """

    def __init__(
        self,
        units,
        activation="tanh",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        dropout=0.0,
        recurrent_dropout=0.0,
        **kwargs,
    ):
        if units <= 0:
            raise ValueError(
                "Received an invalid value for argument `units`, "
                f"expected a positive integer, got {units}."
            )
        # By default use cached variable under v2 mode, see b/143699808.
        if tf.compat.v1.executing_eagerly_outside_functions():
            self._enable_caching_device = kwargs.pop(
                "enable_caching_device", True
            )
        else:
            self._enable_caching_device = kwargs.pop(
                "enable_caching_device", False
            )
        super().__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.dropout = min(1.0, max(0.0, dropout))
        self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))
        self.state_size = self.units
        self.output_size = self.units

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        super().build(input_shape)
        default_caching_device = rnn_utils.caching_device(self)
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units),
            name="kernel",
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            caching_device=default_caching_device,
        )
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            name="recurrent_kernel",
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint,
            caching_device=default_caching_device,
        )
        if self.use_bias:
            self.bias = self.add_weight(
                shape=(self.units,),
                name="bias",
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                caching_device=default_caching_device,
            )
        else:
            self.bias = None
        self.built = True
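
    # The step computation below applies the simple-RNN recurrence:
    #     output = activation(inputs @ kernel + bias
    #                         + prev_output @ recurrent_kernel)
    # with independent dropout masks applied to the input and recurrent
    # terms when `dropout` / `recurrent_dropout` are non-zero.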

    def call(self, inputs, states, training=None):
        prev_output = states[0] if tf.nest.is_nested(states) else states
        dp_mask = self.get_dropout_mask_for_cell(inputs, training)
        rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
            prev_output, training
        )

        if dp_mask is not None:
            h = backend.dot(inputs * dp_mask, self.kernel)
        else:
            h = backend.dot(inputs, self.kernel)
        if self.bias is not None:
            h = backend.bias_add(h, self.bias)

        if rec_dp_mask is not None:
            prev_output = prev_output * rec_dp_mask
        output = h + backend.dot(prev_output, self.recurrent_kernel)
        if self.activation is not None:
            output = self.activation(output)

        new_state = [output] if tf.nest.is_nested(states) else output
        return output, new_state
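
    # The default initial state is a single all-zeros tensor of shape
    # `[batch_size, units]`, matching the `state_size` set in `__init__`.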

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        return rnn_utils.generate_zero_filled_state_for_cell(
            self, inputs, batch_size, dtype
        )

    def get_config(self):
        config = {
            "units": self.units,
            "activation": activations.serialize(self.activation),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
            "dropout": self.dropout,
            "recurrent_dropout": self.recurrent_dropout,
        }
        config.update(rnn_utils.config_for_enable_caching_device(self))
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


@keras_export("keras.layers.SimpleRNN")
class SimpleRNN(RNN):
    """Fully-connected RNN where the output is to be fed back to the input.

    See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
    for details about the usage of RNN API.

    Args:
      units: Positive integer, dimensionality of the output space.
      activation: Activation function to use.
        Default: hyperbolic tangent (`tanh`).
        If you pass `None`, no activation is applied
        (i.e. "linear" activation: `a(x) = x`).
      use_bias: Boolean (default `True`), whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs. Default:
        `glorot_uniform`.
      recurrent_initializer: Initializer for the `recurrent_kernel`
        weights matrix, used for the linear transformation of the recurrent
        state. Default: `orthogonal`.
      bias_initializer: Initializer for the bias vector. Default: `zeros`.
      kernel_regularizer: Regularizer function applied to the `kernel` weights
        matrix. Default: `None`.
      recurrent_regularizer: Regularizer function applied to the
        `recurrent_kernel` weights matrix. Default: `None`.
      bias_regularizer: Regularizer function applied to the bias vector.
        Default: `None`.
      activity_regularizer: Regularizer function applied to the output of the
        layer (its "activation"). Default: `None`.
      kernel_constraint: Constraint function applied to the `kernel` weights
        matrix. Default: `None`.
      recurrent_constraint: Constraint function applied to the
        `recurrent_kernel` weights matrix. Default: `None`.
      bias_constraint: Constraint function applied to the bias vector. Default:
        `None`.
      dropout: Float between 0 and 1.
        Fraction of the units to drop for the linear transformation of the
        inputs. Default: 0.
      recurrent_dropout: Float between 0 and 1.
        Fraction of the units to drop for the linear transformation of the
        recurrent state. Default: 0.
      return_sequences: Boolean. Whether to return the last output
        in the output sequence, or the full sequence. Default: `False`.
      return_state: Boolean. Whether to return the last state
        in addition to the output. Default: `False`.
      go_backwards: Boolean (default `False`).
        If `True`, process the input sequence backwards and return the
        reversed sequence.
      stateful: Boolean (default `False`). If `True`, the last state
        for each sample at index i in a batch will be used as the initial
        state for the sample of index i in the following batch.
      unroll: Boolean (default `False`).
        If `True`, the network will be unrolled,
        else a symbolic loop will be used.
        Unrolling can speed up an RNN,
        although it tends to be more memory-intensive.
        Unrolling is only suitable for short sequences.

    Call arguments:
      inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
      mask: Binary tensor of shape `[batch, timesteps]` indicating whether
        a given timestep should be masked. An individual `True` entry indicates
        that the corresponding timestep should be utilized, while a `False`
        entry indicates that the corresponding timestep should be ignored.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. This argument is passed to the cell
        when calling it. This is only relevant if `dropout` or
        `recurrent_dropout` is used.
      initial_state: List of initial state tensors to be passed to the first
        call of the cell.

    Examples:

    ```python
    inputs = np.random.random([32, 10, 8]).astype(np.float32)
    simple_rnn = tf.keras.layers.SimpleRNN(4)

    output = simple_rnn(inputs)  # The output has shape `[32, 4]`.

    simple_rnn = tf.keras.layers.SimpleRNN(
        4, return_sequences=True, return_state=True)

    # whole_sequence_output has shape `[32, 10, 4]`.
    # final_state has shape `[32, 4]`.
    whole_sequence_output, final_state = simple_rnn(inputs)
    ```
    """
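
    # `SimpleRNN` is a thin wrapper: it builds a single `SimpleRNNCell` with
    # the given hyperparameters and delegates iteration over the time
    # dimension to the base `RNN` layer. The properties further below simply
    # proxy the corresponding attributes of the wrapped cell.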

    def __init__(
        self,
        units,
        activation="tanh",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        dropout=0.0,
        recurrent_dropout=0.0,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        unroll=False,
        **kwargs,
    ):
        if "implementation" in kwargs:
            kwargs.pop("implementation")
            logging.warning(
                "The `implementation` argument "
                "in `SimpleRNN` has been deprecated. "
                "Please remove it from your layer call."
            )
        if "enable_caching_device" in kwargs:
            cell_kwargs = {
                "enable_caching_device": kwargs.pop("enable_caching_device")
            }
        else:
            cell_kwargs = {}
        cell = SimpleRNNCell(
            units,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            recurrent_regularizer=recurrent_regularizer,
            bias_regularizer=bias_regularizer,
            kernel_constraint=kernel_constraint,
            recurrent_constraint=recurrent_constraint,
            bias_constraint=bias_constraint,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
            dtype=kwargs.get("dtype"),
            trainable=kwargs.get("trainable", True),
            name="simple_rnn_cell",
            **cell_kwargs,
        )
        super().__init__(
            cell,
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            unroll=unroll,
            **kwargs,
        )
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.input_spec = [InputSpec(ndim=3)]

    def call(self, inputs, mask=None, training=None, initial_state=None):
        return super().call(
            inputs, mask=mask, training=training, initial_state=initial_state
        )

    @property
    def units(self):
        return self.cell.units

    @property
    def activation(self):
        return self.cell.activation

    @property
    def use_bias(self):
        return self.cell.use_bias

    @property
    def kernel_initializer(self):
        return self.cell.kernel_initializer

    @property
    def recurrent_initializer(self):
        return self.cell.recurrent_initializer

    @property
    def bias_initializer(self):
        return self.cell.bias_initializer

    @property
    def kernel_regularizer(self):
        return self.cell.kernel_regularizer

    @property
    def recurrent_regularizer(self):
        return self.cell.recurrent_regularizer

    @property
    def bias_regularizer(self):
        return self.cell.bias_regularizer

    @property
    def kernel_constraint(self):
        return self.cell.kernel_constraint

    @property
    def recurrent_constraint(self):
        return self.cell.recurrent_constraint

    @property
    def bias_constraint(self):
        return self.cell.bias_constraint

    @property
    def dropout(self):
        return self.cell.dropout

    @property
    def recurrent_dropout(self):
        return self.cell.recurrent_dropout

    def get_config(self):
        config = {
            "units": self.units,
            "activation": activations.serialize(self.activation),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self.activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
            "dropout": self.dropout,
            "recurrent_dropout": self.recurrent_dropout,
        }
        base_config = super().get_config()
        config.update(rnn_utils.config_for_enable_caching_device(self.cell))
        del base_config["cell"]
        return dict(list(base_config.items()) + list(config.items()))
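
    # The deprecated `implementation` key (see the warning in `__init__`) is
    # discarded here so that saved configs that still contain it can be
    # deserialized.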

    @classmethod
    def from_config(cls, config):
        if "implementation" in config:
            config.pop("implementation")
        return cls(**config)