Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/layers/recurrent.py: 23%

1073 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=protected-access 

16# pylint: disable=g-classes-have-attributes 

17"""Recurrent layers and their base classes.""" 

18 

19import collections 

20import warnings 

21 

22import numpy as np 

23 

24from tensorflow.python.distribute import distribute_lib 

25from tensorflow.python.eager import context 

26from tensorflow.python.framework import ops 

27from tensorflow.python.framework import tensor_shape 

28from tensorflow.python.keras import activations 

29from tensorflow.python.keras import backend 

30from tensorflow.python.keras import constraints 

31from tensorflow.python.keras import initializers 

32from tensorflow.python.keras import regularizers 

33from tensorflow.python.keras.engine.base_layer import Layer 

34from tensorflow.python.keras.engine.input_spec import InputSpec 

35from tensorflow.python.keras.saving.saved_model import layer_serialization 

36from tensorflow.python.keras.utils import control_flow_util 

37from tensorflow.python.keras.utils import generic_utils 

38from tensorflow.python.keras.utils import tf_utils 

39from tensorflow.python.ops import array_ops 

40from tensorflow.python.ops import array_ops_stack 

41from tensorflow.python.ops import cond 

42from tensorflow.python.ops import math_ops 

43from tensorflow.python.ops import state_ops 

44from tensorflow.python.platform import tf_logging as logging 

45from tensorflow.python.trackable import base as trackable 

46from tensorflow.python.util import nest 

47from tensorflow.python.util.tf_export import keras_export 

48from tensorflow.tools.docs import doc_controls 

49 

50 

51RECURRENT_DROPOUT_WARNING_MSG = ( 

52 'RNN `implementation=2` is not supported when `recurrent_dropout` is set. ' 

53 'Using `implementation=1`.') 

54 

55 

56@keras_export('keras.layers.StackedRNNCells') 

57class StackedRNNCells(Layer): 

58 """Wrapper allowing a stack of RNN cells to behave as a single cell. 

59 

60 Used to implement efficient stacked RNNs. 

61 

62 Args: 

63 cells: List of RNN cell instances. 

64 

65 Examples: 

66 

67 ```python 

68 batch_size = 3 

69 sentence_max_length = 5 

70 n_features = 2 

71 new_shape = (batch_size, sentence_max_length, n_features) 

72 x = tf.constant(np.reshape(np.arange(30), new_shape), dtype=tf.float32) 

73 

74 rnn_cells = [tf.keras.layers.LSTMCell(128) for _ in range(2)] 

75 stacked_lstm = tf.keras.layers.StackedRNNCells(rnn_cells) 

76 lstm_layer = tf.keras.layers.RNN(stacked_lstm) 

77 

78 result = lstm_layer(x) 

79 ``` 

80 """ 

81 

82 def __init__(self, cells, **kwargs): 

83 for cell in cells: 

84 if not 'call' in dir(cell): 

85 raise ValueError('All cells must have a `call` method. ' 

86 'received cells:', cells) 

87 if not 'state_size' in dir(cell): 

88 raise ValueError('All cells must have a ' 

89 '`state_size` attribute. ' 

90 'received cells:', cells) 

91 self.cells = cells 

92 # reverse_state_order determines whether the state size will be in a reverse 

93 # order of the cells' state. User might want to set this to True to keep the 

94 # existing behavior. This is only useful when using RNN(return_state=True), 

95 # since the state will be returned in the same order as state_size. 

96 self.reverse_state_order = kwargs.pop('reverse_state_order', False) 

97 if self.reverse_state_order: 

98 logging.warning('reverse_state_order=True in StackedRNNCells will soon ' 

99 'be deprecated. Please update the code to work with the ' 

100 'natural order of states if you rely on the RNN states, ' 

101 'eg RNN(return_state=True).') 

102 super(StackedRNNCells, self).__init__(**kwargs) 

103 

104 @property 

105 def state_size(self): 

106 return tuple(c.state_size for c in 

107 (self.cells[::-1] if self.reverse_state_order else self.cells)) 

108 

109 @property 

110 def output_size(self): 

111 if getattr(self.cells[-1], 'output_size', None) is not None: 

112 return self.cells[-1].output_size 

113 elif _is_multiple_state(self.cells[-1].state_size): 

114 return self.cells[-1].state_size[0] 

115 else: 

116 return self.cells[-1].state_size 

117 

118 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 

119 initial_states = [] 

120 for cell in self.cells[::-1] if self.reverse_state_order else self.cells: 

121 get_initial_state_fn = getattr(cell, 'get_initial_state', None) 

122 if get_initial_state_fn: 

123 initial_states.append(get_initial_state_fn( 

124 inputs=inputs, batch_size=batch_size, dtype=dtype)) 

125 else: 

126 initial_states.append(_generate_zero_filled_state_for_cell( 

127 cell, inputs, batch_size, dtype)) 

128 

129 return tuple(initial_states) 

130 

131 def call(self, inputs, states, constants=None, training=None, **kwargs): 

132 # Recover per-cell states. 

133 state_size = (self.state_size[::-1] 

134 if self.reverse_state_order else self.state_size) 

135 nested_states = nest.pack_sequence_as(state_size, nest.flatten(states)) 

136 

137 # Call the cells in order and store the returned states. 

138 new_nested_states = [] 

139 for cell, states in zip(self.cells, nested_states): 

140 states = states if nest.is_nested(states) else [states] 

141 # TF cells do not wrap the state into a list when there is only one state. 

142 is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None 

143 states = states[0] if len(states) == 1 and is_tf_rnn_cell else states 

144 if generic_utils.has_arg(cell.call, 'training'): 

145 kwargs['training'] = training 

146 else: 

147 kwargs.pop('training', None) 

148 # Use the __call__ function for callable objects, eg layers, so that it 

149 # will have the proper name scopes for the ops, etc. 

150 cell_call_fn = cell.__call__ if callable(cell) else cell.call 

151 if generic_utils.has_arg(cell.call, 'constants'): 

152 inputs, states = cell_call_fn(inputs, states, 

153 constants=constants, **kwargs) 

154 else: 

155 inputs, states = cell_call_fn(inputs, states, **kwargs) 

156 new_nested_states.append(states) 

157 

158 return inputs, nest.pack_sequence_as(state_size, 

159 nest.flatten(new_nested_states)) 

160 

161 @tf_utils.shape_type_conversion 

162 def build(self, input_shape): 

163 if isinstance(input_shape, list): 

164 input_shape = input_shape[0] 

165 for cell in self.cells: 

166 if isinstance(cell, Layer) and not cell.built: 

167 with backend.name_scope(cell.name): 

168 cell.build(input_shape) 

169 cell.built = True 

170 if getattr(cell, 'output_size', None) is not None: 

171 output_dim = cell.output_size 

172 elif _is_multiple_state(cell.state_size): 

173 output_dim = cell.state_size[0] 

174 else: 

175 output_dim = cell.state_size 

176 input_shape = tuple([input_shape[0]] + 

177 tensor_shape.TensorShape(output_dim).as_list()) 

178 self.built = True 

179 

180 def get_config(self): 

181 cells = [] 

182 for cell in self.cells: 

183 cells.append(generic_utils.serialize_keras_object(cell)) 

184 config = {'cells': cells} 

185 base_config = super(StackedRNNCells, self).get_config() 

186 return dict(list(base_config.items()) + list(config.items())) 

187 

188 @classmethod 

189 def from_config(cls, config, custom_objects=None): 

190 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top 

191 cells = [] 

192 for cell_config in config.pop('cells'): 

193 cells.append( 

194 deserialize_layer(cell_config, custom_objects=custom_objects)) 

195 return cls(cells, **config) 

196 

197 

198@keras_export('keras.layers.RNN') 

199class RNN(Layer): 

200 """Base class for recurrent layers. 

201 

202 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 

203 for details about the usage of RNN API. 

204 

205 Args: 

206 cell: A RNN cell instance or a list of RNN cell instances. 

207 A RNN cell is a class that has: 

208 - A `call(input_at_t, states_at_t)` method, returning 

209 `(output_at_t, states_at_t_plus_1)`. The call method of the 

210 cell can also take the optional argument `constants`, see 

211 section "Note on passing external constants" below. 

212 - A `state_size` attribute. This can be a single integer 

213 (single state) in which case it is the size of the recurrent 

214 state. This can also be a list/tuple of integers (one size per state). 

215 The `state_size` can also be TensorShape or tuple/list of 

216 TensorShape, to represent high dimension state. 

217 - An `output_size` attribute. This can be a single integer or a 

218 TensorShape, which represents the shape of the output. For backward 

219 compatibility, if this attribute is not available for the 

220 cell, the value will be inferred from the first element of the 

221 `state_size`. 

222 - A `get_initial_state(inputs=None, batch_size=None, dtype=None)` 

223 method that creates a tensor meant to be fed to `call()` as the 

224 initial state, if the user didn't specify any initial state via other 

225 means. The returned initial state should have a shape of 

226 [batch_size, cell.state_size]. The cell might choose to create a 

227 tensor full of zeros, or full of other values based on the cell's 

228 implementation. 

229 `inputs` is the input tensor to the RNN layer, which should 

230 contain the batch size as its shape[0], and also the dtype. Note that 

231 the shape[0] might be `None` during graph construction. Either 

232 the `inputs` or the pair of `batch_size` and `dtype` is provided. 

233 `batch_size` is a scalar tensor that represents the batch size 

234 of the inputs. `dtype` is `tf.DType` that represents the dtype of 

235 the inputs. 

236 For backward compatibility, if this method is not implemented 

237 by the cell, the RNN layer will create a zero-filled tensor with the 

238 size of [batch_size, cell.state_size]. 

239 In the case that `cell` is a list of RNN cell instances, the cells 

240 will be stacked on top of each other in the RNN, resulting in an 

241 efficient stacked RNN. 

242 return_sequences: Boolean (default `False`). Whether to return the last 

243 output in the output sequence, or the full sequence. 

244 return_state: Boolean (default `False`). Whether to return the last state 

245 in addition to the output. 

246 go_backwards: Boolean (default `False`). 

247 If True, process the input sequence backwards and return the 

248 reversed sequence. 

249 stateful: Boolean (default `False`). If True, the last state 

250 for each sample at index i in a batch will be used as initial 

251 state for the sample of index i in the following batch. 

252 unroll: Boolean (default `False`). 

253 If True, the network will be unrolled, else a symbolic loop will be used. 

254 Unrolling can speed up an RNN, although it tends to be more 

255 memory-intensive. Unrolling is only suitable for short sequences. 

256 time_major: The shape format of the `inputs` and `outputs` tensors. 

257 If True, the inputs and outputs will be in shape 

258 `(timesteps, batch, ...)`, whereas in the False case, it will be 

259 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more 

260 efficient because it avoids transposes at the beginning and end of the 

261 RNN calculation. However, most TensorFlow data is batch-major, so by 

262 default this function accepts input and emits output in batch-major 

263 form. 

264 zero_output_for_mask: Boolean (default `False`). 

265 Whether the output should use zeros for the masked timesteps. Note that 

266 this field is only used when `return_sequences` is True and mask is 

267 provided. It can be useful if you want to reuse the raw output sequence of 

268 the RNN without interference from the masked timesteps, e.g., merging 

269 bidirectional RNNs. 

270 

271 Call arguments: 

272 inputs: Input tensor. 

273 mask: Binary tensor of shape `[batch_size, timesteps]` indicating whether 

274 a given timestep should be masked. An individual `True` entry indicates 

275 that the corresponding timestep should be utilized, while a `False` 

276 entry indicates that the corresponding timestep should be ignored. 

277 training: Python boolean indicating whether the layer should behave in 

278 training mode or in inference mode. This argument is passed to the cell 

279 when calling it. This is for use with cells that use dropout. 

280 initial_state: List of initial state tensors to be passed to the first 

281 call of the cell. 

282 constants: List of constant tensors to be passed to the cell at each 

283 timestep. 

284 

285 Input shape: 

286 N-D tensor with shape `[batch_size, timesteps, ...]` or 

287 `[timesteps, batch_size, ...]` when time_major is True. 

288 

289 Output shape: 

290 - If `return_state`: a list of tensors. The first tensor is 

291 the output. The remaining tensors are the last states, 

292 each with shape `[batch_size, state_size]`, where `state_size` could 

293 be a high dimension tensor shape. 

294 - If `return_sequences`: N-D tensor with shape 

295 `[batch_size, timesteps, output_size]`, where `output_size` could 

296 be a high dimension tensor shape, or 

297 `[timesteps, batch_size, output_size]` when `time_major` is True. 

298 - Else, N-D tensor with shape `[batch_size, output_size]`, where 

299 `output_size` could be a high dimension tensor shape. 

300 

301 Masking: 

302 This layer supports masking for input data with a variable number 

303 of timesteps. To introduce masks to your data, 

304 use a [tf.keras.layers.Embedding] layer with the `mask_zero` parameter 

305 set to `True`. 

306 
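 For example (a minimal sketch added for illustration; the vocabulary size 
 and layer widths below are illustrative only, not taken from this module): 

 ```python 
 import numpy as np 
 import tensorflow as tf 

 # Index 0 is reserved for padding; masked timesteps are skipped by the LSTM. 
 model = tf.keras.Sequential([ 
     tf.keras.layers.Embedding(input_dim=1000, output_dim=16, mask_zero=True), 
     tf.keras.layers.LSTM(32), 
 ]) 
 padded = np.array([[3, 7, 0, 0], [5, 0, 0, 0]])  # zero-padded sequences 
 output = model(padded)  # shape (2, 32) 
 ``` 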

307 Note on using statefulness in RNNs: 

308 You can set RNN layers to be 'stateful', which means that the states 

309 computed for the samples in one batch will be reused as initial states 

310 for the samples in the next batch. This assumes a one-to-one mapping 

311 between samples in different successive batches. 

312 

313 To enable statefulness: 

314 - Specify `stateful=True` in the layer constructor. 

315 - Specify a fixed batch size for your model, by passing 

316 if using a Sequential model: 

317 `batch_input_shape=(...)` to the first layer in your model, or 

318 if using a functional model with 1 or more Input layers: 

319 `batch_shape=(...)` to all the first layers in your model. 

320 This is the expected shape of your inputs 

321 *including the batch size*. 

322 It should be a tuple of integers, e.g. `(32, 10, 100)`. 

323 - Specify `shuffle=False` when calling `fit()`. 

324 

325 To reset the states of your model, call `.reset_states()` on either 

326 a specific layer, or on your entire model. 

327 
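 A minimal sketch of a stateful setup (the shapes and sizes below are 
 illustrative only): 

 ```python 
 import numpy as np 
 import tensorflow as tf 

 # 64 samples, batch size 32: sample i of batch k+1 continues sample i of 
 # batch k, so carrying state between batches is meaningful. 
 x = np.random.random((64, 10, 100)).astype(np.float32) 
 y = np.random.random((64, 1)).astype(np.float32) 

 model = tf.keras.Sequential([ 
     tf.keras.layers.LSTM(32, stateful=True, batch_input_shape=(32, 10, 100)), 
     tf.keras.layers.Dense(1), 
 ]) 
 model.compile(optimizer='adam', loss='mse') 
 model.fit(x, y, batch_size=32, shuffle=False) 
 model.reset_states()  # clear the recorded states before unrelated data 
 ``` 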

328 Note on specifying the initial state of RNNs: 

329 You can specify the initial state of RNN layers symbolically by 

330 calling them with the keyword argument `initial_state`. The value of 

331 `initial_state` should be a tensor or list of tensors representing 

332 the initial state of the RNN layer. 

333 

334 You can specify the initial state of RNN layers numerically by 

335 calling `reset_states` with the keyword argument `states`. The value of 

336 `states` should be a numpy array or list of numpy arrays representing 

337 the initial state of the RNN layer. 

338 
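 For example, a symbolic initial state can be built with `Input` layers 
 (a minimal sketch; the sizes are illustrative only): 

 ```python 
 import tensorflow as tf 

 inputs = tf.keras.Input(shape=(10, 8)) 
 initial_h = tf.keras.Input(shape=(32,))  # LSTM hidden state 
 initial_c = tf.keras.Input(shape=(32,))  # LSTM cell state 
 outputs = tf.keras.layers.LSTM(32)( 
     inputs, initial_state=[initial_h, initial_c]) 
 model = tf.keras.Model([inputs, initial_h, initial_c], outputs) 
 ``` 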

339 Note on passing external constants to RNNs: 

340 You can pass "external" constants to the cell using the `constants` 

341 keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This 

342 requires that the `cell.call` method accepts the same keyword argument 

343 `constants`. Such constants can be used to condition the cell 

344 transformation on additional static inputs (not changing over time), 

345 a.k.a. an attention mechanism. 

346 
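 A minimal sketch of a cell that accepts constants (this cell is hypothetical 
 and not part of this module; it assumes the constant has 3 features and 
 follows the same conventions as the `MinimalRNNCell` example below): 

 ```python 
 class CellWithConstants(Layer): 

   def __init__(self, units, **kwargs): 
     super(CellWithConstants, self).__init__(**kwargs) 
     self.units = units 
     self.state_size = units 

   def build(self, input_shape): 
     self.kernel = self.add_weight(shape=(input_shape[-1], self.units), 
                                   initializer='uniform', name='kernel') 
     self.recurrent_kernel = self.add_weight(shape=(self.units, self.units), 
                                             initializer='uniform', 
                                             name='recurrent_kernel') 
     self.constant_kernel = self.add_weight(shape=(3, self.units), 
                                            initializer='uniform', 
                                            name='constant_kernel') 
     self.built = True 

   def call(self, inputs, states, constants): 
     prev_output = states[0] 
     constant = constants[0] 
     output = (backend.dot(inputs, self.kernel) + 
               backend.dot(prev_output, self.recurrent_kernel) + 
               backend.dot(constant, self.constant_kernel)) 
     return output, [output] 

 x = keras.Input((None, 5)) 
 c = keras.Input((3,)) 
 y = RNN(CellWithConstants(32))(x, constants=c) 
 ``` 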

347 Examples: 

348 

349 ```python 

350 # First, let's define a RNN Cell, as a layer subclass. 

351 

352 class MinimalRNNCell(keras.layers.Layer): 

353 

354 def __init__(self, units, **kwargs): 

355 self.units = units 

356 self.state_size = units 

357 super(MinimalRNNCell, self).__init__(**kwargs) 

358 

359 def build(self, input_shape): 

360 self.kernel = self.add_weight(shape=(input_shape[-1], self.units), 

361 initializer='uniform', 

362 name='kernel') 

363 self.recurrent_kernel = self.add_weight( 

364 shape=(self.units, self.units), 

365 initializer='uniform', 

366 name='recurrent_kernel') 

367 self.built = True 

368 

369 def call(self, inputs, states): 

370 prev_output = states[0] 

371 h = backend.dot(inputs, self.kernel) 

372 output = h + backend.dot(prev_output, self.recurrent_kernel) 

373 return output, [output] 

374 

375 # Let's use this cell in a RNN layer: 

376 

377 cell = MinimalRNNCell(32) 

378 x = keras.Input((None, 5)) 

379 layer = RNN(cell) 

380 y = layer(x) 

381 

382 # Here's how to use the cell to build a stacked RNN: 

383 

384 cells = [MinimalRNNCell(32), MinimalRNNCell(64)] 

385 x = keras.Input((None, 5)) 

386 layer = RNN(cells) 

387 y = layer(x) 

388 ``` 

389 """ 

390 

391 def __init__(self, 

392 cell, 

393 return_sequences=False, 

394 return_state=False, 

395 go_backwards=False, 

396 stateful=False, 

397 unroll=False, 

398 time_major=False, 

399 **kwargs): 

400 if isinstance(cell, (list, tuple)): 

401 cell = StackedRNNCells(cell) 

402 if not 'call' in dir(cell): 

403 raise ValueError('`cell` should have a `call` method. ' 

404 'The RNN was passed:', cell) 

405 if not 'state_size' in dir(cell): 

406 raise ValueError('The RNN cell should have ' 

407 'an attribute `state_size` ' 

408 '(tuple of integers, ' 

409 'one integer per RNN state).') 

410 # If True, the output for masked timesteps will be zeros, whereas in the 

411 # False case, the previous timestep's output is returned for masked timesteps. 

412 self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False) 

413 

414 if 'input_shape' not in kwargs and ( 

415 'input_dim' in kwargs or 'input_length' in kwargs): 

416 input_shape = (kwargs.pop('input_length', None), 

417 kwargs.pop('input_dim', None)) 

418 kwargs['input_shape'] = input_shape 

419 

420 super(RNN, self).__init__(**kwargs) 

421 self.cell = cell 

422 self.return_sequences = return_sequences 

423 self.return_state = return_state 

424 self.go_backwards = go_backwards 

425 self.stateful = stateful 

426 self.unroll = unroll 

427 self.time_major = time_major 

428 

429 self.supports_masking = True 

430 # The input shape is not known yet; it could have nested tensor inputs, in 

431 # which case the input spec will be the list of specs for the nested inputs, 

432 # and the structure of the input_spec will be the same as the input. 

433 self.input_spec = None 

434 self.state_spec = None 

435 self._states = None 

436 self.constants_spec = None 

437 self._num_constants = 0 

438 

439 if stateful: 

440 if distribute_lib.has_strategy(): 

441 raise ValueError('RNNs with stateful=True not yet supported with ' 

442 'tf.distribute.Strategy.') 

443 

444 @property 

445 def _use_input_spec_as_call_signature(self): 

446 if self.unroll: 

447 # When the RNN layer is unrolled, the time step shape cannot be unknown. 

448 # The input spec does not define the time step (because this layer can be 

449 # called with any time step value, as long as it is not None), so it 

450 # cannot be used as the call function signature when saving to SavedModel. 

451 return False 

452 return super(RNN, self)._use_input_spec_as_call_signature 

453 

454 @property 

455 def states(self): 

456 if self._states is None: 

457 state = nest.map_structure(lambda _: None, self.cell.state_size) 

458 return state if nest.is_nested(self.cell.state_size) else [state] 

459 return self._states 

460 

461 @states.setter 

462 # Automatic tracking catches "self._states" which adds an extra weight and 

463 # breaks HDF5 checkpoints. 

464 @trackable.no_automatic_dependency_tracking 

465 def states(self, states): 

466 self._states = states 

467 

468 def compute_output_shape(self, input_shape): 

469 if isinstance(input_shape, list): 

470 input_shape = input_shape[0] 

471 # Check whether the input shape contains any nested shapes. It could be 

472 # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy 

473 # inputs. 

474 try: 

475 input_shape = tensor_shape.TensorShape(input_shape) 

476 except (ValueError, TypeError): 

477 # A nested tensor input 

478 input_shape = nest.flatten(input_shape)[0] 

479 

480 batch = input_shape[0] 

481 time_step = input_shape[1] 

482 if self.time_major: 

483 batch, time_step = time_step, batch 

484 

485 if _is_multiple_state(self.cell.state_size): 

486 state_size = self.cell.state_size 

487 else: 

488 state_size = [self.cell.state_size] 

489 

490 def _get_output_shape(flat_output_size): 

491 output_dim = tensor_shape.TensorShape(flat_output_size).as_list() 

492 if self.return_sequences: 

493 if self.time_major: 

494 output_shape = tensor_shape.TensorShape( 

495 [time_step, batch] + output_dim) 

496 else: 

497 output_shape = tensor_shape.TensorShape( 

498 [batch, time_step] + output_dim) 

499 else: 

500 output_shape = tensor_shape.TensorShape([batch] + output_dim) 

501 return output_shape 

502 

503 if getattr(self.cell, 'output_size', None) is not None: 

504 # cell.output_size could be nested structure. 

505 output_shape = nest.flatten(nest.map_structure( 

506 _get_output_shape, self.cell.output_size)) 

507 output_shape = output_shape[0] if len(output_shape) == 1 else output_shape 

508 else: 

509 # Note that state_size[0] could be a tensor_shape or int. 

510 output_shape = _get_output_shape(state_size[0]) 

511 

512 if self.return_state: 

513 def _get_state_shape(flat_state): 

514 state_shape = [batch] + tensor_shape.TensorShape(flat_state).as_list() 

515 return tensor_shape.TensorShape(state_shape) 

516 state_shape = nest.map_structure(_get_state_shape, state_size) 

517 return generic_utils.to_list(output_shape) + nest.flatten(state_shape) 

518 else: 

519 return output_shape 

520 

521 def compute_mask(self, inputs, mask): 

522 # Time step masks must be the same for each input. 

523 # This is because the mask for an RNN is of size [batch, time_steps, 1], 

524 # and specifies which time steps should be skipped, and a time step 

525 # must be skipped for all inputs. 

526 # TODO(scottzhu): Should we accept multiple different masks? 

527 mask = nest.flatten(mask)[0] 

528 output_mask = mask if self.return_sequences else None 

529 if self.return_state: 

530 state_mask = [None for _ in self.states] 

531 return [output_mask] + state_mask 

532 else: 

533 return output_mask 

534 

535 def build(self, input_shape): 

536 if isinstance(input_shape, list): 

537 input_shape = input_shape[0] 

538 # The input_shape here could be a nest structure. 

539 

540 # do the tensor_shape to shapes here. The input could be single tensor, or a 

541 # nested structure of tensors. 

542 def get_input_spec(shape): 

543 """Convert input shape to InputSpec.""" 

544 if isinstance(shape, tensor_shape.TensorShape): 

545 input_spec_shape = shape.as_list() 

546 else: 

547 input_spec_shape = list(shape) 

548 batch_index, time_step_index = (1, 0) if self.time_major else (0, 1) 

549 if not self.stateful: 

550 input_spec_shape[batch_index] = None 

551 input_spec_shape[time_step_index] = None 

552 return InputSpec(shape=tuple(input_spec_shape)) 

553 

554 def get_step_input_shape(shape): 

555 if isinstance(shape, tensor_shape.TensorShape): 

556 shape = tuple(shape.as_list()) 

557 # remove the timestep from the input_shape 

558 return shape[1:] if self.time_major else (shape[0],) + shape[2:] 

559 

560 # Check whether the input shape contains any nested shapes. It could be 

561 # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy 

562 # inputs. 

563 try: 

564 input_shape = tensor_shape.TensorShape(input_shape) 

565 except (ValueError, TypeError): 

566 # A nested tensor input 

567 pass 

568 

569 if not nest.is_nested(input_shape): 

570 # This indicates that there is only one input. 

571 if self.input_spec is not None: 

572 self.input_spec[0] = get_input_spec(input_shape) 

573 else: 

574 self.input_spec = [get_input_spec(input_shape)] 

575 step_input_shape = get_step_input_shape(input_shape) 

576 else: 

577 if self.input_spec is not None: 

578 self.input_spec[0] = nest.map_structure(get_input_spec, input_shape) 

579 else: 

580 self.input_spec = generic_utils.to_list( 

581 nest.map_structure(get_input_spec, input_shape)) 

582 step_input_shape = nest.map_structure(get_step_input_shape, input_shape) 

583 

584 # allow cell (if layer) to build before we set or validate state_spec. 

585 if isinstance(self.cell, Layer) and not self.cell.built: 

586 with backend.name_scope(self.cell.name): 

587 self.cell.build(step_input_shape) 

588 self.cell.built = True 

589 

590 # set or validate state_spec 

591 if _is_multiple_state(self.cell.state_size): 

592 state_size = list(self.cell.state_size) 

593 else: 

594 state_size = [self.cell.state_size] 

595 

596 if self.state_spec is not None: 

597 # initial_state was passed in call, check compatibility 

598 self._validate_state_spec(state_size, self.state_spec) 

599 else: 

600 self.state_spec = [ 

601 InputSpec(shape=[None] + tensor_shape.TensorShape(dim).as_list()) 

602 for dim in state_size 

603 ] 

604 if self.stateful: 

605 self.reset_states() 

606 self.built = True 

607 

608 @staticmethod 

609 def _validate_state_spec(cell_state_sizes, init_state_specs): 

610 """Validate the state spec between the initial_state and the state_size. 

611 

612 Args: 

613 cell_state_sizes: list, the `state_size` attribute from the cell. 

614 init_state_specs: list, the `state_spec` from the initial_state that is 

615 passed in `call()`. 

616 

617 Raises: 

618 ValueError: When initial state spec is not compatible with the state size. 

619 """ 

620 validation_error = ValueError( 

621 'An `initial_state` was passed that is not compatible with ' 

622 '`cell.state_size`. Received `state_spec`={}; ' 

623 'however `cell.state_size` is ' 

624 '{}'.format(init_state_specs, cell_state_sizes)) 

625 flat_cell_state_sizes = nest.flatten(cell_state_sizes) 

626 flat_state_specs = nest.flatten(init_state_specs) 

627 

628 if len(flat_cell_state_sizes) != len(flat_state_specs): 

629 raise validation_error 

630 for cell_state_spec, cell_state_size in zip(flat_state_specs, 

631 flat_cell_state_sizes): 

632 if not tensor_shape.TensorShape( 

633 # Ignore the first axis for init_state which is for batch 

634 cell_state_spec.shape[1:]).is_compatible_with( 

635 tensor_shape.TensorShape(cell_state_size)): 

636 raise validation_error 

637 

638 @doc_controls.do_not_doc_inheritable 

639 def get_initial_state(self, inputs): 

640 get_initial_state_fn = getattr(self.cell, 'get_initial_state', None) 

641 

642 if nest.is_nested(inputs): 

643 # The inputs are nested sequences. Use the first element in the sequence to 

644 # get the batch size and dtype. 

645 inputs = nest.flatten(inputs)[0] 

646 

647 input_shape = array_ops.shape(inputs) 

648 batch_size = input_shape[1] if self.time_major else input_shape[0] 

649 dtype = inputs.dtype 

650 if get_initial_state_fn: 

651 init_state = get_initial_state_fn( 

652 inputs=None, batch_size=batch_size, dtype=dtype) 

653 else: 

654 init_state = _generate_zero_filled_state(batch_size, self.cell.state_size, 

655 dtype) 

656 # Keras RNNs expect the states in a list, even if it's a single state tensor. 

657 if not nest.is_nested(init_state): 

658 init_state = [init_state] 

659 # Force the state to be a list in case it is a namedtuple, e.g. LSTMStateTuple. 

660 return list(init_state) 

661 

662 def __call__(self, inputs, initial_state=None, constants=None, **kwargs): 

663 inputs, initial_state, constants = _standardize_args(inputs, 

664 initial_state, 

665 constants, 

666 self._num_constants) 

667 

668 if initial_state is None and constants is None: 

669 return super(RNN, self).__call__(inputs, **kwargs) 

670 

671 # If any of `initial_state` or `constants` are specified and are Keras 

672 # tensors, then add them to the inputs and temporarily modify the 

673 # input_spec to include them. 

674 

675 additional_inputs = [] 

676 additional_specs = [] 

677 if initial_state is not None: 

678 additional_inputs += initial_state 

679 self.state_spec = nest.map_structure( 

680 lambda s: InputSpec(shape=backend.int_shape(s)), initial_state) 

681 additional_specs += self.state_spec 

682 if constants is not None: 

683 additional_inputs += constants 

684 self.constants_spec = [ 

685 InputSpec(shape=backend.int_shape(constant)) for constant in constants 

686 ] 

687 self._num_constants = len(constants) 

688 additional_specs += self.constants_spec 

689 # additional_inputs can be empty if initial_state or constants are provided 

690 # but empty (e.g. the cell is stateless). 

691 flat_additional_inputs = nest.flatten(additional_inputs) 

692 is_keras_tensor = backend.is_keras_tensor( 

693 flat_additional_inputs[0]) if flat_additional_inputs else True 

694 for tensor in flat_additional_inputs: 

695 if backend.is_keras_tensor(tensor) != is_keras_tensor: 

696 raise ValueError('The initial state or constants of an RNN' 

697 ' layer cannot be specified with a mix of' 

698 ' Keras tensors and non-Keras tensors' 

699 ' (a "Keras tensor" is a tensor that was' 

700 ' returned by a Keras layer, or by `Input`)') 

701 

702 if is_keras_tensor: 

703 # Compute the full input spec, including state and constants 

704 full_input = [inputs] + additional_inputs 

705 if self.built: 

706 # Keep the input_spec since it has been populated in build() method. 

707 full_input_spec = self.input_spec + additional_specs 

708 else: 

709 # The original input_spec is None since there could be a nested tensor 

710 # input. Update the input_spec to match the inputs. 

711 full_input_spec = generic_utils.to_list( 

712 nest.map_structure(lambda _: None, inputs)) + additional_specs 

713 # Perform the call with temporarily replaced input_spec 

714 self.input_spec = full_input_spec 

715 output = super(RNN, self).__call__(full_input, **kwargs) 

716 # Remove the additional_specs from the input spec and keep the rest. It is 

717 # important to keep them since the input spec was populated by build(), and 

718 # will be reused when stateful=True. 

719 self.input_spec = self.input_spec[:-len(additional_specs)] 

720 return output 

721 else: 

722 if initial_state is not None: 

723 kwargs['initial_state'] = initial_state 

724 if constants is not None: 

725 kwargs['constants'] = constants 

726 return super(RNN, self).__call__(inputs, **kwargs) 

727 

728 def call(self, 

729 inputs, 

730 mask=None, 

731 training=None, 

732 initial_state=None, 

733 constants=None): 

734 # The input should be dense, padded with zeros. If a ragged input is fed 

735 # into the layer, it is padded and the row lengths are used for masking. 

736 inputs, row_lengths = backend.convert_inputs_if_ragged(inputs) 

737 is_ragged_input = (row_lengths is not None) 

738 self._validate_args_if_ragged(is_ragged_input, mask) 

739 

740 inputs, initial_state, constants = self._process_inputs( 

741 inputs, initial_state, constants) 

742 

743 self._maybe_reset_cell_dropout_mask(self.cell) 

744 if isinstance(self.cell, StackedRNNCells): 

745 for cell in self.cell.cells: 

746 self._maybe_reset_cell_dropout_mask(cell) 

747 

748 if mask is not None: 

749 # Time step masks must be the same for each input. 

750 # TODO(scottzhu): Should we accept multiple different masks? 

751 mask = nest.flatten(mask)[0] 

752 

753 if nest.is_nested(inputs): 

754 # In the case of nested input, use the first element for shape check. 

755 input_shape = backend.int_shape(nest.flatten(inputs)[0]) 

756 else: 

757 input_shape = backend.int_shape(inputs) 

758 timesteps = input_shape[0] if self.time_major else input_shape[1] 

759 if self.unroll and timesteps is None: 

760 raise ValueError('Cannot unroll a RNN if the ' 

761 'time dimension is undefined. \n' 

762 '- If using a Sequential model, ' 

763 'specify the time dimension by passing ' 

764 'an `input_shape` or `batch_input_shape` ' 

765 'argument to your first layer. If your ' 

766 'first layer is an Embedding, you can ' 

767 'also use the `input_length` argument.\n' 

768 '- If using the functional API, specify ' 

769 'the time dimension by passing a `shape` ' 

770 'or `batch_shape` argument to your Input layer.') 

771 

772 kwargs = {} 

773 if generic_utils.has_arg(self.cell.call, 'training'): 

774 kwargs['training'] = training 

775 

776 # TF RNN cells expect a single tensor as state instead of a list-wrapped tensor. 

777 is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None 

778 # Use the __call__ function for callable objects, eg layers, so that it 

779 # will have the proper name scopes for the ops, etc. 

780 cell_call_fn = self.cell.__call__ if callable(self.cell) else self.cell.call 

781 if constants: 

782 if not generic_utils.has_arg(self.cell.call, 'constants'): 

783 raise ValueError('RNN cell does not support constants') 

784 

785 def step(inputs, states): 

786 constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type 

787 states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type 

788 

789 states = states[0] if len(states) == 1 and is_tf_rnn_cell else states 

790 output, new_states = cell_call_fn( 

791 inputs, states, constants=constants, **kwargs) 

792 if not nest.is_nested(new_states): 

793 new_states = [new_states] 

794 return output, new_states 

795 else: 

796 

797 def step(inputs, states): 

798 states = states[0] if len(states) == 1 and is_tf_rnn_cell else states 

799 output, new_states = cell_call_fn(inputs, states, **kwargs) 

800 if not nest.is_nested(new_states): 

801 new_states = [new_states] 

802 return output, new_states 

803 last_output, outputs, states = backend.rnn( 

804 step, 

805 inputs, 

806 initial_state, 

807 constants=constants, 

808 go_backwards=self.go_backwards, 

809 mask=mask, 

810 unroll=self.unroll, 

811 input_length=row_lengths if row_lengths is not None else timesteps, 

812 time_major=self.time_major, 

813 zero_output_for_mask=self.zero_output_for_mask) 

814 

815 if self.stateful: 

816 updates = [ 

817 state_ops.assign(self_state, state) for self_state, state in zip( 

818 nest.flatten(self.states), nest.flatten(states)) 

819 ] 

820 self.add_update(updates) 

821 

822 if self.return_sequences: 

823 output = backend.maybe_convert_to_ragged( 

824 is_ragged_input, outputs, row_lengths, go_backwards=self.go_backwards) 

825 else: 

826 output = last_output 

827 

828 if self.return_state: 

829 if not isinstance(states, (list, tuple)): 

830 states = [states] 

831 else: 

832 states = list(states) 

833 return generic_utils.to_list(output) + states 

834 else: 

835 return output 

836 

837 def _process_inputs(self, inputs, initial_state, constants): 

838 # input shape: `(samples, time (padded with zeros), input_dim)` 

839 # note that the .build() method of subclasses MUST define 

840 # self.input_spec and self.state_spec with complete input shapes. 

841 if (isinstance(inputs, collections.abc.Sequence) 

842 and not isinstance(inputs, tuple)): 

843 # get initial_state from full input spec 

844 # as they could be copied to multiple GPUs. 

845 if not self._num_constants: 

846 initial_state = inputs[1:] 

847 else: 

848 initial_state = inputs[1:-self._num_constants] 

849 constants = inputs[-self._num_constants:] 

850 if len(initial_state) == 0: 

851 initial_state = None 

852 inputs = inputs[0] 

853 

854 if self.stateful: 

855 if initial_state is not None: 

856 # When the layer is stateful and initial_state is provided, check if the 

857 # recorded state is the same as the default value (zeros). Use the recorded 

858 # state if it is not the same as the default. 

859 non_zero_count = math_ops.add_n([math_ops.count_nonzero_v2(s) 

860 for s in nest.flatten(self.states)]) 

861 # Set strict = True to keep the original structure of the state. 

862 initial_state = cond.cond(non_zero_count > 0, 

863 true_fn=lambda: self.states, 

864 false_fn=lambda: initial_state, 

865 strict=True) 

866 else: 

867 initial_state = self.states 

868 elif initial_state is None: 

869 initial_state = self.get_initial_state(inputs) 

870 

871 if len(initial_state) != len(self.states): 

872 raise ValueError('Layer has ' + str(len(self.states)) + 

873 ' states but was passed ' + str(len(initial_state)) + 

874 ' initial states.') 

875 return inputs, initial_state, constants 

876 

877 def _validate_args_if_ragged(self, is_ragged_input, mask): 

878 if not is_ragged_input: 

879 return 

880 

881 if mask is not None: 

882 raise ValueError('The mask that was passed in was ' + str(mask) + 

883 ' and cannot be applied to RaggedTensor inputs. Please ' 

884 'make sure that there is no mask passed in by upstream ' 

885 'layers.') 

886 if self.unroll: 

887 raise ValueError('The input received contains RaggedTensors and does ' 

888 'not support unrolling. Disable unrolling by passing ' 

889 '`unroll=False` in the RNN Layer constructor.') 

890 

891 def _maybe_reset_cell_dropout_mask(self, cell): 

892 if isinstance(cell, DropoutRNNCellMixin): 

893 cell.reset_dropout_mask() 

894 cell.reset_recurrent_dropout_mask() 

895 

896 def reset_states(self, states=None): 

897 """Reset the recorded states for the stateful RNN layer. 

898 

899 Can only be used when RNN layer is constructed with `stateful` = `True`. 

900 Args: 

901 states: Numpy arrays that contain the values for the initial state, which 

902 will be fed to the cell at the first time step. When the value is None, 

903 a zero-filled numpy array will be created based on the cell state size. 

904 

905 Raises: 

906 AttributeError: When the RNN layer is not stateful. 

907 ValueError: When the batch size of the RNN layer is unknown. 

908 ValueError: When the input numpy array is not compatible with the RNN 

909 layer state, either size-wise or dtype-wise. 
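 For example (a minimal sketch; the layer, batch size, and unit count below 
 are illustrative only): 

 ```python 
 import numpy as np 
 import tensorflow as tf 

 layer = tf.keras.layers.SimpleRNN( 
     32, stateful=True, batch_input_shape=(16, 10, 8)) 
 layer(np.zeros((16, 10, 8), dtype=np.float32))  # builds the layer 
 layer.reset_states()                            # back to all zeros 
 layer.reset_states(states=np.ones((16, 32)))    # set specific values 
 ``` 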

910 """ 

911 if not self.stateful: 

912 raise AttributeError('Layer must be stateful.') 

913 spec_shape = None 

914 if self.input_spec is not None: 

915 spec_shape = nest.flatten(self.input_spec[0])[0].shape 

916 if spec_shape is None: 

917 # It is possible for the spec shape to be None, e.g. when constructing an RNN 

918 # with a custom cell, or standard RNN layers (LSTM/GRU) for which we only know 

919 # the input has 3 dims, but not its full shape spec before build(). 

920 batch_size = None 

921 else: 

922 batch_size = spec_shape[1] if self.time_major else spec_shape[0] 

923 if not batch_size: 

924 raise ValueError('If a RNN is stateful, it needs to know ' 

925 'its batch size. Specify the batch size ' 

926 'of your input tensors: \n' 

927 '- If using a Sequential model, ' 

928 'specify the batch size by passing ' 

929 'a `batch_input_shape` ' 

930 'argument to your first layer.\n' 

931 '- If using the functional API, specify ' 

932 'the batch size by passing a ' 

933 '`batch_shape` argument to your Input layer.') 

934 # initialize state if None 

935 if nest.flatten(self.states)[0] is None: 

936 if getattr(self.cell, 'get_initial_state', None): 

937 flat_init_state_values = nest.flatten(self.cell.get_initial_state( 

938 inputs=None, batch_size=batch_size, 

939 dtype=self.dtype or backend.floatx())) 

940 else: 

941 flat_init_state_values = nest.flatten(_generate_zero_filled_state( 

942 batch_size, self.cell.state_size, self.dtype or backend.floatx())) 

943 flat_states_variables = nest.map_structure( 

944 backend.variable, flat_init_state_values) 

945 self.states = nest.pack_sequence_as(self.cell.state_size, 

946 flat_states_variables) 

947 if not nest.is_nested(self.states): 

948 self.states = [self.states] 

949 elif states is None: 

950 for state, size in zip(nest.flatten(self.states), 

951 nest.flatten(self.cell.state_size)): 

952 backend.set_value( 

953 state, 

954 np.zeros([batch_size] + tensor_shape.TensorShape(size).as_list())) 

955 else: 

956 flat_states = nest.flatten(self.states) 

957 flat_input_states = nest.flatten(states) 

958 if len(flat_input_states) != len(flat_states): 

959 raise ValueError('Layer ' + self.name + ' expects ' + 

960 str(len(flat_states)) + ' states, ' 

961 'but it received ' + str(len(flat_input_states)) + 

962 ' state values. Input received: ' + str(states)) 

963 set_value_tuples = [] 

964 for i, (value, state) in enumerate(zip(flat_input_states, 

965 flat_states)): 

966 if value.shape != state.shape: 

967 raise ValueError( 

968 'State ' + str(i) + ' is incompatible with layer ' + 

969 self.name + ': expected shape=' + str( 

970 (batch_size, state)) + ', found shape=' + str(value.shape)) 

971 set_value_tuples.append((state, value)) 

972 backend.batch_set_value(set_value_tuples) 

973 

974 def get_config(self): 

975 config = { 

976 'return_sequences': self.return_sequences, 

977 'return_state': self.return_state, 

978 'go_backwards': self.go_backwards, 

979 'stateful': self.stateful, 

980 'unroll': self.unroll, 

981 'time_major': self.time_major 

982 } 

983 if self._num_constants: 

984 config['num_constants'] = self._num_constants 

985 if self.zero_output_for_mask: 

986 config['zero_output_for_mask'] = self.zero_output_for_mask 

987 

988 config['cell'] = generic_utils.serialize_keras_object(self.cell) 

989 base_config = super(RNN, self).get_config() 

990 return dict(list(base_config.items()) + list(config.items())) 

991 

992 @classmethod 

993 def from_config(cls, config, custom_objects=None): 

994 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top 

995 cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects) 

996 num_constants = config.pop('num_constants', 0) 

997 layer = cls(cell, **config) 

998 layer._num_constants = num_constants 

999 return layer 

1000 

1001 @property 

1002 def _trackable_saved_model_saver(self): 

1003 return layer_serialization.RNNSavedModelSaver(self) 

1004 

1005 

1006@keras_export('keras.layers.AbstractRNNCell') 

1007class AbstractRNNCell(Layer): 

1008 """Abstract object representing an RNN cell. 

1009 

1010 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 

1011 for details about the usage of RNN API. 

1012 

1013 This is the base class for implementing RNN cells with custom behavior. 

1014 

1015 Every `RNNCell` must have the properties below and implement `call` with 

1016 the signature `(output, next_state) = call(input, state)`. 

1017 

1018 Examples: 

1019 

1020 ```python 

1021 class MinimalRNNCell(AbstractRNNCell): 

1022 

1023 def __init__(self, units, **kwargs): 

1024 self.units = units 

1025 super(MinimalRNNCell, self).__init__(**kwargs) 

1026 

1027 @property 

1028 def state_size(self): 

1029 return self.units 

1030 

1031 def build(self, input_shape): 

1032 self.kernel = self.add_weight(shape=(input_shape[-1], self.units), 

1033 initializer='uniform', 

1034 name='kernel') 

1035 self.recurrent_kernel = self.add_weight( 

1036 shape=(self.units, self.units), 

1037 initializer='uniform', 

1038 name='recurrent_kernel') 

1039 self.built = True 

1040 

1041 def call(self, inputs, states): 

1042 prev_output = states[0] 

1043 h = backend.dot(inputs, self.kernel) 

1044 output = h + backend.dot(prev_output, self.recurrent_kernel) 

1045 return output, output 

1046 ``` 

1047 

1048 This definition of cell differs from the definition used in the literature. 

1049 In the literature, 'cell' refers to an object with a single scalar output. 

1050 This definition refers to a horizontal array of such units. 

1051 

1052 An RNN cell, in the most abstract setting, is anything that has 

1053 a state and performs some operation that takes a matrix of inputs. 

1054 This operation results in an output matrix with `self.output_size` columns. 

1055 If `self.state_size` is an integer, this operation also results in a new 

1056 state matrix with `self.state_size` columns. If `self.state_size` is a 

1057 (possibly nested tuple of) TensorShape object(s), then it should return a 

1058 matching structure of Tensors having shape `[batch_size].concatenate(s)` 

1059 for each `s` in `self.state_size`. 

1060 """ 

1061 

1062 def call(self, inputs, states): 

1063 """The function that contains the logic for one RNN step calculation. 

1064 

1065 Args: 

1066 inputs: the input tensor, which is a slice from the overall RNN input along 

1067 the time dimension (usually the second dimension). 

1068 states: the state tensor from the previous step, which has the same shape 

1069 as `(batch, state_size)`. In the case of timestep 0, it will be the 

1070 initial state the user specified, or a zero-filled tensor otherwise. 

1071 

1072 Returns: 

1073 A tuple of two tensors: 

1074 1. output tensor for the current timestep, with size `output_size`. 

1075 2. state tensor for next step, which has the shape of `state_size`. 

1076 """ 

1077 raise NotImplementedError('Abstract method') 

1078 

1079 @property 

1080 def state_size(self): 

1081 """size(s) of state(s) used by this cell. 

1082 

1083 It can be represented by an Integer, a TensorShape or a tuple of Integers 

1084 or TensorShapes. 

1085 """ 

1086 raise NotImplementedError('Abstract method') 

1087 

1088 @property 

1089 def output_size(self): 

1090 """Integer or TensorShape: size of outputs produced by this cell.""" 

1091 raise NotImplementedError('Abstract method') 

1092 

1093 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 

1094 return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) 

1095 

1096 

1097@doc_controls.do_not_generate_docs 

1098class DropoutRNNCellMixin(object): 

1099 """Object that hold dropout related fields for RNN Cell. 

1100 

1101 This class is not a standalone RNN cell. It suppose to be used with a RNN cell 

1102 by multiple inheritance. Any cell that mix with class should have following 

1103 fields: 

1104 dropout: a float number within range [0, 1). The ratio that the input 

1105 tensor need to dropout. 

1106 recurrent_dropout: a float number within range [0, 1). The ratio that the 

1107 recurrent state weights need to dropout. 

1108 This object will create and cache created dropout masks, and reuse them for 

1109 the incoming data, so that the same mask is used for every batch input. 
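
 A minimal sketch of a custom cell using this mixin (the cell below is 
 hypothetical and not part of this module; it assumes the `Layer` base class 
 and `backend` imported at the top of this file): 

 ```python 
 class DropoutSimpleCell(DropoutRNNCellMixin, Layer): 

   def __init__(self, units, dropout=0., recurrent_dropout=0., **kwargs): 
     super(DropoutSimpleCell, self).__init__(**kwargs) 
     self.units = units 
     self.state_size = units 
     # Fields required by the mixin. 
     self.dropout = min(1., max(0., dropout)) 
     self.recurrent_dropout = min(1., max(0., recurrent_dropout)) 

   def build(self, input_shape): 
     self.kernel = self.add_weight(shape=(input_shape[-1], self.units), 
                                   initializer='glorot_uniform', name='kernel') 
     self.recurrent_kernel = self.add_weight(shape=(self.units, self.units), 
                                             initializer='orthogonal', 
                                             name='recurrent_kernel') 
     self.built = True 

   def call(self, inputs, states, training=None): 
     prev_output = states[0] 
     dp_mask = self.get_dropout_mask_for_cell(inputs, training) 
     rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( 
         prev_output, training) 
     if dp_mask is not None: 
       inputs = inputs * dp_mask 
     if rec_dp_mask is not None: 
       prev_output = prev_output * rec_dp_mask 
     output = backend.dot(inputs, self.kernel) + backend.dot( 
         prev_output, self.recurrent_kernel) 
     return output, [output] 
 ``` 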

1110 """ 

1111 

1112 def __init__(self, *args, **kwargs): 

1113 self._create_non_trackable_mask_cache() 

1114 super(DropoutRNNCellMixin, self).__init__(*args, **kwargs) 

1115 

1116 @trackable.no_automatic_dependency_tracking 

1117 def _create_non_trackable_mask_cache(self): 

1118 """Create the cache for dropout and recurrent dropout mask. 

1119 

1120 Note that the following two masks will be used in "graph function" mode, 

1121 i.e. these masks are symbolic tensors. In eager mode, the `eager_*_mask` 

1122 tensors will be generated differently than in the "graph function" case, 

1123 and they will be cached. 

1124 

1125 Also note that in graph mode, we still cache those masks only because the 

1126 RNN could be created with `unroll=True`. In that case, the `cell.call()` 

1127 function will be invoked multiple times, and we want to ensure same mask 

1128 is used every time. 

1129 

1130 Also, the caches are created without tracking. Since they are not picklable 

1131 by python when deepcopied, we don't want `layer._obj_reference_counts_dict` 

1132 to track them by default. 

1133 """ 

1134 self._dropout_mask_cache = backend.ContextValueCache( 

1135 self._create_dropout_mask) 

1136 self._recurrent_dropout_mask_cache = backend.ContextValueCache( 

1137 self._create_recurrent_dropout_mask) 

1138 

1139 def reset_dropout_mask(self): 

1140 """Reset the cached dropout masks if any. 

1141 

1142 It is important for the RNN layer to invoke this in its `call()` method so 

1143 that the cached mask is cleared before calling `cell.call()`. The mask 

1144 should be cached across the timesteps within the same batch, but shouldn't 

1145 be cached between batches. Otherwise it will introduce an unreasonable bias 

1146 against certain indices of data within the batch. 

1147 """ 

1148 self._dropout_mask_cache.clear() 

1149 

1150 def reset_recurrent_dropout_mask(self): 

1151 """Reset the cached recurrent dropout masks if any. 

1152 

1153 It is important for the RNN layer to invoke this in its call() method so 

1154 that the cached mask is cleared before calling cell.call(). The mask 

1155 should be cached across the timesteps within the same batch, but shouldn't 

1156 be cached between batches. Otherwise it will introduce an unreasonable bias 

1157 against certain indices of data within the batch. 

1158 """ 

1159 self._recurrent_dropout_mask_cache.clear() 

1160 

1161 def _create_dropout_mask(self, inputs, training, count=1): 

1162 return _generate_dropout_mask( 

1163 array_ops.ones_like(inputs), 

1164 self.dropout, 

1165 training=training, 

1166 count=count) 

1167 

1168 def _create_recurrent_dropout_mask(self, inputs, training, count=1): 

1169 return _generate_dropout_mask( 

1170 array_ops.ones_like(inputs), 

1171 self.recurrent_dropout, 

1172 training=training, 

1173 count=count) 

1174 

1175 def get_dropout_mask_for_cell(self, inputs, training, count=1): 

1176 """Get the dropout mask for RNN cell's input. 

1177 

1178 It will create a mask based on context if there isn't any existing cached 

1179 mask. If a new mask is generated, it will update the cache in the cell. 

1180 

1181 Args: 

1182 inputs: The input tensor whose shape will be used to generate dropout 

1183 mask. 

1184 training: Boolean tensor, whether it's in training mode; dropout will be 

1185 ignored in non-training mode. 

1186 count: Int, how many dropout masks will be generated. It is useful for cells 

1187 that have internal weights fused together. 

1188 Returns: 

1189 List of mask tensors, generated or cached masks based on context. 

1190 """ 

1191 if self.dropout == 0: 

1192 return None 

1193 init_kwargs = dict(inputs=inputs, training=training, count=count) 

1194 return self._dropout_mask_cache.setdefault(kwargs=init_kwargs) 

1195 

1196 def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1): 

1197 """Get the recurrent dropout mask for RNN cell. 

1198 

1199 It will create a mask based on context if there isn't any existing cached 

1200 mask. If a new mask is generated, it will update the cache in the cell. 

1201 

1202 Args: 

1203 inputs: The input tensor whose shape will be used to generate dropout 

1204 mask. 

1205 training: Boolean tensor, whether it's in training mode; dropout will be 

1206 ignored in non-training mode. 

1207 count: Int, how many dropout masks will be generated. It is useful for cells 

1208 that have internal weights fused together. 

1209 Returns: 

1210 List of mask tensors, generated or cached masks based on context. 

1211 """ 

1212 if self.recurrent_dropout == 0: 

1213 return None 

1214 init_kwargs = dict(inputs=inputs, training=training, count=count) 

1215 return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs) 

1216 

1217 def __getstate__(self): 

1218 # Used for deepcopy. The caching can't be pickled by python, since it will 

1219 # contain tensor and graph. 

1220 state = super(DropoutRNNCellMixin, self).__getstate__() 

1221 state.pop('_dropout_mask_cache', None) 

1222 state.pop('_recurrent_dropout_mask_cache', None) 

1223 return state 

1224 

1225 def __setstate__(self, state): 

1226 state['_dropout_mask_cache'] = backend.ContextValueCache( 

1227 self._create_dropout_mask) 

1228 state['_recurrent_dropout_mask_cache'] = backend.ContextValueCache( 

1229 self._create_recurrent_dropout_mask) 

1230 super(DropoutRNNCellMixin, self).__setstate__(state) 

1231 

1232 

1233@keras_export('keras.layers.SimpleRNNCell') 

1234class SimpleRNNCell(DropoutRNNCellMixin, Layer): 

1235 """Cell class for SimpleRNN. 

1236 

1237 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 

1238 for details about the usage of RNN API. 

1239 

1240 This class processes one step within the whole time sequence input, whereas 

1241 `tf.keras.layers.SimpleRNN` processes the whole sequence. 

1242 

1243 Args: 

1244 units: Positive integer, dimensionality of the output space. 

1245 activation: Activation function to use. 

1246 Default: hyperbolic tangent (`tanh`). 

1247 If you pass `None`, no activation is applied 

1248 (i.e. "linear" activation: `a(x) = x`). 

1249 use_bias: Boolean, (default `True`), whether the layer uses a bias vector. 

1250 kernel_initializer: Initializer for the `kernel` weights matrix, 

1251 used for the linear transformation of the inputs. Default: 

1252 `glorot_uniform`. 

1253 recurrent_initializer: Initializer for the `recurrent_kernel` 

1254 weights matrix, used for the linear transformation of the recurrent state. 

1255 Default: `orthogonal`. 

1256 bias_initializer: Initializer for the bias vector. Default: `zeros`. 

1257 kernel_regularizer: Regularizer function applied to the `kernel` weights 

1258 matrix. Default: `None`. 

1259 recurrent_regularizer: Regularizer function applied to the 

1260 `recurrent_kernel` weights matrix. Default: `None`. 

1261 bias_regularizer: Regularizer function applied to the bias vector. Default: 

1262 `None`. 

1263 kernel_constraint: Constraint function applied to the `kernel` weights 

1264 matrix. Default: `None`. 

1265 recurrent_constraint: Constraint function applied to the `recurrent_kernel` 

1266 weights matrix. Default: `None`. 

1267 bias_constraint: Constraint function applied to the bias vector. Default: 

1268 `None`. 

1269 dropout: Float between 0 and 1. Fraction of the units to drop for the linear 

1270 transformation of the inputs. Default: 0. 

1271 recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for 

1272 the linear transformation of the recurrent state. Default: 0. 

1273 

1274 Call arguments: 

1275 inputs: A 2D tensor, with shape of `[batch, feature]`. 

1276 states: A 2D tensor with shape of `[batch, units]`, which is the state from 

1277 the previous time step. For timestep 0, the initial state provided by the 

1278 user will be fed to the cell. 

1279 training: Python boolean indicating whether the layer should behave in 

1280 training mode or in inference mode. Only relevant when `dropout` or 

1281 `recurrent_dropout` is used. 

1282 

1283 Examples: 

1284 

1285 ```python 

1286 inputs = np.random.random([32, 10, 8]).astype(np.float32) 

1287 rnn = tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(4)) 

1288 

1289 output = rnn(inputs) # The output has shape `[32, 4]`. 

1290 

1291 rnn = tf.keras.layers.RNN( 

1292 tf.keras.layers.SimpleRNNCell(4), 

1293 return_sequences=True, 

1294 return_state=True) 

1295 

1296 # whole_sequence_output has shape `[32, 10, 4]`. 

1297 # final_state has shape `[32, 4]`. 

1298 whole_sequence_output, final_state = rnn(inputs) 

1299 ``` 

1300 """ 

1301 

1302 def __init__(self, 

1303 units, 

1304 activation='tanh', 

1305 use_bias=True, 

1306 kernel_initializer='glorot_uniform', 

1307 recurrent_initializer='orthogonal', 

1308 bias_initializer='zeros', 

1309 kernel_regularizer=None, 

1310 recurrent_regularizer=None, 

1311 bias_regularizer=None, 

1312 kernel_constraint=None, 

1313 recurrent_constraint=None, 

1314 bias_constraint=None, 

1315 dropout=0., 

1316 recurrent_dropout=0., 

1317 **kwargs): 

1318 if units < 0: 

1319 raise ValueError(f'Received an invalid value for units, expected ' 

1320 f'a positive integer, got {units}.') 

1321 # By default use cached variable under v2 mode, see b/143699808. 

1322 if ops.executing_eagerly_outside_functions(): 

1323 self._enable_caching_device = kwargs.pop('enable_caching_device', True) 

1324 else: 

1325 self._enable_caching_device = kwargs.pop('enable_caching_device', False) 

1326 super(SimpleRNNCell, self).__init__(**kwargs) 

1327 self.units = units 

1328 self.activation = activations.get(activation) 

1329 self.use_bias = use_bias 

1330 

1331 self.kernel_initializer = initializers.get(kernel_initializer) 

1332 self.recurrent_initializer = initializers.get(recurrent_initializer) 

1333 self.bias_initializer = initializers.get(bias_initializer) 

1334 

1335 self.kernel_regularizer = regularizers.get(kernel_regularizer) 

1336 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 

1337 self.bias_regularizer = regularizers.get(bias_regularizer) 

1338 

1339 self.kernel_constraint = constraints.get(kernel_constraint) 

1340 self.recurrent_constraint = constraints.get(recurrent_constraint) 

1341 self.bias_constraint = constraints.get(bias_constraint) 

1342 

1343 self.dropout = min(1., max(0., dropout)) 

1344 self.recurrent_dropout = min(1., max(0., recurrent_dropout)) 

1345 self.state_size = self.units 

1346 self.output_size = self.units 

1347 

1348 @tf_utils.shape_type_conversion 

1349 def build(self, input_shape): 

1350 default_caching_device = _caching_device(self) 

1351 self.kernel = self.add_weight( 

1352 shape=(input_shape[-1], self.units), 

1353 name='kernel', 

1354 initializer=self.kernel_initializer, 

1355 regularizer=self.kernel_regularizer, 

1356 constraint=self.kernel_constraint, 

1357 caching_device=default_caching_device) 

1358 self.recurrent_kernel = self.add_weight( 

1359 shape=(self.units, self.units), 

1360 name='recurrent_kernel', 

1361 initializer=self.recurrent_initializer, 

1362 regularizer=self.recurrent_regularizer, 

1363 constraint=self.recurrent_constraint, 

1364 caching_device=default_caching_device) 

1365 if self.use_bias: 

1366 self.bias = self.add_weight( 

1367 shape=(self.units,), 

1368 name='bias', 

1369 initializer=self.bias_initializer, 

1370 regularizer=self.bias_regularizer, 

1371 constraint=self.bias_constraint, 

1372 caching_device=default_caching_device) 

1373 else: 

1374 self.bias = None 

1375 self.built = True 

1376 

1377 def call(self, inputs, states, training=None): 

1378 prev_output = states[0] if nest.is_nested(states) else states 
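    # The cell below computes 
    #   output = activation(inputs . kernel + prev_output . recurrent_kernel + bias), 
    # with the optional dropout masks applied to `inputs` and `prev_output`. 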

1379 dp_mask = self.get_dropout_mask_for_cell(inputs, training) 

1380 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( 

1381 prev_output, training) 

1382 

1383 if dp_mask is not None: 

1384 h = backend.dot(inputs * dp_mask, self.kernel) 

1385 else: 

1386 h = backend.dot(inputs, self.kernel) 

1387 if self.bias is not None: 

1388 h = backend.bias_add(h, self.bias) 

1389 

1390 if rec_dp_mask is not None: 

1391 prev_output = prev_output * rec_dp_mask 

1392 output = h + backend.dot(prev_output, self.recurrent_kernel) 

1393 if self.activation is not None: 

1394 output = self.activation(output) 

1395 

1396 new_state = [output] if nest.is_nested(states) else output 

1397 return output, new_state 

1398 

1399 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 

1400 return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) 

1401 

1402 def get_config(self): 

1403 config = { 

1404 'units': 

1405 self.units, 

1406 'activation': 

1407 activations.serialize(self.activation), 

1408 'use_bias': 

1409 self.use_bias, 

1410 'kernel_initializer': 

1411 initializers.serialize(self.kernel_initializer), 

1412 'recurrent_initializer': 

1413 initializers.serialize(self.recurrent_initializer), 

1414 'bias_initializer': 

1415 initializers.serialize(self.bias_initializer), 

1416 'kernel_regularizer': 

1417 regularizers.serialize(self.kernel_regularizer), 

1418 'recurrent_regularizer': 

1419 regularizers.serialize(self.recurrent_regularizer), 

1420 'bias_regularizer': 

1421 regularizers.serialize(self.bias_regularizer), 

1422 'kernel_constraint': 

1423 constraints.serialize(self.kernel_constraint), 

1424 'recurrent_constraint': 

1425 constraints.serialize(self.recurrent_constraint), 

1426 'bias_constraint': 

1427 constraints.serialize(self.bias_constraint), 

1428 'dropout': 

1429 self.dropout, 

1430 'recurrent_dropout': 

1431 self.recurrent_dropout 

1432 } 

1433 config.update(_config_for_enable_caching_device(self)) 

1434 base_config = super(SimpleRNNCell, self).get_config() 

1435 return dict(list(base_config.items()) + list(config.items())) 

1436 

1437 

1438@keras_export('keras.layers.SimpleRNN') 

1439class SimpleRNN(RNN): 

1440 """Fully-connected RNN where the output is to be fed back to input. 

1441 

1442 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 

1443 for details about the usage of the RNN API. 

1444 

1445 Args: 

1446 units: Positive integer, dimensionality of the output space. 

1447 activation: Activation function to use. 

1448 Default: hyperbolic tangent (`tanh`). 

1449 If you pass `None`, no activation is applied 

1450 (i.e. "linear" activation: `a(x) = x`). 

1451 use_bias: Boolean (default `True`), whether the layer uses a bias vector. 

1452 kernel_initializer: Initializer for the `kernel` weights matrix, 

1453 used for the linear transformation of the inputs. Default: 

1454 `glorot_uniform`. 

1455 recurrent_initializer: Initializer for the `recurrent_kernel` 

1456 weights matrix, used for the linear transformation of the recurrent state. 

1457 Default: `orthogonal`. 

1458 bias_initializer: Initializer for the bias vector. Default: `zeros`. 

1459 kernel_regularizer: Regularizer function applied to the `kernel` weights 

1460 matrix. Default: `None`. 

1461 recurrent_regularizer: Regularizer function applied to the 

1462 `recurrent_kernel` weights matrix. Default: `None`. 

1463 bias_regularizer: Regularizer function applied to the bias vector. Default: 

1464 `None`. 

1465 activity_regularizer: Regularizer function applied to the output of the 

1466 layer (its "activation"). Default: `None`. 

1467 kernel_constraint: Constraint function applied to the `kernel` weights 

1468 matrix. Default: `None`. 

1469 recurrent_constraint: Constraint function applied to the `recurrent_kernel` 

1470 weights matrix. Default: `None`. 

1471 bias_constraint: Constraint function applied to the bias vector. Default: 

1472 `None`. 

1473 dropout: Float between 0 and 1. 

1474 Fraction of the units to drop for the linear transformation of the inputs. 

1475 Default: 0. 

1476 recurrent_dropout: Float between 0 and 1. 

1477 Fraction of the units to drop for the linear transformation of the 

1478 recurrent state. Default: 0. 

1479 return_sequences: Boolean. Whether to return the last output 

1480 in the output sequence, or the full sequence. Default: `False`. 

1481 return_state: Boolean. Whether to return the last state 

1482 in addition to the output. Default: `False`. 

1483 go_backwards: Boolean (default False). 

1484 If True, process the input sequence backwards and return the 

1485 reversed sequence. 

1486 stateful: Boolean (default False). If True, the last state 

1487 for each sample at index i in a batch will be used as initial 

1488 state for the sample of index i in the following batch. 

1489 unroll: Boolean (default False). 

1490 If True, the network will be unrolled, 

1491 else a symbolic loop will be used. 

1492 Unrolling can speed up an RNN, 

1493 although it tends to be more memory-intensive. 

1494 Unrolling is only suitable for short sequences. 

1495 

1496 Call arguments: 

1497 inputs: A 3D tensor, with shape `[batch, timesteps, feature]`. 

1498 mask: Binary tensor of shape `[batch, timesteps]` indicating whether 

1499 a given timestep should be masked. An individual `True` entry indicates 

1500 that the corresponding timestep should be utilized, while a `False` entry 

1501 indicates that the corresponding timestep should be ignored. 

1502 training: Python boolean indicating whether the layer should behave in 

1503 training mode or in inference mode. This argument is passed to the cell 

1504 when calling it. This is only relevant if `dropout` or 

1505 `recurrent_dropout` is used. 

1506 initial_state: List of initial state tensors to be passed to the first 

1507 call of the cell. 

1508 

1509 Examples: 

1510 

1511 ```python 

1512 inputs = np.random.random([32, 10, 8]).astype(np.float32) 

1513 simple_rnn = tf.keras.layers.SimpleRNN(4) 

1514 

1515 output = simple_rnn(inputs) # The output has shape `[32, 4]`. 

1516 

1517 simple_rnn = tf.keras.layers.SimpleRNN( 

1518 4, return_sequences=True, return_state=True) 

1519 

1520 # whole_sequence_output has shape `[32, 10, 4]`. 

1521 # final_state has shape `[32, 4]`. 

1522 whole_sequence_output, final_state = simple_rnn(inputs) 

1523 ``` 

1524 """ 

1525 

1526 def __init__(self, 

1527 units, 

1528 activation='tanh', 

1529 use_bias=True, 

1530 kernel_initializer='glorot_uniform', 

1531 recurrent_initializer='orthogonal', 

1532 bias_initializer='zeros', 

1533 kernel_regularizer=None, 

1534 recurrent_regularizer=None, 

1535 bias_regularizer=None, 

1536 activity_regularizer=None, 

1537 kernel_constraint=None, 

1538 recurrent_constraint=None, 

1539 bias_constraint=None, 

1540 dropout=0., 

1541 recurrent_dropout=0., 

1542 return_sequences=False, 

1543 return_state=False, 

1544 go_backwards=False, 

1545 stateful=False, 

1546 unroll=False, 

1547 **kwargs): 

1548 if 'implementation' in kwargs: 

1549 kwargs.pop('implementation') 

1550 logging.warning('The `implementation` argument ' 

1551 'in `SimpleRNN` has been deprecated. ' 

1552 'Please remove it from your layer call.') 

1553 if 'enable_caching_device' in kwargs: 

1554 cell_kwargs = {'enable_caching_device': 

1555 kwargs.pop('enable_caching_device')} 

1556 else: 

1557 cell_kwargs = {} 

1558 cell = SimpleRNNCell( 

1559 units, 

1560 activation=activation, 

1561 use_bias=use_bias, 

1562 kernel_initializer=kernel_initializer, 

1563 recurrent_initializer=recurrent_initializer, 

1564 bias_initializer=bias_initializer, 

1565 kernel_regularizer=kernel_regularizer, 

1566 recurrent_regularizer=recurrent_regularizer, 

1567 bias_regularizer=bias_regularizer, 

1568 kernel_constraint=kernel_constraint, 

1569 recurrent_constraint=recurrent_constraint, 

1570 bias_constraint=bias_constraint, 

1571 dropout=dropout, 

1572 recurrent_dropout=recurrent_dropout, 

1573 dtype=kwargs.get('dtype'), 

1574 trainable=kwargs.get('trainable', True), 

1575 **cell_kwargs) 

1576 super(SimpleRNN, self).__init__( 

1577 cell, 

1578 return_sequences=return_sequences, 

1579 return_state=return_state, 

1580 go_backwards=go_backwards, 

1581 stateful=stateful, 

1582 unroll=unroll, 

1583 **kwargs) 

1584 self.activity_regularizer = regularizers.get(activity_regularizer) 

1585 self.input_spec = [InputSpec(ndim=3)] 

1586 

1587 def call(self, inputs, mask=None, training=None, initial_state=None): 

1588 return super(SimpleRNN, self).call( 

1589 inputs, mask=mask, training=training, initial_state=initial_state) 

1590 

1591 @property 

1592 def units(self): 

1593 return self.cell.units 

1594 

1595 @property 

1596 def activation(self): 

1597 return self.cell.activation 

1598 

1599 @property 

1600 def use_bias(self): 

1601 return self.cell.use_bias 

1602 

1603 @property 

1604 def kernel_initializer(self): 

1605 return self.cell.kernel_initializer 

1606 

1607 @property 

1608 def recurrent_initializer(self): 

1609 return self.cell.recurrent_initializer 

1610 

1611 @property 

1612 def bias_initializer(self): 

1613 return self.cell.bias_initializer 

1614 

1615 @property 

1616 def kernel_regularizer(self): 

1617 return self.cell.kernel_regularizer 

1618 

1619 @property 

1620 def recurrent_regularizer(self): 

1621 return self.cell.recurrent_regularizer 

1622 

1623 @property 

1624 def bias_regularizer(self): 

1625 return self.cell.bias_regularizer 

1626 

1627 @property 

1628 def kernel_constraint(self): 

1629 return self.cell.kernel_constraint 

1630 

1631 @property 

1632 def recurrent_constraint(self): 

1633 return self.cell.recurrent_constraint 

1634 

1635 @property 

1636 def bias_constraint(self): 

1637 return self.cell.bias_constraint 

1638 

1639 @property 

1640 def dropout(self): 

1641 return self.cell.dropout 

1642 

1643 @property 

1644 def recurrent_dropout(self): 

1645 return self.cell.recurrent_dropout 

1646 

1647 def get_config(self): 

1648 config = { 

1649 'units': 

1650 self.units, 

1651 'activation': 

1652 activations.serialize(self.activation), 

1653 'use_bias': 

1654 self.use_bias, 

1655 'kernel_initializer': 

1656 initializers.serialize(self.kernel_initializer), 

1657 'recurrent_initializer': 

1658 initializers.serialize(self.recurrent_initializer), 

1659 'bias_initializer': 

1660 initializers.serialize(self.bias_initializer), 

1661 'kernel_regularizer': 

1662 regularizers.serialize(self.kernel_regularizer), 

1663 'recurrent_regularizer': 

1664 regularizers.serialize(self.recurrent_regularizer), 

1665 'bias_regularizer': 

1666 regularizers.serialize(self.bias_regularizer), 

1667 'activity_regularizer': 

1668 regularizers.serialize(self.activity_regularizer), 

1669 'kernel_constraint': 

1670 constraints.serialize(self.kernel_constraint), 

1671 'recurrent_constraint': 

1672 constraints.serialize(self.recurrent_constraint), 

1673 'bias_constraint': 

1674 constraints.serialize(self.bias_constraint), 

1675 'dropout': 

1676 self.dropout, 

1677 'recurrent_dropout': 

1678 self.recurrent_dropout 

1679 } 

1680 base_config = super(SimpleRNN, self).get_config() 

1681 config.update(_config_for_enable_caching_device(self.cell)) 

1682 del base_config['cell'] 

1683 return dict(list(base_config.items()) + list(config.items())) 

1684 

1685 @classmethod 

1686 def from_config(cls, config): 

1687 if 'implementation' in config: 

1688 config.pop('implementation') 

1689 return cls(**config) 

1690 

1691 

1692@keras_export(v1=['keras.layers.GRUCell']) 

1693class GRUCell(DropoutRNNCellMixin, Layer): 

1694 """Cell class for the GRU layer. 

1695 

1696 Args: 

1697 units: Positive integer, dimensionality of the output space. 

1698 activation: Activation function to use. 

1699 Default: hyperbolic tangent (`tanh`). 

1700 If you pass `None`, no activation is applied 

1701 (i.e. "linear" activation: `a(x) = x`). 

1702 recurrent_activation: Activation function to use 

1703 for the recurrent step. 

1704 Default: hard sigmoid (`hard_sigmoid`). 

1705 If you pass `None`, no activation is applied 

1706 (i.e. "linear" activation: `a(x) = x`). 

1707 use_bias: Boolean, whether the layer uses a bias vector. 

1708 kernel_initializer: Initializer for the `kernel` weights matrix, 

1709 used for the linear transformation of the inputs. 

1710 recurrent_initializer: Initializer for the `recurrent_kernel` 

1711 weights matrix, 

1712 used for the linear transformation of the recurrent state. 

1713 bias_initializer: Initializer for the bias vector. 

1714 kernel_regularizer: Regularizer function applied to 

1715 the `kernel` weights matrix. 

1716 recurrent_regularizer: Regularizer function applied to 

1717 the `recurrent_kernel` weights matrix. 

1718 bias_regularizer: Regularizer function applied to the bias vector. 

1719 kernel_constraint: Constraint function applied to 

1720 the `kernel` weights matrix. 

1721 recurrent_constraint: Constraint function applied to 

1722 the `recurrent_kernel` weights matrix. 

1723 bias_constraint: Constraint function applied to the bias vector. 

1724 dropout: Float between 0 and 1. 

1725 Fraction of the units to drop for the linear transformation of the inputs. 

1726 recurrent_dropout: Float between 0 and 1. 

1727 Fraction of the units to drop for 

1728 the linear transformation of the recurrent state. 

1729 reset_after: GRU convention (whether to apply reset gate after or 

1730 before matrix multiplication). False = "before" (default), 

1731 True = "after" (CuDNN compatible). 

1732 

1733 Call arguments: 

1734 inputs: A 2D tensor. 

1735 states: List of state tensors corresponding to the previous timestep. 

1736 training: Python boolean indicating whether the layer should behave in 

1737 training mode or in inference mode. Only relevant when `dropout` or 

1738 `recurrent_dropout` is used. 
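
Example (a minimal usage sketch; the input shapes are arbitrary, and `RNN` 
and `GRUCell` refer to the classes defined in this module): 

```python 
inputs = np.random.random([32, 10, 8]).astype(np.float32) 
rnn = RNN(GRUCell(4)) 
output = rnn(inputs)  # The output has shape `[32, 4]`. 

rnn = RNN(GRUCell(4), return_sequences=True, return_state=True) 
# whole_sequence_output has shape `[32, 10, 4]`. 
# final_state has shape `[32, 4]`. 
whole_sequence_output, final_state = rnn(inputs) 
``` 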

1739 """ 

1740 

1741 def __init__(self, 

1742 units, 

1743 activation='tanh', 

1744 recurrent_activation='hard_sigmoid', 

1745 use_bias=True, 

1746 kernel_initializer='glorot_uniform', 

1747 recurrent_initializer='orthogonal', 

1748 bias_initializer='zeros', 

1749 kernel_regularizer=None, 

1750 recurrent_regularizer=None, 

1751 bias_regularizer=None, 

1752 kernel_constraint=None, 

1753 recurrent_constraint=None, 

1754 bias_constraint=None, 

1755 dropout=0., 

1756 recurrent_dropout=0., 

1757 reset_after=False, 

1758 **kwargs): 

1759 if units <= 0: 

1760 raise ValueError(f'Received an invalid value for units, expected ' 

1761 f'a positive integer, got {units}.') 

1762 # By default use cached variable under v2 mode, see b/143699808. 

1763 if ops.executing_eagerly_outside_functions(): 

1764 self._enable_caching_device = kwargs.pop('enable_caching_device', True) 

1765 else: 

1766 self._enable_caching_device = kwargs.pop('enable_caching_device', False) 

1767 super(GRUCell, self).__init__(**kwargs) 

1768 self.units = units 

1769 self.activation = activations.get(activation) 

1770 self.recurrent_activation = activations.get(recurrent_activation) 

1771 self.use_bias = use_bias 

1772 

1773 self.kernel_initializer = initializers.get(kernel_initializer) 

1774 self.recurrent_initializer = initializers.get(recurrent_initializer) 

1775 self.bias_initializer = initializers.get(bias_initializer) 

1776 

1777 self.kernel_regularizer = regularizers.get(kernel_regularizer) 

1778 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 

1779 self.bias_regularizer = regularizers.get(bias_regularizer) 

1780 

1781 self.kernel_constraint = constraints.get(kernel_constraint) 

1782 self.recurrent_constraint = constraints.get(recurrent_constraint) 

1783 self.bias_constraint = constraints.get(bias_constraint) 

1784 

1785 self.dropout = min(1., max(0., dropout)) 

1786 self.recurrent_dropout = min(1., max(0., recurrent_dropout)) 

1787 

1788 implementation = kwargs.pop('implementation', 1) 

1789 if self.recurrent_dropout != 0 and implementation != 1: 

1790 logging.debug(RECURRENT_DROPOUT_WARNING_MSG) 

1791 self.implementation = 1 

1792 else: 

1793 self.implementation = implementation 

1794 self.reset_after = reset_after 

1795 self.state_size = self.units 

1796 self.output_size = self.units 

1797 

1798 @tf_utils.shape_type_conversion 

1799 def build(self, input_shape): 

1800 input_dim = input_shape[-1] 

1801 default_caching_device = _caching_device(self) 

1802 self.kernel = self.add_weight( 

1803 shape=(input_dim, self.units * 3), 

1804 name='kernel', 

1805 initializer=self.kernel_initializer, 

1806 regularizer=self.kernel_regularizer, 

1807 constraint=self.kernel_constraint, 

1808 caching_device=default_caching_device) 

1809 self.recurrent_kernel = self.add_weight( 

1810 shape=(self.units, self.units * 3), 

1811 name='recurrent_kernel', 

1812 initializer=self.recurrent_initializer, 

1813 regularizer=self.recurrent_regularizer, 

1814 constraint=self.recurrent_constraint, 

1815 caching_device=default_caching_device) 

1816 

1817 if self.use_bias: 

1818 if not self.reset_after: 

1819 bias_shape = (3 * self.units,) 

1820 else: 

1821 # separate biases for input and recurrent kernels 

1822 # Note: the shape is intentionally different from CuDNNGRU biases 

1823 # `(2 * 3 * self.units,)`, so that we can distinguish the classes 

1824 # when loading and converting saved weights. 

1825 bias_shape = (2, 3 * self.units) 

1826 self.bias = self.add_weight(shape=bias_shape, 

1827 name='bias', 

1828 initializer=self.bias_initializer, 

1829 regularizer=self.bias_regularizer, 

1830 constraint=self.bias_constraint, 

1831 caching_device=default_caching_device) 

1832 else: 

1833 self.bias = None 

1834 self.built = True 

1835 

1836 def call(self, inputs, states, training=None): 

1837 h_tm1 = states[0] if nest.is_nested(states) else states # previous memory 
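    # GRU recurrence computed below (with optional dropout masks), where the 
    # U_* are slices of `recurrent_kernel`: 
    #   z = recurrent_activation(x_z + U_z . h_tm1)       (update gate) 
    #   r = recurrent_activation(x_r + U_r . h_tm1)       (reset gate) 
    #   hh = activation(x_h + U_h . (r * h_tm1))          (candidate state; with 
    #        `reset_after=True` the reset gate is applied after U_h . h_tm1) 
    #   h = z * h_tm1 + (1 - z) * hh 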

1838 

1839 dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3) 

1840 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( 

1841 h_tm1, training, count=3) 

1842 

1843 if self.use_bias: 

1844 if not self.reset_after: 

1845 input_bias, recurrent_bias = self.bias, None 

1846 else: 

1847 input_bias, recurrent_bias = array_ops_stack.unstack(self.bias) 

1848 

1849 if self.implementation == 1: 

1850 if 0. < self.dropout < 1.: 

1851 inputs_z = inputs * dp_mask[0] 

1852 inputs_r = inputs * dp_mask[1] 

1853 inputs_h = inputs * dp_mask[2] 

1854 else: 

1855 inputs_z = inputs 

1856 inputs_r = inputs 

1857 inputs_h = inputs 

1858 

1859 x_z = backend.dot(inputs_z, self.kernel[:, :self.units]) 

1860 x_r = backend.dot(inputs_r, self.kernel[:, self.units:self.units * 2]) 

1861 x_h = backend.dot(inputs_h, self.kernel[:, self.units * 2:]) 

1862 

1863 if self.use_bias: 

1864 x_z = backend.bias_add(x_z, input_bias[:self.units]) 

1865 x_r = backend.bias_add(x_r, input_bias[self.units: self.units * 2]) 

1866 x_h = backend.bias_add(x_h, input_bias[self.units * 2:]) 

1867 

1868 if 0. < self.recurrent_dropout < 1.: 

1869 h_tm1_z = h_tm1 * rec_dp_mask[0] 

1870 h_tm1_r = h_tm1 * rec_dp_mask[1] 

1871 h_tm1_h = h_tm1 * rec_dp_mask[2] 

1872 else: 

1873 h_tm1_z = h_tm1 

1874 h_tm1_r = h_tm1 

1875 h_tm1_h = h_tm1 

1876 

1877 recurrent_z = backend.dot(h_tm1_z, self.recurrent_kernel[:, :self.units]) 

1878 recurrent_r = backend.dot( 

1879 h_tm1_r, self.recurrent_kernel[:, self.units:self.units * 2]) 

1880 if self.reset_after and self.use_bias: 

1881 recurrent_z = backend.bias_add(recurrent_z, recurrent_bias[:self.units]) 

1882 recurrent_r = backend.bias_add( 

1883 recurrent_r, recurrent_bias[self.units:self.units * 2]) 

1884 

1885 z = self.recurrent_activation(x_z + recurrent_z) 

1886 r = self.recurrent_activation(x_r + recurrent_r) 

1887 

1888 # reset gate applied after/before matrix multiplication 

1889 if self.reset_after: 

1890 recurrent_h = backend.dot( 

1891 h_tm1_h, self.recurrent_kernel[:, self.units * 2:]) 

1892 if self.use_bias: 

1893 recurrent_h = backend.bias_add( 

1894 recurrent_h, recurrent_bias[self.units * 2:]) 

1895 recurrent_h = r * recurrent_h 

1896 else: 

1897 recurrent_h = backend.dot( 

1898 r * h_tm1_h, self.recurrent_kernel[:, self.units * 2:]) 

1899 

1900 hh = self.activation(x_h + recurrent_h) 

1901 else: 

1902 if 0. < self.dropout < 1.: 

1903 inputs = inputs * dp_mask[0] 

1904 

1905 # inputs projected by all gate matrices at once 

1906 matrix_x = backend.dot(inputs, self.kernel) 

1907 if self.use_bias: 

1908 # biases: bias_z_i, bias_r_i, bias_h_i 

1909 matrix_x = backend.bias_add(matrix_x, input_bias) 

1910 

1911 x_z, x_r, x_h = array_ops.split(matrix_x, 3, axis=-1) 

1912 

1913 if self.reset_after: 

1914 # hidden state projected by all gate matrices at once 

1915 matrix_inner = backend.dot(h_tm1, self.recurrent_kernel) 

1916 if self.use_bias: 

1917 matrix_inner = backend.bias_add(matrix_inner, recurrent_bias) 

1918 else: 

1919 # hidden state projected separately for update/reset and new 

1920 matrix_inner = backend.dot( 

1921 h_tm1, self.recurrent_kernel[:, :2 * self.units]) 

1922 

1923 recurrent_z, recurrent_r, recurrent_h = array_ops.split( 

1924 matrix_inner, [self.units, self.units, -1], axis=-1) 

1925 

1926 z = self.recurrent_activation(x_z + recurrent_z) 

1927 r = self.recurrent_activation(x_r + recurrent_r) 

1928 

1929 if self.reset_after: 

1930 recurrent_h = r * recurrent_h 

1931 else: 

1932 recurrent_h = backend.dot( 

1933 r * h_tm1, self.recurrent_kernel[:, 2 * self.units:]) 

1934 

1935 hh = self.activation(x_h + recurrent_h) 

1936 # previous and candidate state mixed by update gate 

1937 h = z * h_tm1 + (1 - z) * hh 

1938 new_state = [h] if nest.is_nested(states) else h 

1939 return h, new_state 

1940 

1941 def get_config(self): 

1942 config = { 

1943 'units': self.units, 

1944 'activation': activations.serialize(self.activation), 

1945 'recurrent_activation': 

1946 activations.serialize(self.recurrent_activation), 

1947 'use_bias': self.use_bias, 

1948 'kernel_initializer': initializers.serialize(self.kernel_initializer), 

1949 'recurrent_initializer': 

1950 initializers.serialize(self.recurrent_initializer), 

1951 'bias_initializer': initializers.serialize(self.bias_initializer), 

1952 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 

1953 'recurrent_regularizer': 

1954 regularizers.serialize(self.recurrent_regularizer), 

1955 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 

1956 'kernel_constraint': constraints.serialize(self.kernel_constraint), 

1957 'recurrent_constraint': 

1958 constraints.serialize(self.recurrent_constraint), 

1959 'bias_constraint': constraints.serialize(self.bias_constraint), 

1960 'dropout': self.dropout, 

1961 'recurrent_dropout': self.recurrent_dropout, 

1962 'implementation': self.implementation, 

1963 'reset_after': self.reset_after 

1964 } 

1965 config.update(_config_for_enable_caching_device(self)) 

1966 base_config = super(GRUCell, self).get_config() 

1967 return dict(list(base_config.items()) + list(config.items())) 

1968 

1969 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 

1970 return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) 

1971 

1972 

1973@keras_export(v1=['keras.layers.GRU']) 

1974class GRU(RNN): 

1975 """Gated Recurrent Unit - Cho et al. 2014. 

1976 

1977 There are two variants. The default one is based on 1406.1078v3 and 

1978 has reset gate applied to hidden state before matrix multiplication. The 

1979 other one is based on original 1406.1078v1 and has the order reversed. 

1980 

1981 The second variant is compatible with CuDNNGRU (GPU-only) and allows 

1982 inference on CPU. Thus it has separate biases for `kernel` and 

1983 `recurrent_kernel`. Use `reset_after=True` and 

1984 `recurrent_activation='sigmoid'`. 

1985 

1986 Args: 

1987 units: Positive integer, dimensionality of the output space. 

1988 activation: Activation function to use. 

1989 Default: hyperbolic tangent (`tanh`). 

1990 If you pass `None`, no activation is applied 

1991 (i.e. "linear" activation: `a(x) = x`). 

1992 recurrent_activation: Activation function to use 

1993 for the recurrent step. 

1994 Default: hard sigmoid (`hard_sigmoid`). 

1995 If you pass `None`, no activation is applied 

1996 (i.e. "linear" activation: `a(x) = x`). 

1997 use_bias: Boolean, whether the layer uses a bias vector. 

1998 kernel_initializer: Initializer for the `kernel` weights matrix, 

1999 used for the linear transformation of the inputs. 

2000 recurrent_initializer: Initializer for the `recurrent_kernel` 

2001 weights matrix, used for the linear transformation of the recurrent state. 

2002 bias_initializer: Initializer for the bias vector. 

2003 kernel_regularizer: Regularizer function applied to 

2004 the `kernel` weights matrix. 

2005 recurrent_regularizer: Regularizer function applied to 

2006 the `recurrent_kernel` weights matrix. 

2007 bias_regularizer: Regularizer function applied to the bias vector. 

2008 activity_regularizer: Regularizer function applied to 

2009 the output of the layer (its "activation"). 

2010 kernel_constraint: Constraint function applied to 

2011 the `kernel` weights matrix. 

2012 recurrent_constraint: Constraint function applied to 

2013 the `recurrent_kernel` weights matrix. 

2014 bias_constraint: Constraint function applied to the bias vector. 

2015 dropout: Float between 0 and 1. 

2016 Fraction of the units to drop for 

2017 the linear transformation of the inputs. 

2018 recurrent_dropout: Float between 0 and 1. 

2019 Fraction of the units to drop for 

2020 the linear transformation of the recurrent state. 

2021 return_sequences: Boolean. Whether to return the last output 

2022 in the output sequence, or the full sequence. 

2023 return_state: Boolean. Whether to return the last state 

2024 in addition to the output. 

2025 go_backwards: Boolean (default False). 

2026 If True, process the input sequence backwards and return the 

2027 reversed sequence. 

2028 stateful: Boolean (default False). If True, the last state 

2029 for each sample at index i in a batch will be used as initial 

2030 state for the sample of index i in the following batch. 

2031 unroll: Boolean (default False). 

2032 If True, the network will be unrolled, 

2033 else a symbolic loop will be used. 

2034 Unrolling can speed up an RNN, 

2035 although it tends to be more memory-intensive. 

2036 Unrolling is only suitable for short sequences. 

2037 time_major: The shape format of the `inputs` and `outputs` tensors. 

2038 If True, the inputs and outputs will be in shape 

2039 `(timesteps, batch, ...)`, whereas in the False case, it will be 

2040 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more 

2041 efficient because it avoids transposes at the beginning and end of the 

2042 RNN calculation. However, most TensorFlow data is batch-major, so by 

2043 default this function accepts input and emits output in batch-major 

2044 form. 

2045 reset_after: GRU convention (whether to apply reset gate after or 

2046 before matrix multiplication). False = "before" (default), 

2047 True = "after" (CuDNN compatible). 

2048 

2049 Call arguments: 

2050 inputs: A 3D tensor. 

2051 mask: Binary tensor of shape `(samples, timesteps)` indicating whether 

2052 a given timestep should be masked. An individual `True` entry indicates 

2053 that the corresponding timestep should be utilized, while a `False` 

2054 entry indicates that the corresponding timestep should be ignored. 

2055 training: Python boolean indicating whether the layer should behave in 

2056 training mode or in inference mode. This argument is passed to the cell 

2057 when calling it. This is only relevant if `dropout` or 

2058 `recurrent_dropout` is used. 

2059 initial_state: List of initial state tensors to be passed to the first 

2060 call of the cell. 
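
Example (a minimal usage sketch; the input shapes are arbitrary, and `GRU` 
refers to the class defined in this module): 

```python 
inputs = np.random.random([32, 10, 8]).astype(np.float32) 
gru = GRU(4) 
output = gru(inputs)  # The output has shape `[32, 4]`. 

# The CuDNN-compatible variant described above. 
gru = GRU(4, reset_after=True, recurrent_activation='sigmoid', 
          return_sequences=True, return_state=True) 
# whole_sequence_output has shape `[32, 10, 4]`. 
# final_state has shape `[32, 4]`. 
whole_sequence_output, final_state = gru(inputs) 
``` 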

2061 """ 

2062 

2063 def __init__(self, 

2064 units, 

2065 activation='tanh', 

2066 recurrent_activation='hard_sigmoid', 

2067 use_bias=True, 

2068 kernel_initializer='glorot_uniform', 

2069 recurrent_initializer='orthogonal', 

2070 bias_initializer='zeros', 

2071 kernel_regularizer=None, 

2072 recurrent_regularizer=None, 

2073 bias_regularizer=None, 

2074 activity_regularizer=None, 

2075 kernel_constraint=None, 

2076 recurrent_constraint=None, 

2077 bias_constraint=None, 

2078 dropout=0., 

2079 recurrent_dropout=0., 

2080 return_sequences=False, 

2081 return_state=False, 

2082 go_backwards=False, 

2083 stateful=False, 

2084 unroll=False, 

2085 reset_after=False, 

2086 **kwargs): 

2087 implementation = kwargs.pop('implementation', 1) 

2088 if implementation == 0: 

2089 logging.warning('`implementation=0` has been deprecated, ' 

2090 'and now defaults to `implementation=1`. ' 

2091 'Please update your layer call.') 

2092 if 'enable_caching_device' in kwargs: 

2093 cell_kwargs = {'enable_caching_device': 

2094 kwargs.pop('enable_caching_device')} 

2095 else: 

2096 cell_kwargs = {} 

2097 cell = GRUCell( 

2098 units, 

2099 activation=activation, 

2100 recurrent_activation=recurrent_activation, 

2101 use_bias=use_bias, 

2102 kernel_initializer=kernel_initializer, 

2103 recurrent_initializer=recurrent_initializer, 

2104 bias_initializer=bias_initializer, 

2105 kernel_regularizer=kernel_regularizer, 

2106 recurrent_regularizer=recurrent_regularizer, 

2107 bias_regularizer=bias_regularizer, 

2108 kernel_constraint=kernel_constraint, 

2109 recurrent_constraint=recurrent_constraint, 

2110 bias_constraint=bias_constraint, 

2111 dropout=dropout, 

2112 recurrent_dropout=recurrent_dropout, 

2113 implementation=implementation, 

2114 reset_after=reset_after, 

2115 dtype=kwargs.get('dtype'), 

2116 trainable=kwargs.get('trainable', True), 

2117 **cell_kwargs) 

2118 super(GRU, self).__init__( 

2119 cell, 

2120 return_sequences=return_sequences, 

2121 return_state=return_state, 

2122 go_backwards=go_backwards, 

2123 stateful=stateful, 

2124 unroll=unroll, 

2125 **kwargs) 

2126 self.activity_regularizer = regularizers.get(activity_regularizer) 

2127 self.input_spec = [InputSpec(ndim=3)] 

2128 

2129 def call(self, inputs, mask=None, training=None, initial_state=None): 

2130 return super(GRU, self).call( 

2131 inputs, mask=mask, training=training, initial_state=initial_state) 

2132 

2133 @property 

2134 def units(self): 

2135 return self.cell.units 

2136 

2137 @property 

2138 def activation(self): 

2139 return self.cell.activation 

2140 

2141 @property 

2142 def recurrent_activation(self): 

2143 return self.cell.recurrent_activation 

2144 

2145 @property 

2146 def use_bias(self): 

2147 return self.cell.use_bias 

2148 

2149 @property 

2150 def kernel_initializer(self): 

2151 return self.cell.kernel_initializer 

2152 

2153 @property 

2154 def recurrent_initializer(self): 

2155 return self.cell.recurrent_initializer 

2156 

2157 @property 

2158 def bias_initializer(self): 

2159 return self.cell.bias_initializer 

2160 

2161 @property 

2162 def kernel_regularizer(self): 

2163 return self.cell.kernel_regularizer 

2164 

2165 @property 

2166 def recurrent_regularizer(self): 

2167 return self.cell.recurrent_regularizer 

2168 

2169 @property 

2170 def bias_regularizer(self): 

2171 return self.cell.bias_regularizer 

2172 

2173 @property 

2174 def kernel_constraint(self): 

2175 return self.cell.kernel_constraint 

2176 

2177 @property 

2178 def recurrent_constraint(self): 

2179 return self.cell.recurrent_constraint 

2180 

2181 @property 

2182 def bias_constraint(self): 

2183 return self.cell.bias_constraint 

2184 

2185 @property 

2186 def dropout(self): 

2187 return self.cell.dropout 

2188 

2189 @property 

2190 def recurrent_dropout(self): 

2191 return self.cell.recurrent_dropout 

2192 

2193 @property 

2194 def implementation(self): 

2195 return self.cell.implementation 

2196 

2197 @property 

2198 def reset_after(self): 

2199 return self.cell.reset_after 

2200 

2201 def get_config(self): 

2202 config = { 

2203 'units': 

2204 self.units, 

2205 'activation': 

2206 activations.serialize(self.activation), 

2207 'recurrent_activation': 

2208 activations.serialize(self.recurrent_activation), 

2209 'use_bias': 

2210 self.use_bias, 

2211 'kernel_initializer': 

2212 initializers.serialize(self.kernel_initializer), 

2213 'recurrent_initializer': 

2214 initializers.serialize(self.recurrent_initializer), 

2215 'bias_initializer': 

2216 initializers.serialize(self.bias_initializer), 

2217 'kernel_regularizer': 

2218 regularizers.serialize(self.kernel_regularizer), 

2219 'recurrent_regularizer': 

2220 regularizers.serialize(self.recurrent_regularizer), 

2221 'bias_regularizer': 

2222 regularizers.serialize(self.bias_regularizer), 

2223 'activity_regularizer': 

2224 regularizers.serialize(self.activity_regularizer), 

2225 'kernel_constraint': 

2226 constraints.serialize(self.kernel_constraint), 

2227 'recurrent_constraint': 

2228 constraints.serialize(self.recurrent_constraint), 

2229 'bias_constraint': 

2230 constraints.serialize(self.bias_constraint), 

2231 'dropout': 

2232 self.dropout, 

2233 'recurrent_dropout': 

2234 self.recurrent_dropout, 

2235 'implementation': 

2236 self.implementation, 

2237 'reset_after': 

2238 self.reset_after 

2239 } 

2240 config.update(_config_for_enable_caching_device(self.cell)) 

2241 base_config = super(GRU, self).get_config() 

2242 del base_config['cell'] 

2243 return dict(list(base_config.items()) + list(config.items())) 

2244 

2245 @classmethod 

2246 def from_config(cls, config): 

2247 if 'implementation' in config and config['implementation'] == 0: 

2248 config['implementation'] = 1 

2249 return cls(**config) 

2250 

2251 

2252@keras_export(v1=['keras.layers.LSTMCell']) 

2253class LSTMCell(DropoutRNNCellMixin, Layer): 

2254 """Cell class for the LSTM layer. 

2255 

2256 Args: 

2257 units: Positive integer, dimensionality of the output space. 

2258 activation: Activation function to use. 

2259 Default: hyperbolic tangent (`tanh`). 

2260 If you pass `None`, no activation is applied 

2261 (i.e. "linear" activation: `a(x) = x`). 

2262 recurrent_activation: Activation function to use 

2263 for the recurrent step. 

2264 Default: hard sigmoid (`hard_sigmoid`). 

2265 If you pass `None`, no activation is applied 

2266 (i.e. "linear" activation: `a(x) = x`). 

2267 use_bias: Boolean, whether the layer uses a bias vector. 

2268 kernel_initializer: Initializer for the `kernel` weights matrix, 

2269 used for the linear transformation of the inputs. 

2270 recurrent_initializer: Initializer for the `recurrent_kernel` 

2271 weights matrix, 

2272 used for the linear transformation of the recurrent state. 

2273 bias_initializer: Initializer for the bias vector. 

2274 unit_forget_bias: Boolean. 

2275 If True, add 1 to the bias of the forget gate at initialization. 

2276 Setting it to true will also force `bias_initializer="zeros"`. 

2277 This is recommended in [Jozefowicz et al., 2015]( 

2278 http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) 

2279 kernel_regularizer: Regularizer function applied to 

2280 the `kernel` weights matrix. 

2281 recurrent_regularizer: Regularizer function applied to 

2282 the `recurrent_kernel` weights matrix. 

2283 bias_regularizer: Regularizer function applied to the bias vector. 

2284 kernel_constraint: Constraint function applied to 

2285 the `kernel` weights matrix. 

2286 recurrent_constraint: Constraint function applied to 

2287 the `recurrent_kernel` weights matrix. 

2288 bias_constraint: Constraint function applied to the bias vector. 

2289 dropout: Float between 0 and 1. 

2290 Fraction of the units to drop for 

2291 the linear transformation of the inputs. 

2292 recurrent_dropout: Float between 0 and 1. 

2293 Fraction of the units to drop for 

2294 the linear transformation of the recurrent state. 

2295 

2296 Call arguments: 

2297 inputs: A 2D tensor. 

2298 states: List of state tensors corresponding to the previous timestep. 

2299 training: Python boolean indicating whether the layer should behave in 

2300 training mode or in inference mode. Only relevant when `dropout` or 

2301 `recurrent_dropout` is used. 
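
Example (a minimal usage sketch; the input shapes are arbitrary, and `RNN` 
and `LSTMCell` refer to the classes defined in this module): 

```python 
inputs = np.random.random([32, 10, 8]).astype(np.float32) 
rnn = RNN(LSTMCell(4), return_state=True) 
# output has shape `[32, 4]`; the two returned states (hidden and carry) 
# each have shape `[32, 4]`. 
output, final_h, final_c = rnn(inputs) 
``` 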

2302 """ 

2303 

2304 def __init__(self, 

2305 units, 

2306 activation='tanh', 

2307 recurrent_activation='hard_sigmoid', 

2308 use_bias=True, 

2309 kernel_initializer='glorot_uniform', 

2310 recurrent_initializer='orthogonal', 

2311 bias_initializer='zeros', 

2312 unit_forget_bias=True, 

2313 kernel_regularizer=None, 

2314 recurrent_regularizer=None, 

2315 bias_regularizer=None, 

2316 kernel_constraint=None, 

2317 recurrent_constraint=None, 

2318 bias_constraint=None, 

2319 dropout=0., 

2320 recurrent_dropout=0., 

2321 **kwargs): 

2322 if units <= 0: 

2323 raise ValueError(f'Received an invalid value for units, expected ' 

2324 f'a positive integer, got {units}.') 

2325 # By default use cached variable under v2 mode, see b/143699808. 

2326 if ops.executing_eagerly_outside_functions(): 

2327 self._enable_caching_device = kwargs.pop('enable_caching_device', True) 

2328 else: 

2329 self._enable_caching_device = kwargs.pop('enable_caching_device', False) 

2330 super(LSTMCell, self).__init__(**kwargs) 

2331 self.units = units 

2332 self.activation = activations.get(activation) 

2333 self.recurrent_activation = activations.get(recurrent_activation) 

2334 self.use_bias = use_bias 

2335 

2336 self.kernel_initializer = initializers.get(kernel_initializer) 

2337 self.recurrent_initializer = initializers.get(recurrent_initializer) 

2338 self.bias_initializer = initializers.get(bias_initializer) 

2339 self.unit_forget_bias = unit_forget_bias 

2340 

2341 self.kernel_regularizer = regularizers.get(kernel_regularizer) 

2342 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 

2343 self.bias_regularizer = regularizers.get(bias_regularizer) 

2344 

2345 self.kernel_constraint = constraints.get(kernel_constraint) 

2346 self.recurrent_constraint = constraints.get(recurrent_constraint) 

2347 self.bias_constraint = constraints.get(bias_constraint) 

2348 

2349 self.dropout = min(1., max(0., dropout)) 

2350 self.recurrent_dropout = min(1., max(0., recurrent_dropout)) 

2351 implementation = kwargs.pop('implementation', 1) 

2352 if self.recurrent_dropout != 0 and implementation != 1: 

2353 logging.debug(RECURRENT_DROPOUT_WARNING_MSG) 

2354 self.implementation = 1 

2355 else: 

2356 self.implementation = implementation 

2357 self.state_size = [self.units, self.units] 

2358 self.output_size = self.units 

2359 

2360 @tf_utils.shape_type_conversion 

2361 def build(self, input_shape): 

2362 default_caching_device = _caching_device(self) 

2363 input_dim = input_shape[-1] 

2364 self.kernel = self.add_weight( 

2365 shape=(input_dim, self.units * 4), 

2366 name='kernel', 

2367 initializer=self.kernel_initializer, 

2368 regularizer=self.kernel_regularizer, 

2369 constraint=self.kernel_constraint, 

2370 caching_device=default_caching_device) 

2371 self.recurrent_kernel = self.add_weight( 

2372 shape=(self.units, self.units * 4), 

2373 name='recurrent_kernel', 

2374 initializer=self.recurrent_initializer, 

2375 regularizer=self.recurrent_regularizer, 

2376 constraint=self.recurrent_constraint, 

2377 caching_device=default_caching_device) 

2378 

2379 if self.use_bias: 

2380 if self.unit_forget_bias: 

2381 
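      # With `unit_forget_bias`, the concatenated bias below is 
      # [b_i, ones, b_c, b_o] (gate order i, f, c, o), so the forget-gate 
      # bias starts at 1 as recommended by Jozefowicz et al., 2015. 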

2382 def bias_initializer(_, *args, **kwargs): 

2383 return backend.concatenate([ 

2384 self.bias_initializer((self.units,), *args, **kwargs), 

2385 initializers.get('ones')((self.units,), *args, **kwargs), 

2386 self.bias_initializer((self.units * 2,), *args, **kwargs), 

2387 ]) 

2388 else: 

2389 bias_initializer = self.bias_initializer 

2390 self.bias = self.add_weight( 

2391 shape=(self.units * 4,), 

2392 name='bias', 

2393 initializer=bias_initializer, 

2394 regularizer=self.bias_regularizer, 

2395 constraint=self.bias_constraint, 

2396 caching_device=default_caching_device) 

2397 else: 

2398 self.bias = None 

2399 self.built = True 

2400 

2401 def _compute_carry_and_output(self, x, h_tm1, c_tm1): 

2402 """Computes carry and output using split kernels.""" 

2403 x_i, x_f, x_c, x_o = x 

2404 h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 

2405 i = self.recurrent_activation( 

2406 x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units])) 

2407 f = self.recurrent_activation(x_f + backend.dot( 

2408 h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2])) 

2409 c = f * c_tm1 + i * self.activation(x_c + backend.dot( 

2410 h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])) 

2411 o = self.recurrent_activation( 

2412 x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:])) 

2413 return c, o 

2414 

2415 def _compute_carry_and_output_fused(self, z, c_tm1): 

2416 """Computes carry and output using fused kernels.""" 

2417 z0, z1, z2, z3 = z 

2418 i = self.recurrent_activation(z0) 

2419 f = self.recurrent_activation(z1) 

2420 c = f * c_tm1 + i * self.activation(z2) 

2421 o = self.recurrent_activation(z3) 

2422 return c, o 

2423 

2424 def call(self, inputs, states, training=None): 

2425 h_tm1 = states[0] # previous memory state 

2426 c_tm1 = states[1] # previous carry state 

2427 

2428 dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4) 

2429 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( 

2430 h_tm1, training, count=4) 

2431 

2432 if self.implementation == 1: 

2433 if 0 < self.dropout < 1.: 

2434 inputs_i = inputs * dp_mask[0] 

2435 inputs_f = inputs * dp_mask[1] 

2436 inputs_c = inputs * dp_mask[2] 

2437 inputs_o = inputs * dp_mask[3] 

2438 else: 

2439 inputs_i = inputs 

2440 inputs_f = inputs 

2441 inputs_c = inputs 

2442 inputs_o = inputs 

2443 k_i, k_f, k_c, k_o = array_ops.split( 

2444 self.kernel, num_or_size_splits=4, axis=1) 

2445 x_i = backend.dot(inputs_i, k_i) 

2446 x_f = backend.dot(inputs_f, k_f) 

2447 x_c = backend.dot(inputs_c, k_c) 

2448 x_o = backend.dot(inputs_o, k_o) 

2449 if self.use_bias: 

2450 b_i, b_f, b_c, b_o = array_ops.split( 

2451 self.bias, num_or_size_splits=4, axis=0) 

2452 x_i = backend.bias_add(x_i, b_i) 

2453 x_f = backend.bias_add(x_f, b_f) 

2454 x_c = backend.bias_add(x_c, b_c) 

2455 x_o = backend.bias_add(x_o, b_o) 

2456 

2457 if 0 < self.recurrent_dropout < 1.: 

2458 h_tm1_i = h_tm1 * rec_dp_mask[0] 

2459 h_tm1_f = h_tm1 * rec_dp_mask[1] 

2460 h_tm1_c = h_tm1 * rec_dp_mask[2] 

2461 h_tm1_o = h_tm1 * rec_dp_mask[3] 

2462 else: 

2463 h_tm1_i = h_tm1 

2464 h_tm1_f = h_tm1 

2465 h_tm1_c = h_tm1 

2466 h_tm1_o = h_tm1 

2467 x = (x_i, x_f, x_c, x_o) 

2468 h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o) 

2469 c, o = self._compute_carry_and_output(x, h_tm1, c_tm1) 

2470 else: 

2471 if 0. < self.dropout < 1.: 

2472 inputs = inputs * dp_mask[0] 

2473 z = backend.dot(inputs, self.kernel) 

2474 z += backend.dot(h_tm1, self.recurrent_kernel) 

2475 if self.use_bias: 

2476 z = backend.bias_add(z, self.bias) 

2477 

2478 z = array_ops.split(z, num_or_size_splits=4, axis=1) 

2479 c, o = self._compute_carry_and_output_fused(z, c_tm1) 

2480 

2481 h = o * self.activation(c) 

2482 return h, [h, c] 

2483 

2484 def get_config(self): 

2485 config = { 

2486 'units': 

2487 self.units, 

2488 'activation': 

2489 activations.serialize(self.activation), 

2490 'recurrent_activation': 

2491 activations.serialize(self.recurrent_activation), 

2492 'use_bias': 

2493 self.use_bias, 

2494 'kernel_initializer': 

2495 initializers.serialize(self.kernel_initializer), 

2496 'recurrent_initializer': 

2497 initializers.serialize(self.recurrent_initializer), 

2498 'bias_initializer': 

2499 initializers.serialize(self.bias_initializer), 

2500 'unit_forget_bias': 

2501 self.unit_forget_bias, 

2502 'kernel_regularizer': 

2503 regularizers.serialize(self.kernel_regularizer), 

2504 'recurrent_regularizer': 

2505 regularizers.serialize(self.recurrent_regularizer), 

2506 'bias_regularizer': 

2507 regularizers.serialize(self.bias_regularizer), 

2508 'kernel_constraint': 

2509 constraints.serialize(self.kernel_constraint), 

2510 'recurrent_constraint': 

2511 constraints.serialize(self.recurrent_constraint), 

2512 'bias_constraint': 

2513 constraints.serialize(self.bias_constraint), 

2514 'dropout': 

2515 self.dropout, 

2516 'recurrent_dropout': 

2517 self.recurrent_dropout, 

2518 'implementation': 

2519 self.implementation 

2520 } 

2521 config.update(_config_for_enable_caching_device(self)) 

2522 base_config = super(LSTMCell, self).get_config() 

2523 return dict(list(base_config.items()) + list(config.items())) 

2524 

2525 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 

2526 return list(_generate_zero_filled_state_for_cell( 

2527 self, inputs, batch_size, dtype)) 

2528 

2529 

2530@keras_export('keras.experimental.PeepholeLSTMCell') 

2531class PeepholeLSTMCell(LSTMCell): 

2532 """Equivalent to LSTMCell class but adds peephole connections. 

2533 

2534 Peephole connections allow the gates to utilize the previous internal state as 

2535 well as the previous hidden state (which is what LSTMCell is limited to). 

2536 This allows PeepholeLSTMCell to learn precise timings more easily than LSTMCell. 

2537 

2538 From [Gers et al., 2002]( 

2539 http://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf): 

2540 

2541 "We find that LSTM augmented by 'peephole connections' from its internal 

2542 cells to its multiplicative gates can learn the fine distinction between 

2543 sequences of spikes spaced either 50 or 49 time steps apart without the help 

2544 of any short training exemplars." 

2545 

2546 The peephole implementation is based on: 

2547 

2548 [Sak et al., 2014](https://research.google.com/pubs/archive/43905.pdf) 

2549 

2550 Example: 

2551 

2552 ```python 

2553 # Create 2 PeepholeLSTMCells 

2554 peephole_lstm_cells = [PeepholeLSTMCell(size) for size in [128, 256]] 

2555 # Create a layer composed sequentially of the peephole LSTM cells. 

2556 layer = RNN(peephole_lstm_cells) 

2557 inputs = tf.keras.Input((10, 8))  # (timesteps, input_dim) 

2558 output = layer(inputs) 

2559 ``` 

2560 """ 

2561 

2562 def __init__(self, 

2563 units, 

2564 activation='tanh', 

2565 recurrent_activation='hard_sigmoid', 

2566 use_bias=True, 

2567 kernel_initializer='glorot_uniform', 

2568 recurrent_initializer='orthogonal', 

2569 bias_initializer='zeros', 

2570 unit_forget_bias=True, 

2571 kernel_regularizer=None, 

2572 recurrent_regularizer=None, 

2573 bias_regularizer=None, 

2574 kernel_constraint=None, 

2575 recurrent_constraint=None, 

2576 bias_constraint=None, 

2577 dropout=0., 

2578 recurrent_dropout=0., 

2579 **kwargs): 

2580 warnings.warn('`tf.keras.experimental.PeepholeLSTMCell` is deprecated ' 

2581 'and will be removed in a future version. ' 

2582 'Please use tensorflow_addons.rnn.PeepholeLSTMCell ' 

2583 'instead.') 

2584 super(PeepholeLSTMCell, self).__init__( 

2585 units=units, 

2586 activation=activation, 

2587 recurrent_activation=recurrent_activation, 

2588 use_bias=use_bias, 

2589 kernel_initializer=kernel_initializer, 

2590 recurrent_initializer=recurrent_initializer, 

2591 bias_initializer=bias_initializer, 

2592 unit_forget_bias=unit_forget_bias, 

2593 kernel_regularizer=kernel_regularizer, 

2594 recurrent_regularizer=recurrent_regularizer, 

2595 bias_regularizer=bias_regularizer, 

2596 kernel_constraint=kernel_constraint, 

2597 recurrent_constraint=recurrent_constraint, 

2598 bias_constraint=bias_constraint, 

2599 dropout=dropout, 

2600 recurrent_dropout=recurrent_dropout, 

2601 implementation=kwargs.pop('implementation', 1), 

2602 **kwargs) 

2603 

2604 def build(self, input_shape): 

2605 super(PeepholeLSTMCell, self).build(input_shape) 

2606 # The following are the weight matrices for the peephole connections. These 

2607 # are multiplied with the previous internal state during the computation of 

2608 # carry and output. 

2609 self.input_gate_peephole_weights = self.add_weight( 

2610 shape=(self.units,), 

2611 name='input_gate_peephole_weights', 

2612 initializer=self.kernel_initializer) 

2613 self.forget_gate_peephole_weights = self.add_weight( 

2614 shape=(self.units,), 

2615 name='forget_gate_peephole_weights', 

2616 initializer=self.kernel_initializer) 

2617 self.output_gate_peephole_weights = self.add_weight( 

2618 shape=(self.units,), 

2619 name='output_gate_peephole_weights', 

2620 initializer=self.kernel_initializer) 

2621 

2622 def _compute_carry_and_output(self, x, h_tm1, c_tm1): 
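    # Same as LSTMCell._compute_carry_and_output, except that peephole terms 
    # (elementwise peephole weights times the previous or updated carry state) 
    # are added to the input, forget, and output gate pre-activations. 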

2623 x_i, x_f, x_c, x_o = x 

2624 h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 

2625 i = self.recurrent_activation( 

2626 x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]) + 

2627 self.input_gate_peephole_weights * c_tm1) 

2628 f = self.recurrent_activation(x_f + backend.dot( 

2629 h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]) + 

2630 self.forget_gate_peephole_weights * c_tm1) 

2631 c = f * c_tm1 + i * self.activation(x_c + backend.dot( 

2632 h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])) 

2633 o = self.recurrent_activation( 

2634 x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]) + 

2635 self.output_gate_peephole_weights * c) 

2636 return c, o 

2637 

2638 def _compute_carry_and_output_fused(self, z, c_tm1): 

2639 z0, z1, z2, z3 = z 

2640 i = self.recurrent_activation(z0 + 

2641 self.input_gate_peephole_weights * c_tm1) 

2642 f = self.recurrent_activation(z1 + 

2643 self.forget_gate_peephole_weights * c_tm1) 

2644 c = f * c_tm1 + i * self.activation(z2) 

2645 o = self.recurrent_activation(z3 + self.output_gate_peephole_weights * c) 

2646 return c, o 

2647 

2648 

2649@keras_export(v1=['keras.layers.LSTM']) 

2650class LSTM(RNN): 

2651 """Long Short-Term Memory layer - Hochreiter 1997. 

2652 

2653 Note that this layer is not optimized for performance on GPU. Please use 

2654 `tf.compat.v1.keras.layers.CuDNNLSTM` for better performance on GPU. 

2655 

2656 Args: 

2657 units: Positive integer, dimensionality of the output space. 

2658 activation: Activation function to use. 

2659 Default: hyperbolic tangent (`tanh`). 

2660 If you pass `None`, no activation is applied 

2661 (i.e. "linear" activation: `a(x) = x`). 

2662 recurrent_activation: Activation function to use 

2663 for the recurrent step. 

2664 Default: hard sigmoid (`hard_sigmoid`). 

2665 If you pass `None`, no activation is applied 

2666 (i.e. "linear" activation: `a(x) = x`). 

2667 use_bias: Boolean, whether the layer uses a bias vector. 

2668 kernel_initializer: Initializer for the `kernel` weights matrix, 

2669 used for the linear transformation of the inputs. 

2670 recurrent_initializer: Initializer for the `recurrent_kernel` 

2671 weights matrix, 

2672 used for the linear transformation of the recurrent state. 

2673 bias_initializer: Initializer for the bias vector. 

2674 unit_forget_bias: Boolean. 

2675 If True, add 1 to the bias of the forget gate at initialization. 

2676 Setting it to true will also force `bias_initializer="zeros"`. 

2677 This is recommended in [Jozefowicz et al., 2015]( 

2678 http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf). 

2679 kernel_regularizer: Regularizer function applied to 

2680 the `kernel` weights matrix. 

2681 recurrent_regularizer: Regularizer function applied to 

2682 the `recurrent_kernel` weights matrix. 

2683 bias_regularizer: Regularizer function applied to the bias vector. 

2684 activity_regularizer: Regularizer function applied to 

2685 the output of the layer (its "activation"). 

2686 kernel_constraint: Constraint function applied to 

2687 the `kernel` weights matrix. 

2688 recurrent_constraint: Constraint function applied to 

2689 the `recurrent_kernel` weights matrix. 

2690 bias_constraint: Constraint function applied to the bias vector. 

2691 dropout: Float between 0 and 1. 

2692 Fraction of the units to drop for 

2693 the linear transformation of the inputs. 

2694 recurrent_dropout: Float between 0 and 1. 

2695 Fraction of the units to drop for 

2696 the linear transformation of the recurrent state. 

2697 return_sequences: Boolean. Whether to return the last output 

2698 in the output sequence, or the full sequence. 

2699 return_state: Boolean. Whether to return the last state 

2700 in addition to the output. 

2701 go_backwards: Boolean (default False). 

2702 If True, process the input sequence backwards and return the 

2703 reversed sequence. 

2704 stateful: Boolean (default False). If True, the last state 

2705 for each sample at index i in a batch will be used as initial 

2706 state for the sample of index i in the following batch. 

2707 unroll: Boolean (default False). 

2708 If True, the network will be unrolled, 

2709 else a symbolic loop will be used. 

2710 Unrolling can speed up an RNN, 

2711 although it tends to be more memory-intensive. 

2712 Unrolling is only suitable for short sequences. 

2713 time_major: The shape format of the `inputs` and `outputs` tensors. 

2714 If True, the inputs and outputs will be in shape 

2715 `(timesteps, batch, ...)`, whereas in the False case, it will be 

2716 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more 

2717 efficient because it avoids transposes at the beginning and end of the 

2718 RNN calculation. However, most TensorFlow data is batch-major, so by 

2719 default this function accepts input and emits output in batch-major 

2720 form. 

2721 

2722 Call arguments: 

2723 inputs: A 3D tensor. 

2724 mask: Binary tensor of shape `(samples, timesteps)` indicating whether 

2725 a given timestep should be masked. An individual `True` entry indicates 

2726 that the corresponding timestep should be utilized, while a `False` 

2727 entry indicates that the corresponding timestep should be ignored. 

2728 training: Python boolean indicating whether the layer should behave in 

2729 training mode or in inference mode. This argument is passed to the cell 

2730 when calling it. This is only relevant if `dropout` or 

2731 `recurrent_dropout` is used. 

2732 initial_state: List of initial state tensors to be passed to the first 

2733 call of the cell. 
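
 Example: a minimal usage sketch (assuming TF 2.x eager execution; the
 output shapes shown are illustrative):

 ```python
 inputs = tf.random.normal([32, 10, 8])

 # Final output only: shape (32, 4).
 lstm = tf.keras.layers.LSTM(4)
 output = lstm(inputs)

 # Full sequence plus the final hidden and cell states.
 lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)
 whole_seq_output, final_hidden_state, final_cell_state = lstm(inputs)
 # whole_seq_output.shape == (32, 10, 4)
 # final_hidden_state.shape == final_cell_state.shape == (32, 4)
 ```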

2734 """ 

2735 

2736 def __init__(self, 

2737 units, 

2738 activation='tanh', 

2739 recurrent_activation='hard_sigmoid', 

2740 use_bias=True, 

2741 kernel_initializer='glorot_uniform', 

2742 recurrent_initializer='orthogonal', 

2743 bias_initializer='zeros', 

2744 unit_forget_bias=True, 

2745 kernel_regularizer=None, 

2746 recurrent_regularizer=None, 

2747 bias_regularizer=None, 

2748 activity_regularizer=None, 

2749 kernel_constraint=None, 

2750 recurrent_constraint=None, 

2751 bias_constraint=None, 

2752 dropout=0., 

2753 recurrent_dropout=0., 

2754 return_sequences=False, 

2755 return_state=False, 

2756 go_backwards=False, 

2757 stateful=False, 

2758 unroll=False, 

2759 **kwargs): 

2760 implementation = kwargs.pop('implementation', 1) 

2761 if implementation == 0: 

2762 logging.warning('`implementation=0` has been deprecated, ' 

2763 'and now defaults to `implementation=1`. ' 

2764 'Please update your layer call.') 

2765 if 'enable_caching_device' in kwargs: 

2766 cell_kwargs = {'enable_caching_device': 

2767 kwargs.pop('enable_caching_device')} 

2768 else: 

2769 cell_kwargs = {} 

2770 cell = LSTMCell( 

2771 units, 

2772 activation=activation, 

2773 recurrent_activation=recurrent_activation, 

2774 use_bias=use_bias, 

2775 kernel_initializer=kernel_initializer, 

2776 recurrent_initializer=recurrent_initializer, 

2777 unit_forget_bias=unit_forget_bias, 

2778 bias_initializer=bias_initializer, 

2779 kernel_regularizer=kernel_regularizer, 

2780 recurrent_regularizer=recurrent_regularizer, 

2781 bias_regularizer=bias_regularizer, 

2782 kernel_constraint=kernel_constraint, 

2783 recurrent_constraint=recurrent_constraint, 

2784 bias_constraint=bias_constraint, 

2785 dropout=dropout, 

2786 recurrent_dropout=recurrent_dropout, 

2787 implementation=implementation, 

2788 dtype=kwargs.get('dtype'), 

2789 trainable=kwargs.get('trainable', True), 

2790 **cell_kwargs) 

2791 super(LSTM, self).__init__( 

2792 cell, 

2793 return_sequences=return_sequences, 

2794 return_state=return_state, 

2795 go_backwards=go_backwards, 

2796 stateful=stateful, 

2797 unroll=unroll, 

2798 **kwargs) 

2799 self.activity_regularizer = regularizers.get(activity_regularizer) 

2800 self.input_spec = [InputSpec(ndim=3)] 

2801 

2802 def call(self, inputs, mask=None, training=None, initial_state=None): 

2803 return super(LSTM, self).call( 

2804 inputs, mask=mask, training=training, initial_state=initial_state) 

2805 

2806 @property 

2807 def units(self): 

2808 return self.cell.units 

2809 

2810 @property 

2811 def activation(self): 

2812 return self.cell.activation 

2813 

2814 @property 

2815 def recurrent_activation(self): 

2816 return self.cell.recurrent_activation 

2817 

2818 @property 

2819 def use_bias(self): 

2820 return self.cell.use_bias 

2821 

2822 @property 

2823 def kernel_initializer(self): 

2824 return self.cell.kernel_initializer 

2825 

2826 @property 

2827 def recurrent_initializer(self): 

2828 return self.cell.recurrent_initializer 

2829 

2830 @property 

2831 def bias_initializer(self): 

2832 return self.cell.bias_initializer 

2833 

2834 @property 

2835 def unit_forget_bias(self): 

2836 return self.cell.unit_forget_bias 

2837 

2838 @property 

2839 def kernel_regularizer(self): 

2840 return self.cell.kernel_regularizer 

2841 

2842 @property 

2843 def recurrent_regularizer(self): 

2844 return self.cell.recurrent_regularizer 

2845 

2846 @property 

2847 def bias_regularizer(self): 

2848 return self.cell.bias_regularizer 

2849 

2850 @property 

2851 def kernel_constraint(self): 

2852 return self.cell.kernel_constraint 

2853 

2854 @property 

2855 def recurrent_constraint(self): 

2856 return self.cell.recurrent_constraint 

2857 

2858 @property 

2859 def bias_constraint(self): 

2860 return self.cell.bias_constraint 

2861 

2862 @property 

2863 def dropout(self): 

2864 return self.cell.dropout 

2865 

2866 @property 

2867 def recurrent_dropout(self): 

2868 return self.cell.recurrent_dropout 

2869 

2870 @property 

2871 def implementation(self): 

2872 return self.cell.implementation 

2873 

2874 def get_config(self): 

2875 config = { 

2876 'units': 

2877 self.units, 

2878 'activation': 

2879 activations.serialize(self.activation), 

2880 'recurrent_activation': 

2881 activations.serialize(self.recurrent_activation), 

2882 'use_bias': 

2883 self.use_bias, 

2884 'kernel_initializer': 

2885 initializers.serialize(self.kernel_initializer), 

2886 'recurrent_initializer': 

2887 initializers.serialize(self.recurrent_initializer), 

2888 'bias_initializer': 

2889 initializers.serialize(self.bias_initializer), 

2890 'unit_forget_bias': 

2891 self.unit_forget_bias, 

2892 'kernel_regularizer': 

2893 regularizers.serialize(self.kernel_regularizer), 

2894 'recurrent_regularizer': 

2895 regularizers.serialize(self.recurrent_regularizer), 

2896 'bias_regularizer': 

2897 regularizers.serialize(self.bias_regularizer), 

2898 'activity_regularizer': 

2899 regularizers.serialize(self.activity_regularizer), 

2900 'kernel_constraint': 

2901 constraints.serialize(self.kernel_constraint), 

2902 'recurrent_constraint': 

2903 constraints.serialize(self.recurrent_constraint), 

2904 'bias_constraint': 

2905 constraints.serialize(self.bias_constraint), 

2906 'dropout': 

2907 self.dropout, 

2908 'recurrent_dropout': 

2909 self.recurrent_dropout, 

2910 'implementation': 

2911 self.implementation 

2912 } 

2913 config.update(_config_for_enable_caching_device(self.cell)) 

2914 base_config = super(LSTM, self).get_config() 

2915 del base_config['cell'] 

2916 return dict(list(base_config.items()) + list(config.items())) 

2917 

2918 @classmethod 

2919 def from_config(cls, config): 

2920 if 'implementation' in config and config['implementation'] == 0: 

2921 config['implementation'] = 1 

2922 return cls(**config) 
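
  # A minimal config round-trip sketch (illustrative only, assuming default
  # constructor arguments otherwise): `get_config` flattens the wrapped cell's
  # settings into the layer config and drops the 'cell' entry, and
  # `from_config` rebuilds an equivalent layer from that flat dict.
  #
  #   layer = LSTM(4, dropout=0.1)
  #   clone = LSTM.from_config(layer.get_config())
  #   assert clone.units == 4 and clone.dropout == 0.1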

2923 

2924 

2925def _generate_dropout_mask(ones, rate, training=None, count=1): 

2926 def dropped_inputs(): 

2927 return backend.dropout(ones, rate) 

2928 

2929 if count > 1: 

2930 return [ 

2931 backend.in_train_phase(dropped_inputs, ones, training=training) 

2932 for _ in range(count) 

2933 ] 

2934 return backend.in_train_phase(dropped_inputs, ones, training=training) 
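
# Behavior sketch (illustrative): given `ones` shaped like the input and a
# dropout rate, this returns either a single mask or, when `count > 1`, a list
# of independently sampled masks -- e.g. a cell with four gates could request
# `count=4`, one mask per gate. The mask is only applied in the training phase;
# at inference the untouched `ones` tensor is returned.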

2935 

2936 

2937def _standardize_args(inputs, initial_state, constants, num_constants): 

2938 """Standardizes `__call__` to a single list of tensor inputs. 

2939 

2940 When running a model loaded from a file, the input tensors 

2941 `initial_state` and `constants` can be passed to `RNN.__call__()` as part 

2942 of `inputs` instead of by the dedicated keyword arguments. This method 

2943 makes sure the arguments are separated and that `initial_state` and 

2944 `constants` are lists of tensors (or None). 

2945 

2946 Args: 

2947 inputs: Tensor or list/tuple of tensors, which may include constants 

2948 and initial states. In that case `num_constants` must be specified. 

2949 initial_state: Tensor or list of tensors or None, initial states. 

2950 constants: Tensor or list of tensors or None, constant tensors. 

2951 num_constants: Expected number of constants (if constants are passed as 

2952 part of the `inputs` list). 

2953 

2954 Returns: 

2955 inputs: Single tensor or tuple of tensors. 

2956 initial_state: List of tensors or None. 

2957 constants: List of tensors or None. 

2958 """ 

2959 if isinstance(inputs, list): 

2960 # There are several situations here: 

2961 # In graph mode, __call__ is only called once. The initial_state 

2962 # and constants could be in inputs (from file loading). 

2963 # In eager mode, __call__ is called twice: once during 

2964 # rnn_layer(inputs=input_t, constants=c_t, ...), and a second time during 

2965 # model.fit/train_on_batch/predict with real numpy data. In the second case, 

2966 # the inputs will contain initial_state and constants as eager tensors. 

2967 # 

2968 # In either case, the real input is the first item in the list, which 

2969 # could itself be a nested structure. It is followed by the initial states, which 

2970 # could be a list of items, or a list of lists if the initial state is a complex 

2971 # structure, and finally by the constants, which form a flat list. 

2972 assert initial_state is None and constants is None 

2973 if num_constants: 

2974 constants = inputs[-num_constants:] 

2975 inputs = inputs[:-num_constants] 

2976 if len(inputs) > 1: 

2977 initial_state = inputs[1:] 

2978 inputs = inputs[:1] 

2979 

2980 if len(inputs) > 1: 

2981 inputs = tuple(inputs) 

2982 else: 

2983 inputs = inputs[0] 

2984 

2985 def to_list_or_none(x): 

2986 if x is None or isinstance(x, list): 

2987 return x 

2988 if isinstance(x, tuple): 

2989 return list(x) 

2990 return [x] 

2991 

2992 initial_state = to_list_or_none(initial_state) 

2993 constants = to_list_or_none(constants) 

2994 

2995 return inputs, initial_state, constants 
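
# Unpacking sketch (illustrative; x, h, c, k are hypothetical tensors):
#   _standardize_args([x, h, c, k], None, None, num_constants=1)
#     -> (x, [h, c], [k])
#   _standardize_args(x, (h, c), k, num_constants=None)
#     -> (x, [h, c], [k])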

2996 

2997 

2998def _is_multiple_state(state_size): 

2999 """Check whether the state_size contains multiple states.""" 

3000 return (hasattr(state_size, '__len__') and 

3001 not isinstance(state_size, tensor_shape.TensorShape)) 

3002 

3003 

3004def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype): 

3005 if inputs is not None: 

3006 batch_size = array_ops.shape(inputs)[0] 

3007 dtype = inputs.dtype 

3008 return _generate_zero_filled_state(batch_size, cell.state_size, dtype) 

3009 

3010 

3011def _generate_zero_filled_state(batch_size_tensor, state_size, dtype): 

3012 """Generate a zero filled tensor with shape [batch_size, state_size].""" 

3013 if batch_size_tensor is None or dtype is None: 

3014 raise ValueError( 

3015 'batch_size and dtype cannot be None while constructing initial state: ' 

3016 'batch_size={}, dtype={}'.format(batch_size_tensor, dtype)) 

3017 

3018 def create_zeros(unnested_state_size): 

3019 flat_dims = tensor_shape.TensorShape(unnested_state_size).as_list() 

3020 init_state_size = [batch_size_tensor] + flat_dims 

3021 return array_ops.zeros(init_state_size, dtype=dtype) 

3022 

3023 if nest.is_nested(state_size): 

3024 return nest.map_structure(create_zeros, state_size) 

3025 else: 

3026 return create_zeros(state_size) 
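
# Illustrative sketch: for a cell whose state_size is [units, units] (as with an
# LSTM cell) and a batch of 32, this yields two zero tensors of shape
# (32, units); nested state_size structures are handled element-wise through
# nest.map_structure, while a scalar state_size yields a single
# (32, state_size) tensor.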

3027 

3028 

3029def _caching_device(rnn_cell): 

3030 """Returns the caching device for the RNN variable. 

3031 

3032 This is useful for distributed training, when the variable is not located on the same 

3033 device as the training worker. By enabling the device cache, the worker 

3034 reads the variable once and caches it locally, rather than fetching it from the 

3035 remote device at every time step. 

3036 

3037 Note that this assumes the variable the cell needs at each time step has 

3038 the same value throughout the forward pass and only gets updated during 

3039 backprop. This holds for all the default cells (SimpleRNN, GRU, LSTM). If the 

3040 cell body relies on any variable that gets updated every time step, then 

3041 device caching will cause it to read a stale value. 

3042 

3043 Args: 

3044 rnn_cell: the rnn cell instance. 

3045 """ 

3046 if context.executing_eagerly(): 

3047 # caching_device is not supported in eager mode. 

3048 return None 

3049 if not getattr(rnn_cell, '_enable_caching_device', False): 

3050 return None 

3051 # Don't set a caching device when running in a loop, since it is possible that 

3052 # train steps could be wrapped in a tf.while_loop. In that scenario caching 

3053 # prevents forward computations in loop iterations from re-reading the 

3054 # updated weights. 

3055 if control_flow_util.IsInWhileLoop(ops.get_default_graph()): 

3056 logging.warning( 

3057 'Variable read device caching has been disabled because the ' 

3058 'RNN is in a tf.while_loop context, where caching would cause ' 

3059 'stale values to be read in the forward path. This could slow down ' 

3060 'training due to duplicated variable reads. Please ' 

3061 'consider updating your code to remove tf.while_loop if possible.') 

3062 return None 

3063 if (rnn_cell._dtype_policy.compute_dtype != 

3064 rnn_cell._dtype_policy.variable_dtype): 

3065 logging.warning( 

3066 'Variable read device caching has been disabled since it ' 

3067 'doesn\'t work with the mixed precision API. This is ' 

3068 'likely to cause a slowdown for RNN training due to ' 

3069 'duplicated variable reads at each timestep, which ' 

3070 'can be significant in a multi remote worker setting. ' 

3071 'Please consider disabling the mixed precision API if ' 

3072 'performance has been affected.') 

3073 return None 

3074 # Cache the value on the device that accesses the variable. 

3075 return lambda op: op.device 
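
# Usage sketch (illustrative): the returned callable is intended to be passed as
# the `caching_device` argument when the cell creates its weights, e.g.
#
#   self.add_weight(..., caching_device=_caching_device(self))
#
# so that each consuming op reads the weight once on its own device and reuses
# the cached copy across time steps, instead of re-fetching it from a remote
# parameter device.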

3076 

3077 

3078def _config_for_enable_caching_device(rnn_cell): 

3079 """Return the dict config for RNN cell wrt to enable_caching_device field. 

3080 

3081 Since enable_caching_device is an internal implementation detail for speeding up 

3082 RNN variable reads in a multi remote worker setting, we 

3083 don't want this config to be serialized in the JSON unconditionally. We only 

3084 serialize this field when a non-default value is used to create the cell. 

3085 Args: 

3086 rnn_cell: the RNN cell to serialize. 

3087 

3088 Returns: 

3089 A dict containing the JSON config for the enable_caching_device value, or an 

3090 empty dict if the enable_caching_device value is the same as the default. 

3091 """ 

3092 default_enable_caching_device = ops.executing_eagerly_outside_functions() 

3093 if rnn_cell._enable_caching_device != default_enable_caching_device: 

3094 return {'enable_caching_device': rnn_cell._enable_caching_device} 

3095 return {}
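
# Serialization sketch (illustrative): with eager execution enabled the default
# is True, so a cell built with enable_caching_device=False contributes
# {'enable_caching_device': False} to its config, while a cell using the
# default contributes nothing ({}) and the field is omitted from the JSON.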