Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/stacked_rnn_cells.py: 26% of 96 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper allowing a stack of RNN cells to behave as a single cell."""

import functools

import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.engine import base_layer
from keras.src.layers.rnn import rnn_utils
from keras.src.saving import serialization_lib
from keras.src.utils import generic_utils
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.StackedRNNCells")
class StackedRNNCells(base_layer.Layer):
    """Wrapper allowing a stack of RNN cells to behave as a single cell.

    Used to implement efficient stacked RNNs.

    Args:
      cells: List of RNN cell instances.

    Examples:

    ```python
    import numpy as np
    import tensorflow as tf

    batch_size = 3
    sentence_max_length = 5
    n_features = 2
    new_shape = (batch_size, sentence_max_length, n_features)
    x = tf.constant(np.reshape(np.arange(30), new_shape), dtype=tf.float32)

    rnn_cells = [tf.keras.layers.LSTMCell(128) for _ in range(2)]
    stacked_lstm = tf.keras.layers.StackedRNNCells(rnn_cells)
    lstm_layer = tf.keras.layers.RNN(stacked_lstm)

    result = lstm_layer(x)
    ```
    """

    def __init__(self, cells, **kwargs):
        for cell in cells:
            if "call" not in dir(cell):
                raise ValueError(
                    "All cells must have a `call` method. "
                    f"Received cell without a `call` method: {cell}"
                )
            if "state_size" not in dir(cell):
                raise ValueError(
                    "All cells must have a `state_size` attribute. "
                    f"Received cell without a `state_size`: {cell}"
                )
        self.cells = cells
        # reverse_state_order determines whether the state sizes are listed
        # in the reverse order of the cells. Users might want to set this to
        # True to keep the existing behavior. This is only useful when using
        # RNN(return_state=True), since the states are returned in the same
        # order as state_size.
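        # For example, with cells [cell_a, cell_b], an RNN built with
        # return_state=True returns states ordered as (cell_a's states,
        # cell_b's states) by default, and in the reversed order when
        # reverse_state_order=True.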
        self.reverse_state_order = kwargs.pop("reverse_state_order", False)
        if self.reverse_state_order:
            logging.warning(
                "reverse_state_order=True in StackedRNNCells will soon "
                "be deprecated. Please update the code to work with the "
                "natural order of states if you rely on the RNN states, "
                "eg RNN(return_state=True)."
            )
        super().__init__(**kwargs)

    @property
    def state_size(self):
        return tuple(
            c.state_size
            for c in (
                self.cells[::-1] if self.reverse_state_order else self.cells
            )
        )

    @property
    def output_size(self):
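        # The stacked cell's output is the output of the last cell. If that
        # cell does not expose an output_size, fall back to (the first entry
        # of) its state size.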
        if getattr(self.cells[-1], "output_size", None) is not None:
            return self.cells[-1].output_size
        elif rnn_utils.is_multiple_state(self.cells[-1].state_size):
            return self.cells[-1].state_size[0]
        else:
            return self.cells[-1].state_size

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        initial_states = []
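        # Collect one initial state per cell, preferring the cell's own
        # get_initial_state and falling back to a zero-filled state that
        # matches the cell's state_size.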
        for cell in (
            self.cells[::-1] if self.reverse_state_order else self.cells
        ):
            get_initial_state_fn = getattr(cell, "get_initial_state", None)
            if get_initial_state_fn:
                initial_states.append(
                    get_initial_state_fn(
                        inputs=inputs, batch_size=batch_size, dtype=dtype
                    )
                )
            else:
                initial_states.append(
                    rnn_utils.generate_zero_filled_state_for_cell(
                        cell, inputs, batch_size, dtype
                    )
                )

        return tuple(initial_states)

    def call(self, inputs, states, constants=None, training=None, **kwargs):
        # Recover per-cell states.
        state_size = (
            self.state_size[::-1]
            if self.reverse_state_order
            else self.state_size
        )
        nested_states = tf.nest.pack_sequence_as(
            state_size, tf.nest.flatten(states)
        )

        # Call the cells in order and store the returned states.
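        # The output of each cell is fed as the input to the next cell in the
        # stack, which is what lets the stacked cells act as one deep cell.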
        new_nested_states = []
        for cell, states in zip(self.cells, nested_states):
            states = states if tf.nest.is_nested(states) else [states]
            # TF cell does not wrap the state into list when there is only one
            # state.
            is_tf_rnn_cell = getattr(cell, "_is_tf_rnn_cell", None) is not None
            states = (
                states[0] if len(states) == 1 and is_tf_rnn_cell else states
            )
            if generic_utils.has_arg(cell.call, "training"):
                kwargs["training"] = training
            else:
                kwargs.pop("training", None)
            # Use the __call__ function for callable objects, e.g. layers, so
            # that it will have the proper name scopes for the ops, etc.
            cell_call_fn = cell.__call__ if callable(cell) else cell.call
            if generic_utils.has_arg(cell.call, "constants"):
                inputs, states = cell_call_fn(
                    inputs, states, constants=constants, **kwargs
                )
            else:
                inputs, states = cell_call_fn(inputs, states, **kwargs)
            new_nested_states.append(states)

        return inputs, tf.nest.pack_sequence_as(
            state_size, tf.nest.flatten(new_nested_states)
        )

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        if isinstance(input_shape, list):
            input_shape = input_shape[0]

        def get_batch_input_shape(batch_size, dim):
            shape = tf.TensorShape(dim).as_list()
            return tuple([batch_size] + shape)

        for cell in self.cells:
            if isinstance(cell, base_layer.Layer) and not cell.built:
                with backend.name_scope(cell.name):
                    cell.build(input_shape)
                    cell.built = True
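            # Each cell's output becomes the input of the next cell, so derive
            # this cell's output shape and use it as the next input_shape.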
            if getattr(cell, "output_size", None) is not None:
                output_dim = cell.output_size
            elif rnn_utils.is_multiple_state(cell.state_size):
                output_dim = cell.state_size[0]
            else:
                output_dim = cell.state_size
            batch_size = tf.nest.flatten(input_shape)[0]
            if tf.nest.is_nested(output_dim):
                input_shape = tf.nest.map_structure(
                    functools.partial(get_batch_input_shape, batch_size),
                    output_dim,
                )
                input_shape = tuple(input_shape)
            else:
                input_shape = tuple(
                    [batch_size] + tf.TensorShape(output_dim).as_list()
                )
        self.built = True

    def get_config(self):
        cells = []
        for cell in self.cells:
            cells.append(serialization_lib.serialize_keras_object(cell))
        config = {"cells": cells}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        from keras.src.layers import deserialize as deserialize_layer

        cells = []
        for cell_config in config.pop("cells"):
            cells.append(
                deserialize_layer(cell_config, custom_objects=custom_objects)
            )
        return cls(cells, **config)