Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/base_rnn.py: 13%
342 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base class for recurrent layers."""
18import collections
20import numpy as np
21import tensorflow.compat.v2 as tf
23from keras.src import backend
24from keras.src.engine import base_layer
25from keras.src.engine.input_spec import InputSpec
26from keras.src.layers.rnn import rnn_utils
27from keras.src.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin
28from keras.src.layers.rnn.stacked_rnn_cells import StackedRNNCells
29from keras.src.saving import serialization_lib
30from keras.src.saving.legacy.saved_model import layer_serialization
31from keras.src.utils import generic_utils
33# isort: off
34from tensorflow.python.util.tf_export import keras_export
35from tensorflow.tools.docs import doc_controls
38@keras_export("keras.layers.RNN")
39class RNN(base_layer.Layer):
40 """Base class for recurrent layers.
42 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
43 for details about the usage of RNN API.
45 Args:
46 cell: An RNN cell instance or a list of RNN cell instances.
47 An RNN cell is a class that has:
48 - A `call(input_at_t, states_at_t)` method, returning
49 `(output_at_t, states_at_t_plus_1)`. The call method of the
50 cell can also take the optional argument `constants`, see
51 section "Note on passing external constants" below.
52 - A `state_size` attribute. This can be a single integer
53 (single state) in which case it is the size of the recurrent
54 state. This can also be a list/tuple of integers (one size per state).
55 The `state_size` can also be a TensorShape or a tuple/list of
56 TensorShapes, to represent a high-dimensional state.
57 - An `output_size` attribute. This can be a single integer or a
58 TensorShape, which represents the shape of the output. For backward
59 compatibility, if this attribute is not available for the
60 cell, the value will be inferred from the first element of the
61 `state_size`.
62 - A `get_initial_state(inputs=None, batch_size=None, dtype=None)`
63 method that creates a tensor meant to be fed to `call()` as the
64 initial state, if the user didn't specify any initial state via other
65 means. The returned initial state should have a shape of
66 [batch_size, cell.state_size]. The cell might choose to create a
67 tensor full of zeros, or full of other values based on the cell's
68 implementation.
69 `inputs` is the input tensor to the RNN layer, whose shape[0]
70 is the batch size and whose dtype is the input dtype. Note that
71 shape[0] might be `None` during graph construction. Either
72 `inputs` or the pair of `batch_size` and `dtype` is provided.
73 `batch_size` is a scalar tensor that represents the batch size
74 of the inputs. `dtype` is `tf.DType` that represents the dtype of
75 the inputs.
76 For backward compatibility, if this method is not implemented
77 by the cell, the RNN layer will create a zero filled tensor with the
78 size of [batch_size, cell.state_size].
79 In the case that `cell` is a list of RNN cell instances, the cells
80 will be stacked on top of each other in the RNN, resulting in an
81 efficient stacked RNN.
82 return_sequences: Boolean (default `False`). Whether to return the last
83 output in the output sequence, or the full sequence.
84 return_state: Boolean (default `False`). Whether to return the last state
85 in addition to the output.
86 go_backwards: Boolean (default `False`).
87 If True, process the input sequence backwards and return the
88 reversed sequence.
89 stateful: Boolean (default `False`). If True, the last state
90 for each sample at index i in a batch will be used as initial
91 state for the sample of index i in the following batch.
92 unroll: Boolean (default `False`).
93 If True, the network will be unrolled, else a symbolic loop will be
94 used. Unrolling can speed up an RNN, although it tends to be more
95 memory-intensive. Unrolling is only suitable for short sequences.
96 time_major: The shape format of the `inputs` and `outputs` tensors.
97 If True, the inputs and outputs will be in shape
98 `(timesteps, batch, ...)`, whereas in the False case, it will be
99 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
100 efficient because it avoids transposes at the beginning and end of the
101 RNN calculation. However, most TensorFlow data is batch-major, so by
102 default this function accepts input and emits output in batch-major
103 form.
104 zero_output_for_mask: Boolean (default `False`).
105 Whether the output should use zeros for the masked timesteps. Note that
106 this field is only used when `return_sequences` is True and a mask is
107 provided. It can be useful if you want to reuse the raw output sequence
108 of the RNN without interference from the masked timesteps, e.g. when
109 merging bidirectional RNNs.
111 Call arguments:
112 inputs: Input tensor.
113 mask: Binary tensor of shape `[batch_size, timesteps]` indicating whether
114 a given timestep should be masked. An individual `True` entry indicates
115 that the corresponding timestep should be utilized, while a `False`
116 entry indicates that the corresponding timestep should be ignored.
117 training: Python boolean indicating whether the layer should behave in
118 training mode or in inference mode. This argument is passed to the cell
119 when calling it. This is for use with cells that use dropout.
120 initial_state: List of initial state tensors to be passed to the first
121 call of the cell.
122 constants: List of constant tensors to be passed to the cell at each
123 timestep.
125 Input shape:
126 N-D tensor with shape `[batch_size, timesteps, ...]` or
127 `[timesteps, batch_size, ...]` when time_major is True.
129 Output shape:
130 - If `return_state`: a list of tensors. The first tensor is
131 the output. The remaining tensors are the last states,
132 each with shape `[batch_size, state_size]`, where `state_size` could
133 be a high dimension tensor shape.
134 - If `return_sequences`: N-D tensor with shape
135 `[batch_size, timesteps, output_size]`, where `output_size` could
136 be a high dimension tensor shape, or
137 `[timesteps, batch_size, output_size]` when `time_major` is True.
138 - Else, N-D tensor with shape `[batch_size, output_size]`, where
139 `output_size` could be a high dimension tensor shape.
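
For example, a minimal sketch illustrating the shapes listed above (the
built-in `SimpleRNNCell` and the sizes below are only illustrative):

```python
import tensorflow as tf
from tensorflow import keras

layer = keras.layers.RNN(
    keras.layers.SimpleRNNCell(32), return_sequences=True, return_state=True
)
# 8 sequences of 10 timesteps with 5 features each.
outputs, final_state = layer(tf.zeros([8, 10, 5]))
# outputs.shape == (8, 10, 32), final_state.shape == (8, 32)
```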
141 Masking:
142 This layer supports masking for input data with a variable number
143 of timesteps. To introduce masks to your data,
144 use an [tf.keras.layers.Embedding] layer with the `mask_zero` parameter
145 set to `True`.
147 Note on using statefulness in RNNs:
148 You can set RNN layers to be 'stateful', which means that the states
149 computed for the samples in one batch will be reused as initial states
150 for the samples in the next batch. This assumes a one-to-one mapping
151 between samples in different successive batches.
153 To enable statefulness:
154 - Specify `stateful=True` in the layer constructor.
155 - Specify a fixed batch size for your model, by passing
156 `batch_input_shape=(...)` to the first layer in your model
157 (if it is a Sequential model), or
158 `batch_shape=(...)` to all `Input` layers in your model
159 (if it is a functional model).
160 This is the expected shape of your inputs
161 *including the batch size*.
162 It should be a tuple of integers, e.g. `(32, 10, 100)`.
163 - Specify `shuffle=False` when calling `fit()`.
165 To reset the states of your model, call `.reset_states()` on either
166 a specific layer, or on your entire model.
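
A minimal sketch of the recipe above (the layer size, batch size and data
are illustrative placeholders):

```python
import numpy as np
from tensorflow import keras

inputs = keras.Input(batch_shape=(32, 10, 100))   # fixed batch size of 32
outputs = keras.layers.RNN(
    keras.layers.SimpleRNNCell(8), stateful=True
)(inputs)
model = keras.Model(inputs, outputs)
model.compile(optimizer="sgd", loss="mse")

x = np.zeros((32, 10, 100), dtype="float32")
y = np.zeros((32, 8), dtype="float32")
model.fit(x, y, batch_size=32, shuffle=False)     # keep sample order fixed
model.reset_states()                              # clear the recorded states
```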
168 Note on specifying the initial state of RNNs:
169 You can specify the initial state of RNN layers symbolically by
170 calling them with the keyword argument `initial_state`. The value of
171 `initial_state` should be a tensor or list of tensors representing
172 the initial state of the RNN layer.
174 You can specify the initial state of RNN layers numerically by
175 calling `reset_states` with the keyword argument `states`. The value of
176 `states` should be a numpy array or list of numpy arrays representing
177 the initial state of the RNN layer.
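
A short sketch of both options (the cell and shapes are illustrative):

```python
import numpy as np
from tensorflow import keras

# Symbolic initial state, passed when calling the layer:
x = keras.Input((None, 5))
s0 = keras.Input((4,))
y = keras.layers.RNN(keras.layers.SimpleRNNCell(4))(x, initial_state=[s0])

# Numeric initial state for a stateful layer, via `reset_states`:
stateful_layer = keras.layers.RNN(
    keras.layers.SimpleRNNCell(4), stateful=True
)
stateful_layer.build((2, 3, 5))   # batch_size=2, timesteps=3, features=5
stateful_layer.reset_states(states=[np.ones((2, 4), dtype="float32")])
```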
179 Note on passing external constants to RNNs:
180 You can pass "external" constants to the cell using the `constants`
181 keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This
182 requires that the `cell.call` method accepts the same keyword argument
183 `constants`. Such constants can be used to condition the cell
184 transformation on additional static inputs (not changing over time),
185 a.k.a. an attention mechanism.
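
A minimal sketch of a cell that accepts such constants (the
`ConditionedCell` below is hypothetical and only illustrates the calling
convention):

```python
import tensorflow as tf
from tensorflow import keras

class ConditionedCell(keras.layers.Layer):
    """Toy cell whose step is conditioned on a static `constants` tensor."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = units

    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units), name="kernel"
        )
        self.built = True

    def call(self, inputs, states, constants=None):
        h = tf.matmul(inputs, self.kernel) + states[0]
        if constants is not None:
            h = h + constants[0]   # condition the step on the static input
        return h, [h]

x = keras.Input((None, 5))
ctx = keras.Input((8,))            # static input, not varying over time
y = keras.layers.RNN(ConditionedCell(8))(x, constants=[ctx])
```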
187 Examples:
189 ```python
190 import keras
191 from keras.src.layers import RNN
192 from keras.src import backend
193 # First, let's define an RNN cell, as a layer subclass.
194 class MinimalRNNCell(keras.layers.Layer):
196 def __init__(self, units, **kwargs):
197 self.units = units
198 self.state_size = units
199 super(MinimalRNNCell, self).__init__(**kwargs)
201 def build(self, input_shape):
202 self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
203 initializer='uniform',
204 name='kernel')
205 self.recurrent_kernel = self.add_weight(
206 shape=(self.units, self.units),
207 initializer='uniform',
208 name='recurrent_kernel')
209 self.built = True
211 def call(self, inputs, states):
212 prev_output = states[0]
213 h = backend.dot(inputs, self.kernel)
214 output = h + backend.dot(prev_output, self.recurrent_kernel)
215 return output, [output]
217 # Let's use this cell in an RNN layer:
219 cell = MinimalRNNCell(32)
220 x = keras.Input((None, 5))
221 layer = RNN(cell)
222 y = layer(x)
224 # Here's how to use the cell to build a stacked RNN:
226 cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
227 x = keras.Input((None, 5))
228 layer = RNN(cells)
229 y = layer(x)
230 ```
231 """
233 def __init__(
234 self,
235 cell,
236 return_sequences=False,
237 return_state=False,
238 go_backwards=False,
239 stateful=False,
240 unroll=False,
241 time_major=False,
242 **kwargs,
243 ):
244 if isinstance(cell, (list, tuple)):
245 cell = StackedRNNCells(cell)
246 if "call" not in dir(cell):
247 raise ValueError(
248 "Argument `cell` should have a `call` method. "
249 f"The RNN was passed: cell={cell}"
250 )
251 if "state_size" not in dir(cell):
252 raise ValueError(
253 "The RNN cell should have a `state_size` attribute "
254 "(tuple of integers, one integer per RNN state). "
255 f"Received: cell={cell}"
256 )
257 # If True, the output for masked timesteps will be zeros, whereas in
258 # the False case, the output from the previous timestep is returned
259 # for masked timesteps.
260 self.zero_output_for_mask = kwargs.pop("zero_output_for_mask", False)
262 if "input_shape" not in kwargs and (
263 "input_dim" in kwargs or "input_length" in kwargs
264 ):
265 input_shape = (
266 kwargs.pop("input_length", None),
267 kwargs.pop("input_dim", None),
268 )
269 kwargs["input_shape"] = input_shape
271 super().__init__(**kwargs)
272 self.cell = cell
273 self.return_sequences = return_sequences
274 self.return_state = return_state
275 self.go_backwards = go_backwards
276 self.stateful = stateful
277 self.unroll = unroll
278 self.time_major = time_major
280 self.supports_masking = True
281 # The input shape is not known yet; the input could contain nested
282 # tensors, in which case the input spec will be a list of specs for
283 # the nested inputs, with the same structure as the input.
284 self.input_spec = None
285 self.state_spec = None
286 self._states = None
287 self.constants_spec = None
288 self._num_constants = 0
290 if stateful:
291 if tf.distribute.has_strategy():
292 raise ValueError(
293 "Stateful RNNs (created with `stateful=True`) "
294 "are not yet supported with tf.distribute.Strategy."
295 )
297 @property
298 def _use_input_spec_as_call_signature(self):
299 if self.unroll:
300 # When the RNN layer is unrolled, the time step shape cannot be
301 # unknown. The input spec does not define the time step (because
302 # this layer can be called with any time step value, as long as it
303 # is not None), so it cannot be used as the call function signature
304 # when saving to SavedModel.
305 return False
306 return super()._use_input_spec_as_call_signature
308 @property
309 def states(self):
310 if self._states is None:
311 state = tf.nest.map_structure(lambda _: None, self.cell.state_size)
312 return state if tf.nest.is_nested(self.cell.state_size) else [state]
313 return self._states
315 @states.setter
316 # Automatic tracking catches "self._states" which adds an extra weight and
317 # breaks HDF5 checkpoints.
318 @tf.__internal__.tracking.no_automatic_dependency_tracking
319 def states(self, states):
320 self._states = states
322 def compute_output_shape(self, input_shape):
323 if isinstance(input_shape, list):
324 input_shape = input_shape[0]
325 # Check whether the input shape contains any nested shapes. It could be
326 # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
327 # numpy inputs.
328 try:
329 input_shape = tf.TensorShape(input_shape)
330 except (ValueError, TypeError):
331 # A nested tensor input
332 input_shape = tf.nest.flatten(input_shape)[0]
334 batch = input_shape[0]
335 time_step = input_shape[1]
336 if self.time_major:
337 batch, time_step = time_step, batch
339 if rnn_utils.is_multiple_state(self.cell.state_size):
340 state_size = self.cell.state_size
341 else:
342 state_size = [self.cell.state_size]
344 def _get_output_shape(flat_output_size):
345 output_dim = tf.TensorShape(flat_output_size).as_list()
346 if self.return_sequences:
347 if self.time_major:
348 output_shape = tf.TensorShape(
349 [time_step, batch] + output_dim
350 )
351 else:
352 output_shape = tf.TensorShape(
353 [batch, time_step] + output_dim
354 )
355 else:
356 output_shape = tf.TensorShape([batch] + output_dim)
357 return output_shape
359 if getattr(self.cell, "output_size", None) is not None:
360 # cell.output_size could be nested structure.
361 output_shape = tf.nest.flatten(
362 tf.nest.map_structure(_get_output_shape, self.cell.output_size)
363 )
364 output_shape = (
365 output_shape[0] if len(output_shape) == 1 else output_shape
366 )
367 else:
368 # Note that state_size[0] could be a tensor_shape or int.
369 output_shape = _get_output_shape(state_size[0])
371 if self.return_state:
373 def _get_state_shape(flat_state):
374 state_shape = [batch] + tf.TensorShape(flat_state).as_list()
375 return tf.TensorShape(state_shape)
377 state_shape = tf.nest.map_structure(_get_state_shape, state_size)
378 return generic_utils.to_list(output_shape) + tf.nest.flatten(
379 state_shape
380 )
381 else:
382 return output_shape
384 def compute_mask(self, inputs, mask):
385 # Time step masks must be the same for each input.
386 # This is because the mask for an RNN is of size [batch, time_steps, 1],
387 # and specifies which time steps should be skipped, and a time step
388 # must be skipped for all inputs.
389 # TODO(scottzhu): Should we accept multiple different masks?
390 mask = tf.nest.flatten(mask)[0]
391 output_mask = mask if self.return_sequences else None
392 if self.return_state:
393 state_mask = [None for _ in self.states]
394 return [output_mask] + state_mask
395 else:
396 return output_mask
398 def build(self, input_shape):
399 if isinstance(input_shape, list):
400 input_shape = input_shape[0]
401 # The input_shape here could be a nested structure.
403 # Convert TensorShapes to plain shapes here. The input could be a
404 # single tensor, or a nested structure of tensors.
405 def get_input_spec(shape):
406 """Convert input shape to InputSpec."""
407 if isinstance(shape, tf.TensorShape):
408 input_spec_shape = shape.as_list()
409 else:
410 input_spec_shape = list(shape)
411 batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
412 if not self.stateful:
413 input_spec_shape[batch_index] = None
414 input_spec_shape[time_step_index] = None
415 return InputSpec(shape=tuple(input_spec_shape))
417 def get_step_input_shape(shape):
418 if isinstance(shape, tf.TensorShape):
419 shape = tuple(shape.as_list())
420 # remove the timestep from the input_shape
421 return shape[1:] if self.time_major else (shape[0],) + shape[2:]
423 def get_state_spec(shape):
424 state_spec_shape = tf.TensorShape(shape).as_list()
425 # append batch dim
426 state_spec_shape = [None] + state_spec_shape
427 return InputSpec(shape=tuple(state_spec_shape))
429 # Check whether the input shape contains any nested shapes. It could be
430 # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
431 # numpy inputs.
432 try:
433 input_shape = tf.TensorShape(input_shape)
434 except (ValueError, TypeError):
435 # A nested tensor input
436 pass
438 if not tf.nest.is_nested(input_shape):
439 # This indicates that there is only one input.
440 if self.input_spec is not None:
441 self.input_spec[0] = get_input_spec(input_shape)
442 else:
443 self.input_spec = [get_input_spec(input_shape)]
444 step_input_shape = get_step_input_shape(input_shape)
445 else:
446 if self.input_spec is not None:
447 self.input_spec[0] = tf.nest.map_structure(
448 get_input_spec, input_shape
449 )
450 else:
451 self.input_spec = generic_utils.to_list(
452 tf.nest.map_structure(get_input_spec, input_shape)
453 )
454 step_input_shape = tf.nest.map_structure(
455 get_step_input_shape, input_shape
456 )
458 # allow cell (if layer) to build before we set or validate state_spec.
459 if isinstance(self.cell, base_layer.Layer) and not self.cell.built:
460 with backend.name_scope(self.cell.name):
461 self.cell.build(step_input_shape)
462 self.cell.built = True
464 # set or validate state_spec
465 if rnn_utils.is_multiple_state(self.cell.state_size):
466 state_size = list(self.cell.state_size)
467 else:
468 state_size = [self.cell.state_size]
470 if self.state_spec is not None:
471 # initial_state was passed in call, check compatibility
472 self._validate_state_spec(state_size, self.state_spec)
473 else:
474 if tf.nest.is_nested(state_size):
475 self.state_spec = tf.nest.map_structure(
476 get_state_spec, state_size
477 )
478 else:
479 self.state_spec = [
480 InputSpec(shape=[None] + tf.TensorShape(dim).as_list())
481 for dim in state_size
482 ]
483 # ensure the generated state_spec is correct.
484 self._validate_state_spec(state_size, self.state_spec)
485 if self.stateful:
486 self.reset_states()
487 super().build(input_shape)
489 @staticmethod
490 def _validate_state_spec(cell_state_sizes, init_state_specs):
491 """Validate the state spec between the initial_state and the state_size.
493 Args:
494 cell_state_sizes: list, the `state_size` attribute from the cell.
495 init_state_specs: list, the `state_spec` from the initial_state that
496 is passed in `call()`.
498 Raises:
499 ValueError: When initial state spec is not compatible with the state
500 size.
501 """
502 validation_error = ValueError(
503 "An `initial_state` was passed that is not compatible with "
504 "`cell.state_size`. Received `state_spec`={}; "
505 "however `cell.state_size` is "
506 "{}".format(init_state_specs, cell_state_sizes)
507 )
508 flat_cell_state_sizes = tf.nest.flatten(cell_state_sizes)
509 flat_state_specs = tf.nest.flatten(init_state_specs)
511 if len(flat_cell_state_sizes) != len(flat_state_specs):
512 raise validation_error
513 for cell_state_spec, cell_state_size in zip(
514 flat_state_specs, flat_cell_state_sizes
515 ):
516 if not tf.TensorShape(
517 # Ignore the first axis for init_state which is for batch
518 cell_state_spec.shape[1:]
519 ).is_compatible_with(tf.TensorShape(cell_state_size)):
520 raise validation_error
522 @doc_controls.do_not_doc_inheritable
523 def get_initial_state(self, inputs):
524 get_initial_state_fn = getattr(self.cell, "get_initial_state", None)
526 if tf.nest.is_nested(inputs):
527 # The inputs are nested sequences. Use the first element in the
528 # sequence to get the batch size and dtype.
529 inputs = tf.nest.flatten(inputs)[0]
531 input_shape = tf.shape(inputs)
532 batch_size = input_shape[1] if self.time_major else input_shape[0]
533 dtype = inputs.dtype
534 if get_initial_state_fn:
535 init_state = get_initial_state_fn(
536 inputs=None, batch_size=batch_size, dtype=dtype
537 )
538 else:
539 init_state = rnn_utils.generate_zero_filled_state(
540 batch_size, self.cell.state_size, dtype
541 )
542 # Keras RNN expects the states in a list, even if it's a single state
543 # tensor.
544 if not tf.nest.is_nested(init_state):
545 init_state = [init_state]
546 # Force the state to be a list in case it is a namedtuple eg
547 # LSTMStateTuple.
548 return list(init_state)
550 def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
551 inputs, initial_state, constants = rnn_utils.standardize_args(
552 inputs, initial_state, constants, self._num_constants
553 )
555 if initial_state is None and constants is None:
556 return super().__call__(inputs, **kwargs)
558 # If any of `initial_state` or `constants` are specified and are Keras
559 # tensors, then add them to the inputs and temporarily modify the
560 # input_spec to include them.
562 additional_inputs = []
563 additional_specs = []
564 if initial_state is not None:
565 additional_inputs += initial_state
566 self.state_spec = tf.nest.map_structure(
567 lambda s: InputSpec(shape=backend.int_shape(s)), initial_state
568 )
569 additional_specs += self.state_spec
570 if constants is not None:
571 additional_inputs += constants
572 self.constants_spec = [
573 InputSpec(shape=backend.int_shape(constant))
574 for constant in constants
575 ]
576 self._num_constants = len(constants)
577 additional_specs += self.constants_spec
578 # additional_inputs can be empty if initial_state or constants are
579 # provided but empty (e.g. the cell is stateless).
580 flat_additional_inputs = tf.nest.flatten(additional_inputs)
581 is_keras_tensor = (
582 backend.is_keras_tensor(flat_additional_inputs[0])
583 if flat_additional_inputs
584 else True
585 )
586 for tensor in flat_additional_inputs:
587 if backend.is_keras_tensor(tensor) != is_keras_tensor:
588 raise ValueError(
589 "The initial state or constants of an RNN layer cannot be "
590 "specified via a mix of Keras tensors and non-Keras "
591 'tensors (a "Keras tensor" is a tensor that was returned '
592 "by a Keras layer or by `Input` during Functional "
593 "model construction). Received: "
594 f"initial_state={initial_state}, constants={constants}"
595 )
597 if is_keras_tensor:
598 # Compute the full input spec, including state and constants
599 full_input = [inputs] + additional_inputs
600 if self.built:
601 # Keep the input_spec since it has been populated in build()
602 # method.
603 full_input_spec = self.input_spec + additional_specs
604 else:
605 # The original input_spec is None since there could be a nested
606 # tensor input. Update the input_spec to match the inputs.
607 full_input_spec = (
608 generic_utils.to_list(
609 tf.nest.map_structure(lambda _: None, inputs)
610 )
611 + additional_specs
612 )
613 # Perform the call with temporarily replaced input_spec
614 self.input_spec = full_input_spec
615 output = super().__call__(full_input, **kwargs)
616 # Remove the additional_specs from the input spec and keep the rest.
617 # It is important to keep the rest since the input spec was populated
618 # by build(), and will be reused when stateful=True.
619 self.input_spec = self.input_spec[: -len(additional_specs)]
620 return output
621 else:
622 if initial_state is not None:
623 kwargs["initial_state"] = initial_state
624 if constants is not None:
625 kwargs["constants"] = constants
626 return super().__call__(inputs, **kwargs)
628 def call(
629 self,
630 inputs,
631 mask=None,
632 training=None,
633 initial_state=None,
634 constants=None,
635 ):
636 # The input should be dense, padded with zeros. If a ragged input is fed
637 # into the layer, it is padded and the row lengths are used for masking.
638 inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
639 is_ragged_input = row_lengths is not None
640 self._validate_args_if_ragged(is_ragged_input, mask)
642 inputs, initial_state, constants = self._process_inputs(
643 inputs, initial_state, constants
644 )
646 self._maybe_reset_cell_dropout_mask(self.cell)
647 if isinstance(self.cell, StackedRNNCells):
648 for cell in self.cell.cells:
649 self._maybe_reset_cell_dropout_mask(cell)
651 if mask is not None:
652 # Time step masks must be the same for each input.
653 # TODO(scottzhu): Should we accept multiple different masks?
654 mask = tf.nest.flatten(mask)[0]
656 if tf.nest.is_nested(inputs):
657 # In the case of nested input, use the first element for shape
658 # check.
659 input_shape = backend.int_shape(tf.nest.flatten(inputs)[0])
660 else:
661 input_shape = backend.int_shape(inputs)
662 timesteps = input_shape[0] if self.time_major else input_shape[1]
663 if self.unroll and timesteps is None:
664 raise ValueError(
665 "Cannot unroll a RNN if the "
666 "time dimension is undefined. \n"
667 "- If using a Sequential model, "
668 "specify the time dimension by passing "
669 "an `input_shape` or `batch_input_shape` "
670 "argument to your first layer. If your "
671 "first layer is an Embedding, you can "
672 "also use the `input_length` argument.\n"
673 "- If using the functional API, specify "
674 "the time dimension by passing a `shape` "
675 "or `batch_shape` argument to your Input layer."
676 )
678 kwargs = {}
679 if generic_utils.has_arg(self.cell.call, "training"):
680 kwargs["training"] = training
682 # TF RNN cells expect a single tensor as state instead of a
683 # list-wrapped tensor.
684 is_tf_rnn_cell = getattr(self.cell, "_is_tf_rnn_cell", None) is not None
685 # Use the __call__ function for callable objects, eg layers, so that it
686 # will have the proper name scopes for the ops, etc.
687 cell_call_fn = (
688 self.cell.__call__ if callable(self.cell) else self.cell.call
689 )
690 if constants:
691 if not generic_utils.has_arg(self.cell.call, "constants"):
692 raise ValueError(
693 f"RNN cell {self.cell} does not support constants. "
694 f"Received: constants={constants}"
695 )
697 def step(inputs, states):
698 constants = states[-self._num_constants :]
699 states = states[: -self._num_constants]
701 states = (
702 states[0] if len(states) == 1 and is_tf_rnn_cell else states
703 )
704 output, new_states = cell_call_fn(
705 inputs, states, constants=constants, **kwargs
706 )
707 if not tf.nest.is_nested(new_states):
708 new_states = [new_states]
709 return output, new_states
711 else:
713 def step(inputs, states):
714 states = (
715 states[0] if len(states) == 1 and is_tf_rnn_cell else states
716 )
717 output, new_states = cell_call_fn(inputs, states, **kwargs)
718 if not tf.nest.is_nested(new_states):
719 new_states = [new_states]
720 return output, new_states
722 last_output, outputs, states = backend.rnn(
723 step,
724 inputs,
725 initial_state,
726 constants=constants,
727 go_backwards=self.go_backwards,
728 mask=mask,
729 unroll=self.unroll,
730 input_length=row_lengths if row_lengths is not None else timesteps,
731 time_major=self.time_major,
732 zero_output_for_mask=self.zero_output_for_mask,
733 return_all_outputs=self.return_sequences,
734 )
736 if self.stateful:
737 updates = [
738 tf.compat.v1.assign(
739 self_state, tf.cast(state, self_state.dtype)
740 )
741 for self_state, state in zip(
742 tf.nest.flatten(self.states), tf.nest.flatten(states)
743 )
744 ]
745 self.add_update(updates)
747 if self.return_sequences:
748 output = backend.maybe_convert_to_ragged(
749 is_ragged_input,
750 outputs,
751 row_lengths,
752 go_backwards=self.go_backwards,
753 )
754 else:
755 output = last_output
757 if self.return_state:
758 if not isinstance(states, (list, tuple)):
759 states = [states]
760 else:
761 states = list(states)
762 return generic_utils.to_list(output) + states
763 else:
764 return output
766 def _process_inputs(self, inputs, initial_state, constants):
767 # input shape: `(samples, time (padded with zeros), input_dim)`
768 # note that the .build() method of subclasses MUST define
769 # self.input_spec and self.state_spec with complete input shapes.
770 if isinstance(inputs, collections.abc.Sequence) and not isinstance(
771 inputs, tuple
772 ):
773 # Get the initial_state from the full input spec,
774 # as the inputs could be copied to multiple GPUs.
775 if not self._num_constants:
776 initial_state = inputs[1:]
777 else:
778 initial_state = inputs[1 : -self._num_constants]
779 constants = inputs[-self._num_constants :]
780 if len(initial_state) == 0:
781 initial_state = None
782 inputs = inputs[0]
784 if self.stateful:
785 if initial_state is not None:
786 # When the layer is stateful and initial_state is provided, check
787 # whether the recorded state is the same as the default value
788 # (zeros). Use the recorded state if it is not the default.
789 non_zero_count = tf.add_n(
790 [
791 tf.math.count_nonzero(s)
792 for s in tf.nest.flatten(self.states)
793 ]
794 )
795 # Set strict = True to keep the original structure of the state.
796 initial_state = tf.compat.v1.cond(
797 non_zero_count > 0,
798 true_fn=lambda: self.states,
799 false_fn=lambda: initial_state,
800 strict=True,
801 )
802 else:
803 initial_state = self.states
804 initial_state = tf.nest.map_structure(
805 # Cast to the layer's compute dtype, falling back to the cell's
806 # compute dtype when the layer does not have one.
807 lambda v: tf.cast(
808 v, self.compute_dtype or self.cell.compute_dtype
809 ),
810 initial_state,
811 )
812 elif initial_state is None:
813 initial_state = self.get_initial_state(inputs)
815 if len(initial_state) != len(self.states):
816 raise ValueError(
817 f"Layer has {len(self.states)} "
818 f"states but was passed {len(initial_state)} initial "
819 f"states. Received: initial_state={initial_state}"
820 )
821 return inputs, initial_state, constants
823 def _validate_args_if_ragged(self, is_ragged_input, mask):
824 if not is_ragged_input:
825 return
827 if mask is not None:
828 raise ValueError(
829 f"The mask that was passed in was {mask}, which "
830 "cannot be applied to RaggedTensor inputs. Please "
831 "make sure that there is no mask injected by upstream "
832 "layers."
833 )
834 if self.unroll:
835 raise ValueError(
836 "The input received contains RaggedTensors and does "
837 "not support unrolling. Disable unrolling by passing "
838 "`unroll=False` in the RNN Layer constructor."
839 )
841 def _maybe_reset_cell_dropout_mask(self, cell):
842 if isinstance(cell, DropoutRNNCellMixin):
843 cell.reset_dropout_mask()
844 cell.reset_recurrent_dropout_mask()
846 def reset_states(self, states=None):
847 """Reset the recorded states for the stateful RNN layer.
849 Can only be used when RNN layer is constructed with `stateful` = `True`.
850 Args:
851 states: Numpy arrays that contain the values for the initial state,
852 which will be fed to the cell at the first time step. When the value
853 is None, a zero-filled numpy array will be created based on the cell
854 state size.
856 Raises:
857 AttributeError: When the RNN layer is not stateful.
858 ValueError: When the batch size of the RNN layer is unknown.
859 ValueError: When the input numpy array is not compatible with the RNN
860 layer state, either size wise or dtype wise.
861 """
862 if not self.stateful:
863 raise AttributeError("Layer must be stateful.")
864 spec_shape = None
865 if self.input_spec is not None:
866 spec_shape = tf.nest.flatten(self.input_spec[0])[0].shape
867 if spec_shape is None:
868 # It is possible for the spec shape to be None, e.g. when constructing
869 # an RNN with a custom cell, or with standard RNN layers (LSTM/GRU)
870 # where we only know the input has 3 dims, but not its full shape
871 # spec before build().
872 batch_size = None
873 else:
874 batch_size = spec_shape[1] if self.time_major else spec_shape[0]
875 if not batch_size:
876 raise ValueError(
877 "If a RNN is stateful, it needs to know "
878 "its batch size. Specify the batch size "
879 "of your input tensors: \n"
880 "- If using a Sequential model, "
881 "specify the batch size by passing "
882 "a `batch_input_shape` "
883 "argument to your first layer.\n"
884 "- If using the functional API, specify "
885 "the batch size by passing a "
886 "`batch_shape` argument to your Input layer."
887 )
888 # initialize state if None
889 if tf.nest.flatten(self.states)[0] is None:
890 if getattr(self.cell, "get_initial_state", None):
891 flat_init_state_values = tf.nest.flatten(
892 self.cell.get_initial_state(
893 inputs=None,
894 batch_size=batch_size,
895 # Use variable_dtype instead of compute_dtype, since the
896 # state is stored in a variable
897 dtype=self.variable_dtype or backend.floatx(),
898 )
899 )
900 else:
901 flat_init_state_values = tf.nest.flatten(
902 rnn_utils.generate_zero_filled_state(
903 batch_size,
904 self.cell.state_size,
905 self.variable_dtype or backend.floatx(),
906 )
907 )
908 flat_states_variables = tf.nest.map_structure(
909 backend.variable, flat_init_state_values
910 )
911 self.states = tf.nest.pack_sequence_as(
912 self.cell.state_size, flat_states_variables
913 )
914 if not tf.nest.is_nested(self.states):
915 self.states = [self.states]
916 elif states is None:
917 for state, size in zip(
918 tf.nest.flatten(self.states),
919 tf.nest.flatten(self.cell.state_size),
920 ):
921 backend.set_value(
922 state,
923 np.zeros([batch_size] + tf.TensorShape(size).as_list()),
924 )
925 else:
926 flat_states = tf.nest.flatten(self.states)
927 flat_input_states = tf.nest.flatten(states)
928 if len(flat_input_states) != len(flat_states):
929 raise ValueError(
930 f"Layer {self.name} expects {len(flat_states)} "
931 f"states, but it received {len(flat_input_states)} "
932 f"state values. States received: {states}"
933 )
934 set_value_tuples = []
935 for i, (value, state) in enumerate(
936 zip(flat_input_states, flat_states)
937 ):
938 if value.shape != state.shape:
939 raise ValueError(
940 f"State {i} is incompatible with layer {self.name}: "
941 f"expected shape={(batch_size, state)} "
942 f"but found shape={value.shape}"
943 )
944 set_value_tuples.append((state, value))
945 backend.batch_set_value(set_value_tuples)
947 def get_config(self):
948 config = {
949 "return_sequences": self.return_sequences,
950 "return_state": self.return_state,
951 "go_backwards": self.go_backwards,
952 "stateful": self.stateful,
953 "unroll": self.unroll,
954 "time_major": self.time_major,
955 }
956 if self._num_constants:
957 config["num_constants"] = self._num_constants
958 if self.zero_output_for_mask:
959 config["zero_output_for_mask"] = self.zero_output_for_mask
961 config["cell"] = serialization_lib.serialize_keras_object(self.cell)
962 base_config = super().get_config()
963 return dict(list(base_config.items()) + list(config.items()))
965 @classmethod
966 def from_config(cls, config, custom_objects=None):
967 from keras.src.layers import deserialize as deserialize_layer
969 cell = deserialize_layer(
970 config.pop("cell"), custom_objects=custom_objects
971 )
972 num_constants = config.pop("num_constants", 0)
973 layer = cls(cell, **config)
974 layer._num_constants = num_constants
975 return layer
977 @property
978 def _trackable_saved_model_saver(self):
979 return layer_serialization.RNNSavedModelSaver(self)