Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py: 29%
524 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-classes-have-attributes
16"""Module implementing RNN Cells.
18This module provides a number of basic commonly used RNN cells, such as LSTM
19(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of
20operators that allow adding dropouts, projections, or embeddings for inputs.
21Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by
22calling the `rnn` ops several times.
23"""
24import collections
25import warnings
27from tensorflow.python.eager import context
28from tensorflow.python.framework import config as tf_config
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import tensor_conversion
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.framework import tensor_util
35from tensorflow.python.keras import activations
36from tensorflow.python.keras import backend
37from tensorflow.python.keras import initializers
38from tensorflow.python.keras.engine import base_layer_utils
39from tensorflow.python.keras.engine import input_spec
40from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_wrapper_impl
41from tensorflow.python.keras.legacy_tf_layers import base as base_layer
42from tensorflow.python.keras.utils import tf_utils
43from tensorflow.python.ops import array_ops
44from tensorflow.python.ops import clip_ops
45from tensorflow.python.ops import init_ops
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import nn_ops
48from tensorflow.python.ops import partitioned_variables
49from tensorflow.python.ops import variable_scope as vs
50from tensorflow.python.ops import variables as tf_variables
51from tensorflow.python.platform import tf_logging as logging
52from tensorflow.python.trackable import base as trackable
53from tensorflow.python.util import nest
54from tensorflow.python.util.tf_export import keras_export
55from tensorflow.python.util.tf_export import tf_export
57_BIAS_VARIABLE_NAME = "bias"
58_WEIGHTS_VARIABLE_NAME = "kernel"
60# This can be used with self.assertRaisesRegexp for assert_like_rnncell.
61ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell"
64def _hasattr(obj, attr_name):
65 try:
66 getattr(obj, attr_name)
67 except AttributeError:
68 return False
69 else:
70 return True
73def assert_like_rnncell(cell_name, cell):
74 """Raises a TypeError if cell is not like an RNNCell.
76 NOTE: Do not rely on the error message (in particular in tests) which can be
77 subject to change to increase readability. Use
78 ASSERT_LIKE_RNNCELL_ERROR_REGEXP.
80 Args:
81 cell_name: A string to give a meaningful error referencing the name of
82 the function argument.
83 cell: The object which should behave like an RNNCell.
85 Raises:
86 TypeError: A human-friendly exception.
87 """
88 conditions = [
89 _hasattr(cell, "output_size"),
90 _hasattr(cell, "state_size"),
91 _hasattr(cell, "get_initial_state") or _hasattr(cell, "zero_state"),
92 callable(cell),
93 ]
94 errors = [
95 "'output_size' property is missing", "'state_size' property is missing",
96 "either 'zero_state' or 'get_initial_state' method is required",
97 "is not callable"
98 ]
100 if not all(conditions):
102 errors = [error for error, cond in zip(errors, conditions) if not cond]
103 raise TypeError("The argument {!r} ({}) is not an RNNCell: {}.".format(
104 cell_name, cell, ", ".join(errors)))
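The check above is purely duck-typed. A minimal sketch of how it behaves, assuming this module's `assert_like_rnncell` is in scope and the `tf.compat.v1` namespace is available:

```python
import tensorflow as tf

# A legacy cell exposes state_size, output_size, zero_state, and __call__.
assert_like_rnncell("cell", tf.compat.v1.nn.rnn_cell.GRUCell(8))  # passes silently

try:
  assert_like_rnncell("cell", object())  # fails every condition
except TypeError as e:
  print(e)  # "... is not an RNNCell: 'output_size' property is missing, ..."
```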
107def _concat(prefix, suffix, static=False):
108 """Concat that enables int, Tensor, or TensorShape values.
110 This function takes a size specification, which can be an integer, a
111 TensorShape, or a Tensor, and converts it into a concatenated Tensor
112 (if static = False) or a list of integers (if static = True).
114 Args:
115 prefix: The prefix; usually the batch size (and/or time step size).
116 (TensorShape, int, or Tensor.)
117 suffix: TensorShape, int, or Tensor.
118 static: If `True`, return a python list with possibly unknown dimensions.
119 Otherwise return a `Tensor`.
121 Returns:
122 shape: the concatenation of prefix and suffix.
124 Raises:
125 ValueError: if `suffix` is not a scalar or vector (or TensorShape).
126 ValueError: if prefix or suffix was `None` and asked for dynamic
127 Tensors out.
128 """
129 if isinstance(prefix, ops.Tensor):
130 p = prefix
131 p_static = tensor_util.constant_value(prefix)
132 if p.shape.ndims == 0:
133 p = array_ops.expand_dims(p, 0)
134 elif p.shape.ndims != 1:
135 raise ValueError("prefix tensor must be either a scalar or vector, "
136 "but saw tensor: %s" % p)
137 else:
138 p = tensor_shape.TensorShape(prefix)
139 p_static = p.as_list() if p.ndims is not None else None
140 p = (
141 constant_op.constant(p.as_list(), dtype=dtypes.int32)
142 if p.is_fully_defined() else None)
143 if isinstance(suffix, ops.Tensor):
144 s = suffix
145 s_static = tensor_util.constant_value(suffix)
146 if s.shape.ndims == 0:
147 s = array_ops.expand_dims(s, 0)
148 elif s.shape.ndims != 1:
149 raise ValueError("suffix tensor must be either a scalar or vector, "
150 "but saw tensor: %s" % s)
151 else:
152 s = tensor_shape.TensorShape(suffix)
153 s_static = s.as_list() if s.ndims is not None else None
154 s = (
155 constant_op.constant(s.as_list(), dtype=dtypes.int32)
156 if s.is_fully_defined() else None)
158 if static:
159 shape = tensor_shape.TensorShape(p_static).concatenate(s_static)
160 shape = shape.as_list() if shape.ndims is not None else None
161 else:
162 if p is None or s is None:
163 raise ValueError("Provided a prefix or suffix of None: %s and %s" %
164 (prefix, suffix))
165 shape = array_ops.concat((p, s), 0)
166 return shape
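A small sketch of the two modes, assuming the module-private `_concat` is in scope; the dynamic branch needs a TensorFlow runtime (eager or graph):

```python
from tensorflow.python.framework import tensor_shape

# Static mode returns a plain Python list of (possibly unknown) dimensions.
_concat(32, tensor_shape.TensorShape([2, 3]), static=True)  # -> [32, 2, 3]
_concat(32, 5, static=True)                                  # -> [32, 5]

# Dynamic mode returns an int32 Tensor holding the concatenated shape.
_concat(32, tensor_shape.TensorShape([2, 3]))                # -> Tensor([32, 2, 3])
```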
169def _zero_state_tensors(state_size, batch_size, dtype):
170 """Create tensors of zeros based on state_size, batch_size, and dtype."""
172 def get_state_shape(s):
173 """Combine s with batch_size to get a proper tensor shape."""
174 c = _concat(batch_size, s)
175 size = array_ops.zeros(c, dtype=dtype)
176 if not context.executing_eagerly():
177 c_static = _concat(batch_size, s, static=True)
178 size.set_shape(c_static)
179 return size
181 return nest.map_structure(get_state_shape, state_size)
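Illustrative only: because the helper maps over `state_size` with `nest.map_structure`, a nested state specification yields a matching nested structure of zero tensors.

```python
import tensorflow as tf

# A hypothetical cell whose state is a pair of differently sized tensors.
zeros = _zero_state_tensors((4, 8), batch_size=2, dtype=tf.float32)
# zeros -> (a [2, 4] tensor of zeros, a [2, 8] tensor of zeros)
```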
184@keras_export(v1=["keras.__internal__.legacy.rnn_cell.RNNCell"])
185@tf_export(v1=["nn.rnn_cell.RNNCell"])
186class RNNCell(base_layer.Layer):
187 """Abstract object representing an RNN cell.
189 Every `RNNCell` must have the properties below and implement `call` with
190 the signature `(output, next_state) = call(input, state)`. The optional
191 third input argument, `scope`, is allowed for backwards compatibility
192 purposes, but should be left off for new subclasses.
194 This definition of cell differs from the definition used in the literature.
195 In the literature, 'cell' refers to an object with a single scalar output.
196 This definition refers to a horizontal array of such units.
198 An RNN cell, in the most abstract setting, is anything that has
199 a state and performs some operation that takes a matrix of inputs.
200 This operation results in an output matrix with `self.output_size` columns.
201 If `self.state_size` is an integer, this operation also results in a new
202 state matrix with `self.state_size` columns. If `self.state_size` is a
203 (possibly nested tuple of) TensorShape object(s), then it should return a
204 matching structure of Tensors having shape `[batch_size].concatenate(s)`
205 for each `s` in `self.state_size`.
206 """
208 def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
209 super(RNNCell, self).__init__(
210 trainable=trainable, name=name, dtype=dtype, **kwargs)
211 # Attribute that indicates whether the cell is a TF RNN cell, due to the slight
212 # difference between TF and Keras RNN cells. Notably, a TF cell does not wrap a
213 # single-tensor state in a list, whereas a Keras cell wraps the state into a
214 # list and its call() has to unwrap it.
215 self._is_tf_rnn_cell = True
217 def __call__(self, inputs, state, scope=None):
218 """Run this RNN cell on inputs, starting from the given state.
220 Args:
221 inputs: `2-D` tensor with shape `[batch_size, input_size]`.
222 state: if `self.state_size` is an integer, this should be a `2-D Tensor`
223 with shape `[batch_size, self.state_size]`. Otherwise, if
224 `self.state_size` is a tuple of integers, this should be a tuple with
225 shapes `[batch_size, s] for s in self.state_size`.
226 scope: VariableScope for the created subgraph; defaults to class name.
228 Returns:
229 A pair containing:
231 - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
232 - New state: Either a single `2-D` tensor, or a tuple of tensors matching
233 the arity and shapes of `state`.
234 """
235 if scope is not None:
236 with vs.variable_scope(
237 scope, custom_getter=self._rnn_get_variable) as scope:
238 return super(RNNCell, self).__call__(inputs, state, scope=scope)
239 else:
240 scope_attrname = "rnncell_scope"
241 scope = getattr(self, scope_attrname, None)
242 if scope is None:
243 scope = vs.variable_scope(
244 vs.get_variable_scope(), custom_getter=self._rnn_get_variable)
245 setattr(self, scope_attrname, scope)
246 with scope:
247 return super(RNNCell, self).__call__(inputs, state)
249 def _rnn_get_variable(self, getter, *args, **kwargs):
250 variable = getter(*args, **kwargs)
251 if ops.executing_eagerly_outside_functions():
252 trainable = variable.trainable
253 else:
254 trainable = (
255 variable in tf_variables.trainable_variables() or
256 (base_layer_utils.is_split_variable(variable) and
257 list(variable)[0] in tf_variables.trainable_variables()))
258 if trainable and all(variable is not v for v in self._trainable_weights):
259 self._trainable_weights.append(variable)
260 elif not trainable and all(
261 variable is not v for v in self._non_trainable_weights):
262 self._non_trainable_weights.append(variable)
263 return variable
265 @property
266 def state_size(self):
267 """size(s) of state(s) used by this cell.
269 It can be represented by an Integer, a TensorShape or a tuple of Integers
270 or TensorShapes.
271 """
272 raise NotImplementedError("Abstract method")
274 @property
275 def output_size(self):
276 """Integer or TensorShape: size of outputs produced by this cell."""
277 raise NotImplementedError("Abstract method")
279 def build(self, _):
280 # This tells the parent Layer object that it's OK to call
281 # self.add_variable() inside the call() method.
282 pass
284 def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
285 if inputs is not None:
286 # Validate the given batch_size and dtype against inputs if provided.
287 inputs = tensor_conversion.convert_to_tensor_v2_with_dispatch(
288 inputs, name="inputs"
289 )
290 if batch_size is not None:
291 if tensor_util.is_tf_type(batch_size):
292 static_batch_size = tensor_util.constant_value(
293 batch_size, partial=True)
294 else:
295 static_batch_size = batch_size
296 if inputs.shape.dims[0].value != static_batch_size:
297 raise ValueError(
298 "batch size from input tensor is different from the "
299 "input param. Input tensor batch: {}, batch_size: {}".format(
300 inputs.shape.dims[0].value, batch_size))
302 if dtype is not None and inputs.dtype != dtype:
303 raise ValueError(
304 "dtype from input tensor is different from the "
305 "input param. Input tensor dtype: {}, dtype: {}".format(
306 inputs.dtype, dtype))
308 batch_size = inputs.shape.dims[0].value or array_ops.shape(inputs)[0]
309 dtype = inputs.dtype
310 if batch_size is None or dtype is None:
311 raise ValueError(
312 "batch_size and dtype cannot be None while constructing initial "
313 "state: batch_size={}, dtype={}".format(batch_size, dtype))
314 return self.zero_state(batch_size, dtype)
316 def zero_state(self, batch_size, dtype):
317 """Return zero-filled state tensor(s).
319 Args:
320 batch_size: int, float, or unit Tensor representing the batch size.
321 dtype: the data type to use for the state.
323 Returns:
324 If `state_size` is an int or TensorShape, then the return value is an
325 `N-D` tensor of shape `[batch_size, state_size]` filled with zeros.
327 If `state_size` is a nested list or tuple, then the return value is
328 a nested list or tuple (of the same structure) of `2-D` tensors with
329 the shapes `[batch_size, s]` for each s in `state_size`.
330 """
331 # Try to use the last cached zero_state. This is done to avoid recreating
332 # zeros, especially when eager execution is enabled.
333 state_size = self.state_size
334 is_eager = context.executing_eagerly()
335 if is_eager and _hasattr(self, "_last_zero_state"):
336 (last_state_size, last_batch_size, last_dtype,
337 last_output) = getattr(self, "_last_zero_state")
338 if (last_batch_size == batch_size and last_dtype == dtype and
339 last_state_size == state_size):
340 return last_output
341 with backend.name_scope(type(self).__name__ + "ZeroState"):
342 output = _zero_state_tensors(state_size, batch_size, dtype)
343 if is_eager:
344 self._last_zero_state = (state_size, batch_size, dtype, output)
345 return output
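A short sketch of the two initial-state entry points on a concrete cell; the cell class and sizes below are illustrative:

```python
import tensorflow as tf

cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(num_units=3)

# Explicit batch size and dtype.
state = cell.zero_state(batch_size=5, dtype=tf.float32)  # [5, 3] tensor of zeros

# Batch size and dtype inferred from an input batch.
inputs = tf.zeros([5, 7])                                 # [batch_size, input_size]
state2 = cell.get_initial_state(inputs=inputs)            # also [5, 3] zeros
```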
347 # TODO(b/134773139): Remove when contrib RNN cells implement `get_config`
348 def get_config(self): # pylint: disable=useless-super-delegation
349 return super(RNNCell, self).get_config()
351 @property
352 def _use_input_spec_as_call_signature(self):
353 # We do not store the shape information for the state argument in the call
354 # function for legacy RNN cells, so do not generate an input signature.
355 return False
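A minimal sketch of the contract described in the class docstring: a subclass only needs `state_size`, `output_size`, and a `call(inputs, state)` that returns `(output, new_state)`. The toy cell below is hypothetical and assumes `input_size == state_size`.

```python
import tensorflow as tf

class RunningAverageCell(tf.compat.v1.nn.rnn_cell.RNNCell):
  """Toy cell: the new state is the mean of the previous state and the input."""

  def __init__(self, units):
    super(RunningAverageCell, self).__init__()
    self._units = units

  @property
  def state_size(self):
    return self._units

  @property
  def output_size(self):
    return self._units

  def call(self, inputs, state):
    new_state = 0.5 * (inputs + state)  # inputs and state are both [batch, units]
    return new_state, new_state

# Usage (shapes illustrative):
#   cell = RunningAverageCell(4)
#   output, next_state = cell(inputs, state)  # inputs, state: [batch, 4] tensors
```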
358class LayerRNNCell(RNNCell):
359 """Subclass of RNNCells that act like proper `tf.Layer` objects.
361 For backwards compatibility purposes, most `RNNCell` instances allow their
362 `call` methods to instantiate variables via `tf.compat.v1.get_variable`. The
363 underlying variable scope thus keeps track of any variables and returns
364 cached versions. This is atypical of `tf.layers` objects, which separate
365 this part of layer building into a `build` method that is only called
366 once.
368 Here we provide a subclass for `RNNCell` objects that act exactly as
369 `Layer` objects do. They must provide a `build` method and their
370 `call` methods must not create variables via `tf.compat.v1.get_variable`.
371 """
373 def __call__(self, inputs, state, scope=None, *args, **kwargs):
374 """Run this RNN cell on inputs, starting from the given state.
376 Args:
377 inputs: `2-D` tensor with shape `[batch_size, input_size]`.
378 state: if `self.state_size` is an integer, this should be a `2-D Tensor`
379 with shape `[batch_size, self.state_size]`. Otherwise, if
380 `self.state_size` is a tuple of integers, this should be a tuple with
381 shapes `[batch_size, s] for s in self.state_size`.
382 scope: optional cell scope.
383 *args: Additional positional arguments.
384 **kwargs: Additional keyword arguments.
386 Returns:
387 A pair containing:
389 - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
390 - New state: Either a single `2-D` tensor, or a tuple of tensors matching
391 the arity and shapes of `state`.
392 """
393 # Bypass RNNCell's variable capturing semantics for LayerRNNCell.
394 # Instead, it is up to subclasses to provide a proper build
395 # method. See the class docstring for more details.
396 return base_layer.Layer.__call__(
397 self, inputs, state, scope=scope, *args, **kwargs)
400@keras_export(v1=["keras.__internal__.legacy.rnn_cell.BasicRNNCell"])
401@tf_export(v1=["nn.rnn_cell.BasicRNNCell"])
402class BasicRNNCell(LayerRNNCell):
403 """The most basic RNN cell.
405 Note that this cell is not optimized for performance. Please use
406 `tf.contrib.cudnn_rnn.CudnnRNNTanh` for better performance on GPU.
408 Args:
409 num_units: int, The number of units in the RNN cell.
410 activation: Nonlinearity to use. Default: `tanh`. It can also be a string
411 matching a Keras activation function name.
412 reuse: (optional) Python boolean describing whether to reuse variables in an
413 existing scope. If not `True`, and the existing scope already has the
414 given variables, an error is raised.
415 name: String, the name of the layer. Layers with the same name will share
416 weights, but to avoid mistakes we require reuse=True in such cases.
417 dtype: Default dtype of the layer (default of `None` means use the type of
418 the first input). Required when `build` is called before `call`.
419 **kwargs: Dict, keyword named properties for common layer attributes, like
420 `trainable` etc when constructing the cell from configs of get_config().
421 """
423 def __init__(self,
424 num_units,
425 activation=None,
426 reuse=None,
427 name=None,
428 dtype=None,
429 **kwargs):
430 warnings.warn("`tf.nn.rnn_cell.BasicRNNCell` is deprecated and will be "
431 "removed in a future version. This class "
432 "is equivalent as `tf.keras.layers.SimpleRNNCell`, "
433 "and will be replaced by that in Tensorflow 2.0.")
434 super(BasicRNNCell, self).__init__(
435 _reuse=reuse, name=name, dtype=dtype, **kwargs)
436 _check_supported_dtypes(self.dtype)
437 if context.executing_eagerly() and tf_config.list_logical_devices("GPU"):
438 logging.warning(
439 "%s: Note that this cell is not optimized for performance. "
440 "Please use tf.contrib.cudnn_rnn.CudnnRNNTanh for better "
441 "performance on GPU.", self)
443 # Inputs must be 2-dimensional.
444 self.input_spec = input_spec.InputSpec(ndim=2)
446 self._num_units = num_units
447 if activation:
448 self._activation = activations.get(activation)
449 else:
450 self._activation = math_ops.tanh
452 @property
453 def state_size(self):
454 return self._num_units
456 @property
457 def output_size(self):
458 return self._num_units
460 @tf_utils.shape_type_conversion
461 def build(self, inputs_shape):
462 if inputs_shape[-1] is None:
463 raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" %
464 str(inputs_shape))
465 _check_supported_dtypes(self.dtype)
467 input_depth = inputs_shape[-1]
468 self._kernel = self.add_variable(
469 _WEIGHTS_VARIABLE_NAME,
470 shape=[input_depth + self._num_units, self._num_units])
471 self._bias = self.add_variable(
472 _BIAS_VARIABLE_NAME,
473 shape=[self._num_units],
474 initializer=init_ops.zeros_initializer(dtype=self.dtype))
476 self.built = True
478 def call(self, inputs, state):
479 """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
480 _check_rnn_cell_input_dtypes([inputs, state])
481 gate_inputs = math_ops.matmul(
482 array_ops.concat([inputs, state], 1), self._kernel)
483 gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
484 output = self._activation(gate_inputs)
485 return output, output
487 def get_config(self):
488 config = {
489 "num_units": self._num_units,
490 "activation": activations.serialize(self._activation),
491 "reuse": self._reuse,
492 }
493 base_config = super(BasicRNNCell, self).get_config()
494 return dict(list(base_config.items()) + list(config.items()))
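A usage sketch in TF1-style graph mode (the mode these legacy cells target); the shapes below are illustrative:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

inputs = tf.compat.v1.placeholder(tf.float32, [None, 20, 8])  # [batch, time, features]
cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(num_units=16)
outputs, final_state = tf.compat.v1.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
# outputs: [batch, 20, 16], final_state: [batch, 16]
```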
497@keras_export(v1=["keras.__internal__.legacy.rnn_cell.GRUCell"])
498@tf_export(v1=["nn.rnn_cell.GRUCell"])
499class GRUCell(LayerRNNCell):
500 """Gated Recurrent Unit cell.
502 Note that this cell is not optimized for performance. Please use
503 `tf.contrib.cudnn_rnn.CudnnGRU` for better performance on GPU, or
504 `tf.contrib.rnn.GRUBlockCellV2` for better performance on CPU.
506 Args:
507 num_units: int, The number of units in the GRU cell.
508 activation: Nonlinearity to use. Default: `tanh`.
509 reuse: (optional) Python boolean describing whether to reuse variables in an
510 existing scope. If not `True`, and the existing scope already has the
511 given variables, an error is raised.
512 kernel_initializer: (optional) The initializer to use for the weight and
513 projection matrices.
514 bias_initializer: (optional) The initializer to use for the bias.
515 name: String, the name of the layer. Layers with the same name will share
516 weights, but to avoid mistakes we require reuse=True in such cases.
517 dtype: Default dtype of the layer (default of `None` means use the type of
518 the first input). Required when `build` is called before `call`.
519 **kwargs: Dict, keyword named properties for common layer attributes, like
520 `trainable` etc when constructing the cell from configs of get_config().
522 References:
523 Learning Phrase Representations using RNN Encoder Decoder for Statistical
524 Machine Translation:
525 [Cho et al., 2014]
526 (https://aclanthology.coli.uni-saarland.de/papers/D14-1179/d14-1179)
527 ([pdf](http://emnlp2014.org/papers/pdf/EMNLP2014179.pdf))
528 """
530 def __init__(self,
531 num_units,
532 activation=None,
533 reuse=None,
534 kernel_initializer=None,
535 bias_initializer=None,
536 name=None,
537 dtype=None,
538 **kwargs):
539 warnings.warn("`tf.nn.rnn_cell.GRUCell` is deprecated and will be removed "
540 "in a future version. This class "
541 "is equivalent as `tf.keras.layers.GRUCell`, "
542 "and will be replaced by that in Tensorflow 2.0.")
543 super(GRUCell, self).__init__(
544 _reuse=reuse, name=name, dtype=dtype, **kwargs)
545 _check_supported_dtypes(self.dtype)
547 if context.executing_eagerly() and tf_config.list_logical_devices("GPU"):
548 logging.warning(
549 "%s: Note that this cell is not optimized for performance. "
550 "Please use tf.contrib.cudnn_rnn.CudnnGRU for better "
551 "performance on GPU.", self)
552 # Inputs must be 2-dimensional.
553 self.input_spec = input_spec.InputSpec(ndim=2)
555 self._num_units = num_units
556 if activation:
557 self._activation = activations.get(activation)
558 else:
559 self._activation = math_ops.tanh
560 self._kernel_initializer = initializers.get(kernel_initializer)
561 self._bias_initializer = initializers.get(bias_initializer)
563 @property
564 def state_size(self):
565 return self._num_units
567 @property
568 def output_size(self):
569 return self._num_units
571 @tf_utils.shape_type_conversion
572 def build(self, inputs_shape):
573 if inputs_shape[-1] is None:
574 raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" %
575 str(inputs_shape))
576 _check_supported_dtypes(self.dtype)
577 input_depth = inputs_shape[-1]
578 self._gate_kernel = self.add_variable(
579 "gates/%s" % _WEIGHTS_VARIABLE_NAME,
580 shape=[input_depth + self._num_units, 2 * self._num_units],
581 initializer=self._kernel_initializer)
582 self._gate_bias = self.add_variable(
583 "gates/%s" % _BIAS_VARIABLE_NAME,
584 shape=[2 * self._num_units],
585 initializer=(self._bias_initializer
586 if self._bias_initializer is not None else
587 init_ops.constant_initializer(1.0, dtype=self.dtype)))
588 self._candidate_kernel = self.add_variable(
589 "candidate/%s" % _WEIGHTS_VARIABLE_NAME,
590 shape=[input_depth + self._num_units, self._num_units],
591 initializer=self._kernel_initializer)
592 self._candidate_bias = self.add_variable(
593 "candidate/%s" % _BIAS_VARIABLE_NAME,
594 shape=[self._num_units],
595 initializer=(self._bias_initializer
596 if self._bias_initializer is not None else
597 init_ops.zeros_initializer(dtype=self.dtype)))
599 self.built = True
601 def call(self, inputs, state):
602 """Gated recurrent unit (GRU) with nunits cells."""
603 _check_rnn_cell_input_dtypes([inputs, state])
605 gate_inputs = math_ops.matmul(
606 array_ops.concat([inputs, state], 1), self._gate_kernel)
607 gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)
609 value = math_ops.sigmoid(gate_inputs)
610 r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
612 r_state = r * state
614 candidate = math_ops.matmul(
615 array_ops.concat([inputs, r_state], 1), self._candidate_kernel)
616 candidate = nn_ops.bias_add(candidate, self._candidate_bias)
618 c = self._activation(candidate)
619 new_h = u * state + (1 - u) * c
620 return new_h, new_h
622 def get_config(self):
623 config = {
624 "num_units": self._num_units,
625 "kernel_initializer": initializers.serialize(self._kernel_initializer),
626 "bias_initializer": initializers.serialize(self._bias_initializer),
627 "activation": activations.serialize(self._activation),
628 "reuse": self._reuse,
629 }
630 base_config = super(GRUCell, self).get_config()
631 return dict(list(base_config.items()) + list(config.items()))
634_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))
637@keras_export(v1=["keras.__internal__.legacy.rnn_cell.LSTMStateTuple"])
638@tf_export(v1=["nn.rnn_cell.LSTMStateTuple"])
639class LSTMStateTuple(_LSTMStateTuple):
640 """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.
642 Stores two elements, `(c, h)`, in that order, where `c` is the hidden state
643 and `h` is the output.
645 Only used when `state_is_tuple=True`.
646 """
647 __slots__ = ()
649 @property
650 def dtype(self):
651 (c, h) = self
652 if c.dtype != h.dtype:
653 raise TypeError("Inconsistent internal state: %s vs %s" %
654 (str(c.dtype), str(h.dtype)))
655 return c.dtype
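Illustrative: the fields can be read by name, and `dtype` checks that `c` and `h` agree.

```python
import tensorflow as tf

state = tf.compat.v1.nn.rnn_cell.LSTMStateTuple(
    c=tf.zeros([2, 4]), h=tf.zeros([2, 4]))
print(state.c.shape, state.h.shape, state.dtype)  # (2, 4) (2, 4) <dtype: 'float32'>
```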
658@keras_export(v1=["keras.__internal__.legacy.rnn_cell.BasicLSTMCell"])
659@tf_export(v1=["nn.rnn_cell.BasicLSTMCell"])
660class BasicLSTMCell(LayerRNNCell):
661 """DEPRECATED: Please use `tf.compat.v1.nn.rnn_cell.LSTMCell` instead.
663 Basic LSTM recurrent network cell.
665 The implementation is based on
667 We add forget_bias (default: 1) to the biases of the forget gate in order to
668 reduce the scale of forgetting at the beginning of training.
670 It does not allow cell clipping or a projection layer, and does not
671 use peep-hole connections: it is the basic baseline.
673 For advanced models, please use the full `tf.compat.v1.nn.rnn_cell.LSTMCell`
674 that follows.
676 Note that this cell is not optimized for performance. Please use
677 `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
678 `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
679 better performance on CPU.
680 """
682 def __init__(self,
683 num_units,
684 forget_bias=1.0,
685 state_is_tuple=True,
686 activation=None,
687 reuse=None,
688 name=None,
689 dtype=None,
690 **kwargs):
691 """Initialize the basic LSTM cell.
693 Args:
694 num_units: int, The number of units in the LSTM cell.
694 forget_bias: float, The bias added to forget gates (see above). Must be set
695 to `0.0` manually when restoring from CudnnLSTM-trained checkpoints.
697 state_is_tuple: If True, accepted and returned states are 2-tuples of the
698 `c_state` and `m_state`. If False, they are concatenated along the
699 column axis. The latter behavior will soon be deprecated.
700 activation: Activation function of the inner states. Default: `tanh`. It
701 can also be a string matching a Keras activation function name.
702 reuse: (optional) Python boolean describing whether to reuse variables in
703 an existing scope. If not `True`, and the existing scope already has
704 the given variables, an error is raised.
705 name: String, the name of the layer. Layers with the same name will share
706 weights, but to avoid mistakes we require reuse=True in such cases.
707 dtype: Default dtype of the layer (default of `None` means use the type of
708 the first input). Required when `build` is called before `call`.
709 **kwargs: Dict, keyword named properties for common layer attributes, like
710 `trainable` etc when constructing the cell from configs of get_config().
711 When restoring from CudnnLSTM-trained checkpoints, use
712 `CudnnCompatibleLSTMCell` instead.
713 """
714 warnings.warn("`tf.nn.rnn_cell.BasicLSTMCell` is deprecated and will be "
715 "removed in a future version. This class "
716 "is equivalent as `tf.keras.layers.LSTMCell`, "
717 "and will be replaced by that in Tensorflow 2.0.")
718 super(BasicLSTMCell, self).__init__(
719 _reuse=reuse, name=name, dtype=dtype, **kwargs)
720 _check_supported_dtypes(self.dtype)
721 if not state_is_tuple:
722 logging.warning(
723 "%s: Using a concatenated state is slower and will soon be "
724 "deprecated. Use state_is_tuple=True.", self)
725 if context.executing_eagerly() and tf_config.list_logical_devices("GPU"):
726 logging.warning(
727 "%s: Note that this cell is not optimized for performance. "
728 "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better "
729 "performance on GPU.", self)
731 # Inputs must be 2-dimensional.
732 self.input_spec = input_spec.InputSpec(ndim=2)
734 self._num_units = num_units
735 self._forget_bias = forget_bias
736 self._state_is_tuple = state_is_tuple
737 if activation:
738 self._activation = activations.get(activation)
739 else:
740 self._activation = math_ops.tanh
742 @property
743 def state_size(self):
744 return (LSTMStateTuple(self._num_units, self._num_units)
745 if self._state_is_tuple else 2 * self._num_units)
747 @property
748 def output_size(self):
749 return self._num_units
751 @tf_utils.shape_type_conversion
752 def build(self, inputs_shape):
753 if inputs_shape[-1] is None:
754 raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" %
755 str(inputs_shape))
756 _check_supported_dtypes(self.dtype)
757 input_depth = inputs_shape[-1]
758 h_depth = self._num_units
759 self._kernel = self.add_variable(
760 _WEIGHTS_VARIABLE_NAME,
761 shape=[input_depth + h_depth, 4 * self._num_units])
762 self._bias = self.add_variable(
763 _BIAS_VARIABLE_NAME,
764 shape=[4 * self._num_units],
765 initializer=init_ops.zeros_initializer(dtype=self.dtype))
767 self.built = True
769 def call(self, inputs, state):
770 """Long short-term memory cell (LSTM).
772 Args:
773 inputs: `2-D` tensor with shape `[batch_size, input_size]`.
774 state: An `LSTMStateTuple` of state tensors, each shaped `[batch_size,
775 num_units]`, if `state_is_tuple` has been set to `True`. Otherwise, a
776 `Tensor` shaped `[batch_size, 2 * num_units]`.
778 Returns:
779 A pair containing the new hidden state, and the new state (either a
780 `LSTMStateTuple` or a concatenated state, depending on
781 `state_is_tuple`).
782 """
783 _check_rnn_cell_input_dtypes([inputs, state])
785 sigmoid = math_ops.sigmoid
786 one = constant_op.constant(1, dtype=dtypes.int32)
787 # Parameters of gates are concatenated into one multiply for efficiency.
788 if self._state_is_tuple:
789 c, h = state
790 else:
791 c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)
793 gate_inputs = math_ops.matmul(
794 array_ops.concat([inputs, h], 1), self._kernel)
795 gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
797 # i = input_gate, j = new_input, f = forget_gate, o = output_gate
798 i, j, f, o = array_ops.split(
799 value=gate_inputs, num_or_size_splits=4, axis=one)
801 forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
802 # Note that using `add` and `multiply` instead of `+` and `*` gives a
803 # performance improvement. So using those at the cost of readability.
804 add = math_ops.add
805 multiply = math_ops.multiply
806 new_c = add(
807 multiply(c, sigmoid(add(f, forget_bias_tensor))),
808 multiply(sigmoid(i), self._activation(j)))
809 new_h = multiply(self._activation(new_c), sigmoid(o))
811 if self._state_is_tuple:
812 new_state = LSTMStateTuple(new_c, new_h)
813 else:
814 new_state = array_ops.concat([new_c, new_h], 1)
815 return new_h, new_state
817 def get_config(self):
818 config = {
819 "num_units": self._num_units,
820 "forget_bias": self._forget_bias,
821 "state_is_tuple": self._state_is_tuple,
822 "activation": activations.serialize(self._activation),
823 "reuse": self._reuse,
824 }
825 base_config = super(BasicLSTMCell, self).get_config()
826 return dict(list(base_config.items()) + list(config.items()))
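A sketch of how `state_is_tuple` changes the state layout (sizes illustrative):

```python
import tensorflow as tf

tuple_cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(4)  # default: state_is_tuple=True
concat_cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(4, state_is_tuple=False)

print(tuple_cell.state_size)   # LSTMStateTuple(c=4, h=4)
print(concat_cell.state_size)  # 8: c and h concatenated along the column axis
```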
829@keras_export(v1=["keras.__internal__.legacy.rnn_cell.LSTMCell"])
830@tf_export(v1=["nn.rnn_cell.LSTMCell"])
831class LSTMCell(LayerRNNCell):
832 """Long short-term memory unit (LSTM) recurrent network cell.
834 The default non-peephole implementation is based on (Gers et al., 1999).
835 The peephole implementation is based on (Sak et al., 2014).
837 The class uses optional peep-hole connections, optional cell clipping, and
838 an optional projection layer.
840 Note that this cell is not optimized for performance. Please use
841 `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
842 `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
843 better performance on CPU.
844 References:
845 Long short-term memory recurrent neural network architectures for large
846 scale acoustic modeling:
847 [Sak et al., 2014]
848 (https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html)
849 ([pdf]
850 (https://www.isca-speech.org/archive/archive_papers/interspeech_2014/i14_0338.pdf))
851 Learning to forget:
852 [Gers et al., 1999]
853 (http://digital-library.theiet.org/content/conferences/10.1049/cp_19991218)
854 ([pdf](https://arxiv.org/pdf/1409.2329.pdf))
855 Long Short-Term Memory:
856 [Hochreiter et al., 1997]
857 (https://www.mitpressjournals.org/doi/abs/10.1162/neco.1997.9.8.1735)
858 ([pdf](http://ml.jku.at/publications/older/3504.pdf))
859 """
861 def __init__(self,
862 num_units,
863 use_peepholes=False,
864 cell_clip=None,
865 initializer=None,
866 num_proj=None,
867 proj_clip=None,
868 num_unit_shards=None,
869 num_proj_shards=None,
870 forget_bias=1.0,
871 state_is_tuple=True,
872 activation=None,
873 reuse=None,
874 name=None,
875 dtype=None,
876 **kwargs):
877 """Initialize the parameters for an LSTM cell.
879 Args:
880 num_units: int, The number of units in the LSTM cell.
881 use_peepholes: bool, set True to enable diagonal/peephole connections.
882 cell_clip: (optional) A float value; if provided, the cell state is clipped
883 by this value prior to the cell output activation.
884 initializer: (optional) The initializer to use for the weight and
885 projection matrices.
886 num_proj: (optional) int, The output dimensionality for the projection
887 matrices. If None, no projection is performed.
888 proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
889 provided, then the projected values are clipped elementwise to within
890 `[-proj_clip, proj_clip]`.
891 num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a
892 variable_scope partitioner instead.
893 num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a
894 variable_scope partitioner instead.
895 forget_bias: Biases of the forget gate are initialized by default to 1 in
896 order to reduce the scale of forgetting at the beginning of the
897 training. Must set it manually to `0.0` when restoring from CudnnLSTM
898 trained checkpoints.
899 state_is_tuple: If True, accepted and returned states are 2-tuples of the
900 `c_state` and `m_state`. If False, they are concatenated along the
901 column axis. This latter behavior will soon be deprecated.
902 activation: Activation function of the inner states. Default: `tanh`. It
903 can also be a string matching a Keras activation function name.
904 reuse: (optional) Python boolean describing whether to reuse variables in
905 an existing scope. If not `True`, and the existing scope already has
906 the given variables, an error is raised.
907 name: String, the name of the layer. Layers with the same name will share
908 weights, but to avoid mistakes we require reuse=True in such cases.
909 dtype: Default dtype of the layer (default of `None` means use the type of
910 the first input). Required when `build` is called before `call`.
911 **kwargs: Dict, keyword named properties for common layer attributes, like
912 `trainable` etc when constructing the cell from configs of get_config().
913 When restoring from CudnnLSTM-trained checkpoints, use
914 `CudnnCompatibleLSTMCell` instead.
915 """
916 warnings.warn("`tf.nn.rnn_cell.LSTMCell` is deprecated and will be "
917 "removed in a future version. This class "
918 "is equivalent as `tf.keras.layers.LSTMCell`, "
919 "and will be replaced by that in Tensorflow 2.0.")
920 super(LSTMCell, self).__init__(
921 _reuse=reuse, name=name, dtype=dtype, **kwargs)
922 _check_supported_dtypes(self.dtype)
923 if not state_is_tuple:
924 logging.warning(
925 "%s: Using a concatenated state is slower and will soon be "
926 "deprecated. Use state_is_tuple=True.", self)
927 if num_unit_shards is not None or num_proj_shards is not None:
928 logging.warning(
929 "%s: The num_unit_shards and proj_unit_shards parameters are "
930 "deprecated and will be removed in Jan 2017. "
931 "Use a variable scope with a partitioner instead.", self)
932 if context.executing_eagerly() and tf_config.list_logical_devices("GPU"):
933 logging.warning(
934 "%s: Note that this cell is not optimized for performance. "
935 "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better "
936 "performance on GPU.", self)
938 # Inputs must be 2-dimensional.
939 self.input_spec = input_spec.InputSpec(ndim=2)
941 self._num_units = num_units
942 self._use_peepholes = use_peepholes
943 self._cell_clip = cell_clip
944 self._initializer = initializers.get(initializer)
945 self._num_proj = num_proj
946 self._proj_clip = proj_clip
947 self._num_unit_shards = num_unit_shards
948 self._num_proj_shards = num_proj_shards
949 self._forget_bias = forget_bias
950 self._state_is_tuple = state_is_tuple
951 if activation:
952 self._activation = activations.get(activation)
953 else:
954 self._activation = math_ops.tanh
956 if num_proj:
957 self._state_size = (
958 LSTMStateTuple(num_units, num_proj) if state_is_tuple else num_units +
959 num_proj)
960 self._output_size = num_proj
961 else:
962 self._state_size = (
963 LSTMStateTuple(num_units, num_units) if state_is_tuple else 2 *
964 num_units)
965 self._output_size = num_units
967 @property
968 def state_size(self):
969 return self._state_size
971 @property
972 def output_size(self):
973 return self._output_size
975 @tf_utils.shape_type_conversion
976 def build(self, inputs_shape):
977 if inputs_shape[-1] is None:
978 raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" %
979 str(inputs_shape))
980 _check_supported_dtypes(self.dtype)
981 input_depth = inputs_shape[-1]
982 h_depth = self._num_units if self._num_proj is None else self._num_proj
983 maybe_partitioner = (
984 partitioned_variables.fixed_size_partitioner(self._num_unit_shards)
985 if self._num_unit_shards is not None else None)
986 self._kernel = self.add_variable(
987 _WEIGHTS_VARIABLE_NAME,
988 shape=[input_depth + h_depth, 4 * self._num_units],
989 initializer=self._initializer,
990 partitioner=maybe_partitioner)
991 if self.dtype is None:
992 initializer = init_ops.zeros_initializer
993 else:
994 initializer = init_ops.zeros_initializer(dtype=self.dtype)
995 self._bias = self.add_variable(
996 _BIAS_VARIABLE_NAME,
997 shape=[4 * self._num_units],
998 initializer=initializer)
999 if self._use_peepholes:
1000 self._w_f_diag = self.add_variable(
1001 "w_f_diag", shape=[self._num_units], initializer=self._initializer)
1002 self._w_i_diag = self.add_variable(
1003 "w_i_diag", shape=[self._num_units], initializer=self._initializer)
1004 self._w_o_diag = self.add_variable(
1005 "w_o_diag", shape=[self._num_units], initializer=self._initializer)
1007 if self._num_proj is not None:
1008 maybe_proj_partitioner = (
1009 partitioned_variables.fixed_size_partitioner(self._num_proj_shards)
1010 if self._num_proj_shards is not None else None)
1011 self._proj_kernel = self.add_variable(
1012 "projection/%s" % _WEIGHTS_VARIABLE_NAME,
1013 shape=[self._num_units, self._num_proj],
1014 initializer=self._initializer,
1015 partitioner=maybe_proj_partitioner)
1017 self.built = True
1019 def call(self, inputs, state):
1020 """Run one step of LSTM.
1022 Args:
1023 inputs: input Tensor, must be 2-D, `[batch, input_size]`.
1024 state: if `state_is_tuple` is False, this must be a state Tensor, `2-D,
1025 [batch, state_size]`. If `state_is_tuple` is True, this must be a tuple
1026 of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`.
1028 Returns:
1029 A tuple containing:
1031 - A `2-D, [batch, output_dim]`, Tensor representing the output of the
1032 LSTM after reading `inputs` when previous state was `state`.
1033 Here output_dim is:
1034 num_proj if num_proj was set,
1035 num_units otherwise.
1036 - Tensor(s) representing the new state of LSTM after reading `inputs` when
1037 the previous state was `state`. Same type and shape(s) as `state`.
1039 Raises:
1040 ValueError: If input size cannot be inferred from inputs via
1041 static shape inference.
1042 """
1043 _check_rnn_cell_input_dtypes([inputs, state])
1045 num_proj = self._num_units if self._num_proj is None else self._num_proj
1046 sigmoid = math_ops.sigmoid
1048 if self._state_is_tuple:
1049 (c_prev, m_prev) = state
1050 else:
1051 c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
1052 m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
1054 input_size = inputs.get_shape().with_rank(2).dims[1].value
1055 if input_size is None:
1056 raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
1058 # i = input_gate, j = new_input, f = forget_gate, o = output_gate
1059 lstm_matrix = math_ops.matmul(
1060 array_ops.concat([inputs, m_prev], 1), self._kernel)
1061 lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias)
1063 i, j, f, o = array_ops.split(
1064 value=lstm_matrix, num_or_size_splits=4, axis=1)
1065 # Diagonal connections
1066 if self._use_peepholes:
1067 c = (
1068 sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
1069 sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
1070 else:
1071 c = (
1072 sigmoid(f + self._forget_bias) * c_prev +
1073 sigmoid(i) * self._activation(j))
1075 if self._cell_clip is not None:
1076 # pylint: disable=invalid-unary-operand-type
1077 c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
1078 # pylint: enable=invalid-unary-operand-type
1079 if self._use_peepholes:
1080 m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
1081 else:
1082 m = sigmoid(o) * self._activation(c)
1084 if self._num_proj is not None:
1085 m = math_ops.matmul(m, self._proj_kernel)
1087 if self._proj_clip is not None:
1088 # pylint: disable=invalid-unary-operand-type
1089 m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
1090 # pylint: enable=invalid-unary-operand-type
1092 new_state = (
1093 LSTMStateTuple(c, m)
1094 if self._state_is_tuple else array_ops.concat([c, m], 1))
1095 return m, new_state
1097 def get_config(self):
1098 config = {
1099 "num_units": self._num_units,
1100 "use_peepholes": self._use_peepholes,
1101 "cell_clip": self._cell_clip,
1102 "initializer": initializers.serialize(self._initializer),
1103 "num_proj": self._num_proj,
1104 "proj_clip": self._proj_clip,
1105 "num_unit_shards": self._num_unit_shards,
1106 "num_proj_shards": self._num_proj_shards,
1107 "forget_bias": self._forget_bias,
1108 "state_is_tuple": self._state_is_tuple,
1109 "activation": activations.serialize(self._activation),
1110 "reuse": self._reuse,
1111 }
1112 base_config = super(LSTMCell, self).get_config()
1113 return dict(list(base_config.items()) + list(config.items()))
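A sketch of the projection option: with `num_proj` set, the output (and the `h` part of the state) has `num_proj` columns instead of `num_units`. Sizes below are illustrative.

```python
import tensorflow as tf

cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=128, num_proj=32)
print(cell.output_size)  # 32
print(cell.state_size)   # LSTMStateTuple(c=128, h=32)
```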
1116class _RNNCellWrapperV1(RNNCell):
1117 """Base class for cells wrappers V1 compatibility.
1119 This class along with `_RNNCellWrapperV2` allows to define cells wrappers that
1120 are compatible with V1 and V2, and defines helper methods for this purpose.
1121 """
1123 def __init__(self, cell, *args, **kwargs):
1124 super(_RNNCellWrapperV1, self).__init__(*args, **kwargs)
1125 assert_like_rnncell("cell", cell)
1126 self.cell = cell
1127 if isinstance(cell, trackable.Trackable):
1128 self._track_trackable(self.cell, name="cell")
1130 def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
1131 """Calls the wrapped cell and performs the wrapping logic.
1133 This method is called from the wrapper's `call` or `__call__` methods.
1135 Args:
1136 inputs: A tensor with wrapped cell's input.
1137 state: A tensor or tuple of tensors with wrapped cell's state.
1138 cell_call_fn: Wrapped cell's method to use for step computation (cell's
1139 `__call__` or 'call' method).
1140 **kwargs: Additional arguments.
1142 Returns:
1143 A pair containing:
1144 - Output: A tensor with cell's output.
1145 - New state: A tensor or tuple of tensors with new wrapped cell's state.
1146 """
1147 raise NotImplementedError
1149 def __call__(self, inputs, state, scope=None):
1150 """Runs the RNN cell step computation.
1152 We assume that the wrapped RNNCell is being built within its `__call__`
1153 method. We directly use the wrapped cell's `__call__` in the overridden
1154 wrapper `__call__` method.
1156 This allows using the wrapped cell and the non-wrapped cell equivalently
1157 when using `__call__`.
1159 Args:
1160 inputs: A tensor with wrapped cell's input.
1161 state: A tensor or tuple of tensors with wrapped cell's state.
1162 scope: VariableScope for the subgraph created in the wrapped cell's
1163 `__call__`.
1165 Returns:
1166 A pair containing:
1168 - Output: A tensor with cell's output.
1169 - New state: A tensor or tuple of tensors with new wrapped cell's state.
1170 """
1171 return self._call_wrapped_cell(
1172 inputs, state, cell_call_fn=self.cell.__call__, scope=scope)
1174 def get_config(self):
1175 config = {
1176 "cell": {
1177 "class_name": self.cell.__class__.__name__,
1178 "config": self.cell.get_config()
1179 },
1180 }
1181 base_config = super(_RNNCellWrapperV1, self).get_config()
1182 return dict(list(base_config.items()) + list(config.items()))
1184 @classmethod
1185 def from_config(cls, config, custom_objects=None):
1186 config = config.copy()
1187 cell = config.pop("cell")
1188 try:
1189 assert_like_rnncell("cell", cell)
1190 return cls(cell, **config)
1191 except TypeError:
1192 raise ValueError("RNNCellWrapper cannot reconstruct the wrapped cell. "
1193 "Please overwrite the cell in the config with a RNNCell "
1194 "instance.")
1197@keras_export(v1=["keras.__internal__.legacy.rnn_cell.DropoutWrapper"])
1198@tf_export(v1=["nn.rnn_cell.DropoutWrapper"])
1199class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase,
1200 _RNNCellWrapperV1):
1201 """Operator adding dropout to inputs and outputs of the given cell."""
1203 def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation
1204 super(DropoutWrapper, self).__init__(*args, **kwargs)
1206 __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__
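A usage sketch: wrapping an existing cell so dropout is applied to its inputs and outputs; the wrapped cell and keep probabilities here are illustrative.

```python
import tensorflow as tf

base_cell = tf.compat.v1.nn.rnn_cell.GRUCell(64)
cell = tf.compat.v1.nn.rnn_cell.DropoutWrapper(
    base_cell, input_keep_prob=0.9, output_keep_prob=0.9)
# `cell` can then be used anywhere an RNNCell is expected, e.g. with dynamic_rnn.
```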
1209@keras_export(v1=["keras.__internal__.legacy.rnn_cell.ResidualWrapper"])
1210@tf_export(v1=["nn.rnn_cell.ResidualWrapper"])
1211class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase,
1212 _RNNCellWrapperV1):
1213 """RNNCell wrapper that ensures cell inputs are added to the outputs."""
1215 def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation
1216 super(ResidualWrapper, self).__init__(*args, **kwargs)
1218 __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__
1221@keras_export(v1=["keras.__internal__.legacy.rnn_cell.DeviceWrapper"])
1222@tf_export(v1=["nn.rnn_cell.DeviceWrapper"])
1223class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase,
1224 _RNNCellWrapperV1):
1226 def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation
1227 super(DeviceWrapper, self).__init__(*args, **kwargs)
1229 __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__
1232@keras_export(v1=["keras.__internal__.legacy.rnn_cell.MultiRNNCell"])
1233@tf_export(v1=["nn.rnn_cell.MultiRNNCell"])
1234class MultiRNNCell(RNNCell):
1235 """RNN cell composed sequentially of multiple simple cells.
1237 Example:
1239 ```python
1240 num_units = [128, 64]
1241 cells = [BasicLSTMCell(num_units=n) for n in num_units]
1242 stacked_rnn_cell = MultiRNNCell(cells)
1243 ```
1244 """
1246 def __init__(self, cells, state_is_tuple=True):
1247 """Create a RNN cell composed sequentially of a number of RNNCells.
1249 Args:
1250 cells: list of RNNCells that will be composed in this order.
1251 state_is_tuple: If True, accepted and returned states are n-tuples, where
1252 `n = len(cells)`. If False, the states are all concatenated along the
1253 column axis. This latter behavior will soon be deprecated.
1255 Raises:
1256 ValueError: if cells is empty (not allowed), or at least one of the cells
1257 returns a state tuple but the flag `state_is_tuple` is `False`.
1258 """
1259 logging.warning("`tf.nn.rnn_cell.MultiRNNCell` is deprecated. This class "
1260 "is equivalent as `tf.keras.layers.StackedRNNCells`, "
1261 "and will be replaced by that in Tensorflow 2.0.")
1262 super(MultiRNNCell, self).__init__()
1263 if not cells:
1264 raise ValueError("Must specify at least one cell for MultiRNNCell.")
1265 if not nest.is_nested(cells):
1266 raise TypeError("cells must be a list or tuple, but saw: %s." % cells)
1268 if len(set(id(cell) for cell in cells)) < len(cells):
1269 logging.log_first_n(
1270 logging.WARN, "At least two cells provided to MultiRNNCell "
1271 "are the same object and will share weights.", 1)
1273 self._cells = cells
1274 for cell_number, cell in enumerate(self._cells):
1275 # Add Trackable dependencies on these cells so their variables get
1276 # saved with this object when using object-based saving.
1277 if isinstance(cell, trackable.Trackable):
1278 # TODO(allenl): Track down non-Trackable callers.
1279 self._track_trackable(cell, name="cell-%d" % (cell_number,))
1280 self._state_is_tuple = state_is_tuple
1281 if not state_is_tuple:
1282 if any(nest.is_nested(c.state_size) for c in self._cells):
1283 raise ValueError("Some cells return tuples of states, but the flag "
1284 "state_is_tuple is not set. State sizes are: %s" %
1285 str([c.state_size for c in self._cells]))
1287 @property
1288 def state_size(self):
1289 if self._state_is_tuple:
1290 return tuple(cell.state_size for cell in self._cells)
1291 else:
1292 return sum(cell.state_size for cell in self._cells)
1294 @property
1295 def output_size(self):
1296 return self._cells[-1].output_size
1298 def zero_state(self, batch_size, dtype):
1299 with backend.name_scope(type(self).__name__ + "ZeroState"):
1300 if self._state_is_tuple:
1301 return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
1302 else:
1303 # We know here that state_size of each cell is not a tuple and
1304 # presumably does not contain TensorArrays or anything else fancy
1305 return super(MultiRNNCell, self).zero_state(batch_size, dtype)
1307 @property
1308 def trainable_weights(self):
1309 if not self.trainable:
1310 return []
1311 weights = []
1312 for cell in self._cells:
1313 if isinstance(cell, base_layer.Layer):
1314 weights += cell.trainable_weights
1315 return weights
1317 @property
1318 def non_trainable_weights(self):
1319 weights = []
1320 for cell in self._cells:
1321 if isinstance(cell, base_layer.Layer):
1322 weights += cell.non_trainable_weights
1323 if not self.trainable:
1324 trainable_weights = []
1325 for cell in self._cells:
1326 if isinstance(cell, base_layer.Layer):
1327 trainable_weights += cell.trainable_weights
1328 return trainable_weights + weights
1329 return weights
1331 def call(self, inputs, state):
1332 """Run this multi-layer cell on inputs, starting from state."""
1333 cur_state_pos = 0
1334 cur_inp = inputs
1335 new_states = []
1336 for i, cell in enumerate(self._cells):
1337 with vs.variable_scope("cell_%d" % i):
1338 if self._state_is_tuple:
1339 if not nest.is_nested(state):
1340 raise ValueError(
1341 "Expected state to be a tuple of length %d, but received: %s" %
1342 (len(self.state_size), state))
1343 cur_state = state[i]
1344 else:
1345 cur_state = array_ops.slice(state, [0, cur_state_pos],
1346 [-1, cell.state_size])
1347 cur_state_pos += cell.state_size
1348 cur_inp, new_state = cell(cur_inp, cur_state)
1349 new_states.append(new_state)
1351 new_states = (
1352 tuple(new_states) if self._state_is_tuple else array_ops.concat(
1353 new_states, 1))
1355 return cur_inp, new_states
1358def _check_rnn_cell_input_dtypes(inputs):
1359 """Check whether the input tensors are with supported dtypes.
1361 Default RNN cells only support float and complex dtypes, since the
1362 activation functions (tanh and sigmoid) only allow those types. This function
1363 raises a descriptive error message if an input is not of a supported type.
1365 Args:
1366 inputs: tensor or nested structure of tensors that are fed to the RNN cell as
1367 input or state.
1369 Raises:
1370 ValueError: if any of the input tensors does not have a float or complex
1371 dtype.
1372 """
1373 for t in nest.flatten(inputs):
1374 _check_supported_dtypes(t.dtype)
1377def _check_supported_dtypes(dtype):
1378 if dtype is None:
1379 return
1380 dtype = dtypes.as_dtype(dtype)
1381 if not (dtype.is_floating or dtype.is_complex):
1382 raise ValueError("RNN cell only supports floating point and complex inputs, "
1383 "but saw dtype: %s" % dtype)