Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py: 60%
47 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module implementing RNN cell wrappers for TF v2."""
17# Note that all the APIs under this module are exported as tf.nn.*. This is due
18# to the fact that those APIs were from tf.nn.rnn_cell_impl. They are ported
19# here to avoid the cyclic dependency issue for serialization. These APIs will
20# probably be deprecated and removed in the future, since a similar API is
21# available in the existing Keras RNN API.
24from tensorflow.python.keras.layers import recurrent
25from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_wrapper_impl
26from tensorflow.python.keras.utils import tf_inspect
27from tensorflow.python.util.deprecation import deprecated
28from tensorflow.python.util.tf_export import tf_export
class _RNNCellWrapperV2(recurrent.AbstractRNNCell):
  """Common base for the V2-compatible RNN cell wrappers.

  Together with `rnn_cell_impl._RNNCellWrapperV1`, this class makes it possible
  to write wrappers that work under both V1 and V2, and it supplies the helper
  methods needed for that purpose.

  Attributes:
    cell: The wrapped RNN cell instance handed to the constructor.
  """

  def __init__(self, cell, *args, **kwargs):
    """Stores the wrapped cell and inspects the signature of its `call`.

    Args:
      cell: The RNN cell to wrap; `cell.call` is inspected here.
      *args: Positional arguments forwarded to the parent constructor.
      **kwargs: Keyword arguments forwarded to the parent constructor.
    """
    super().__init__(*args, **kwargs)
    self.cell = cell
    call_spec = tf_inspect.getfullargspec(cell.call)
    # The wrapper is flagged as expecting `training` when the wrapped `call`
    # either names that argument explicitly or takes arbitrary `**kwargs`.
    names_training = "training" in call_spec.args
    takes_var_keywords = call_spec.varkw is not None
    self._expects_training_arg = names_training or takes_var_keywords

  def call(self, inputs, state, **kwargs):
    """Runs one step of the wrapped RNN cell.

    By the time `call` is used, the wrapper is assumed to have been built, so
    the wrapped cell has been built through its own `build` method and its
    `call` method can be invoked directly.

    This lets the wrapped cell and the non-wrapped cell be used equivalently
    through `call` and `build`.

    NOTE(review): `_call_wrapped_cell` is not defined in this class; it is
    presumably supplied by the concrete wrapper's other base class - confirm.

    Args:
      inputs: A tensor with the wrapped cell's input.
      state: A tensor or tuple of tensors with the wrapped cell's state.
      **kwargs: Additional arguments passed to the wrapped cell's `call`.

    Returns:
      A pair containing:

      - Output: A tensor with the cell's output.
      - New state: A tensor or tuple of tensors with the wrapped cell's new
        state.
    """
    wrapped_call = self.cell.call
    return self._call_wrapped_cell(
        inputs, state, cell_call_fn=wrapped_call, **kwargs)

  def build(self, inputs_shape):
    """Builds the wrapped cell and marks this wrapper as built."""
    self.cell.build(inputs_shape)
    self.built = True

  def get_config(self):
    """Returns the parent config extended with the wrapped cell's entry.

    The wrapped cell is stored under the "cell" key as a dict holding its class
    name ("class_name") and its own config ("config").
    """
    wrapped = self.cell
    cell_entry = {
        "class_name": wrapped.__class__.__name__,
        "config": wrapped.get_config(),
    }
    merged = dict(super().get_config())
    # Entries added here take precedence over same-named parent entries.
    merged.update({"cell": cell_entry})
    return merged

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Re-creates a wrapper from a config dict holding a "cell" entry.

    Args:
      config: Config mapping; it is copied, so the caller's object is not
        modified. Its "cell" entry is deserialized into the wrapped cell and
        every remaining entry is passed to the constructor as a keyword.
      custom_objects: Optional value forwarded to the layer deserializer.

    Returns:
      A new instance of `cls` wrapping the deserialized cell.
    """
    remaining = config.copy()
    # Imported here rather than at module level, as in the original code.
    from tensorflow.python.keras.layers.serialization import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    cell_spec = remaining.pop("cell")
    wrapped = deserialize_layer(cell_spec, custom_objects=custom_objects)
    return cls(wrapped, **remaining)
@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellDropoutWrapper", v1=[])
class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase,
                     _RNNCellWrapperV2):
  """Operator adding dropout to inputs and outputs of the given cell.

  Wrapping a Keras `LSTMCell` is rejected with a `ValueError` at construction
  time.
  """

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super().__init__(*args, **kwargs)
    # `self.cell` is only available after the parent constructors have run,
    # so the check on the wrapped cell has to come after the delegation.
    wraps_keras_lstm = isinstance(self.cell, recurrent.LSTMCell)
    if wraps_keras_lstm:
      message = ("keras LSTM cell does not work with DropoutWrapper. "
                 "Please use LSTMCell(dropout=x, recurrent_dropout=y) "
                 "instead.")
      raise ValueError(message)

  # Reuse the constructor documentation of the shared implementation class.
  __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__
@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellResidualWrapper", v1=[])
class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase,
                      _RNNCellWrapperV2):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    # Pure delegation; the override exists so the docstring assignment below
    # has a function of its own to attach to.
    super().__init__(*args, **kwargs)

  # Reuse the constructor documentation of the shared implementation class.
  __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__
@deprecated(None, "Please use tf.keras.layers.RNN instead.")
@tf_export("nn.RNNCellDeviceWrapper", v1=[])
class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase,
                    _RNNCellWrapperV2):
  """Operator that ensures an RNNCell runs on a particular device."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    # Pure delegation; the override exists so the docstring assignment below
    # has a function of its own to attach to.
    super().__init__(*args, **kwargs)

  # Reuse the constructor documentation of the shared implementation class.
  __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__