Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/base_rnn.py: 13%

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for recurrent layers."""


import collections

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.engine import base_layer
from keras.src.engine.input_spec import InputSpec
from keras.src.layers.rnn import rnn_utils
from keras.src.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin
from keras.src.layers.rnn.stacked_rnn_cells import StackedRNNCells
from keras.src.saving import serialization_lib
from keras.src.saving.legacy.saved_model import layer_serialization
from keras.src.utils import generic_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls


@keras_export("keras.layers.RNN")
class RNN(base_layer.Layer):
    """Base class for recurrent layers.

    See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
    for details about the usage of the RNN API.

    Args:
      cell: An RNN cell instance or a list of RNN cell instances.
        An RNN cell is a class that has:
        - A `call(input_at_t, states_at_t)` method, returning
          `(output_at_t, states_at_t_plus_1)`. The call method of the
          cell can also take the optional argument `constants`, see
          section "Note on passing external constants" below.
        - A `state_size` attribute. This can be a single integer
          (single state), in which case it is the size of the recurrent
          state. This can also be a list/tuple of integers (one size per
          state). The `state_size` can also be a TensorShape or a
          tuple/list of TensorShapes, to represent a high-dimensional
          state.
        - An `output_size` attribute. This can be a single integer or a
          TensorShape, which represents the shape of the output. For
          backward compatibility, if this attribute is not available for
          the cell, the value will be inferred from the first element of
          the `state_size`.
        - A `get_initial_state(inputs=None, batch_size=None, dtype=None)`
          method that creates a tensor meant to be fed to `call()` as the
          initial state, if the user didn't specify any initial state via
          other means. The returned initial state should have a shape of
          `[batch_size, cell.state_size]`. The cell might choose to
          create a tensor full of zeros, or of other values, based on the
          cell's implementation.
          `inputs` is the input tensor to the RNN layer, which should
          contain the batch size as its `shape[0]`, and also the dtype.
          Note that `shape[0]` might be `None` during graph construction.
          Either the `inputs` or the pair of `batch_size` and `dtype` is
          provided. `batch_size` is a scalar tensor that represents the
          batch size of the inputs. `dtype` is a `tf.DType` that
          represents the dtype of the inputs.
          For backward compatibility, if this method is not implemented
          by the cell, the RNN layer will create a zero-filled tensor
          with shape `[batch_size, cell.state_size]`.
        In the case that `cell` is a list of RNN cell instances, the
        cells will be stacked on top of each other in the RNN, resulting
        in an efficient stacked RNN.
      return_sequences: Boolean (default `False`). Whether to return the
        last output in the output sequence, or the full sequence.
      return_state: Boolean (default `False`). Whether to return the last
        state in addition to the output.
      go_backwards: Boolean (default `False`).
        If `True`, process the input sequence backwards and return the
        reversed sequence.
      stateful: Boolean (default `False`). If `True`, the last state for
        each sample at index `i` in a batch will be used as the initial
        state for the sample of index `i` in the following batch.
      unroll: Boolean (default `False`).
        If `True`, the network will be unrolled, else a symbolic loop
        will be used. Unrolling can speed up an RNN, although it tends to
        be more memory-intensive. Unrolling is only suitable for short
        sequences.
      time_major: The shape format of the `inputs` and `outputs` tensors.
        If `True`, the inputs and outputs will be in shape
        `(timesteps, batch, ...)`, whereas in the `False` case, they will
        be in shape `(batch, timesteps, ...)`. Using `time_major = True`
        is a bit more efficient because it avoids transposes at the
        beginning and end of the RNN calculation. However, most
        TensorFlow data is batch-major, so by default this function
        accepts input and emits output in batch-major form.
      zero_output_for_mask: Boolean (default `False`).
        Whether the output should use zeros for the masked timesteps.
        Note that this field is only used when `return_sequences` is
        `True` and a mask is provided. It can be useful if you want to
        reuse the raw output sequence of the RNN without interference
        from the masked timesteps, e.g. when merging bidirectional RNNs.

    Call arguments:
      inputs: Input tensor.
      mask: Binary tensor of shape `[batch_size, timesteps]` indicating
        whether a given timestep should be masked. An individual `True`
        entry indicates that the corresponding timestep should be
        utilized, while a `False` entry indicates that the corresponding
        timestep should be ignored.
      training: Python boolean indicating whether the layer should behave
        in training mode or in inference mode. This argument is passed to
        the cell when calling it. This is for use with cells that use
        dropout.
      initial_state: List of initial state tensors to be passed to the
        first call of the cell.
      constants: List of constant tensors to be passed to the cell at
        each timestep.

    Input shape:
      N-D tensor with shape `[batch_size, timesteps, ...]`, or
      `[timesteps, batch_size, ...]` when `time_major` is `True`.

    Output shape:
      - If `return_state`: a list of tensors. The first tensor is the
        output. The remaining tensors are the last states, each with
        shape `[batch_size, state_size]`, where `state_size` could be a
        high-dimensional tensor shape.
      - If `return_sequences`: N-D tensor with shape
        `[batch_size, timesteps, output_size]`, where `output_size`
        could be a high-dimensional tensor shape, or
        `[timesteps, batch_size, output_size]` when `time_major` is
        `True`.
      - Else, N-D tensor with shape `[batch_size, output_size]`, where
        `output_size` could be a high-dimensional tensor shape.

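    For example, with the built-in `LSTM` layer, which follows this
    contract (an illustrative sketch; the unit count and input shapes
    are arbitrary):

    ```python
    inputs = tf.random.normal([4, 10, 16])
    lstm = tf.keras.layers.LSTM(8, return_sequences=True, return_state=True)
    whole_seq_output, final_h, final_c = lstm(inputs)
    # whole_seq_output.shape == (4, 10, 8)
    # final_h.shape == final_c.shape == (4, 8)
    ```
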

    Masking:
      This layer supports masking for input data with a variable number
      of timesteps. To introduce masks to your data, use a
      [tf.keras.layers.Embedding] layer with the `mask_zero` parameter
      set to `True`.

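    For example (a minimal sketch; the vocabulary and layer sizes are
    arbitrary):

    ```python
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(
            input_dim=1000, output_dim=16, mask_zero=True
        ),
        tf.keras.layers.LSTM(32),
    ])
    ```
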

    Note on using statefulness in RNNs:
      You can set RNN layers to be 'stateful', which means that the
      states computed for the samples in one batch will be reused as
      initial states for the samples in the next batch. This assumes a
      one-to-one mapping between samples in different successive batches.

      To enable statefulness:
        - Specify `stateful=True` in the layer constructor.
        - Specify a fixed batch size for your model. For a sequential
          model, pass `batch_input_shape=(...)` to the first layer in
          your model; for a functional model with one or more `Input`
          layers, pass `batch_shape=(...)` to all the first layers in
          your model. This is the expected shape of your inputs
          *including the batch size*. It should be a tuple of integers,
          e.g. `(32, 10, 100)`.
        - Specify `shuffle=False` when calling `fit()`.

      To reset the states of your model, call `.reset_states()` on
      either a specific layer, or on your entire model, as in the
      sketch below.

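    For example (an illustrative sketch; the batch size of 32 and the
    layer sizes are arbitrary):

    ```python
    model = tf.keras.Sequential([
        tf.keras.layers.LSTM(
            64, stateful=True, batch_input_shape=(32, 10, 16)
        ),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    # model.fit(x, y, batch_size=32, shuffle=False)  # x, y: your data
    model.reset_states()  # clear the carried-over states
    ```
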

    Note on specifying the initial state of RNNs:
      You can specify the initial state of RNN layers symbolically by
      calling them with the keyword argument `initial_state`. The value
      of `initial_state` should be a tensor or list of tensors
      representing the initial state of the RNN layer.

      You can specify the initial state of RNN layers numerically by
      calling `reset_states` with the keyword argument `states`. The
      value of `states` should be a numpy array or list of numpy arrays
      representing the initial state of the RNN layer.

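    For example (a minimal sketch; an `LSTM` with 32 units expects two
    initial state tensors, `h` and `c`, of shape `(batch, 32)` each):

    ```python
    inputs = tf.keras.Input((None, 5))
    initial_h = tf.keras.Input((32,))
    initial_c = tf.keras.Input((32,))
    outputs = tf.keras.layers.LSTM(32)(
        inputs, initial_state=[initial_h, initial_c]
    )
    ```
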

    Note on passing external constants to RNNs:
      You can pass "external" constants to the cell using the
      `constants` keyword argument of the `RNN.__call__` (as well as
      `RNN.call`) method. This requires that the `cell.call` method
      accepts the same keyword argument `constants`. Such constants can
      be used to condition the cell transformation on additional static
      inputs (not changing over time), a.k.a. an attention mechanism, as
      sketched below.

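    For example (an illustrative sketch; `ConstantsCell` is a
    hypothetical cell whose `call` accepts `constants`):

    ```python
    class ConstantsCell(tf.keras.layers.Layer):
        def __init__(self, units, **kwargs):
            super().__init__(**kwargs)
            self.units = units
            self.state_size = units

        def call(self, inputs, states, constants=None):
            [context] = constants  # static input, fixed across timesteps
            output = states[0] + tf.reduce_mean(
                inputs, axis=-1, keepdims=True
            ) + context
            return output, [output]

    x = tf.keras.Input((None, 5))
    context = tf.keras.Input((8,))
    y = tf.keras.layers.RNN(ConstantsCell(8))(x, constants=[context])
    ```
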

    Examples:

    ```python
    import keras
    from keras.src.layers import RNN
    from keras.src import backend

    # First, let's define an RNN cell, as a layer subclass.
    class MinimalRNNCell(keras.layers.Layer):

        def __init__(self, units, **kwargs):
            self.units = units
            self.state_size = units
            super(MinimalRNNCell, self).__init__(**kwargs)

        def build(self, input_shape):
            self.kernel = self.add_weight(
                shape=(input_shape[-1], self.units),
                initializer='uniform',
                name='kernel')
            self.recurrent_kernel = self.add_weight(
                shape=(self.units, self.units),
                initializer='uniform',
                name='recurrent_kernel')
            self.built = True

        def call(self, inputs, states):
            prev_output = states[0]
            h = backend.dot(inputs, self.kernel)
            output = h + backend.dot(prev_output, self.recurrent_kernel)
            return output, [output]

    # Let's use this cell in a RNN layer:

    cell = MinimalRNNCell(32)
    x = keras.Input((None, 5))
    layer = RNN(cell)
    y = layer(x)

    # Here's how to use the cell to build a stacked RNN:

    cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
    x = keras.Input((None, 5))
    layer = RNN(cells)
    y = layer(x)
    ```
    """

    def __init__(
        self,
        cell,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        unroll=False,
        time_major=False,
        **kwargs,
    ):
        if isinstance(cell, (list, tuple)):
            cell = StackedRNNCells(cell)
        if "call" not in dir(cell):
            raise ValueError(
                "Argument `cell` should have a `call` method. "
                f"The RNN was passed: cell={cell}"
            )
        if "state_size" not in dir(cell):
            raise ValueError(
                "The RNN cell should have a `state_size` attribute "
                "(tuple of integers, one integer per RNN state). "
                f"Received: cell={cell}"
            )
        # If True, the output for a masked timestep will be zeros,
        # whereas in the False case, the output from the previous
        # timestep is returned for the masked timestep.
        self.zero_output_for_mask = kwargs.pop("zero_output_for_mask", False)

        if "input_shape" not in kwargs and (
            "input_dim" in kwargs or "input_length" in kwargs
        ):
            input_shape = (
                kwargs.pop("input_length", None),
                kwargs.pop("input_dim", None),
            )
            kwargs["input_shape"] = input_shape

        super().__init__(**kwargs)
        self.cell = cell
        self.return_sequences = return_sequences
        self.return_state = return_state
        self.go_backwards = go_backwards
        self.stateful = stateful
        self.unroll = unroll
        self.time_major = time_major

        self.supports_masking = True
        # The input shape is unknown yet; it could have nested tensor
        # inputs, and the input spec will be the list of specs for nested
        # inputs; the structure of the input_spec will be the same as
        # that of the input.
        self.input_spec = None
        self.state_spec = None
        self._states = None
        self.constants_spec = None
        self._num_constants = 0

        if stateful:
            if tf.distribute.has_strategy():
                raise ValueError(
                    "Stateful RNNs (created with `stateful=True`) "
                    "are not yet supported with tf.distribute.Strategy."
                )

    @property
    def _use_input_spec_as_call_signature(self):
        if self.unroll:
            # When the RNN layer is unrolled, the time step shape cannot
            # be unknown. The input spec does not define the time step
            # (because this layer can be called with any time step value,
            # as long as it is not None), so it cannot be used as the
            # call function signature when saving to SavedModel.
            return False
        return super()._use_input_spec_as_call_signature

    @property
    def states(self):
        if self._states is None:
            state = tf.nest.map_structure(lambda _: None, self.cell.state_size)
            return state if tf.nest.is_nested(self.cell.state_size) else [state]
        return self._states

    @states.setter
    # Automatic tracking catches "self._states", which adds an extra
    # weight and breaks HDF5 checkpoints.
    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def states(self, states):
        self._states = states

    def compute_output_shape(self, input_shape):
        if isinstance(input_shape, list):
            input_shape = input_shape[0]
        # Check whether the input shape contains any nested shapes. It
        # could be (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3),
        # the latter coming from numpy inputs.
        try:
            input_shape = tf.TensorShape(input_shape)
        except (ValueError, TypeError):
            # A nested tensor input
            input_shape = tf.nest.flatten(input_shape)[0]

        batch = input_shape[0]
        time_step = input_shape[1]
        if self.time_major:
            batch, time_step = time_step, batch

        if rnn_utils.is_multiple_state(self.cell.state_size):
            state_size = self.cell.state_size
        else:
            state_size = [self.cell.state_size]

        def _get_output_shape(flat_output_size):
            output_dim = tf.TensorShape(flat_output_size).as_list()
            if self.return_sequences:
                if self.time_major:
                    output_shape = tf.TensorShape(
                        [time_step, batch] + output_dim
                    )
                else:
                    output_shape = tf.TensorShape(
                        [batch, time_step] + output_dim
                    )
            else:
                output_shape = tf.TensorShape([batch] + output_dim)
            return output_shape

        if getattr(self.cell, "output_size", None) is not None:
            # cell.output_size could be a nested structure.
            output_shape = tf.nest.flatten(
                tf.nest.map_structure(_get_output_shape, self.cell.output_size)
            )
            output_shape = (
                output_shape[0] if len(output_shape) == 1 else output_shape
            )
        else:
            # Note that state_size[0] could be a tensor_shape or int.
            output_shape = _get_output_shape(state_size[0])

        if self.return_state:

            def _get_state_shape(flat_state):
                state_shape = [batch] + tf.TensorShape(flat_state).as_list()
                return tf.TensorShape(state_shape)

            state_shape = tf.nest.map_structure(_get_state_shape, state_size)
            return generic_utils.to_list(output_shape) + tf.nest.flatten(
                state_shape
            )
        else:
            return output_shape

    def compute_mask(self, inputs, mask):
        # Time step masks must be the same for each input.
        # This is because the mask for an RNN is of size
        # [batch, time_steps, 1], and specifies which time steps should
        # be skipped, and a time step must be skipped for all inputs.
        # TODO(scottzhu): Should we accept multiple different masks?
        mask = tf.nest.flatten(mask)[0]
        output_mask = mask if self.return_sequences else None
        if self.return_state:
            state_mask = [None for _ in self.states]
            return [output_mask] + state_mask
        else:
            return output_mask

    def build(self, input_shape):
        if isinstance(input_shape, list):
            input_shape = input_shape[0]
        # The input_shape here could be a nested structure.

        # Convert TensorShapes to plain shapes here. The input could be
        # a single tensor, or a nested structure of tensors.
        def get_input_spec(shape):
            """Convert input shape to InputSpec."""
            if isinstance(shape, tf.TensorShape):
                input_spec_shape = shape.as_list()
            else:
                input_spec_shape = list(shape)
            batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
            if not self.stateful:
                input_spec_shape[batch_index] = None
            input_spec_shape[time_step_index] = None
            return InputSpec(shape=tuple(input_spec_shape))

        def get_step_input_shape(shape):
            if isinstance(shape, tf.TensorShape):
                shape = tuple(shape.as_list())
            # Remove the timestep from the input_shape.
            return shape[1:] if self.time_major else (shape[0],) + shape[2:]

        def get_state_spec(shape):
            state_spec_shape = tf.TensorShape(shape).as_list()
            # Append the batch dim.
            state_spec_shape = [None] + state_spec_shape
            return InputSpec(shape=tuple(state_spec_shape))

        # Check whether the input shape contains any nested shapes. It
        # could be (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3),
        # the latter coming from numpy inputs.
        try:
            input_shape = tf.TensorShape(input_shape)
        except (ValueError, TypeError):
            # A nested tensor input
            pass

        if not tf.nest.is_nested(input_shape):
            # This indicates that there is only one input.
            if self.input_spec is not None:
                self.input_spec[0] = get_input_spec(input_shape)
            else:
                self.input_spec = [get_input_spec(input_shape)]
            step_input_shape = get_step_input_shape(input_shape)
        else:
            if self.input_spec is not None:
                self.input_spec[0] = tf.nest.map_structure(
                    get_input_spec, input_shape
                )
            else:
                self.input_spec = generic_utils.to_list(
                    tf.nest.map_structure(get_input_spec, input_shape)
                )
            step_input_shape = tf.nest.map_structure(
                get_step_input_shape, input_shape
            )

        # Allow the cell (if a layer) to build before we set or validate
        # state_spec.
        if isinstance(self.cell, base_layer.Layer) and not self.cell.built:
            with backend.name_scope(self.cell.name):
                self.cell.build(step_input_shape)
                self.cell.built = True

        # Set or validate state_spec.
        if rnn_utils.is_multiple_state(self.cell.state_size):
            state_size = list(self.cell.state_size)
        else:
            state_size = [self.cell.state_size]

        if self.state_spec is not None:
            # initial_state was passed in call; check compatibility.
            self._validate_state_spec(state_size, self.state_spec)
        else:
            if tf.nest.is_nested(state_size):
                self.state_spec = tf.nest.map_structure(
                    get_state_spec, state_size
                )
            else:
                self.state_spec = [
                    InputSpec(shape=[None] + tf.TensorShape(dim).as_list())
                    for dim in state_size
                ]
            # Ensure the generated state_spec is correct.
            self._validate_state_spec(state_size, self.state_spec)
        if self.stateful:
            self.reset_states()
        super().build(input_shape)

    @staticmethod
    def _validate_state_spec(cell_state_sizes, init_state_specs):
        """Validate the state spec between the initial_state and the
        state_size.

        Args:
          cell_state_sizes: list, the `state_size` attribute from the cell.
          init_state_specs: list, the `state_spec` from the initial_state
            that is passed in `call()`.

        Raises:
          ValueError: When the initial state spec is not compatible with
            the state size.
        """
        validation_error = ValueError(
            "An `initial_state` was passed that is not compatible with "
            "`cell.state_size`. Received `state_spec`={}; "
            "however `cell.state_size` is "
            "{}".format(init_state_specs, cell_state_sizes)
        )
        flat_cell_state_sizes = tf.nest.flatten(cell_state_sizes)
        flat_state_specs = tf.nest.flatten(init_state_specs)

        if len(flat_cell_state_sizes) != len(flat_state_specs):
            raise validation_error
        for cell_state_spec, cell_state_size in zip(
            flat_state_specs, flat_cell_state_sizes
        ):
            if not tf.TensorShape(
                # Ignore the first axis for init_state which is for batch
                cell_state_spec.shape[1:]
            ).is_compatible_with(tf.TensorShape(cell_state_size)):
                raise validation_error

    @doc_controls.do_not_doc_inheritable
    def get_initial_state(self, inputs):
        get_initial_state_fn = getattr(self.cell, "get_initial_state", None)

        if tf.nest.is_nested(inputs):
            # The inputs are nested sequences. Use the first element in
            # the sequence to get the batch size and dtype.
            inputs = tf.nest.flatten(inputs)[0]

        input_shape = tf.shape(inputs)
        batch_size = input_shape[1] if self.time_major else input_shape[0]
        dtype = inputs.dtype
        if get_initial_state_fn:
            init_state = get_initial_state_fn(
                inputs=None, batch_size=batch_size, dtype=dtype
            )
        else:
            init_state = rnn_utils.generate_zero_filled_state(
                batch_size, self.cell.state_size, dtype
            )
        # Keras RNN expects the states in a list, even if it's a single
        # state tensor.
        if not tf.nest.is_nested(init_state):
            init_state = [init_state]
        # Force the state to be a list in case it is a namedtuple, e.g.
        # LSTMStateTuple.
        return list(init_state)

    def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
        inputs, initial_state, constants = rnn_utils.standardize_args(
            inputs, initial_state, constants, self._num_constants
        )

        if initial_state is None and constants is None:
            return super().__call__(inputs, **kwargs)

        # If any of `initial_state` or `constants` are specified and are
        # Keras tensors, then add them to the inputs and temporarily
        # modify the input_spec to include them.

        additional_inputs = []
        additional_specs = []
        if initial_state is not None:
            additional_inputs += initial_state
            self.state_spec = tf.nest.map_structure(
                lambda s: InputSpec(shape=backend.int_shape(s)), initial_state
            )
            additional_specs += self.state_spec
        if constants is not None:
            additional_inputs += constants
            self.constants_spec = [
                InputSpec(shape=backend.int_shape(constant))
                for constant in constants
            ]
            self._num_constants = len(constants)
            additional_specs += self.constants_spec
        # additional_inputs can be empty if initial_state or constants
        # are provided but empty (e.g. the cell is stateless).
        flat_additional_inputs = tf.nest.flatten(additional_inputs)
        is_keras_tensor = (
            backend.is_keras_tensor(flat_additional_inputs[0])
            if flat_additional_inputs
            else True
        )
        for tensor in flat_additional_inputs:
            if backend.is_keras_tensor(tensor) != is_keras_tensor:
                raise ValueError(
                    "The initial state or constants of an RNN layer cannot be "
                    "specified via a mix of Keras tensors and non-Keras "
                    'tensors (a "Keras tensor" is a tensor that was returned '
                    "by a Keras layer or by `Input` during Functional "
                    "model construction). Received: "
                    f"initial_state={initial_state}, constants={constants}"
                )

        if is_keras_tensor:
            # Compute the full input spec, including state and constants.
            full_input = [inputs] + additional_inputs
            if self.built:
                # Keep the input_spec since it has been populated in the
                # build() method.
                full_input_spec = self.input_spec + additional_specs
            else:
                # The original input_spec is None since there could be a
                # nested tensor input. Update the input_spec to match the
                # inputs.
                full_input_spec = (
                    generic_utils.to_list(
                        tf.nest.map_structure(lambda _: None, inputs)
                    )
                    + additional_specs
                )
            # Perform the call with a temporarily replaced input_spec.
            self.input_spec = full_input_spec
            output = super().__call__(full_input, **kwargs)
            # Remove the additional_specs from the input spec and keep
            # the rest. It is important to keep them, since the input
            # spec was populated by build() and will be reused when
            # stateful=True.
            self.input_spec = self.input_spec[: -len(additional_specs)]
            return output
        else:
            if initial_state is not None:
                kwargs["initial_state"] = initial_state
            if constants is not None:
                kwargs["constants"] = constants
            return super().__call__(inputs, **kwargs)

    def call(
        self,
        inputs,
        mask=None,
        training=None,
        initial_state=None,
        constants=None,
    ):
        # The input should be dense, padded with zeros. If a ragged input
        # is fed into the layer, it is padded and the row lengths are
        # used for masking.
        inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
        is_ragged_input = row_lengths is not None
        self._validate_args_if_ragged(is_ragged_input, mask)

        inputs, initial_state, constants = self._process_inputs(
            inputs, initial_state, constants
        )

        self._maybe_reset_cell_dropout_mask(self.cell)
        if isinstance(self.cell, StackedRNNCells):
            for cell in self.cell.cells:
                self._maybe_reset_cell_dropout_mask(cell)

        if mask is not None:
            # Time step masks must be the same for each input.
            # TODO(scottzhu): Should we accept multiple different masks?
            mask = tf.nest.flatten(mask)[0]

        if tf.nest.is_nested(inputs):
            # In the case of nested input, use the first element for the
            # shape check.
            input_shape = backend.int_shape(tf.nest.flatten(inputs)[0])
        else:
            input_shape = backend.int_shape(inputs)
        timesteps = input_shape[0] if self.time_major else input_shape[1]
        if self.unroll and timesteps is None:
            raise ValueError(
                "Cannot unroll a RNN if the "
                "time dimension is undefined. \n"
                "- If using a Sequential model, "
                "specify the time dimension by passing "
                "an `input_shape` or `batch_input_shape` "
                "argument to your first layer. If your "
                "first layer is an Embedding, you can "
                "also use the `input_length` argument.\n"
                "- If using the functional API, specify "
                "the time dimension by passing a `shape` "
                "or `batch_shape` argument to your Input layer."
            )

        kwargs = {}
        if generic_utils.has_arg(self.cell.call, "training"):
            kwargs["training"] = training

        # TF RNN cells expect a single tensor as state instead of a
        # list-wrapped tensor.
        is_tf_rnn_cell = getattr(self.cell, "_is_tf_rnn_cell", None) is not None
        # Use the __call__ function for callable objects, e.g. layers, so
        # that it will have the proper name scopes for the ops, etc.
        cell_call_fn = (
            self.cell.__call__ if callable(self.cell) else self.cell.call
        )
        if constants:
            if not generic_utils.has_arg(self.cell.call, "constants"):
                raise ValueError(
                    f"RNN cell {self.cell} does not support constants. "
                    f"Received: constants={constants}"
                )

            def step(inputs, states):
                constants = states[-self._num_constants :]
                states = states[: -self._num_constants]

                states = (
                    states[0] if len(states) == 1 and is_tf_rnn_cell else states
                )
                output, new_states = cell_call_fn(
                    inputs, states, constants=constants, **kwargs
                )
                if not tf.nest.is_nested(new_states):
                    new_states = [new_states]
                return output, new_states

        else:

            def step(inputs, states):
                states = (
                    states[0] if len(states) == 1 and is_tf_rnn_cell else states
                )
                output, new_states = cell_call_fn(inputs, states, **kwargs)
                if not tf.nest.is_nested(new_states):
                    new_states = [new_states]
                return output, new_states

        last_output, outputs, states = backend.rnn(
            step,
            inputs,
            initial_state,
            constants=constants,
            go_backwards=self.go_backwards,
            mask=mask,
            unroll=self.unroll,
            input_length=row_lengths if row_lengths is not None else timesteps,
            time_major=self.time_major,
            zero_output_for_mask=self.zero_output_for_mask,
            return_all_outputs=self.return_sequences,
        )

        if self.stateful:
            updates = [
                tf.compat.v1.assign(
                    self_state, tf.cast(state, self_state.dtype)
                )
                for self_state, state in zip(
                    tf.nest.flatten(self.states), tf.nest.flatten(states)
                )
            ]
            self.add_update(updates)

        if self.return_sequences:
            output = backend.maybe_convert_to_ragged(
                is_ragged_input,
                outputs,
                row_lengths,
                go_backwards=self.go_backwards,
            )
        else:
            output = last_output

        if self.return_state:
            if not isinstance(states, (list, tuple)):
                states = [states]
            else:
                states = list(states)
            return generic_utils.to_list(output) + states
        else:
            return output

    def _process_inputs(self, inputs, initial_state, constants):
        # input shape: `(samples, time (padded with zeros), input_dim)`
        # Note that the .build() method of subclasses MUST define
        # self.input_spec and self.state_spec with complete input shapes.
        if isinstance(inputs, collections.abc.Sequence) and not isinstance(
            inputs, tuple
        ):
            # Get initial_state from the full input spec, as they could
            # be copied to multiple GPUs.
            if not self._num_constants:
                initial_state = inputs[1:]
            else:
                initial_state = inputs[1 : -self._num_constants]
                constants = inputs[-self._num_constants :]
            if len(initial_state) == 0:
                initial_state = None
            inputs = inputs[0]

        if self.stateful:
            if initial_state is not None:
                # When the layer is stateful and an initial_state is
                # provided, check whether the recorded state is the same
                # as the default value (zeros). Use the recorded state if
                # it is not the same as the default.
                non_zero_count = tf.add_n(
                    [
                        tf.math.count_nonzero(s)
                        for s in tf.nest.flatten(self.states)
                    ]
                )
                # Set strict = True to keep the original structure of the
                # state.
                initial_state = tf.compat.v1.cond(
                    non_zero_count > 0,
                    true_fn=lambda: self.states,
                    false_fn=lambda: initial_state,
                    strict=True,
                )
            else:
                initial_state = self.states
            initial_state = tf.nest.map_structure(
                # When the layer has an inferred dtype, use the dtype
                # from the cell.
                lambda v: tf.cast(
                    v, self.compute_dtype or self.cell.compute_dtype
                ),
                initial_state,
            )
        elif initial_state is None:
            initial_state = self.get_initial_state(inputs)

        if len(initial_state) != len(self.states):
            raise ValueError(
                f"Layer has {len(self.states)} "
                f"states but was passed {len(initial_state)} initial "
                f"states. Received: initial_state={initial_state}"
            )
        return inputs, initial_state, constants

    def _validate_args_if_ragged(self, is_ragged_input, mask):
        if not is_ragged_input:
            return

        if mask is not None:
            raise ValueError(
                f"The mask that was passed in was {mask}, which "
                "cannot be applied to RaggedTensor inputs. Please "
                "make sure that there is no mask injected by upstream "
                "layers."
            )
        if self.unroll:
            raise ValueError(
                "The input received contains RaggedTensors, which do "
                "not support unrolling. Disable unrolling by passing "
                "`unroll=False` in the RNN layer constructor."
            )

    def _maybe_reset_cell_dropout_mask(self, cell):
        if isinstance(cell, DropoutRNNCellMixin):
            cell.reset_dropout_mask()
            cell.reset_recurrent_dropout_mask()

    def reset_states(self, states=None):
        """Reset the recorded states for the stateful RNN layer.

        Can only be used when the RNN layer is constructed with
        `stateful=True`.

        Args:
          states: Numpy arrays that contain the values for the initial
            state, which will be fed to the cell at the first time step.
            When the value is `None`, a zero-filled numpy array will be
            created based on the cell state size.

        Raises:
          AttributeError: When the RNN layer is not stateful.
          ValueError: When the batch size of the RNN layer is unknown.
          ValueError: When the input numpy array is not compatible with
            the RNN layer state, either in size or dtype.
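
        Example (a minimal sketch; the shapes are arbitrary, and an LSTM
        carries two states, `h` and `c`):

        ```python
        layer = tf.keras.layers.LSTM(
            4, stateful=True, batch_input_shape=(2, 3, 5)
        )
        layer(tf.zeros([2, 3, 5]))  # build the layer and record states
        layer.reset_states()  # zero-fill the recorded states
        layer.reset_states([np.zeros((2, 4)), np.zeros((2, 4))])
        ```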

861 """ 

862 if not self.stateful: 

863 raise AttributeError("Layer must be stateful.") 

864 spec_shape = None 

865 if self.input_spec is not None: 

866 spec_shape = tf.nest.flatten(self.input_spec[0])[0].shape 

867 if spec_shape is None: 

868 # It is possible to have spec shape to be None, eg when construct a 

869 # RNN with a custom cell, or standard RNN layers (LSTM/GRU) which we 

870 # only know it has 3 dim input, but not its full shape spec before 

871 # build(). 

872 batch_size = None 

873 else: 

874 batch_size = spec_shape[1] if self.time_major else spec_shape[0] 

875 if not batch_size: 

876 raise ValueError( 

877 "If a RNN is stateful, it needs to know " 

878 "its batch size. Specify the batch size " 

879 "of your input tensors: \n" 

880 "- If using a Sequential model, " 

881 "specify the batch size by passing " 

882 "a `batch_input_shape` " 

883 "argument to your first layer.\n" 

884 "- If using the functional API, specify " 

885 "the batch size by passing a " 

886 "`batch_shape` argument to your Input layer." 

887 ) 

888 # initialize state if None 

889 if tf.nest.flatten(self.states)[0] is None: 

890 if getattr(self.cell, "get_initial_state", None): 

891 flat_init_state_values = tf.nest.flatten( 

892 self.cell.get_initial_state( 

893 inputs=None, 

894 batch_size=batch_size, 

895 # Use variable_dtype instead of compute_dtype, since the 

896 # state is stored in a variable 

897 dtype=self.variable_dtype or backend.floatx(), 

898 ) 

899 ) 

900 else: 

901 flat_init_state_values = tf.nest.flatten( 

902 rnn_utils.generate_zero_filled_state( 

903 batch_size, 

904 self.cell.state_size, 

905 self.variable_dtype or backend.floatx(), 

906 ) 

907 ) 

908 flat_states_variables = tf.nest.map_structure( 

909 backend.variable, flat_init_state_values 

910 ) 

911 self.states = tf.nest.pack_sequence_as( 

912 self.cell.state_size, flat_states_variables 

913 ) 

914 if not tf.nest.is_nested(self.states): 

915 self.states = [self.states] 

916 elif states is None: 

917 for state, size in zip( 

918 tf.nest.flatten(self.states), 

919 tf.nest.flatten(self.cell.state_size), 

920 ): 

921 backend.set_value( 

922 state, 

923 np.zeros([batch_size] + tf.TensorShape(size).as_list()), 

924 ) 

925 else: 

926 flat_states = tf.nest.flatten(self.states) 

927 flat_input_states = tf.nest.flatten(states) 

928 if len(flat_input_states) != len(flat_states): 

929 raise ValueError( 

930 f"Layer {self.name} expects {len(flat_states)} " 

931 f"states, but it received {len(flat_input_states)} " 

932 f"state values. States received: {states}" 

933 ) 

934 set_value_tuples = [] 

935 for i, (value, state) in enumerate( 

936 zip(flat_input_states, flat_states) 

937 ): 

938 if value.shape != state.shape: 

939 raise ValueError( 

940 f"State {i} is incompatible with layer {self.name}: " 

941 f"expected shape={(batch_size, state)} " 

942 f"but found shape={value.shape}" 

943 ) 

944 set_value_tuples.append((state, value)) 

945 backend.batch_set_value(set_value_tuples) 

946 

    def get_config(self):
        config = {
            "return_sequences": self.return_sequences,
            "return_state": self.return_state,
            "go_backwards": self.go_backwards,
            "stateful": self.stateful,
            "unroll": self.unroll,
            "time_major": self.time_major,
        }
        if self._num_constants:
            config["num_constants"] = self._num_constants
        if self.zero_output_for_mask:
            config["zero_output_for_mask"] = self.zero_output_for_mask

        config["cell"] = serialization_lib.serialize_keras_object(self.cell)
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        from keras.src.layers import deserialize as deserialize_layer

        cell = deserialize_layer(
            config.pop("cell"), custom_objects=custom_objects
        )
        num_constants = config.pop("num_constants", 0)
        layer = cls(cell, **config)
        layer._num_constants = num_constants
        return layer
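    # Example of config round-tripping (an illustrative sketch using the
    # built-in SimpleRNNCell):
    #
    #     cell = tf.keras.layers.SimpleRNNCell(4)
    #     layer = tf.keras.layers.RNN(cell)
    #     config = layer.get_config()
    #     clone = tf.keras.layers.RNN.from_config(config)
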

    @property
    def _trackable_saved_model_saver(self):
        return layer_serialization.RNNSavedModelSaver(self)