Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/layers/recurrent.py: 23%
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16# pylint: disable=g-classes-have-attributes
17"""Recurrent layers and their base classes."""
19import collections
20import warnings
22import numpy as np
24from tensorflow.python.distribute import distribute_lib
25from tensorflow.python.eager import context
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.keras import activations
29from tensorflow.python.keras import backend
30from tensorflow.python.keras import constraints
31from tensorflow.python.keras import initializers
32from tensorflow.python.keras import regularizers
33from tensorflow.python.keras.engine.base_layer import Layer
34from tensorflow.python.keras.engine.input_spec import InputSpec
35from tensorflow.python.keras.saving.saved_model import layer_serialization
36from tensorflow.python.keras.utils import control_flow_util
37from tensorflow.python.keras.utils import generic_utils
38from tensorflow.python.keras.utils import tf_utils
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import array_ops_stack
41from tensorflow.python.ops import cond
42from tensorflow.python.ops import math_ops
43from tensorflow.python.ops import state_ops
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.trackable import base as trackable
46from tensorflow.python.util import nest
47from tensorflow.python.util.tf_export import keras_export
48from tensorflow.tools.docs import doc_controls
51RECURRENT_DROPOUT_WARNING_MSG = (
52 'RNN `implementation=2` is not supported when `recurrent_dropout` is set. '
53 'Using `implementation=1`.')
56@keras_export('keras.layers.StackedRNNCells')
57class StackedRNNCells(Layer):
58 """Wrapper allowing a stack of RNN cells to behave as a single cell.
60 Used to implement efficient stacked RNNs.
62 Args:
63 cells: List of RNN cell instances.
65 Examples:
67 ```python
68 batch_size = 3
69 sentence_max_length = 5
70 n_features = 2
71 new_shape = (batch_size, sentence_max_length, n_features)
72 x = tf.constant(np.reshape(np.arange(30), new_shape), dtype = tf.float32)
74 rnn_cells = [tf.keras.layers.LSTMCell(128) for _ in range(2)]
75 stacked_lstm = tf.keras.layers.StackedRNNCells(rnn_cells)
76 lstm_layer = tf.keras.layers.RNN(stacked_lstm)
78 result = lstm_layer(x)
79 ```
80 """
82 def __init__(self, cells, **kwargs):
83 for cell in cells:
84 if not 'call' in dir(cell):
85 raise ValueError('All cells must have a `call` method. '
86 'received cells:', cells)
87 if not 'state_size' in dir(cell):
88 raise ValueError('All cells must have a '
89 '`state_size` attribute. '
90 'received cells:', cells)
91 self.cells = cells
92 # reverse_state_order determines whether the state size will be in the reverse
93 # order of the cells' states. The user might want to set this to True to keep
94 # the existing behavior. This is only useful when using RNN(return_state=True),
95 # since the states will be returned in the same order as state_size.
96 self.reverse_state_order = kwargs.pop('reverse_state_order', False)
97 if self.reverse_state_order:
98 logging.warning('reverse_state_order=True in StackedRNNCells will soon '
99 'be deprecated. Please update the code to work with the '
100 'natural order of states if you rely on the RNN states, '
101 'eg RNN(return_state=True).')
102 super(StackedRNNCells, self).__init__(**kwargs)
104 @property
105 def state_size(self):
106 return tuple(c.state_size for c in
107 (self.cells[::-1] if self.reverse_state_order else self.cells))
109 @property
110 def output_size(self):
111 if getattr(self.cells[-1], 'output_size', None) is not None:
112 return self.cells[-1].output_size
113 elif _is_multiple_state(self.cells[-1].state_size):
114 return self.cells[-1].state_size[0]
115 else:
116 return self.cells[-1].state_size
118 def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
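# The returned value is a tuple with one entry per wrapped cell, where each
# entry is that cell's own initial-state structure. As a sketch (assuming two
# `LSTMCell(3)` cells and `batch_size=2`), the result would look like
# `([h0, c0], [h1, c1])`, with every tensor of shape `(2, 3)`.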
119 initial_states = []
120 for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
121 get_initial_state_fn = getattr(cell, 'get_initial_state', None)
122 if get_initial_state_fn:
123 initial_states.append(get_initial_state_fn(
124 inputs=inputs, batch_size=batch_size, dtype=dtype))
125 else:
126 initial_states.append(_generate_zero_filled_state_for_cell(
127 cell, inputs, batch_size, dtype))
129 return tuple(initial_states)
131 def call(self, inputs, states, constants=None, training=None, **kwargs):
132 # Recover per-cell states.
133 state_size = (self.state_size[::-1]
134 if self.reverse_state_order else self.state_size)
135 nested_states = nest.pack_sequence_as(state_size, nest.flatten(states))
137 # Call the cells in order and store the returned states.
138 new_nested_states = []
139 for cell, states in zip(self.cells, nested_states):
140 states = states if nest.is_nested(states) else [states]
141 # The TF cell does not wrap the state into a list when there is only one state.
142 is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None
143 states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
144 if generic_utils.has_arg(cell.call, 'training'):
145 kwargs['training'] = training
146 else:
147 kwargs.pop('training', None)
148 # Use the __call__ function for callable objects, eg layers, so that it
149 # will have the proper name scopes for the ops, etc.
150 cell_call_fn = cell.__call__ if callable(cell) else cell.call
151 if generic_utils.has_arg(cell.call, 'constants'):
152 inputs, states = cell_call_fn(inputs, states,
153 constants=constants, **kwargs)
154 else:
155 inputs, states = cell_call_fn(inputs, states, **kwargs)
156 new_nested_states.append(states)
158 return inputs, nest.pack_sequence_as(state_size,
159 nest.flatten(new_nested_states))
161 @tf_utils.shape_type_conversion
162 def build(self, input_shape):
163 if isinstance(input_shape, list):
164 input_shape = input_shape[0]
165 for cell in self.cells:
166 if isinstance(cell, Layer) and not cell.built:
167 with backend.name_scope(cell.name):
168 cell.build(input_shape)
169 cell.built = True
170 if getattr(cell, 'output_size', None) is not None:
171 output_dim = cell.output_size
172 elif _is_multiple_state(cell.state_size):
173 output_dim = cell.state_size[0]
174 else:
175 output_dim = cell.state_size
176 input_shape = tuple([input_shape[0]] +
177 tensor_shape.TensorShape(output_dim).as_list())
178 self.built = True
180 def get_config(self):
181 cells = []
182 for cell in self.cells:
183 cells.append(generic_utils.serialize_keras_object(cell))
184 config = {'cells': cells}
185 base_config = super(StackedRNNCells, self).get_config()
186 return dict(list(base_config.items()) + list(config.items()))
188 @classmethod
189 def from_config(cls, config, custom_objects=None):
190 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
191 cells = []
192 for cell_config in config.pop('cells'):
193 cells.append(
194 deserialize_layer(cell_config, custom_objects=custom_objects))
195 return cls(cells, **config)
198@keras_export('keras.layers.RNN')
199class RNN(Layer):
200 """Base class for recurrent layers.
202 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
203 for details about the usage of RNN API.
205 Args:
206 cell: A RNN cell instance or a list of RNN cell instances.
207 A RNN cell is a class that has:
208 - A `call(input_at_t, states_at_t)` method, returning
209 `(output_at_t, states_at_t_plus_1)`. The call method of the
210 cell can also take the optional argument `constants`, see
211 section "Note on passing external constants" below.
212 - A `state_size` attribute. This can be a single integer
213 (single state) in which case it is the size of the recurrent
214 state. This can also be a list/tuple of integers (one size per state).
215 The `state_size` can also be TensorShape or tuple/list of
216 TensorShape, to represent high dimension state.
217 - An `output_size` attribute. This can be a single integer or a
218 TensorShape, which represents the shape of the output. For backward
219 compatibility, if this attribute is not available for the
220 cell, the value will be inferred from the first element of
221 `state_size`.
222 - A `get_initial_state(inputs=None, batch_size=None, dtype=None)`
223 method that creates a tensor meant to be fed to `call()` as the
224 initial state, if the user didn't specify any initial state via other
225 means. The returned initial state should have a shape of
226 [batch_size, cell.state_size]. The cell might choose to create a
227 tensor full of zeros, or full of other values based on the cell's
228 implementation.
229 `inputs` is the input tensor to the RNN layer, from which the batch
230 size (its shape[0]) and dtype can be taken. Note that
231 shape[0] might be `None` during graph construction. Either
232 `inputs` or the pair of `batch_size` and `dtype` is provided.
233 `batch_size` is a scalar tensor that represents the batch size
234 of the inputs. `dtype` is `tf.DType` that represents the dtype of
235 the inputs.
236 For backward compatibility, if this method is not implemented
237 by the cell, the RNN layer will create a zero filled tensor with the
238 size of [batch_size, cell.state_size].
239 In the case that `cell` is a list of RNN cell instances, the cells
240 will be stacked on top of each other in the RNN, resulting in an
241 efficient stacked RNN.
242 return_sequences: Boolean (default `False`). Whether to return the last
243 output in the output sequence, or the full sequence.
244 return_state: Boolean (default `False`). Whether to return the last state
245 in addition to the output.
246 go_backwards: Boolean (default `False`).
247 If True, process the input sequence backwards and return the
248 reversed sequence.
249 stateful: Boolean (default `False`). If True, the last state
250 for each sample at index i in a batch will be used as initial
251 state for the sample of index i in the following batch.
252 unroll: Boolean (default `False`).
253 If True, the network will be unrolled, else a symbolic loop will be used.
254 Unrolling can speed up an RNN, although it tends to be more
255 memory-intensive. Unrolling is only suitable for short sequences.
256 time_major: The shape format of the `inputs` and `outputs` tensors.
257 If True, the inputs and outputs will be in shape
258 `(timesteps, batch, ...)`, whereas in the False case, it will be
259 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
260 efficient because it avoids transposes at the beginning and end of the
261 RNN calculation. However, most TensorFlow data is batch-major, so by
262 default this function accepts input and emits output in batch-major
263 form.
264 zero_output_for_mask: Boolean (default `False`).
265 Whether the output should use zeros for the masked timesteps. Note that
266 this field is only used when `return_sequences` is True and a mask is
267 provided. It can be useful if you want to reuse the raw output sequence of
268 the RNN without interference from the masked timesteps, e.g., when merging
269 bidirectional RNNs.
271 Call arguments:
272 inputs: Input tensor.
273 mask: Binary tensor of shape `[batch_size, timesteps]` indicating whether
274 a given timestep should be masked. An individual `True` entry indicates
275 that the corresponding timestep should be utilized, while a `False`
276 entry indicates that the corresponding timestep should be ignored.
277 training: Python boolean indicating whether the layer should behave in
278 training mode or in inference mode. This argument is passed to the cell
279 when calling it. This is for use with cells that use dropout.
280 initial_state: List of initial state tensors to be passed to the first
281 call of the cell.
282 constants: List of constant tensors to be passed to the cell at each
283 timestep.
285 Input shape:
286 N-D tensor with shape `[batch_size, timesteps, ...]` or
287 `[timesteps, batch_size, ...]` when time_major is True.
289 Output shape:
290 - If `return_state`: a list of tensors. The first tensor is
291 the output. The remaining tensors are the last states,
292 each with shape `[batch_size, state_size]`, where `state_size` could
293 be a high dimension tensor shape.
294 - If `return_sequences`: N-D tensor with shape
295 `[batch_size, timesteps, output_size]`, where `output_size` could
296 be a high dimension tensor shape, or
297 `[timesteps, batch_size, output_size]` when `time_major` is True.
298 - Else, N-D tensor with shape `[batch_size, output_size]`, where
299 `output_size` could be a high dimension tensor shape.
301 Masking:
302 This layer supports masking for input data with a variable number
303 of timesteps. To introduce masks to your data,
304 use an [tf.keras.layers.Embedding] layer with the `mask_zero` parameter
305 set to `True`.
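For instance, a minimal sketch (the vocabulary and layer sizes below are
illustrative):

```python
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=1000, output_dim=16, mask_zero=True),
    tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(8)),
])
```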
307 Note on using statefulness in RNNs:
308 You can set RNN layers to be 'stateful', which means that the states
309 computed for the samples in one batch will be reused as initial states
310 for the samples in the next batch. This assumes a one-to-one mapping
311 between samples in different successive batches.
313 To enable statefulness:
314 - Specify `stateful=True` in the layer constructor.
315 - Specify a fixed batch size for your model, by passing
316 `batch_input_shape=(...)` to the first layer in your model
317 if it is a Sequential model, or
318 `batch_shape=(...)` to all the first (Input) layers in your model
319 if it is a functional model with 1 or more Input layers.
320 This is the expected shape of your inputs
321 *including the batch size*.
322 It should be a tuple of integers, e.g. `(32, 10, 100)`.
323 - Specify `shuffle=False` when calling `fit()`.
325 To reset the states of your model, call `.reset_states()` on either
326 a specific layer, or on your entire model.
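A minimal sketch of a stateful setup (the batch size, timesteps and feature
dimension below are illustrative):

```python
model = tf.keras.Sequential()
model.add(tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(16),
                              stateful=True,
                              batch_input_shape=(32, 10, 100)))
model.compile(optimizer='sgd', loss='mse')
# ... train with model.fit(..., shuffle=False), then clear the carried state:
model.reset_states()
```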
328 Note on specifying the initial state of RNNs:
329 You can specify the initial state of RNN layers symbolically by
330 calling them with the keyword argument `initial_state`. The value of
331 `initial_state` should be a tensor or list of tensors representing
332 the initial state of the RNN layer.
334 You can specify the initial state of RNN layers numerically by
335 calling `reset_states` with the keyword argument `states`. The value of
336 `states` should be a numpy array or list of numpy arrays representing
337 the initial state of the RNN layer.
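For example, a sketch with a symbolic initial state (shapes are illustrative):

```python
inputs = tf.keras.Input((None, 8))
initial_state = tf.keras.Input((16,))
rnn = tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(16))
outputs = rnn(inputs, initial_state=[initial_state])
```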
339 Note on passing external constants to RNNs:
340 You can pass "external" constants to the cell using the `constants`
341 keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This
342 requires that the `cell.call` method accepts the same keyword argument
343 `constants`. Such constants can be used to condition the cell
344 transformation on additional static inputs (not changing over time),
345 a.k.a. an attention mechanism.
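A sketch of a cell whose `call` accepts `constants` (the cell below is
hypothetical and kept minimal; the constant simply conditions every step):

```python
class CellWithConstants(keras.layers.Layer):

  def __init__(self, units, **kwargs):
    self.units = units
    self.state_size = units
    super(CellWithConstants, self).__init__(**kwargs)

  def build(self, input_shape):
    self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                  initializer='uniform', name='kernel')
    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units),
        initializer='uniform', name='recurrent_kernel')
    self.built = True

  def call(self, inputs, states, constants=None):
    prev_output = states[0]
    output = backend.dot(inputs, self.kernel) + backend.dot(
        prev_output, self.recurrent_kernel)
    if constants is not None:
      # Condition the step on the static (non time-varying) input.
      output = output + constants[0]
    return output, [output]

# The constant is passed alongside the sequence input:
x = keras.Input((None, 5))
c = keras.Input((32,))
y = RNN(CellWithConstants(32))(x, constants=c)
```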
347 Examples:
349 ```python
350 # First, let's define a RNN Cell, as a layer subclass.
352 class MinimalRNNCell(keras.layers.Layer):
354 def __init__(self, units, **kwargs):
355 self.units = units
356 self.state_size = units
357 super(MinimalRNNCell, self).__init__(**kwargs)
359 def build(self, input_shape):
360 self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
361 initializer='uniform',
362 name='kernel')
363 self.recurrent_kernel = self.add_weight(
364 shape=(self.units, self.units),
365 initializer='uniform',
366 name='recurrent_kernel')
367 self.built = True
369 def call(self, inputs, states):
370 prev_output = states[0]
371 h = backend.dot(inputs, self.kernel)
372 output = h + backend.dot(prev_output, self.recurrent_kernel)
373 return output, [output]
375 # Let's use this cell in a RNN layer:
377 cell = MinimalRNNCell(32)
378 x = keras.Input((None, 5))
379 layer = RNN(cell)
380 y = layer(x)
382 # Here's how to use the cell to build a stacked RNN:
384 cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
385 x = keras.Input((None, 5))
386 layer = RNN(cells)
387 y = layer(x)
388 ```
389 """
391 def __init__(self,
392 cell,
393 return_sequences=False,
394 return_state=False,
395 go_backwards=False,
396 stateful=False,
397 unroll=False,
398 time_major=False,
399 **kwargs):
400 if isinstance(cell, (list, tuple)):
401 cell = StackedRNNCells(cell)
402 if not 'call' in dir(cell):
403 raise ValueError('`cell` should have a `call` method. '
404 'The RNN was passed:', cell)
405 if not 'state_size' in dir(cell):
406 raise ValueError('The RNN cell should have '
407 'an attribute `state_size` '
408 '(tuple of integers, '
409 'one integer per RNN state).')
410 # If True, the output for masked timesteps will be zeros, whereas in the
411 # False case, the output from the previous timestep is returned for masked timesteps.
412 self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False)
414 if 'input_shape' not in kwargs and (
415 'input_dim' in kwargs or 'input_length' in kwargs):
416 input_shape = (kwargs.pop('input_length', None),
417 kwargs.pop('input_dim', None))
418 kwargs['input_shape'] = input_shape
420 super(RNN, self).__init__(**kwargs)
421 self.cell = cell
422 self.return_sequences = return_sequences
423 self.return_state = return_state
424 self.go_backwards = go_backwards
425 self.stateful = stateful
426 self.unroll = unroll
427 self.time_major = time_major
429 self.supports_masking = True
430 # The input shape is unknown yet, it could have nested tensor inputs, and
431 # the input spec will be the list of specs for nested inputs, the structure
432 # of the input_spec will be the same as the input.
433 self.input_spec = None
434 self.state_spec = None
435 self._states = None
436 self.constants_spec = None
437 self._num_constants = 0
439 if stateful:
440 if distribute_lib.has_strategy():
441 raise ValueError('RNNs with stateful=True not yet supported with '
442 'tf.distribute.Strategy.')
444 @property
445 def _use_input_spec_as_call_signature(self):
446 if self.unroll:
447 # When the RNN layer is unrolled, the time step shape cannot be unknown.
448 # The input spec does not define the time step (because this layer can be
449 # called with any time step value, as long as it is not None), so it
450 # cannot be used as the call function signature when saving to SavedModel.
451 return False
452 return super(RNN, self)._use_input_spec_as_call_signature
454 @property
455 def states(self):
456 if self._states is None:
457 state = nest.map_structure(lambda _: None, self.cell.state_size)
458 return state if nest.is_nested(self.cell.state_size) else [state]
459 return self._states
461 @states.setter
462 # Automatic tracking catches "self._states" which adds an extra weight and
463 # breaks HDF5 checkpoints.
464 @trackable.no_automatic_dependency_tracking
465 def states(self, states):
466 self._states = states
468 def compute_output_shape(self, input_shape):
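# Shape sketch (assuming batch-major inputs of shape (batch, time, features) and
# an integer cell.output_size / cell.state_size of `units`; purely illustrative):
#   return_sequences=False, return_state=False -> (batch, units)
#   return_sequences=True,  return_state=False -> (batch, time, units)
#   return_sequences=True,  return_state=True  -> [(batch, time, units), (batch, units)]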
469 if isinstance(input_shape, list):
470 input_shape = input_shape[0]
471 # Check whether the input shape contains any nested shapes. It could be
472 # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy
473 # inputs.
474 try:
475 input_shape = tensor_shape.TensorShape(input_shape)
476 except (ValueError, TypeError):
477 # A nested tensor input
478 input_shape = nest.flatten(input_shape)[0]
480 batch = input_shape[0]
481 time_step = input_shape[1]
482 if self.time_major:
483 batch, time_step = time_step, batch
485 if _is_multiple_state(self.cell.state_size):
486 state_size = self.cell.state_size
487 else:
488 state_size = [self.cell.state_size]
490 def _get_output_shape(flat_output_size):
491 output_dim = tensor_shape.TensorShape(flat_output_size).as_list()
492 if self.return_sequences:
493 if self.time_major:
494 output_shape = tensor_shape.TensorShape(
495 [time_step, batch] + output_dim)
496 else:
497 output_shape = tensor_shape.TensorShape(
498 [batch, time_step] + output_dim)
499 else:
500 output_shape = tensor_shape.TensorShape([batch] + output_dim)
501 return output_shape
503 if getattr(self.cell, 'output_size', None) is not None:
504 # cell.output_size could be nested structure.
505 output_shape = nest.flatten(nest.map_structure(
506 _get_output_shape, self.cell.output_size))
507 output_shape = output_shape[0] if len(output_shape) == 1 else output_shape
508 else:
509 # Note that state_size[0] could be a tensor_shape or int.
510 output_shape = _get_output_shape(state_size[0])
512 if self.return_state:
513 def _get_state_shape(flat_state):
514 state_shape = [batch] + tensor_shape.TensorShape(flat_state).as_list()
515 return tensor_shape.TensorShape(state_shape)
516 state_shape = nest.map_structure(_get_state_shape, state_size)
517 return generic_utils.to_list(output_shape) + nest.flatten(state_shape)
518 else:
519 return output_shape
521 def compute_mask(self, inputs, mask):
522 # Time step masks must be the same for each input.
523 # This is because the mask for an RNN is of size [batch, time_steps, 1],
524 # and specifies which time steps should be skipped, and a time step
525 # must be skipped for all inputs.
526 # TODO(scottzhu): Should we accept multiple different masks?
527 mask = nest.flatten(mask)[0]
528 output_mask = mask if self.return_sequences else None
529 if self.return_state:
530 state_mask = [None for _ in self.states]
531 return [output_mask] + state_mask
532 else:
533 return output_mask
535 def build(self, input_shape):
536 if isinstance(input_shape, list):
537 input_shape = input_shape[0]
538 # The input_shape here could be a nested structure.
540 # Convert the tensor_shape to shapes here. The input could be a single
541 # tensor, or a nested structure of tensors.
542 def get_input_spec(shape):
543 """Convert input shape to InputSpec."""
544 if isinstance(shape, tensor_shape.TensorShape):
545 input_spec_shape = shape.as_list()
546 else:
547 input_spec_shape = list(shape)
548 batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
549 if not self.stateful:
550 input_spec_shape[batch_index] = None
551 input_spec_shape[time_step_index] = None
552 return InputSpec(shape=tuple(input_spec_shape))
554 def get_step_input_shape(shape):
555 if isinstance(shape, tensor_shape.TensorShape):
556 shape = tuple(shape.as_list())
557 # remove the timestep from the input_shape
558 return shape[1:] if self.time_major else (shape[0],) + shape[2:]
560 # Check whether the input shape contains any nested shapes. It could be
561 # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy
562 # inputs.
563 try:
564 input_shape = tensor_shape.TensorShape(input_shape)
565 except (ValueError, TypeError):
566 # A nested tensor input
567 pass
569 if not nest.is_nested(input_shape):
570 # This indicates that there is only one input.
571 if self.input_spec is not None:
572 self.input_spec[0] = get_input_spec(input_shape)
573 else:
574 self.input_spec = [get_input_spec(input_shape)]
575 step_input_shape = get_step_input_shape(input_shape)
576 else:
577 if self.input_spec is not None:
578 self.input_spec[0] = nest.map_structure(get_input_spec, input_shape)
579 else:
580 self.input_spec = generic_utils.to_list(
581 nest.map_structure(get_input_spec, input_shape))
582 step_input_shape = nest.map_structure(get_step_input_shape, input_shape)
584 # allow cell (if layer) to build before we set or validate state_spec.
585 if isinstance(self.cell, Layer) and not self.cell.built:
586 with backend.name_scope(self.cell.name):
587 self.cell.build(step_input_shape)
588 self.cell.built = True
590 # set or validate state_spec
591 if _is_multiple_state(self.cell.state_size):
592 state_size = list(self.cell.state_size)
593 else:
594 state_size = [self.cell.state_size]
596 if self.state_spec is not None:
597 # initial_state was passed in call, check compatibility
598 self._validate_state_spec(state_size, self.state_spec)
599 else:
600 self.state_spec = [
601 InputSpec(shape=[None] + tensor_shape.TensorShape(dim).as_list())
602 for dim in state_size
603 ]
604 if self.stateful:
605 self.reset_states()
606 self.built = True
608 @staticmethod
609 def _validate_state_spec(cell_state_sizes, init_state_specs):
610 """Validate the state spec between the initial_state and the state_size.
612 Args:
613 cell_state_sizes: list, the `state_size` attribute from the cell.
614 init_state_specs: list, the `state_spec` from the initial_state that is
615 passed in `call()`.
617 Raises:
618 ValueError: When initial state spec is not compatible with the state size.
619 """
620 validation_error = ValueError(
621 'An `initial_state` was passed that is not compatible with '
622 '`cell.state_size`. Received `state_spec`={}; '
623 'however `cell.state_size` is '
624 '{}'.format(init_state_specs, cell_state_sizes))
625 flat_cell_state_sizes = nest.flatten(cell_state_sizes)
626 flat_state_specs = nest.flatten(init_state_specs)
628 if len(flat_cell_state_sizes) != len(flat_state_specs):
629 raise validation_error
630 for cell_state_spec, cell_state_size in zip(flat_state_specs,
631 flat_cell_state_sizes):
632 if not tensor_shape.TensorShape(
633 # Ignore the first axis for init_state which is for batch
634 cell_state_spec.shape[1:]).is_compatible_with(
635 tensor_shape.TensorShape(cell_state_size)):
636 raise validation_error
638 @doc_controls.do_not_doc_inheritable
639 def get_initial_state(self, inputs):
640 get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
642 if nest.is_nested(inputs):
643 # The input is a nested sequence. Use the first element in the sequence to
644 # get the batch size and dtype.
645 inputs = nest.flatten(inputs)[0]
647 input_shape = array_ops.shape(inputs)
648 batch_size = input_shape[1] if self.time_major else input_shape[0]
649 dtype = inputs.dtype
650 if get_initial_state_fn:
651 init_state = get_initial_state_fn(
652 inputs=None, batch_size=batch_size, dtype=dtype)
653 else:
654 init_state = _generate_zero_filled_state(batch_size, self.cell.state_size,
655 dtype)
656 # Keras RNN expects the states in a list, even if it's a single state tensor.
657 if not nest.is_nested(init_state):
658 init_state = [init_state]
659 # Force the state to be a list in case it is a namedtuple, e.g. LSTMStateTuple.
660 return list(init_state)
662 def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
663 inputs, initial_state, constants = _standardize_args(inputs,
664 initial_state,
665 constants,
666 self._num_constants)
668 if initial_state is None and constants is None:
669 return super(RNN, self).__call__(inputs, **kwargs)
671 # If any of `initial_state` or `constants` are specified and are Keras
672 # tensors, then add them to the inputs and temporarily modify the
673 # input_spec to include them.
675 additional_inputs = []
676 additional_specs = []
677 if initial_state is not None:
678 additional_inputs += initial_state
679 self.state_spec = nest.map_structure(
680 lambda s: InputSpec(shape=backend.int_shape(s)), initial_state)
681 additional_specs += self.state_spec
682 if constants is not None:
683 additional_inputs += constants
684 self.constants_spec = [
685 InputSpec(shape=backend.int_shape(constant)) for constant in constants
686 ]
687 self._num_constants = len(constants)
688 additional_specs += self.constants_spec
689 # additional_inputs can be empty if initial_state or constants are provided
690 # but empty (e.g. the cell is stateless).
691 flat_additional_inputs = nest.flatten(additional_inputs)
692 is_keras_tensor = backend.is_keras_tensor(
693 flat_additional_inputs[0]) if flat_additional_inputs else True
694 for tensor in flat_additional_inputs:
695 if backend.is_keras_tensor(tensor) != is_keras_tensor:
696 raise ValueError('The initial state or constants of an RNN'
697 ' layer cannot be specified with a mix of'
698 ' Keras tensors and non-Keras tensors'
699 ' (a "Keras tensor" is a tensor that was'
700 ' returned by a Keras layer, or by `Input`)')
702 if is_keras_tensor:
703 # Compute the full input spec, including state and constants
704 full_input = [inputs] + additional_inputs
705 if self.built:
706 # Keep the input_spec since it has been populated in build() method.
707 full_input_spec = self.input_spec + additional_specs
708 else:
709 # The original input_spec is None since there could be a nested tensor
710 # input. Update the input_spec to match the inputs.
711 full_input_spec = generic_utils.to_list(
712 nest.map_structure(lambda _: None, inputs)) + additional_specs
713 # Perform the call with temporarily replaced input_spec
714 self.input_spec = full_input_spec
715 output = super(RNN, self).__call__(full_input, **kwargs)
716 # Remove the additional_specs from the input spec and keep the rest. It is
717 # important to keep these since the input spec was populated by build(), and
718 # will be reused when stateful=True.
719 self.input_spec = self.input_spec[:-len(additional_specs)]
720 return output
721 else:
722 if initial_state is not None:
723 kwargs['initial_state'] = initial_state
724 if constants is not None:
725 kwargs['constants'] = constants
726 return super(RNN, self).__call__(inputs, **kwargs)
728 def call(self,
729 inputs,
730 mask=None,
731 training=None,
732 initial_state=None,
733 constants=None):
734 # The input should be dense, padded with zeros. If a ragged input is fed
735 # into the layer, it is padded and the row lengths are used for masking.
736 inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
737 is_ragged_input = (row_lengths is not None)
738 self._validate_args_if_ragged(is_ragged_input, mask)
740 inputs, initial_state, constants = self._process_inputs(
741 inputs, initial_state, constants)
743 self._maybe_reset_cell_dropout_mask(self.cell)
744 if isinstance(self.cell, StackedRNNCells):
745 for cell in self.cell.cells:
746 self._maybe_reset_cell_dropout_mask(cell)
748 if mask is not None:
749 # Time step masks must be the same for each input.
750 # TODO(scottzhu): Should we accept multiple different masks?
751 mask = nest.flatten(mask)[0]
753 if nest.is_nested(inputs):
754 # In the case of nested input, use the first element for shape check.
755 input_shape = backend.int_shape(nest.flatten(inputs)[0])
756 else:
757 input_shape = backend.int_shape(inputs)
758 timesteps = input_shape[0] if self.time_major else input_shape[1]
759 if self.unroll and timesteps is None:
760 raise ValueError('Cannot unroll a RNN if the '
761 'time dimension is undefined. \n'
762 '- If using a Sequential model, '
763 'specify the time dimension by passing '
764 'an `input_shape` or `batch_input_shape` '
765 'argument to your first layer. If your '
766 'first layer is an Embedding, you can '
767 'also use the `input_length` argument.\n'
768 '- If using the functional API, specify '
769 'the time dimension by passing a `shape` '
770 'or `batch_shape` argument to your Input layer.')
772 kwargs = {}
773 if generic_utils.has_arg(self.cell.call, 'training'):
774 kwargs['training'] = training
776 # TF RNN cells expect a single tensor as state instead of a list-wrapped tensor.
777 is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
778 # Use the __call__ function for callable objects, eg layers, so that it
779 # will have the proper name scopes for the ops, etc.
780 cell_call_fn = self.cell.__call__ if callable(self.cell) else self.cell.call
781 if constants:
782 if not generic_utils.has_arg(self.cell.call, 'constants'):
783 raise ValueError('RNN cell does not support constants')
785 def step(inputs, states):
786 constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type
787 states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type
789 states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
790 output, new_states = cell_call_fn(
791 inputs, states, constants=constants, **kwargs)
792 if not nest.is_nested(new_states):
793 new_states = [new_states]
794 return output, new_states
795 else:
797 def step(inputs, states):
798 states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
799 output, new_states = cell_call_fn(inputs, states, **kwargs)
800 if not nest.is_nested(new_states):
801 new_states = [new_states]
802 return output, new_states
803 last_output, outputs, states = backend.rnn(
804 step,
805 inputs,
806 initial_state,
807 constants=constants,
808 go_backwards=self.go_backwards,
809 mask=mask,
810 unroll=self.unroll,
811 input_length=row_lengths if row_lengths is not None else timesteps,
812 time_major=self.time_major,
813 zero_output_for_mask=self.zero_output_for_mask)
815 if self.stateful:
816 updates = [
817 state_ops.assign(self_state, state) for self_state, state in zip(
818 nest.flatten(self.states), nest.flatten(states))
819 ]
820 self.add_update(updates)
822 if self.return_sequences:
823 output = backend.maybe_convert_to_ragged(
824 is_ragged_input, outputs, row_lengths, go_backwards=self.go_backwards)
825 else:
826 output = last_output
828 if self.return_state:
829 if not isinstance(states, (list, tuple)):
830 states = [states]
831 else:
832 states = list(states)
833 return generic_utils.to_list(output) + states
834 else:
835 return output
837 def _process_inputs(self, inputs, initial_state, constants):
838 # input shape: `(samples, time (padded with zeros), input_dim)`
839 # note that the .build() method of subclasses MUST define
840 # self.input_spec and self.state_spec with complete input shapes.
841 if (isinstance(inputs, collections.abc.Sequence)
842 and not isinstance(inputs, tuple)):
843 # Get initial_state from the full input spec, since the inputs could be
844 # copied to multiple GPUs.
845 if not self._num_constants:
846 initial_state = inputs[1:]
847 else:
848 initial_state = inputs[1:-self._num_constants]
849 constants = inputs[-self._num_constants:]
850 if len(initial_state) == 0:
851 initial_state = None
852 inputs = inputs[0]
854 if self.stateful:
855 if initial_state is not None:
856 # When the layer is stateful and initial_state is provided, check if the
857 # recorded state is the same as the default value (zeros). Use the recorded
858 # state if it is not the same as the default.
859 non_zero_count = math_ops.add_n([math_ops.count_nonzero_v2(s)
860 for s in nest.flatten(self.states)])
861 # Set strict = True to keep the original structure of the state.
862 initial_state = cond.cond(non_zero_count > 0,
863 true_fn=lambda: self.states,
864 false_fn=lambda: initial_state,
865 strict=True)
866 else:
867 initial_state = self.states
868 elif initial_state is None:
869 initial_state = self.get_initial_state(inputs)
871 if len(initial_state) != len(self.states):
872 raise ValueError('Layer has ' + str(len(self.states)) +
873 ' states but was passed ' + str(len(initial_state)) +
874 ' initial states.')
875 return inputs, initial_state, constants
877 def _validate_args_if_ragged(self, is_ragged_input, mask):
878 if not is_ragged_input:
879 return
881 if mask is not None:
882 raise ValueError('The mask that was passed in was ' + str(mask) +
883 ' and cannot be applied to RaggedTensor inputs. Please '
884 'make sure that there is no mask passed in by upstream '
885 'layers.')
886 if self.unroll:
887 raise ValueError('The input received contains RaggedTensors and does '
888 'not support unrolling. Disable unrolling by passing '
889 '`unroll=False` in the RNN Layer constructor.')
891 def _maybe_reset_cell_dropout_mask(self, cell):
892 if isinstance(cell, DropoutRNNCellMixin):
893 cell.reset_dropout_mask()
894 cell.reset_recurrent_dropout_mask()
896 def reset_states(self, states=None):
897 """Reset the recorded states for the stateful RNN layer.
899 Can only be used when the RNN layer is constructed with `stateful=True`.
900 Args:
901 states: Numpy arrays that contain the values for the initial state, which
902 will be fed to the cell at the first time step. When the value is None,
903 a zero-filled numpy array will be created based on the cell state size.
905 Raises:
906 AttributeError: When the RNN layer is not stateful.
907 ValueError: When the batch size of the RNN layer is unknown.
908 ValueError: When the input numpy array is not compatible with the RNN
909 layer state, either size-wise or dtype-wise.
910 """
911 if not self.stateful:
912 raise AttributeError('Layer must be stateful.')
913 spec_shape = None
914 if self.input_spec is not None:
915 spec_shape = nest.flatten(self.input_spec[0])[0].shape
916 if spec_shape is None:
917 # It is possible for the spec shape to be None, e.g. when constructing an RNN
918 # with a custom cell, or for standard RNN layers (LSTM/GRU) where we only know
919 # the input is 3-dimensional, but not its full shape spec before build().
920 batch_size = None
921 else:
922 batch_size = spec_shape[1] if self.time_major else spec_shape[0]
923 if not batch_size:
924 raise ValueError('If a RNN is stateful, it needs to know '
925 'its batch size. Specify the batch size '
926 'of your input tensors: \n'
927 '- If using a Sequential model, '
928 'specify the batch size by passing '
929 'a `batch_input_shape` '
930 'argument to your first layer.\n'
931 '- If using the functional API, specify '
932 'the batch size by passing a '
933 '`batch_shape` argument to your Input layer.')
934 # initialize state if None
935 if nest.flatten(self.states)[0] is None:
936 if getattr(self.cell, 'get_initial_state', None):
937 flat_init_state_values = nest.flatten(self.cell.get_initial_state(
938 inputs=None, batch_size=batch_size,
939 dtype=self.dtype or backend.floatx()))
940 else:
941 flat_init_state_values = nest.flatten(_generate_zero_filled_state(
942 batch_size, self.cell.state_size, self.dtype or backend.floatx()))
943 flat_states_variables = nest.map_structure(
944 backend.variable, flat_init_state_values)
945 self.states = nest.pack_sequence_as(self.cell.state_size,
946 flat_states_variables)
947 if not nest.is_nested(self.states):
948 self.states = [self.states]
949 elif states is None:
950 for state, size in zip(nest.flatten(self.states),
951 nest.flatten(self.cell.state_size)):
952 backend.set_value(
953 state,
954 np.zeros([batch_size] + tensor_shape.TensorShape(size).as_list()))
955 else:
956 flat_states = nest.flatten(self.states)
957 flat_input_states = nest.flatten(states)
958 if len(flat_input_states) != len(flat_states):
959 raise ValueError('Layer ' + self.name + ' expects ' +
960 str(len(flat_states)) + ' states, '
961 'but it received ' + str(len(flat_input_states)) +
962 ' state values. Input received: ' + str(states))
963 set_value_tuples = []
964 for i, (value, state) in enumerate(zip(flat_input_states,
965 flat_states)):
966 if value.shape != state.shape:
967 raise ValueError(
968 'State ' + str(i) + ' is incompatible with layer ' +
969 self.name + ': expected shape=' + str(
970 (batch_size, state)) + ', found shape=' + str(value.shape))
971 set_value_tuples.append((state, value))
972 backend.batch_set_value(set_value_tuples)
974 def get_config(self):
975 config = {
976 'return_sequences': self.return_sequences,
977 'return_state': self.return_state,
978 'go_backwards': self.go_backwards,
979 'stateful': self.stateful,
980 'unroll': self.unroll,
981 'time_major': self.time_major
982 }
983 if self._num_constants:
984 config['num_constants'] = self._num_constants
985 if self.zero_output_for_mask:
986 config['zero_output_for_mask'] = self.zero_output_for_mask
988 config['cell'] = generic_utils.serialize_keras_object(self.cell)
989 base_config = super(RNN, self).get_config()
990 return dict(list(base_config.items()) + list(config.items()))
992 @classmethod
993 def from_config(cls, config, custom_objects=None):
994 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
995 cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects)
996 num_constants = config.pop('num_constants', 0)
997 layer = cls(cell, **config)
998 layer._num_constants = num_constants
999 return layer
1001 @property
1002 def _trackable_saved_model_saver(self):
1003 return layer_serialization.RNNSavedModelSaver(self)
1006@keras_export('keras.layers.AbstractRNNCell')
1007class AbstractRNNCell(Layer):
1008 """Abstract object representing an RNN cell.
1010 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
1011 for details about the usage of RNN API.
1013 This is the base class for implementing RNN cells with custom behavior.
1015 Every `RNNCell` must have the properties below and implement `call` with
1016 the signature `(output, next_state) = call(input, state)`.
1018 Examples:
1020 ```python
1021 class MinimalRNNCell(AbstractRNNCell):
1023 def __init__(self, units, **kwargs):
1024 self.units = units
1025 super(MinimalRNNCell, self).__init__(**kwargs)
1027 @property
1028 def state_size(self):
1029 return self.units
1031 def build(self, input_shape):
1032 self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
1033 initializer='uniform',
1034 name='kernel')
1035 self.recurrent_kernel = self.add_weight(
1036 shape=(self.units, self.units),
1037 initializer='uniform',
1038 name='recurrent_kernel')
1039 self.built = True
1041 def call(self, inputs, states):
1042 prev_output = states[0]
1043 h = backend.dot(inputs, self.kernel)
1044 output = h + backend.dot(prev_output, self.recurrent_kernel)
1045 return output, output
1046 ```
1048 This definition of cell differs from the definition used in the literature.
1049 In the literature, 'cell' refers to an object with a single scalar output.
1050 This definition refers to a horizontal array of such units.
1052 An RNN cell, in the most abstract setting, is anything that has
1053 a state and performs some operation that takes a matrix of inputs.
1054 This operation results in an output matrix with `self.output_size` columns.
1055 If `self.state_size` is an integer, this operation also results in a new
1056 state matrix with `self.state_size` columns. If `self.state_size` is a
1057 (possibly nested tuple of) TensorShape object(s), then it should return a
1058 matching structure of Tensors having shape `[batch_size].concatenate(s)`
1059 for each `s` in `self.state_size`.
1060 """
1062 def call(self, inputs, states):
1063 """The function that contains the logic for one RNN step calculation.
1065 Args:
1066 inputs: the input tensor, which is a slice from the overall RNN input along
1067 the time dimension (usually the second dimension).
1068 states: the state tensor from the previous step, which has the shape
1069 `(batch, state_size)`. In the case of timestep 0, it will be the
1070 initial state specified by the user, or a zero-filled tensor otherwise.
1072 Returns:
1073 A tuple of two tensors:
1074 1. output tensor for the current timestep, with size `output_size`.
1075 2. state tensor for next step, which has the shape of `state_size`.
1076 """
1077 raise NotImplementedError('Abstract method')
1079 @property
1080 def state_size(self):
1081 """size(s) of state(s) used by this cell.
1083 It can be represented by an Integer, a TensorShape or a tuple of Integers
1084 or TensorShapes.
1085 """
1086 raise NotImplementedError('Abstract method')
1088 @property
1089 def output_size(self):
1090 """Integer or TensorShape: size of outputs produced by this cell."""
1091 raise NotImplementedError('Abstract method')
1093 def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
1094 return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
1097@doc_controls.do_not_generate_docs
1098class DropoutRNNCellMixin(object):
1099 """Object that hold dropout related fields for RNN Cell.
1101 This class is not a standalone RNN cell. It suppose to be used with a RNN cell
1102 by multiple inheritance. Any cell that mix with class should have following
1103 fields:
1104 dropout: a float number within range [0, 1). The ratio that the input
1105 tensor need to dropout.
1106 recurrent_dropout: a float number within range [0, 1). The ratio that the
1107 recurrent state weights need to dropout.
1108 This object will create and cache created dropout masks, and reuse them for
1109 the incoming data, so that the same mask is used for every batch input.
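Below is a minimal sketch of a cell using this mixin (the cell name and sizes
are hypothetical; it mirrors the structure of `SimpleRNNCell` and only
illustrates where the cached masks are fetched and applied):

```python
class DropoutSketchCell(DropoutRNNCellMixin, keras.layers.Layer):

  def __init__(self, units, dropout=0., recurrent_dropout=0., **kwargs):
    super(DropoutSketchCell, self).__init__(**kwargs)
    self.units = units
    self.dropout = dropout                      # field required by the mixin
    self.recurrent_dropout = recurrent_dropout  # field required by the mixin
    self.state_size = units

  def build(self, input_shape):
    self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                  initializer='glorot_uniform', name='kernel')
    self.recurrent_kernel = self.add_weight(shape=(self.units, self.units),
                                            initializer='orthogonal',
                                            name='recurrent_kernel')
    self.built = True

  def call(self, inputs, states, training=None):
    prev_output = states[0]
    # Fetch (or create and cache) the masks for this batch.
    dp_mask = self.get_dropout_mask_for_cell(inputs, training)
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(prev_output, training)
    if dp_mask is not None:
      inputs = inputs * dp_mask
    if rec_dp_mask is not None:
      prev_output = prev_output * rec_dp_mask
    output = backend.dot(inputs, self.kernel) + backend.dot(
        prev_output, self.recurrent_kernel)
    return output, [output]
```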
1110 """
1112 def __init__(self, *args, **kwargs):
1113 self._create_non_trackable_mask_cache()
1114 super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)
1116 @trackable.no_automatic_dependency_tracking
1117 def _create_non_trackable_mask_cache(self):
1118 """Create the cache for dropout and recurrent dropout mask.
1120 Note that the following two masks will be used in "graph function" mode,
1121 i.e. these masks are symbolic tensors. In eager mode, the `eager_*_mask`
1122 tensors will be generated differently than in the "graph function" case,
1123 and they will be cached.
1125 Also note that in graph mode, we still cache those masks only because the
1126 RNN could be created with `unroll=True`. In that case, the `cell.call()`
1127 function will be invoked multiple times, and we want to ensure same mask
1128 is used every time.
1130 Also, the caches are created without tracking. Since they are not picklable
1131 by Python when deep-copying, we don't want `layer._obj_reference_counts_dict`
1132 to track them by default.
1133 """
1134 self._dropout_mask_cache = backend.ContextValueCache(
1135 self._create_dropout_mask)
1136 self._recurrent_dropout_mask_cache = backend.ContextValueCache(
1137 self._create_recurrent_dropout_mask)
1139 def reset_dropout_mask(self):
1140 """Reset the cached dropout masks if any.
1142 It is important for the RNN layer to invoke this in its `call()` method so
1143 that the cached mask is cleared before calling `cell.call()`. The mask
1144 should be cached across timesteps within the same batch, but shouldn't
1145 be cached between batches. Otherwise it will introduce an unreasonable bias
1146 against certain indices of data within the batch.
1147 """
1148 self._dropout_mask_cache.clear()
1150 def reset_recurrent_dropout_mask(self):
1151 """Reset the cached recurrent dropout masks if any.
1153 It is important for the RNN layer to invoke this in its `call()` method so
1154 that the cached mask is cleared before calling `cell.call()`. The mask
1155 should be cached across timesteps within the same batch, but shouldn't
1156 be cached between batches. Otherwise it will introduce an unreasonable bias
1157 against certain indices of data within the batch.
1158 """
1159 self._recurrent_dropout_mask_cache.clear()
1161 def _create_dropout_mask(self, inputs, training, count=1):
1162 return _generate_dropout_mask(
1163 array_ops.ones_like(inputs),
1164 self.dropout,
1165 training=training,
1166 count=count)
1168 def _create_recurrent_dropout_mask(self, inputs, training, count=1):
1169 return _generate_dropout_mask(
1170 array_ops.ones_like(inputs),
1171 self.recurrent_dropout,
1172 training=training,
1173 count=count)
1175 def get_dropout_mask_for_cell(self, inputs, training, count=1):
1176 """Get the dropout mask for RNN cell's input.
1178 It will create a mask based on context if there isn't any existing cached
1179 mask. If a new mask is generated, it will update the cache in the cell.
1181 Args:
1182 inputs: The input tensor whose shape will be used to generate the dropout
1183 mask.
1184 training: Boolean tensor, whether it's in training mode; dropout will be
1185 ignored in non-training mode.
1186 count: Int, how many dropout masks will be generated. It is useful for cells
1187 that have internal weights fused together.
1188 Returns:
1189 List of mask tensors, generated or cached masks based on context.
1190 """
1191 if self.dropout == 0:
1192 return None
1193 init_kwargs = dict(inputs=inputs, training=training, count=count)
1194 return self._dropout_mask_cache.setdefault(kwargs=init_kwargs)
1196 def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1):
1197 """Get the recurrent dropout mask for RNN cell.
1199 It will create a mask based on context if there isn't any existing cached
1200 mask. If a new mask is generated, it will update the cache in the cell.
1202 Args:
1203 inputs: The input tensor whose shape will be used to generate the dropout
1204 mask.
1205 training: Boolean tensor, whether it's in training mode; dropout will be
1206 ignored in non-training mode.
1207 count: Int, how many dropout masks will be generated. It is useful for cells
1208 that have internal weights fused together.
1209 Returns:
1210 List of mask tensors, generated or cached masks based on context.
1211 """
1212 if self.recurrent_dropout == 0:
1213 return None
1214 init_kwargs = dict(inputs=inputs, training=training, count=count)
1215 return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs)
1217 def __getstate__(self):
1218 # Used for deepcopy. The caches can't be pickled by Python, since they will
1219 # contain tensors and graphs.
1220 state = super(DropoutRNNCellMixin, self).__getstate__()
1221 state.pop('_dropout_mask_cache', None)
1222 state.pop('_recurrent_dropout_mask_cache', None)
1223 return state
1225 def __setstate__(self, state):
1226 state['_dropout_mask_cache'] = backend.ContextValueCache(
1227 self._create_dropout_mask)
1228 state['_recurrent_dropout_mask_cache'] = backend.ContextValueCache(
1229 self._create_recurrent_dropout_mask)
1230 super(DropoutRNNCellMixin, self).__setstate__(state)
1233@keras_export('keras.layers.SimpleRNNCell')
1234class SimpleRNNCell(DropoutRNNCellMixin, Layer):
1235 """Cell class for SimpleRNN.
1237 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
1238 for details about the usage of RNN API.
1240 This class processes one step within the whole time sequence input, whereas
1241 `tf.keras.layers.SimpleRNN` processes the whole sequence.
1243 Args:
1244 units: Positive integer, dimensionality of the output space.
1245 activation: Activation function to use.
1246 Default: hyperbolic tangent (`tanh`).
1247 If you pass `None`, no activation is applied
1248 (i.e. "linear" activation: `a(x) = x`).
1249 use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
1250 kernel_initializer: Initializer for the `kernel` weights matrix,
1251 used for the linear transformation of the inputs. Default:
1252 `glorot_uniform`.
1253 recurrent_initializer: Initializer for the `recurrent_kernel`
1254 weights matrix, used for the linear transformation of the recurrent state.
1255 Default: `orthogonal`.
1256 bias_initializer: Initializer for the bias vector. Default: `zeros`.
1257 kernel_regularizer: Regularizer function applied to the `kernel` weights
1258 matrix. Default: `None`.
1259 recurrent_regularizer: Regularizer function applied to the
1260 `recurrent_kernel` weights matrix. Default: `None`.
1261 bias_regularizer: Regularizer function applied to the bias vector. Default:
1262 `None`.
1263 kernel_constraint: Constraint function applied to the `kernel` weights
1264 matrix. Default: `None`.
1265 recurrent_constraint: Constraint function applied to the `recurrent_kernel`
1266 weights matrix. Default: `None`.
1267 bias_constraint: Constraint function applied to the bias vector. Default:
1268 `None`.
1269 dropout: Float between 0 and 1. Fraction of the units to drop for the linear
1270 transformation of the inputs. Default: 0.
1271 recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
1272 the linear transformation of the recurrent state. Default: 0.
1274 Call arguments:
1275 inputs: A 2D tensor, with shape of `[batch, feature]`.
1276 states: A 2D tensor with shape of `[batch, units]`, which is the state from
1277 the previous time step. For timestep 0, the initial state provided by the user
1278 will be fed to the cell.
1279 training: Python boolean indicating whether the layer should behave in
1280 training mode or in inference mode. Only relevant when `dropout` or
1281 `recurrent_dropout` is used.
1283 Examples:
1285 ```python
1286 inputs = np.random.random([32, 10, 8]).astype(np.float32)
1287 rnn = tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(4))
1289 output = rnn(inputs) # The output has shape `[32, 4]`.
1291 rnn = tf.keras.layers.RNN(
1292 tf.keras.layers.SimpleRNNCell(4),
1293 return_sequences=True,
1294 return_state=True)
1296 # whole_sequence_output has shape `[32, 10, 4]`.
1297 # final_state has shape `[32, 4]`.
1298 whole_sequence_output, final_state = rnn(inputs)
1299 ```
1300 """
1302 def __init__(self,
1303 units,
1304 activation='tanh',
1305 use_bias=True,
1306 kernel_initializer='glorot_uniform',
1307 recurrent_initializer='orthogonal',
1308 bias_initializer='zeros',
1309 kernel_regularizer=None,
1310 recurrent_regularizer=None,
1311 bias_regularizer=None,
1312 kernel_constraint=None,
1313 recurrent_constraint=None,
1314 bias_constraint=None,
1315 dropout=0.,
1316 recurrent_dropout=0.,
1317 **kwargs):
1318 if units < 0:
1319 raise ValueError(f'Received an invalid value for units, expected '
1320 f'a positive integer, got {units}.')
1321 # By default use cached variable under v2 mode, see b/143699808.
1322 if ops.executing_eagerly_outside_functions():
1323 self._enable_caching_device = kwargs.pop('enable_caching_device', True)
1324 else:
1325 self._enable_caching_device = kwargs.pop('enable_caching_device', False)
1326 super(SimpleRNNCell, self).__init__(**kwargs)
1327 self.units = units
1328 self.activation = activations.get(activation)
1329 self.use_bias = use_bias
1331 self.kernel_initializer = initializers.get(kernel_initializer)
1332 self.recurrent_initializer = initializers.get(recurrent_initializer)
1333 self.bias_initializer = initializers.get(bias_initializer)
1335 self.kernel_regularizer = regularizers.get(kernel_regularizer)
1336 self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
1337 self.bias_regularizer = regularizers.get(bias_regularizer)
1339 self.kernel_constraint = constraints.get(kernel_constraint)
1340 self.recurrent_constraint = constraints.get(recurrent_constraint)
1341 self.bias_constraint = constraints.get(bias_constraint)
1343 self.dropout = min(1., max(0., dropout))
1344 self.recurrent_dropout = min(1., max(0., recurrent_dropout))
1345 self.state_size = self.units
1346 self.output_size = self.units
1348 @tf_utils.shape_type_conversion
1349 def build(self, input_shape):
1350 default_caching_device = _caching_device(self)
1351 self.kernel = self.add_weight(
1352 shape=(input_shape[-1], self.units),
1353 name='kernel',
1354 initializer=self.kernel_initializer,
1355 regularizer=self.kernel_regularizer,
1356 constraint=self.kernel_constraint,
1357 caching_device=default_caching_device)
1358 self.recurrent_kernel = self.add_weight(
1359 shape=(self.units, self.units),
1360 name='recurrent_kernel',
1361 initializer=self.recurrent_initializer,
1362 regularizer=self.recurrent_regularizer,
1363 constraint=self.recurrent_constraint,
1364 caching_device=default_caching_device)
1365 if self.use_bias:
1366 self.bias = self.add_weight(
1367 shape=(self.units,),
1368 name='bias',
1369 initializer=self.bias_initializer,
1370 regularizer=self.bias_regularizer,
1371 constraint=self.bias_constraint,
1372 caching_device=default_caching_device)
1373 else:
1374 self.bias = None
1375 self.built = True
1377 def call(self, inputs, states, training=None):
1378 prev_output = states[0] if nest.is_nested(states) else states
1379 dp_mask = self.get_dropout_mask_for_cell(inputs, training)
1380 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
1381 prev_output, training)
1383 if dp_mask is not None:
1384 h = backend.dot(inputs * dp_mask, self.kernel)
1385 else:
1386 h = backend.dot(inputs, self.kernel)
1387 if self.bias is not None:
1388 h = backend.bias_add(h, self.bias)
1390 if rec_dp_mask is not None:
1391 prev_output = prev_output * rec_dp_mask
1392 output = h + backend.dot(prev_output, self.recurrent_kernel)
1393 if self.activation is not None:
1394 output = self.activation(output)
1396 new_state = [output] if nest.is_nested(states) else output
1397 return output, new_state
1399 def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
1400 return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
1402 def get_config(self):
1403 config = {
1404 'units':
1405 self.units,
1406 'activation':
1407 activations.serialize(self.activation),
1408 'use_bias':
1409 self.use_bias,
1410 'kernel_initializer':
1411 initializers.serialize(self.kernel_initializer),
1412 'recurrent_initializer':
1413 initializers.serialize(self.recurrent_initializer),
1414 'bias_initializer':
1415 initializers.serialize(self.bias_initializer),
1416 'kernel_regularizer':
1417 regularizers.serialize(self.kernel_regularizer),
1418 'recurrent_regularizer':
1419 regularizers.serialize(self.recurrent_regularizer),
1420 'bias_regularizer':
1421 regularizers.serialize(self.bias_regularizer),
1422 'kernel_constraint':
1423 constraints.serialize(self.kernel_constraint),
1424 'recurrent_constraint':
1425 constraints.serialize(self.recurrent_constraint),
1426 'bias_constraint':
1427 constraints.serialize(self.bias_constraint),
1428 'dropout':
1429 self.dropout,
1430 'recurrent_dropout':
1431 self.recurrent_dropout
1432 }
1433 config.update(_config_for_enable_caching_device(self))
1434 base_config = super(SimpleRNNCell, self).get_config()
1435 return dict(list(base_config.items()) + list(config.items()))
1438@keras_export('keras.layers.SimpleRNN')
1439class SimpleRNN(RNN):
1440 """Fully-connected RNN where the output is to be fed back to input.
1442 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
1443 for details about the usage of RNN API.
1445 Args:
1446 units: Positive integer, dimensionality of the output space.
1447 activation: Activation function to use.
1448 Default: hyperbolic tangent (`tanh`).
1449 If you pass None, no activation is applied
1450      (i.e. "linear" activation: `a(x) = x`).
1451 use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
1452 kernel_initializer: Initializer for the `kernel` weights matrix,
1453 used for the linear transformation of the inputs. Default:
1454 `glorot_uniform`.
1455 recurrent_initializer: Initializer for the `recurrent_kernel`
1456 weights matrix, used for the linear transformation of the recurrent state.
1457 Default: `orthogonal`.
1458 bias_initializer: Initializer for the bias vector. Default: `zeros`.
1459 kernel_regularizer: Regularizer function applied to the `kernel` weights
1460 matrix. Default: `None`.
1461 recurrent_regularizer: Regularizer function applied to the
1462 `recurrent_kernel` weights matrix. Default: `None`.
1463 bias_regularizer: Regularizer function applied to the bias vector. Default:
1464 `None`.
1465 activity_regularizer: Regularizer function applied to the output of the
1466 layer (its "activation"). Default: `None`.
1467 kernel_constraint: Constraint function applied to the `kernel` weights
1468 matrix. Default: `None`.
1469 recurrent_constraint: Constraint function applied to the `recurrent_kernel`
1470 weights matrix. Default: `None`.
1471 bias_constraint: Constraint function applied to the bias vector. Default:
1472 `None`.
1473 dropout: Float between 0 and 1.
1474 Fraction of the units to drop for the linear transformation of the inputs.
1475 Default: 0.
1476 recurrent_dropout: Float between 0 and 1.
1477 Fraction of the units to drop for the linear transformation of the
1478 recurrent state. Default: 0.
1479 return_sequences: Boolean. Whether to return the last output
1480 in the output sequence, or the full sequence. Default: `False`.
1481 return_state: Boolean. Whether to return the last state
1482      in addition to the output. Default: `False`.
1483 go_backwards: Boolean (default False).
1484 If True, process the input sequence backwards and return the
1485 reversed sequence.
1486 stateful: Boolean (default False). If True, the last state
1487 for each sample at index i in a batch will be used as initial
1488 state for the sample of index i in the following batch.
1489 unroll: Boolean (default False).
1490 If True, the network will be unrolled,
1491 else a symbolic loop will be used.
1492 Unrolling can speed-up a RNN,
1493 although it tends to be more memory-intensive.
1494 Unrolling is only suitable for short sequences.
1496 Call arguments:
1497 inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
1498 mask: Binary tensor of shape `[batch, timesteps]` indicating whether
1499 a given timestep should be masked. An individual `True` entry indicates
1500 that the corresponding timestep should be utilized, while a `False` entry
1501 indicates that the corresponding timestep should be ignored.
1502 training: Python boolean indicating whether the layer should behave in
1503 training mode or in inference mode. This argument is passed to the cell
1504 when calling it. This is only relevant if `dropout` or
1505 `recurrent_dropout` is used.
1506 initial_state: List of initial state tensors to be passed to the first
1507 call of the cell.
1509 Examples:
1511 ```python
1512 inputs = np.random.random([32, 10, 8]).astype(np.float32)
1513 simple_rnn = tf.keras.layers.SimpleRNN(4)
1515 output = simple_rnn(inputs) # The output has shape `[32, 4]`.
1517 simple_rnn = tf.keras.layers.SimpleRNN(
1518 4, return_sequences=True, return_state=True)
1520 # whole_sequence_output has shape `[32, 10, 4]`.
1521 # final_state has shape `[32, 4]`.
1522 whole_sequence_output, final_state = simple_rnn(inputs)
1523 ```
1524 """
1526 def __init__(self,
1527 units,
1528 activation='tanh',
1529 use_bias=True,
1530 kernel_initializer='glorot_uniform',
1531 recurrent_initializer='orthogonal',
1532 bias_initializer='zeros',
1533 kernel_regularizer=None,
1534 recurrent_regularizer=None,
1535 bias_regularizer=None,
1536 activity_regularizer=None,
1537 kernel_constraint=None,
1538 recurrent_constraint=None,
1539 bias_constraint=None,
1540 dropout=0.,
1541 recurrent_dropout=0.,
1542 return_sequences=False,
1543 return_state=False,
1544 go_backwards=False,
1545 stateful=False,
1546 unroll=False,
1547 **kwargs):
1548 if 'implementation' in kwargs:
1549 kwargs.pop('implementation')
1550 logging.warning('The `implementation` argument '
1551 'in `SimpleRNN` has been deprecated. '
1552 'Please remove it from your layer call.')
1553 if 'enable_caching_device' in kwargs:
1554 cell_kwargs = {'enable_caching_device':
1555 kwargs.pop('enable_caching_device')}
1556 else:
1557 cell_kwargs = {}
1558 cell = SimpleRNNCell(
1559 units,
1560 activation=activation,
1561 use_bias=use_bias,
1562 kernel_initializer=kernel_initializer,
1563 recurrent_initializer=recurrent_initializer,
1564 bias_initializer=bias_initializer,
1565 kernel_regularizer=kernel_regularizer,
1566 recurrent_regularizer=recurrent_regularizer,
1567 bias_regularizer=bias_regularizer,
1568 kernel_constraint=kernel_constraint,
1569 recurrent_constraint=recurrent_constraint,
1570 bias_constraint=bias_constraint,
1571 dropout=dropout,
1572 recurrent_dropout=recurrent_dropout,
1573 dtype=kwargs.get('dtype'),
1574 trainable=kwargs.get('trainable', True),
1575 **cell_kwargs)
1576 super(SimpleRNN, self).__init__(
1577 cell,
1578 return_sequences=return_sequences,
1579 return_state=return_state,
1580 go_backwards=go_backwards,
1581 stateful=stateful,
1582 unroll=unroll,
1583 **kwargs)
1584 self.activity_regularizer = regularizers.get(activity_regularizer)
1585 self.input_spec = [InputSpec(ndim=3)]
1587 def call(self, inputs, mask=None, training=None, initial_state=None):
1588 return super(SimpleRNN, self).call(
1589 inputs, mask=mask, training=training, initial_state=initial_state)
1591 @property
1592 def units(self):
1593 return self.cell.units
1595 @property
1596 def activation(self):
1597 return self.cell.activation
1599 @property
1600 def use_bias(self):
1601 return self.cell.use_bias
1603 @property
1604 def kernel_initializer(self):
1605 return self.cell.kernel_initializer
1607 @property
1608 def recurrent_initializer(self):
1609 return self.cell.recurrent_initializer
1611 @property
1612 def bias_initializer(self):
1613 return self.cell.bias_initializer
1615 @property
1616 def kernel_regularizer(self):
1617 return self.cell.kernel_regularizer
1619 @property
1620 def recurrent_regularizer(self):
1621 return self.cell.recurrent_regularizer
1623 @property
1624 def bias_regularizer(self):
1625 return self.cell.bias_regularizer
1627 @property
1628 def kernel_constraint(self):
1629 return self.cell.kernel_constraint
1631 @property
1632 def recurrent_constraint(self):
1633 return self.cell.recurrent_constraint
1635 @property
1636 def bias_constraint(self):
1637 return self.cell.bias_constraint
1639 @property
1640 def dropout(self):
1641 return self.cell.dropout
1643 @property
1644 def recurrent_dropout(self):
1645 return self.cell.recurrent_dropout
1647 def get_config(self):
1648 config = {
1649 'units':
1650 self.units,
1651 'activation':
1652 activations.serialize(self.activation),
1653 'use_bias':
1654 self.use_bias,
1655 'kernel_initializer':
1656 initializers.serialize(self.kernel_initializer),
1657 'recurrent_initializer':
1658 initializers.serialize(self.recurrent_initializer),
1659 'bias_initializer':
1660 initializers.serialize(self.bias_initializer),
1661 'kernel_regularizer':
1662 regularizers.serialize(self.kernel_regularizer),
1663 'recurrent_regularizer':
1664 regularizers.serialize(self.recurrent_regularizer),
1665 'bias_regularizer':
1666 regularizers.serialize(self.bias_regularizer),
1667 'activity_regularizer':
1668 regularizers.serialize(self.activity_regularizer),
1669 'kernel_constraint':
1670 constraints.serialize(self.kernel_constraint),
1671 'recurrent_constraint':
1672 constraints.serialize(self.recurrent_constraint),
1673 'bias_constraint':
1674 constraints.serialize(self.bias_constraint),
1675 'dropout':
1676 self.dropout,
1677 'recurrent_dropout':
1678 self.recurrent_dropout
1679 }
1680 base_config = super(SimpleRNN, self).get_config()
1681 config.update(_config_for_enable_caching_device(self.cell))
1682 del base_config['cell']
1683 return dict(list(base_config.items()) + list(config.items()))
1685 @classmethod
1686 def from_config(cls, config):
1687 if 'implementation' in config:
1688 config.pop('implementation')
1689 return cls(**config)
1692@keras_export(v1=['keras.layers.GRUCell'])
1693class GRUCell(DropoutRNNCellMixin, Layer):
1694 """Cell class for the GRU layer.
1696 Args:
1697 units: Positive integer, dimensionality of the output space.
1698 activation: Activation function to use.
1699 Default: hyperbolic tangent (`tanh`).
1700 If you pass None, no activation is applied
1701      (i.e. "linear" activation: `a(x) = x`).
1702 recurrent_activation: Activation function to use
1703 for the recurrent step.
1704 Default: hard sigmoid (`hard_sigmoid`).
1705 If you pass `None`, no activation is applied
1706      (i.e. "linear" activation: `a(x) = x`).
1707 use_bias: Boolean, whether the layer uses a bias vector.
1708 kernel_initializer: Initializer for the `kernel` weights matrix,
1709 used for the linear transformation of the inputs.
1710 recurrent_initializer: Initializer for the `recurrent_kernel`
1711 weights matrix,
1712 used for the linear transformation of the recurrent state.
1713 bias_initializer: Initializer for the bias vector.
1714 kernel_regularizer: Regularizer function applied to
1715 the `kernel` weights matrix.
1716 recurrent_regularizer: Regularizer function applied to
1717 the `recurrent_kernel` weights matrix.
1718 bias_regularizer: Regularizer function applied to the bias vector.
1719 kernel_constraint: Constraint function applied to
1720 the `kernel` weights matrix.
1721 recurrent_constraint: Constraint function applied to
1722 the `recurrent_kernel` weights matrix.
1723 bias_constraint: Constraint function applied to the bias vector.
1724 dropout: Float between 0 and 1.
1725 Fraction of the units to drop for the linear transformation of the inputs.
1726 recurrent_dropout: Float between 0 and 1.
1727 Fraction of the units to drop for
1728 the linear transformation of the recurrent state.
1729 reset_after: GRU convention (whether to apply reset gate after or
1730 before matrix multiplication). False = "before" (default),
1731 True = "after" (CuDNN compatible).
1733 Call arguments:
1734 inputs: A 2D tensor.
1735 states: List of state tensors corresponding to the previous timestep.
1736 training: Python boolean indicating whether the layer should behave in
1737 training mode or in inference mode. Only relevant when `dropout` or
1738 `recurrent_dropout` is used.
1739 """
1741 def __init__(self,
1742 units,
1743 activation='tanh',
1744 recurrent_activation='hard_sigmoid',
1745 use_bias=True,
1746 kernel_initializer='glorot_uniform',
1747 recurrent_initializer='orthogonal',
1748 bias_initializer='zeros',
1749 kernel_regularizer=None,
1750 recurrent_regularizer=None,
1751 bias_regularizer=None,
1752 kernel_constraint=None,
1753 recurrent_constraint=None,
1754 bias_constraint=None,
1755 dropout=0.,
1756 recurrent_dropout=0.,
1757 reset_after=False,
1758 **kwargs):
1759    if units <= 0:
1760 raise ValueError(f'Received an invalid value for units, expected '
1761 f'a positive integer, got {units}.')
1762 # By default use cached variable under v2 mode, see b/143699808.
1763 if ops.executing_eagerly_outside_functions():
1764 self._enable_caching_device = kwargs.pop('enable_caching_device', True)
1765 else:
1766 self._enable_caching_device = kwargs.pop('enable_caching_device', False)
1767 super(GRUCell, self).__init__(**kwargs)
1768 self.units = units
1769 self.activation = activations.get(activation)
1770 self.recurrent_activation = activations.get(recurrent_activation)
1771 self.use_bias = use_bias
1773 self.kernel_initializer = initializers.get(kernel_initializer)
1774 self.recurrent_initializer = initializers.get(recurrent_initializer)
1775 self.bias_initializer = initializers.get(bias_initializer)
1777 self.kernel_regularizer = regularizers.get(kernel_regularizer)
1778 self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
1779 self.bias_regularizer = regularizers.get(bias_regularizer)
1781 self.kernel_constraint = constraints.get(kernel_constraint)
1782 self.recurrent_constraint = constraints.get(recurrent_constraint)
1783 self.bias_constraint = constraints.get(bias_constraint)
1785 self.dropout = min(1., max(0., dropout))
1786 self.recurrent_dropout = min(1., max(0., recurrent_dropout))
1788 implementation = kwargs.pop('implementation', 1)
1789 if self.recurrent_dropout != 0 and implementation != 1:
1790 logging.debug(RECURRENT_DROPOUT_WARNING_MSG)
1791 self.implementation = 1
1792 else:
1793 self.implementation = implementation
1794 self.reset_after = reset_after
1795 self.state_size = self.units
1796 self.output_size = self.units
1798 @tf_utils.shape_type_conversion
1799 def build(self, input_shape):
1800 input_dim = input_shape[-1]
1801 default_caching_device = _caching_device(self)
1802 self.kernel = self.add_weight(
1803 shape=(input_dim, self.units * 3),
1804 name='kernel',
1805 initializer=self.kernel_initializer,
1806 regularizer=self.kernel_regularizer,
1807 constraint=self.kernel_constraint,
1808 caching_device=default_caching_device)
1809 self.recurrent_kernel = self.add_weight(
1810 shape=(self.units, self.units * 3),
1811 name='recurrent_kernel',
1812 initializer=self.recurrent_initializer,
1813 regularizer=self.recurrent_regularizer,
1814 constraint=self.recurrent_constraint,
1815 caching_device=default_caching_device)
1817 if self.use_bias:
1818 if not self.reset_after:
1819 bias_shape = (3 * self.units,)
1820 else:
1821 # separate biases for input and recurrent kernels
1822 # Note: the shape is intentionally different from CuDNNGRU biases
1823 # `(2 * 3 * self.units,)`, so that we can distinguish the classes
1824 # when loading and converting saved weights.
1825 bias_shape = (2, 3 * self.units)
1826 self.bias = self.add_weight(shape=bias_shape,
1827 name='bias',
1828 initializer=self.bias_initializer,
1829 regularizer=self.bias_regularizer,
1830 constraint=self.bias_constraint,
1831 caching_device=default_caching_device)
1832 else:
1833 self.bias = None
1834 self.built = True
1836 def call(self, inputs, states, training=None):
1837 h_tm1 = states[0] if nest.is_nested(states) else states # previous memory
1839 dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
1840 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
1841 h_tm1, training, count=3)
1843 if self.use_bias:
1844 if not self.reset_after:
1845 input_bias, recurrent_bias = self.bias, None
1846 else:
1847 input_bias, recurrent_bias = array_ops_stack.unstack(self.bias)
1849 if self.implementation == 1:
1850 if 0. < self.dropout < 1.:
1851 inputs_z = inputs * dp_mask[0]
1852 inputs_r = inputs * dp_mask[1]
1853 inputs_h = inputs * dp_mask[2]
1854 else:
1855 inputs_z = inputs
1856 inputs_r = inputs
1857 inputs_h = inputs
1859 x_z = backend.dot(inputs_z, self.kernel[:, :self.units])
1860 x_r = backend.dot(inputs_r, self.kernel[:, self.units:self.units * 2])
1861 x_h = backend.dot(inputs_h, self.kernel[:, self.units * 2:])
1863 if self.use_bias:
1864 x_z = backend.bias_add(x_z, input_bias[:self.units])
1865 x_r = backend.bias_add(x_r, input_bias[self.units: self.units * 2])
1866 x_h = backend.bias_add(x_h, input_bias[self.units * 2:])
1868 if 0. < self.recurrent_dropout < 1.:
1869 h_tm1_z = h_tm1 * rec_dp_mask[0]
1870 h_tm1_r = h_tm1 * rec_dp_mask[1]
1871 h_tm1_h = h_tm1 * rec_dp_mask[2]
1872 else:
1873 h_tm1_z = h_tm1
1874 h_tm1_r = h_tm1
1875 h_tm1_h = h_tm1
1877 recurrent_z = backend.dot(h_tm1_z, self.recurrent_kernel[:, :self.units])
1878 recurrent_r = backend.dot(
1879 h_tm1_r, self.recurrent_kernel[:, self.units:self.units * 2])
1880 if self.reset_after and self.use_bias:
1881 recurrent_z = backend.bias_add(recurrent_z, recurrent_bias[:self.units])
1882 recurrent_r = backend.bias_add(
1883 recurrent_r, recurrent_bias[self.units:self.units * 2])
1885 z = self.recurrent_activation(x_z + recurrent_z)
1886 r = self.recurrent_activation(x_r + recurrent_r)
1888 # reset gate applied after/before matrix multiplication
1889 if self.reset_after:
1890 recurrent_h = backend.dot(
1891 h_tm1_h, self.recurrent_kernel[:, self.units * 2:])
1892 if self.use_bias:
1893 recurrent_h = backend.bias_add(
1894 recurrent_h, recurrent_bias[self.units * 2:])
1895 recurrent_h = r * recurrent_h
1896 else:
1897 recurrent_h = backend.dot(
1898 r * h_tm1_h, self.recurrent_kernel[:, self.units * 2:])
1900 hh = self.activation(x_h + recurrent_h)
1901 else:
1902 if 0. < self.dropout < 1.:
1903 inputs = inputs * dp_mask[0]
1905 # inputs projected by all gate matrices at once
1906 matrix_x = backend.dot(inputs, self.kernel)
1907 if self.use_bias:
1908 # biases: bias_z_i, bias_r_i, bias_h_i
1909 matrix_x = backend.bias_add(matrix_x, input_bias)
1911 x_z, x_r, x_h = array_ops.split(matrix_x, 3, axis=-1)
1913 if self.reset_after:
1914 # hidden state projected by all gate matrices at once
1915 matrix_inner = backend.dot(h_tm1, self.recurrent_kernel)
1916 if self.use_bias:
1917 matrix_inner = backend.bias_add(matrix_inner, recurrent_bias)
1918 else:
1919 # hidden state projected separately for update/reset and new
1920 matrix_inner = backend.dot(
1921 h_tm1, self.recurrent_kernel[:, :2 * self.units])
1923 recurrent_z, recurrent_r, recurrent_h = array_ops.split(
1924 matrix_inner, [self.units, self.units, -1], axis=-1)
1926 z = self.recurrent_activation(x_z + recurrent_z)
1927 r = self.recurrent_activation(x_r + recurrent_r)
1929 if self.reset_after:
1930 recurrent_h = r * recurrent_h
1931 else:
1932 recurrent_h = backend.dot(
1933 r * h_tm1, self.recurrent_kernel[:, 2 * self.units:])
1935 hh = self.activation(x_h + recurrent_h)
1936 # previous and candidate state mixed by update gate
1937 h = z * h_tm1 + (1 - z) * hh
1938 new_state = [h] if nest.is_nested(states) else h
1939 return h, new_state
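For reference, the recurrence implemented above can be summarized as follows (dropout ignored, default `reset_after=False` convention), with $\sigma$ the recurrent activation and $\phi$ the main activation:

```latex
\begin{aligned}
z_t &= \sigma(x_t W_z + h_{t-1} U_z + b_z) \\
r_t &= \sigma(x_t W_r + h_{t-1} U_r + b_r) \\
\tilde{h}_t &= \phi\big(x_t W_h + (r_t \odot h_{t-1}) U_h + b_h\big) \\
h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t
\end{aligned}
```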
1941 def get_config(self):
1942 config = {
1943 'units': self.units,
1944 'activation': activations.serialize(self.activation),
1945 'recurrent_activation':
1946 activations.serialize(self.recurrent_activation),
1947 'use_bias': self.use_bias,
1948 'kernel_initializer': initializers.serialize(self.kernel_initializer),
1949 'recurrent_initializer':
1950 initializers.serialize(self.recurrent_initializer),
1951 'bias_initializer': initializers.serialize(self.bias_initializer),
1952 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
1953 'recurrent_regularizer':
1954 regularizers.serialize(self.recurrent_regularizer),
1955 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
1956 'kernel_constraint': constraints.serialize(self.kernel_constraint),
1957 'recurrent_constraint':
1958 constraints.serialize(self.recurrent_constraint),
1959 'bias_constraint': constraints.serialize(self.bias_constraint),
1960 'dropout': self.dropout,
1961 'recurrent_dropout': self.recurrent_dropout,
1962 'implementation': self.implementation,
1963 'reset_after': self.reset_after
1964 }
1965 config.update(_config_for_enable_caching_device(self))
1966 base_config = super(GRUCell, self).get_config()
1967 return dict(list(base_config.items()) + list(config.items()))
1969 def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
1970 return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
1973@keras_export(v1=['keras.layers.GRU'])
1974class GRU(RNN):
1975 """Gated Recurrent Unit - Cho et al. 2014.
1977  There are two variants. The default one is based on arXiv:1406.1078v3 and
1978  applies the reset gate to the hidden state before the matrix multiplication.
1979  The other is based on the original arXiv:1406.1078v1 and has the order reversed.
1981  The second variant is compatible with CuDNNGRU (GPU-only) and allows
1982  inference on CPU. It therefore has separate biases for `kernel` and
1983  `recurrent_kernel`. Use `reset_after=True` together with
1984  `recurrent_activation='sigmoid'` (see the usage sketch after this docstring).
1986 Args:
1987 units: Positive integer, dimensionality of the output space.
1988 activation: Activation function to use.
1989 Default: hyperbolic tangent (`tanh`).
1990 If you pass `None`, no activation is applied
1991      (i.e. "linear" activation: `a(x) = x`).
1992 recurrent_activation: Activation function to use
1993 for the recurrent step.
1994 Default: hard sigmoid (`hard_sigmoid`).
1995 If you pass `None`, no activation is applied
1996      (i.e. "linear" activation: `a(x) = x`).
1997 use_bias: Boolean, whether the layer uses a bias vector.
1998 kernel_initializer: Initializer for the `kernel` weights matrix,
1999 used for the linear transformation of the inputs.
2000 recurrent_initializer: Initializer for the `recurrent_kernel`
2001 weights matrix, used for the linear transformation of the recurrent state.
2002 bias_initializer: Initializer for the bias vector.
2003 kernel_regularizer: Regularizer function applied to
2004 the `kernel` weights matrix.
2005 recurrent_regularizer: Regularizer function applied to
2006 the `recurrent_kernel` weights matrix.
2007 bias_regularizer: Regularizer function applied to the bias vector.
2008 activity_regularizer: Regularizer function applied to
2009      the output of the layer (its "activation").
2010 kernel_constraint: Constraint function applied to
2011 the `kernel` weights matrix.
2012 recurrent_constraint: Constraint function applied to
2013 the `recurrent_kernel` weights matrix.
2014 bias_constraint: Constraint function applied to the bias vector.
2015 dropout: Float between 0 and 1.
2016 Fraction of the units to drop for
2017 the linear transformation of the inputs.
2018 recurrent_dropout: Float between 0 and 1.
2019 Fraction of the units to drop for
2020 the linear transformation of the recurrent state.
2021 return_sequences: Boolean. Whether to return the last output
2022 in the output sequence, or the full sequence.
2023 return_state: Boolean. Whether to return the last state
2024 in addition to the output.
2025 go_backwards: Boolean (default False).
2026 If True, process the input sequence backwards and return the
2027 reversed sequence.
2028 stateful: Boolean (default False). If True, the last state
2029 for each sample at index i in a batch will be used as initial
2030 state for the sample of index i in the following batch.
2031 unroll: Boolean (default False).
2032 If True, the network will be unrolled,
2033 else a symbolic loop will be used.
2034 Unrolling can speed-up a RNN,
2035 although it tends to be more memory-intensive.
2036 Unrolling is only suitable for short sequences.
2037 time_major: The shape format of the `inputs` and `outputs` tensors.
2038 If True, the inputs and outputs will be in shape
2039 `(timesteps, batch, ...)`, whereas in the False case, it will be
2040 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
2041 efficient because it avoids transposes at the beginning and end of the
2042 RNN calculation. However, most TensorFlow data is batch-major, so by
2043 default this function accepts input and emits output in batch-major
2044 form.
2045 reset_after: GRU convention (whether to apply reset gate after or
2046 before matrix multiplication). False = "before" (default),
2047 True = "after" (CuDNN compatible).
2049 Call arguments:
2050 inputs: A 3D tensor.
2051 mask: Binary tensor of shape `(samples, timesteps)` indicating whether
2052 a given timestep should be masked. An individual `True` entry indicates
2053 that the corresponding timestep should be utilized, while a `False`
2054 entry indicates that the corresponding timestep should be ignored.
2055 training: Python boolean indicating whether the layer should behave in
2056 training mode or in inference mode. This argument is passed to the cell
2057 when calling it. This is only relevant if `dropout` or
2058 `recurrent_dropout` is used.
2059 initial_state: List of initial state tensors to be passed to the first
2060 call of the cell.
2061 """
2063 def __init__(self,
2064 units,
2065 activation='tanh',
2066 recurrent_activation='hard_sigmoid',
2067 use_bias=True,
2068 kernel_initializer='glorot_uniform',
2069 recurrent_initializer='orthogonal',
2070 bias_initializer='zeros',
2071 kernel_regularizer=None,
2072 recurrent_regularizer=None,
2073 bias_regularizer=None,
2074 activity_regularizer=None,
2075 kernel_constraint=None,
2076 recurrent_constraint=None,
2077 bias_constraint=None,
2078 dropout=0.,
2079 recurrent_dropout=0.,
2080 return_sequences=False,
2081 return_state=False,
2082 go_backwards=False,
2083 stateful=False,
2084 unroll=False,
2085 reset_after=False,
2086 **kwargs):
2087 implementation = kwargs.pop('implementation', 1)
2088 if implementation == 0:
2089 logging.warning('`implementation=0` has been deprecated, '
2090                      'and now defaults to `implementation=1`. '
2091 'Please update your layer call.')
2092 if 'enable_caching_device' in kwargs:
2093 cell_kwargs = {'enable_caching_device':
2094 kwargs.pop('enable_caching_device')}
2095 else:
2096 cell_kwargs = {}
2097 cell = GRUCell(
2098 units,
2099 activation=activation,
2100 recurrent_activation=recurrent_activation,
2101 use_bias=use_bias,
2102 kernel_initializer=kernel_initializer,
2103 recurrent_initializer=recurrent_initializer,
2104 bias_initializer=bias_initializer,
2105 kernel_regularizer=kernel_regularizer,
2106 recurrent_regularizer=recurrent_regularizer,
2107 bias_regularizer=bias_regularizer,
2108 kernel_constraint=kernel_constraint,
2109 recurrent_constraint=recurrent_constraint,
2110 bias_constraint=bias_constraint,
2111 dropout=dropout,
2112 recurrent_dropout=recurrent_dropout,
2113 implementation=implementation,
2114 reset_after=reset_after,
2115 dtype=kwargs.get('dtype'),
2116 trainable=kwargs.get('trainable', True),
2117 **cell_kwargs)
2118 super(GRU, self).__init__(
2119 cell,
2120 return_sequences=return_sequences,
2121 return_state=return_state,
2122 go_backwards=go_backwards,
2123 stateful=stateful,
2124 unroll=unroll,
2125 **kwargs)
2126 self.activity_regularizer = regularizers.get(activity_regularizer)
2127 self.input_spec = [InputSpec(ndim=3)]
2129 def call(self, inputs, mask=None, training=None, initial_state=None):
2130 return super(GRU, self).call(
2131 inputs, mask=mask, training=training, initial_state=initial_state)
2133 @property
2134 def units(self):
2135 return self.cell.units
2137 @property
2138 def activation(self):
2139 return self.cell.activation
2141 @property
2142 def recurrent_activation(self):
2143 return self.cell.recurrent_activation
2145 @property
2146 def use_bias(self):
2147 return self.cell.use_bias
2149 @property
2150 def kernel_initializer(self):
2151 return self.cell.kernel_initializer
2153 @property
2154 def recurrent_initializer(self):
2155 return self.cell.recurrent_initializer
2157 @property
2158 def bias_initializer(self):
2159 return self.cell.bias_initializer
2161 @property
2162 def kernel_regularizer(self):
2163 return self.cell.kernel_regularizer
2165 @property
2166 def recurrent_regularizer(self):
2167 return self.cell.recurrent_regularizer
2169 @property
2170 def bias_regularizer(self):
2171 return self.cell.bias_regularizer
2173 @property
2174 def kernel_constraint(self):
2175 return self.cell.kernel_constraint
2177 @property
2178 def recurrent_constraint(self):
2179 return self.cell.recurrent_constraint
2181 @property
2182 def bias_constraint(self):
2183 return self.cell.bias_constraint
2185 @property
2186 def dropout(self):
2187 return self.cell.dropout
2189 @property
2190 def recurrent_dropout(self):
2191 return self.cell.recurrent_dropout
2193 @property
2194 def implementation(self):
2195 return self.cell.implementation
2197 @property
2198 def reset_after(self):
2199 return self.cell.reset_after
2201 def get_config(self):
2202 config = {
2203 'units':
2204 self.units,
2205 'activation':
2206 activations.serialize(self.activation),
2207 'recurrent_activation':
2208 activations.serialize(self.recurrent_activation),
2209 'use_bias':
2210 self.use_bias,
2211 'kernel_initializer':
2212 initializers.serialize(self.kernel_initializer),
2213 'recurrent_initializer':
2214 initializers.serialize(self.recurrent_initializer),
2215 'bias_initializer':
2216 initializers.serialize(self.bias_initializer),
2217 'kernel_regularizer':
2218 regularizers.serialize(self.kernel_regularizer),
2219 'recurrent_regularizer':
2220 regularizers.serialize(self.recurrent_regularizer),
2221 'bias_regularizer':
2222 regularizers.serialize(self.bias_regularizer),
2223 'activity_regularizer':
2224 regularizers.serialize(self.activity_regularizer),
2225 'kernel_constraint':
2226 constraints.serialize(self.kernel_constraint),
2227 'recurrent_constraint':
2228 constraints.serialize(self.recurrent_constraint),
2229 'bias_constraint':
2230 constraints.serialize(self.bias_constraint),
2231 'dropout':
2232 self.dropout,
2233 'recurrent_dropout':
2234 self.recurrent_dropout,
2235 'implementation':
2236 self.implementation,
2237 'reset_after':
2238 self.reset_after
2239 }
2240 config.update(_config_for_enable_caching_device(self.cell))
2241 base_config = super(GRU, self).get_config()
2242 del base_config['cell']
2243 return dict(list(base_config.items()) + list(config.items()))
2245 @classmethod
2246 def from_config(cls, config):
2247 if 'implementation' in config and config['implementation'] == 0:
2248 config['implementation'] = 1
2249 return cls(**config)
2252@keras_export(v1=['keras.layers.LSTMCell'])
2253class LSTMCell(DropoutRNNCellMixin, Layer):
2254 """Cell class for the LSTM layer.
2256 Args:
2257 units: Positive integer, dimensionality of the output space.
2258 activation: Activation function to use.
2259 Default: hyperbolic tangent (`tanh`).
2260 If you pass `None`, no activation is applied
2261      (i.e. "linear" activation: `a(x) = x`).
2262 recurrent_activation: Activation function to use
2263 for the recurrent step.
2264 Default: hard sigmoid (`hard_sigmoid`).
2265 If you pass `None`, no activation is applied
2266      (i.e. "linear" activation: `a(x) = x`).
2267 use_bias: Boolean, whether the layer uses a bias vector.
2268 kernel_initializer: Initializer for the `kernel` weights matrix,
2269 used for the linear transformation of the inputs.
2270 recurrent_initializer: Initializer for the `recurrent_kernel`
2271 weights matrix,
2272 used for the linear transformation of the recurrent state.
2273 bias_initializer: Initializer for the bias vector.
2274 unit_forget_bias: Boolean.
2275 If True, add 1 to the bias of the forget gate at initialization.
2276 Setting it to true will also force `bias_initializer="zeros"`.
2277 This is recommended in [Jozefowicz et al., 2015](
2278 http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
2279 kernel_regularizer: Regularizer function applied to
2280 the `kernel` weights matrix.
2281 recurrent_regularizer: Regularizer function applied to
2282 the `recurrent_kernel` weights matrix.
2283 bias_regularizer: Regularizer function applied to the bias vector.
2284 kernel_constraint: Constraint function applied to
2285 the `kernel` weights matrix.
2286 recurrent_constraint: Constraint function applied to
2287 the `recurrent_kernel` weights matrix.
2288 bias_constraint: Constraint function applied to the bias vector.
2289 dropout: Float between 0 and 1.
2290 Fraction of the units to drop for
2291 the linear transformation of the inputs.
2292 recurrent_dropout: Float between 0 and 1.
2293 Fraction of the units to drop for
2294 the linear transformation of the recurrent state.
2296 Call arguments:
2297 inputs: A 2D tensor.
2298 states: List of state tensors corresponding to the previous timestep.
2299 training: Python boolean indicating whether the layer should behave in
2300 training mode or in inference mode. Only relevant when `dropout` or
2301 `recurrent_dropout` is used.
2302 """
2304 def __init__(self,
2305 units,
2306 activation='tanh',
2307 recurrent_activation='hard_sigmoid',
2308 use_bias=True,
2309 kernel_initializer='glorot_uniform',
2310 recurrent_initializer='orthogonal',
2311 bias_initializer='zeros',
2312 unit_forget_bias=True,
2313 kernel_regularizer=None,
2314 recurrent_regularizer=None,
2315 bias_regularizer=None,
2316 kernel_constraint=None,
2317 recurrent_constraint=None,
2318 bias_constraint=None,
2319 dropout=0.,
2320 recurrent_dropout=0.,
2321 **kwargs):
2322    if units <= 0:
2323 raise ValueError(f'Received an invalid value for units, expected '
2324 f'a positive integer, got {units}.')
2325 # By default use cached variable under v2 mode, see b/143699808.
2326 if ops.executing_eagerly_outside_functions():
2327 self._enable_caching_device = kwargs.pop('enable_caching_device', True)
2328 else:
2329 self._enable_caching_device = kwargs.pop('enable_caching_device', False)
2330 super(LSTMCell, self).__init__(**kwargs)
2331 self.units = units
2332 self.activation = activations.get(activation)
2333 self.recurrent_activation = activations.get(recurrent_activation)
2334 self.use_bias = use_bias
2336 self.kernel_initializer = initializers.get(kernel_initializer)
2337 self.recurrent_initializer = initializers.get(recurrent_initializer)
2338 self.bias_initializer = initializers.get(bias_initializer)
2339 self.unit_forget_bias = unit_forget_bias
2341 self.kernel_regularizer = regularizers.get(kernel_regularizer)
2342 self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
2343 self.bias_regularizer = regularizers.get(bias_regularizer)
2345 self.kernel_constraint = constraints.get(kernel_constraint)
2346 self.recurrent_constraint = constraints.get(recurrent_constraint)
2347 self.bias_constraint = constraints.get(bias_constraint)
2349 self.dropout = min(1., max(0., dropout))
2350 self.recurrent_dropout = min(1., max(0., recurrent_dropout))
2351 implementation = kwargs.pop('implementation', 1)
2352 if self.recurrent_dropout != 0 and implementation != 1:
2353 logging.debug(RECURRENT_DROPOUT_WARNING_MSG)
2354 self.implementation = 1
2355 else:
2356 self.implementation = implementation
2357 self.state_size = [self.units, self.units]
2358 self.output_size = self.units
2360 @tf_utils.shape_type_conversion
2361 def build(self, input_shape):
2362 default_caching_device = _caching_device(self)
2363 input_dim = input_shape[-1]
2364 self.kernel = self.add_weight(
2365 shape=(input_dim, self.units * 4),
2366 name='kernel',
2367 initializer=self.kernel_initializer,
2368 regularizer=self.kernel_regularizer,
2369 constraint=self.kernel_constraint,
2370 caching_device=default_caching_device)
2371 self.recurrent_kernel = self.add_weight(
2372 shape=(self.units, self.units * 4),
2373 name='recurrent_kernel',
2374 initializer=self.recurrent_initializer,
2375 regularizer=self.recurrent_regularizer,
2376 constraint=self.recurrent_constraint,
2377 caching_device=default_caching_device)
2379 if self.use_bias:
2380 if self.unit_forget_bias:
2382 def bias_initializer(_, *args, **kwargs):
2383 return backend.concatenate([
2384 self.bias_initializer((self.units,), *args, **kwargs),
2385 initializers.get('ones')((self.units,), *args, **kwargs),
2386 self.bias_initializer((self.units * 2,), *args, **kwargs),
2387 ])
2388 else:
2389 bias_initializer = self.bias_initializer
2390 self.bias = self.add_weight(
2391 shape=(self.units * 4,),
2392 name='bias',
2393 initializer=bias_initializer,
2394 regularizer=self.bias_regularizer,
2395 constraint=self.bias_constraint,
2396 caching_device=default_caching_device)
2397 else:
2398 self.bias = None
2399 self.built = True
2401 def _compute_carry_and_output(self, x, h_tm1, c_tm1):
2402 """Computes carry and output using split kernels."""
2403 x_i, x_f, x_c, x_o = x
2404 h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
2405 i = self.recurrent_activation(
2406 x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]))
2407 f = self.recurrent_activation(x_f + backend.dot(
2408 h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]))
2409 c = f * c_tm1 + i * self.activation(x_c + backend.dot(
2410 h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
2411 o = self.recurrent_activation(
2412 x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]))
2413 return c, o
2415 def _compute_carry_and_output_fused(self, z, c_tm1):
2416 """Computes carry and output using fused kernels."""
2417 z0, z1, z2, z3 = z
2418 i = self.recurrent_activation(z0)
2419 f = self.recurrent_activation(z1)
2420 c = f * c_tm1 + i * self.activation(z2)
2421 o = self.recurrent_activation(z3)
2422 return c, o
2424 def call(self, inputs, states, training=None):
2425 h_tm1 = states[0] # previous memory state
2426 c_tm1 = states[1] # previous carry state
2428 dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
2429 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
2430 h_tm1, training, count=4)
2432 if self.implementation == 1:
2433 if 0 < self.dropout < 1.:
2434 inputs_i = inputs * dp_mask[0]
2435 inputs_f = inputs * dp_mask[1]
2436 inputs_c = inputs * dp_mask[2]
2437 inputs_o = inputs * dp_mask[3]
2438 else:
2439 inputs_i = inputs
2440 inputs_f = inputs
2441 inputs_c = inputs
2442 inputs_o = inputs
2443 k_i, k_f, k_c, k_o = array_ops.split(
2444 self.kernel, num_or_size_splits=4, axis=1)
2445 x_i = backend.dot(inputs_i, k_i)
2446 x_f = backend.dot(inputs_f, k_f)
2447 x_c = backend.dot(inputs_c, k_c)
2448 x_o = backend.dot(inputs_o, k_o)
2449 if self.use_bias:
2450 b_i, b_f, b_c, b_o = array_ops.split(
2451 self.bias, num_or_size_splits=4, axis=0)
2452 x_i = backend.bias_add(x_i, b_i)
2453 x_f = backend.bias_add(x_f, b_f)
2454 x_c = backend.bias_add(x_c, b_c)
2455 x_o = backend.bias_add(x_o, b_o)
2457 if 0 < self.recurrent_dropout < 1.:
2458 h_tm1_i = h_tm1 * rec_dp_mask[0]
2459 h_tm1_f = h_tm1 * rec_dp_mask[1]
2460 h_tm1_c = h_tm1 * rec_dp_mask[2]
2461 h_tm1_o = h_tm1 * rec_dp_mask[3]
2462 else:
2463 h_tm1_i = h_tm1
2464 h_tm1_f = h_tm1
2465 h_tm1_c = h_tm1
2466 h_tm1_o = h_tm1
2467 x = (x_i, x_f, x_c, x_o)
2468 h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
2469 c, o = self._compute_carry_and_output(x, h_tm1, c_tm1)
2470 else:
2471 if 0. < self.dropout < 1.:
2472 inputs = inputs * dp_mask[0]
2473 z = backend.dot(inputs, self.kernel)
2474 z += backend.dot(h_tm1, self.recurrent_kernel)
2475 if self.use_bias:
2476 z = backend.bias_add(z, self.bias)
2478 z = array_ops.split(z, num_or_size_splits=4, axis=1)
2479 c, o = self._compute_carry_and_output_fused(z, c_tm1)
2481 h = o * self.activation(c)
2482 return h, [h, c]
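For reference, both implementation paths above compute the standard LSTM step (dropout ignored), with $\sigma$ the recurrent activation and $\phi$ the main activation:

```latex
\begin{aligned}
i_t &= \sigma(x_t W_i + h_{t-1} U_i + b_i) \\
f_t &= \sigma(x_t W_f + h_{t-1} U_f + b_f) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \phi(x_t W_c + h_{t-1} U_c + b_c) \\
o_t &= \sigma(x_t W_o + h_{t-1} U_o + b_o) \\
h_t &= o_t \odot \phi(c_t)
\end{aligned}
```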
2484 def get_config(self):
2485 config = {
2486 'units':
2487 self.units,
2488 'activation':
2489 activations.serialize(self.activation),
2490 'recurrent_activation':
2491 activations.serialize(self.recurrent_activation),
2492 'use_bias':
2493 self.use_bias,
2494 'kernel_initializer':
2495 initializers.serialize(self.kernel_initializer),
2496 'recurrent_initializer':
2497 initializers.serialize(self.recurrent_initializer),
2498 'bias_initializer':
2499 initializers.serialize(self.bias_initializer),
2500 'unit_forget_bias':
2501 self.unit_forget_bias,
2502 'kernel_regularizer':
2503 regularizers.serialize(self.kernel_regularizer),
2504 'recurrent_regularizer':
2505 regularizers.serialize(self.recurrent_regularizer),
2506 'bias_regularizer':
2507 regularizers.serialize(self.bias_regularizer),
2508 'kernel_constraint':
2509 constraints.serialize(self.kernel_constraint),
2510 'recurrent_constraint':
2511 constraints.serialize(self.recurrent_constraint),
2512 'bias_constraint':
2513 constraints.serialize(self.bias_constraint),
2514 'dropout':
2515 self.dropout,
2516 'recurrent_dropout':
2517 self.recurrent_dropout,
2518 'implementation':
2519 self.implementation
2520 }
2521 config.update(_config_for_enable_caching_device(self))
2522 base_config = super(LSTMCell, self).get_config()
2523 return dict(list(base_config.items()) + list(config.items()))
2525 def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
2526 return list(_generate_zero_filled_state_for_cell(
2527 self, inputs, batch_size, dtype))
2530@keras_export('keras.experimental.PeepholeLSTMCell')
2531class PeepholeLSTMCell(LSTMCell):
2532 """Equivalent to LSTMCell class but adds peephole connections.
2534  Peephole connections allow the gates to use the previous internal (cell) state
2535  in addition to the previous hidden state, which is all that LSTMCell can see.
2536  This lets PeepholeLSTMCell learn precise timings more easily than LSTMCell.
2538 From [Gers et al., 2002](
2539 http://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf):
2541 "We find that LSTM augmented by 'peephole connections' from its internal
2542 cells to its multiplicative gates can learn the fine distinction between
2543 sequences of spikes spaced either 50 or 49 time steps apart without the help
2544 of any short training exemplars."
2546 The peephole implementation is based on:
2548 [Sak et al., 2014](https://research.google.com/pubs/archive/43905.pdf)
2550 Example:
2552 ```python
2553 # Create 2 PeepholeLSTMCells
2554 peephole_lstm_cells = [PeepholeLSTMCell(size) for size in [128, 256]]
2555 # Create a layer composed sequentially of the peephole LSTM cells.
2556 layer = RNN(peephole_lstm_cells)
2557  inputs = keras.Input((timesteps, input_dim))
2558  output = layer(inputs)
2559 ```
2560 """
2562 def __init__(self,
2563 units,
2564 activation='tanh',
2565 recurrent_activation='hard_sigmoid',
2566 use_bias=True,
2567 kernel_initializer='glorot_uniform',
2568 recurrent_initializer='orthogonal',
2569 bias_initializer='zeros',
2570 unit_forget_bias=True,
2571 kernel_regularizer=None,
2572 recurrent_regularizer=None,
2573 bias_regularizer=None,
2574 kernel_constraint=None,
2575 recurrent_constraint=None,
2576 bias_constraint=None,
2577 dropout=0.,
2578 recurrent_dropout=0.,
2579 **kwargs):
2580 warnings.warn('`tf.keras.experimental.PeepholeLSTMCell` is deprecated '
2581 'and will be removed in a future version. '
2582 'Please use tensorflow_addons.rnn.PeepholeLSTMCell '
2583 'instead.')
2584 super(PeepholeLSTMCell, self).__init__(
2585 units=units,
2586 activation=activation,
2587 recurrent_activation=recurrent_activation,
2588 use_bias=use_bias,
2589 kernel_initializer=kernel_initializer,
2590 recurrent_initializer=recurrent_initializer,
2591 bias_initializer=bias_initializer,
2592 unit_forget_bias=unit_forget_bias,
2593 kernel_regularizer=kernel_regularizer,
2594 recurrent_regularizer=recurrent_regularizer,
2595 bias_regularizer=bias_regularizer,
2596 kernel_constraint=kernel_constraint,
2597 recurrent_constraint=recurrent_constraint,
2598 bias_constraint=bias_constraint,
2599 dropout=dropout,
2600 recurrent_dropout=recurrent_dropout,
2601 implementation=kwargs.pop('implementation', 1),
2602 **kwargs)
2604 def build(self, input_shape):
2605 super(PeepholeLSTMCell, self).build(input_shape)
2606 # The following are the weight matrices for the peephole connections. These
2607 # are multiplied with the previous internal state during the computation of
2608 # carry and output.
2609 self.input_gate_peephole_weights = self.add_weight(
2610 shape=(self.units,),
2611 name='input_gate_peephole_weights',
2612 initializer=self.kernel_initializer)
2613 self.forget_gate_peephole_weights = self.add_weight(
2614 shape=(self.units,),
2615 name='forget_gate_peephole_weights',
2616 initializer=self.kernel_initializer)
2617 self.output_gate_peephole_weights = self.add_weight(
2618 shape=(self.units,),
2619 name='output_gate_peephole_weights',
2620 initializer=self.kernel_initializer)
2622 def _compute_carry_and_output(self, x, h_tm1, c_tm1):
2623 x_i, x_f, x_c, x_o = x
2624 h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
2625 i = self.recurrent_activation(
2626 x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]) +
2627 self.input_gate_peephole_weights * c_tm1)
2628 f = self.recurrent_activation(x_f + backend.dot(
2629 h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]) +
2630 self.forget_gate_peephole_weights * c_tm1)
2631 c = f * c_tm1 + i * self.activation(x_c + backend.dot(
2632 h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
2633 o = self.recurrent_activation(
2634 x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]) +
2635 self.output_gate_peephole_weights * c)
2636 return c, o
2638 def _compute_carry_and_output_fused(self, z, c_tm1):
2639 z0, z1, z2, z3 = z
2640 i = self.recurrent_activation(z0 +
2641 self.input_gate_peephole_weights * c_tm1)
2642 f = self.recurrent_activation(z1 +
2643 self.forget_gate_peephole_weights * c_tm1)
2644 c = f * c_tm1 + i * self.activation(z2)
2645 o = self.recurrent_activation(z3 + self.output_gate_peephole_weights * c)
2646 return c, o
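The peephole variant above differs from the base cell only in the gate pre-activations: the input and forget gates additionally see the previous carry state $c_{t-1}$, and the output gate sees the new carry state $c_t$, each through an elementwise peephole weight vector $w$ (dropout ignored):

```latex
\begin{aligned}
i_t &= \sigma(x_t W_i + h_{t-1} U_i + w_i \odot c_{t-1} + b_i) \\
f_t &= \sigma(x_t W_f + h_{t-1} U_f + w_f \odot c_{t-1} + b_f) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \phi(x_t W_c + h_{t-1} U_c + b_c) \\
o_t &= \sigma(x_t W_o + h_{t-1} U_o + w_o \odot c_t + b_o) \\
h_t &= o_t \odot \phi(c_t)
\end{aligned}
```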
2649@keras_export(v1=['keras.layers.LSTM'])
2650class LSTM(RNN):
2651 """Long Short-Term Memory layer - Hochreiter 1997.
2653 Note that this cell is not optimized for performance on GPU. Please use
2654 `tf.compat.v1.keras.layers.CuDNNLSTM` for better performance on GPU.
2656 Args:
2657 units: Positive integer, dimensionality of the output space.
2658 activation: Activation function to use.
2659 Default: hyperbolic tangent (`tanh`).
2660 If you pass `None`, no activation is applied
2661      (i.e. "linear" activation: `a(x) = x`).
2662 recurrent_activation: Activation function to use
2663 for the recurrent step.
2664 Default: hard sigmoid (`hard_sigmoid`).
2665 If you pass `None`, no activation is applied
2666      (i.e. "linear" activation: `a(x) = x`).
2667 use_bias: Boolean, whether the layer uses a bias vector.
2668 kernel_initializer: Initializer for the `kernel` weights matrix,
2669      used for the linear transformation of the inputs.
2670 recurrent_initializer: Initializer for the `recurrent_kernel`
2671 weights matrix,
2672 used for the linear transformation of the recurrent state.
2673 bias_initializer: Initializer for the bias vector.
2674 unit_forget_bias: Boolean.
2675 If True, add 1 to the bias of the forget gate at initialization.
2676 Setting it to true will also force `bias_initializer="zeros"`.
2677 This is recommended in [Jozefowicz et al., 2015](
2678 http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
2679 kernel_regularizer: Regularizer function applied to
2680 the `kernel` weights matrix.
2681 recurrent_regularizer: Regularizer function applied to
2682 the `recurrent_kernel` weights matrix.
2683 bias_regularizer: Regularizer function applied to the bias vector.
2684 activity_regularizer: Regularizer function applied to
2685 the output of the layer (its "activation").
2686 kernel_constraint: Constraint function applied to
2687 the `kernel` weights matrix.
2688 recurrent_constraint: Constraint function applied to
2689 the `recurrent_kernel` weights matrix.
2690 bias_constraint: Constraint function applied to the bias vector.
2691 dropout: Float between 0 and 1.
2692 Fraction of the units to drop for
2693 the linear transformation of the inputs.
2694 recurrent_dropout: Float between 0 and 1.
2695 Fraction of the units to drop for
2696 the linear transformation of the recurrent state.
2697    return_sequences: Boolean. Whether to return the last output
2698 in the output sequence, or the full sequence.
2699 return_state: Boolean. Whether to return the last state
2700 in addition to the output.
2701 go_backwards: Boolean (default False).
2702 If True, process the input sequence backwards and return the
2703 reversed sequence.
2704 stateful: Boolean (default False). If True, the last state
2705 for each sample at index i in a batch will be used as initial
2706 state for the sample of index i in the following batch.
2707 unroll: Boolean (default False).
2708 If True, the network will be unrolled,
2709 else a symbolic loop will be used.
2710 Unrolling can speed-up a RNN,
2711 although it tends to be more memory-intensive.
2712 Unrolling is only suitable for short sequences.
2713 time_major: The shape format of the `inputs` and `outputs` tensors.
2714 If True, the inputs and outputs will be in shape
2715 `(timesteps, batch, ...)`, whereas in the False case, it will be
2716 `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
2717 efficient because it avoids transposes at the beginning and end of the
2718 RNN calculation. However, most TensorFlow data is batch-major, so by
2719 default this function accepts input and emits output in batch-major
2720 form.
2722 Call arguments:
2723 inputs: A 3D tensor.
2724 mask: Binary tensor of shape `(samples, timesteps)` indicating whether
2725 a given timestep should be masked. An individual `True` entry indicates
2726 that the corresponding timestep should be utilized, while a `False`
2727 entry indicates that the corresponding timestep should be ignored.
2728 training: Python boolean indicating whether the layer should behave in
2729 training mode or in inference mode. This argument is passed to the cell
2730 when calling it. This is only relevant if `dropout` or
2731 `recurrent_dropout` is used.
2732 initial_state: List of initial state tensors to be passed to the first
2733 call of the cell.
2734 """
2736 def __init__(self,
2737 units,
2738 activation='tanh',
2739 recurrent_activation='hard_sigmoid',
2740 use_bias=True,
2741 kernel_initializer='glorot_uniform',
2742 recurrent_initializer='orthogonal',
2743 bias_initializer='zeros',
2744 unit_forget_bias=True,
2745 kernel_regularizer=None,
2746 recurrent_regularizer=None,
2747 bias_regularizer=None,
2748 activity_regularizer=None,
2749 kernel_constraint=None,
2750 recurrent_constraint=None,
2751 bias_constraint=None,
2752 dropout=0.,
2753 recurrent_dropout=0.,
2754 return_sequences=False,
2755 return_state=False,
2756 go_backwards=False,
2757 stateful=False,
2758 unroll=False,
2759 **kwargs):
2760 implementation = kwargs.pop('implementation', 1)
2761 if implementation == 0:
2762 logging.warning('`implementation=0` has been deprecated, '
2763                      'and now defaults to `implementation=1`. '
2764 'Please update your layer call.')
2765 if 'enable_caching_device' in kwargs:
2766 cell_kwargs = {'enable_caching_device':
2767 kwargs.pop('enable_caching_device')}
2768 else:
2769 cell_kwargs = {}
2770 cell = LSTMCell(
2771 units,
2772 activation=activation,
2773 recurrent_activation=recurrent_activation,
2774 use_bias=use_bias,
2775 kernel_initializer=kernel_initializer,
2776 recurrent_initializer=recurrent_initializer,
2777 unit_forget_bias=unit_forget_bias,
2778 bias_initializer=bias_initializer,
2779 kernel_regularizer=kernel_regularizer,
2780 recurrent_regularizer=recurrent_regularizer,
2781 bias_regularizer=bias_regularizer,
2782 kernel_constraint=kernel_constraint,
2783 recurrent_constraint=recurrent_constraint,
2784 bias_constraint=bias_constraint,
2785 dropout=dropout,
2786 recurrent_dropout=recurrent_dropout,
2787 implementation=implementation,
2788 dtype=kwargs.get('dtype'),
2789 trainable=kwargs.get('trainable', True),
2790 **cell_kwargs)
2791 super(LSTM, self).__init__(
2792 cell,
2793 return_sequences=return_sequences,
2794 return_state=return_state,
2795 go_backwards=go_backwards,
2796 stateful=stateful,
2797 unroll=unroll,
2798 **kwargs)
2799 self.activity_regularizer = regularizers.get(activity_regularizer)
2800 self.input_spec = [InputSpec(ndim=3)]
2802 def call(self, inputs, mask=None, training=None, initial_state=None):
2803 return super(LSTM, self).call(
2804 inputs, mask=mask, training=training, initial_state=initial_state)
2806 @property
2807 def units(self):
2808 return self.cell.units
2810 @property
2811 def activation(self):
2812 return self.cell.activation
2814 @property
2815 def recurrent_activation(self):
2816 return self.cell.recurrent_activation
2818 @property
2819 def use_bias(self):
2820 return self.cell.use_bias
2822 @property
2823 def kernel_initializer(self):
2824 return self.cell.kernel_initializer
2826 @property
2827 def recurrent_initializer(self):
2828 return self.cell.recurrent_initializer
2830 @property
2831 def bias_initializer(self):
2832 return self.cell.bias_initializer
2834 @property
2835 def unit_forget_bias(self):
2836 return self.cell.unit_forget_bias
2838 @property
2839 def kernel_regularizer(self):
2840 return self.cell.kernel_regularizer
2842 @property
2843 def recurrent_regularizer(self):
2844 return self.cell.recurrent_regularizer
2846 @property
2847 def bias_regularizer(self):
2848 return self.cell.bias_regularizer
2850 @property
2851 def kernel_constraint(self):
2852 return self.cell.kernel_constraint
2854 @property
2855 def recurrent_constraint(self):
2856 return self.cell.recurrent_constraint
2858 @property
2859 def bias_constraint(self):
2860 return self.cell.bias_constraint
2862 @property
2863 def dropout(self):
2864 return self.cell.dropout
2866 @property
2867 def recurrent_dropout(self):
2868 return self.cell.recurrent_dropout
2870 @property
2871 def implementation(self):
2872 return self.cell.implementation
2874 def get_config(self):
2875 config = {
2876 'units':
2877 self.units,
2878 'activation':
2879 activations.serialize(self.activation),
2880 'recurrent_activation':
2881 activations.serialize(self.recurrent_activation),
2882 'use_bias':
2883 self.use_bias,
2884 'kernel_initializer':
2885 initializers.serialize(self.kernel_initializer),
2886 'recurrent_initializer':
2887 initializers.serialize(self.recurrent_initializer),
2888 'bias_initializer':
2889 initializers.serialize(self.bias_initializer),
2890 'unit_forget_bias':
2891 self.unit_forget_bias,
2892 'kernel_regularizer':
2893 regularizers.serialize(self.kernel_regularizer),
2894 'recurrent_regularizer':
2895 regularizers.serialize(self.recurrent_regularizer),
2896 'bias_regularizer':
2897 regularizers.serialize(self.bias_regularizer),
2898 'activity_regularizer':
2899 regularizers.serialize(self.activity_regularizer),
2900 'kernel_constraint':
2901 constraints.serialize(self.kernel_constraint),
2902 'recurrent_constraint':
2903 constraints.serialize(self.recurrent_constraint),
2904 'bias_constraint':
2905 constraints.serialize(self.bias_constraint),
2906 'dropout':
2907 self.dropout,
2908 'recurrent_dropout':
2909 self.recurrent_dropout,
2910 'implementation':
2911 self.implementation
2912 }
2913 config.update(_config_for_enable_caching_device(self.cell))
2914 base_config = super(LSTM, self).get_config()
2915 del base_config['cell']
2916 return dict(list(base_config.items()) + list(config.items()))
2918 @classmethod
2919 def from_config(cls, config):
2920 if 'implementation' in config and config['implementation'] == 0:
2921 config['implementation'] = 1
2922 return cls(**config)
2925def _generate_dropout_mask(ones, rate, training=None, count=1):
2926 def dropped_inputs():
2927 return backend.dropout(ones, rate)
2929 if count > 1:
2930 return [
2931 backend.in_train_phase(dropped_inputs, ones, training=training)
2932 for _ in range(count)
2933 ]
2934 return backend.in_train_phase(dropped_inputs, ones, training=training)
2937def _standardize_args(inputs, initial_state, constants, num_constants):
2938 """Standardizes `__call__` to a single list of tensor inputs.
2940 When running a model loaded from a file, the input tensors
2941 `initial_state` and `constants` can be passed to `RNN.__call__()` as part
2942 of `inputs` instead of by the dedicated keyword arguments. This method
2943 makes sure the arguments are separated and that `initial_state` and
2944 `constants` are lists of tensors (or None).
2946 Args:
2947    inputs: Tensor or list/tuple of tensors, which may include constants
2948      and initial states. In that case `num_constants` must be specified.
2949 initial_state: Tensor or list of tensors or None, initial states.
2950 constants: Tensor or list of tensors or None, constant tensors.
2951    num_constants: Expected number of constants (if constants are passed as
2952      part of the `inputs` list).
2954 Returns:
2955 inputs: Single tensor or tuple of tensors.
2956 initial_state: List of tensors or None.
2957 constants: List of tensors or None.
2958 """
2959 if isinstance(inputs, list):
2960 # There are several situations here:
2961 # In graph mode, __call__ will only be called once. The initial_state
2962 # and constants could be in inputs (from file loading).
2963 # In eager mode, __call__ will be called twice: once during
2964 # rnn_layer(inputs=input_t, constants=c_t, ...), and a second time during
2965 # model.fit/train_on_batch/predict with real np data. In the second case,
2966 # the inputs will contain initial_state and constants as eager tensors.
2967 #
2968 # For either case, the real input is the first item in the list, which
2969 # could itself be a nested structure. It is followed by the initial states,
2970 # which may be a list of tensors (or a list of lists if the initial state is
2971 # a nested structure), and finally by the constants, which form a flat list.
2972 assert initial_state is None and constants is None
2973 if num_constants:
2974 constants = inputs[-num_constants:]
2975 inputs = inputs[:-num_constants]
2976 if len(inputs) > 1:
2977 initial_state = inputs[1:]
2978 inputs = inputs[:1]
2980 if len(inputs) > 1:
2981 inputs = tuple(inputs)
2982 else:
2983 inputs = inputs[0]
2985 def to_list_or_none(x):
2986 if x is None or isinstance(x, list):
2987 return x
2988 if isinstance(x, tuple):
2989 return list(x)
2990 return [x]
2992 initial_state = to_list_or_none(initial_state)
2993 constants = to_list_or_none(constants)
2995 return inputs, initial_state, constants
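# --- Editor's illustrative sketch (not part of the original source) ---
# Unpacking a call-time list in which the states and constants were folded into
# `inputs` (e.g. when running a model loaded from a file). Assuming two initial
# states and one constant:
#
#   packed = [x, h0, c0, const]
#   inputs, states, consts = _standardize_args(packed, None, None, num_constants=1)
#   # inputs -> x, states -> [h0, c0], consts -> [const]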
2998def _is_multiple_state(state_size):
2999 """Check whether the state_size contains multiple states."""
3000 return (hasattr(state_size, '__len__') and
3001 not isinstance(state_size, tensor_shape.TensorShape))
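# --- Editor's illustrative sketch (not part of the original source) ---
#   _is_multiple_state([64, 64])                          # True: e.g. LSTM (h, c)
#   _is_multiple_state(64)                                # False: single state
#   _is_multiple_state(tensor_shape.TensorShape([64]))    # False: one (possibly
#                                                         # multi-dim) state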
3004def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
3005 if inputs is not None:
3006 batch_size = array_ops.shape(inputs)[0]
3007 dtype = inputs.dtype
3008 return _generate_zero_filled_state(batch_size, cell.state_size, dtype)
3011def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
3012 """Generate a zero filled tensor with shape [batch_size, state_size]."""
3013 if batch_size_tensor is None or dtype is None:
3014 raise ValueError(
3015 'batch_size and dtype cannot be None while constructing initial state: '
3016 'batch_size={}, dtype={}'.format(batch_size_tensor, dtype))
3018 def create_zeros(unnested_state_size):
3019 flat_dims = tensor_shape.TensorShape(unnested_state_size).as_list()
3020 init_state_size = [batch_size_tensor] + flat_dims
3021 return array_ops.zeros(init_state_size, dtype=dtype)
3023 if nest.is_nested(state_size):
3024 return nest.map_structure(create_zeros, state_size)
3025 else:
3026 return create_zeros(state_size)
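# --- Editor's illustrative sketch (not part of the original source) ---
# For a nested state_size such as an LSTM cell's [units, units], zeros are
# created per structure element:
#
#   _generate_zero_filled_state(batch_size_tensor=8, state_size=[4, 4],
#                               dtype='float32')
#   # -> [zeros of shape (8, 4), zeros of shape (8, 4)]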
3029def _caching_device(rnn_cell):
3030 """Returns the caching device for the RNN variable.
3032 This is useful for distributed training, when the variable is not located on
3033 the same device as the training worker. By enabling the device cache, the
3034 worker reads the variable once and caches it locally, rather than fetching it
3035 from the remote device at every time step.
3037 Note that this assumes the variables the cell needs at each time step keep the
3038 same value throughout the forward pass and are only updated during backprop.
3039 This holds for all the default cells (SimpleRNN, GRU, LSTM). If the
3040 cell body relies on any variable that is updated at every time step, the
3041 caching device will cause it to read a stale value.
3043 Args:
3044 rnn_cell: the rnn cell instance.
3045 """
3046 if context.executing_eagerly():
3047 # caching_device is not supported in eager mode.
3048 return None
3049 if not getattr(rnn_cell, '_enable_caching_device', False):
3050 return None
3051 # Don't set a caching device when running in a loop, since it is possible that
3052 # train steps could be wrapped in a tf.while_loop. In that scenario caching
3053 # prevents forward computations in loop iterations from re-reading the
3054 # updated weights.
3055 if control_flow_util.IsInWhileLoop(ops.get_default_graph()):
3056 logging.warning(
3057 'Variable read device caching has been disabled because the '
3058 'RNN is in a tf.while_loop context, which will cause '
3059 'reading of stale values in the forward path. This could slow down '
3060 'the training due to duplicated variable reads. Please '
3061 'consider updating your code to remove tf.while_loop if possible.')
3062 return None
3063 if (rnn_cell._dtype_policy.compute_dtype !=
3064 rnn_cell._dtype_policy.variable_dtype):
3065 logging.warning(
3066 'Variable read device caching has been disabled since it '
3067 'doesn\'t work with the mixed precision API. This is '
3068 'likely to cause a slowdown for RNN training due to '
3069 'duplicated reads of the variable at each timestep, which '
3070 'will be significant in a multi remote worker setting. '
3071 'Please consider disabling mixed precision API if '
3072 'the performance has been affected.')
3073 return None
3074 # Cache the value on the device that accesses the variable.
3075 return lambda op: op.device
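# --- Editor's illustrative sketch (not part of the original source) ---
# The returned callable is passed as the `caching_device` when a cell creates
# its weights, so each op that reads the variable caches the value on its own
# device, e.g. inside a cell's build():
#
#   default_caching_device = _caching_device(self)
#   self.kernel = self.add_weight(
#       shape=(input_dim, self.units),
#       name='kernel',
#       caching_device=default_caching_device)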
3078def _config_for_enable_caching_device(rnn_cell):
3079 """Return the dict config for RNN cell wrt to enable_caching_device field.
3081 Since enable_caching_device is a internal implementation detail for speed up
3082 the RNN variable read when running on the multi remote worker setting, we
3083 don't want this config to be serialized constantly in the JSON. We will only
3084 serialize this field when a none default value is used to create the cell.
3085 Args:
3086 rnn_cell: the RNN cell for serialize.
3088 Returns:
3089 A dict which contains the JSON config for enable_caching_device value or
3090 empty dict if the enable_caching_device value is same as the default value.
3091 """
3092 default_enable_caching_device = ops.executing_eagerly_outside_functions()
3093 if rnn_cell._enable_caching_device != default_enable_caching_device:
3094 return {'enable_caching_device': rnn_cell._enable_caching_device}
3095 return {}
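# --- Editor's illustrative sketch (not part of the original source) ---
# This helper is merged into a layer's or cell's get_config(), as done above for
# LSTM, so the field only shows up in the serialized JSON when a non-default
# value was used:
#
#   config.update(_config_for_enable_caching_device(self.cell))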