Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/stacked_rnn_cells.py: 26%

96 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Wrapper allowing a stack of RNN cells to behave as a single cell.""" 

16 

17 

18import functools 

19 

20import tensorflow.compat.v2 as tf 

21 

22from keras.src import backend 

23from keras.src.engine import base_layer 

24from keras.src.layers.rnn import rnn_utils 

25from keras.src.saving import serialization_lib 

26from keras.src.utils import generic_utils 

27from keras.src.utils import tf_utils 

28 

29# isort: off 

30from tensorflow.python.platform import tf_logging as logging 

31from tensorflow.python.util.tf_export import keras_export 

32 

33 

@keras_export("keras.layers.StackedRNNCells")
class StackedRNNCells(base_layer.Layer):
    """Wrapper allowing a stack of RNN cells to behave as a single cell.

    Used to implement efficient stacked RNNs.

    Args:
      cells: List of RNN cell instances. Every cell must expose a `call`
        method and a `state_size` attribute.
      **kwargs: Forwarded to the base layer constructor, except for
        `reverse_state_order` (default `False`), which is consumed here.
        When `True`, `state_size` and `get_initial_state` report the
        per-cell entries in reverse cell order.

    Examples:

    ```python
    batch_size = 3
    sentence_max_length = 5
    n_features = 2
    new_shape = (batch_size, sentence_max_length, n_features)
    x = tf.constant(np.reshape(np.arange(30), new_shape), dtype = tf.float32)

    rnn_cells = [tf.keras.layers.LSTMCell(128) for _ in range(2)]
    stacked_lstm = tf.keras.layers.StackedRNNCells(rnn_cells)
    lstm_layer = tf.keras.layers.RNN(stacked_lstm)

    result = lstm_layer(x)
    ```
    """

    def __init__(self, cells, **kwargs):
        """Validate `cells` and store them.

        Raises:
          ValueError: If any cell lacks a `call` method or a `state_size`
            attribute.
        """
        for cell in cells:
            if "call" not in dir(cell):
                raise ValueError(
                    "All cells must have a `call` method. "
                    f"Received cell without a `call` method: {cell}"
                )
            if "state_size" not in dir(cell):
                raise ValueError(
                    "All cells must have a `state_size` attribute. "
                    f"Received cell without a `state_size`: {cell}"
                )
        self.cells = cells
        # reverse_state_order determines whether the state size will be in a
        # reverse order of the cells' state. User might want to set this to
        # True to keep the existing behavior. This is only useful when using
        # RNN(return_state=True), since the states are returned in the same
        # order as state_size.
        self.reverse_state_order = kwargs.pop("reverse_state_order", False)
        if self.reverse_state_order:
            logging.warning(
                "reverse_state_order=True in StackedRNNCells will soon "
                "be deprecated. Please update the code to work with the "
                "natural order of states if you rely on the RNN states, "
                "eg RNN(return_state=True)."
            )
        super().__init__(**kwargs)

    @staticmethod
    def _cell_output_size(cell):
        """Return the output size of a single `cell`.

        Prefers an explicit, non-None `output_size` attribute. Otherwise
        falls back to the first entry of `state_size` when
        `rnn_utils.is_multiple_state` reports several states, or to
        `state_size` itself.
        """
        output_size = getattr(cell, "output_size", None)
        if output_size is not None:
            return output_size
        if rnn_utils.is_multiple_state(cell.state_size):
            return cell.state_size[0]
        return cell.state_size

    @property
    def state_size(self):
        """Tuple of every cell's `state_size`.

        Entries follow cell order, or reverse cell order when
        `reverse_state_order` is set.
        """
        ordered_cells = (
            self.cells[::-1] if self.reverse_state_order else self.cells
        )
        return tuple(cell.state_size for cell in ordered_cells)

    @property
    def output_size(self):
        """Output size of the stack, i.e. the output size of the last cell."""
        return self._cell_output_size(self.cells[-1])

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        """Collect the initial state of every cell into a tuple.

        A cell's own `get_initial_state` is used when it has one; otherwise
        the state comes from
        `rnn_utils.generate_zero_filled_state_for_cell`. Entries follow the
        same ordering as `state_size`.

        Args:
          inputs: Passed through to each cell's initial-state routine.
          batch_size: Passed through to each cell's initial-state routine.
          dtype: Passed through to each cell's initial-state routine.

        Returns:
          Tuple with one initial-state entry per cell.
        """
        ordered_cells = (
            self.cells[::-1] if self.reverse_state_order else self.cells
        )
        initial_states = []
        for cell in ordered_cells:
            get_initial_state_fn = getattr(cell, "get_initial_state", None)
            if get_initial_state_fn:
                cell_state = get_initial_state_fn(
                    inputs=inputs, batch_size=batch_size, dtype=dtype
                )
            else:
                cell_state = rnn_utils.generate_zero_filled_state_for_cell(
                    cell, inputs, batch_size, dtype
                )
            initial_states.append(cell_state)

        return tuple(initial_states)

    def call(self, inputs, states, constants=None, training=None, **kwargs):
        """Run `inputs` through every cell in order.

        The output of each cell becomes the input of the next one.

        Args:
          inputs: Input passed to the first cell.
          states: Flat or nested states for all cells; flattened and
            re-packed per cell before use.
          constants: Forwarded only to cells whose `call` accepts a
            `constants` argument.
          training: Forwarded only to cells whose `call` accepts a
            `training` argument.
          **kwargs: Extra keyword arguments forwarded to every cell.

        Returns:
          A pair `(output, new_states)`: the output of the last cell and
          the new states of all cells packed into the per-cell structure.
        """
        # Recover per-cell states. The `state_size` property is already
        # reversed when `reverse_state_order` is set, so reversing it again
        # yields the structure in cell order.
        state_size = (
            self.state_size[::-1]
            if self.reverse_state_order
            else self.state_size
        )
        nested_states = tf.nest.pack_sequence_as(
            state_size, tf.nest.flatten(states)
        )

        # Call the cells in order and store the returned states.
        new_nested_states = []
        for cell, cell_states in zip(self.cells, nested_states):
            if not tf.nest.is_nested(cell_states):
                cell_states = [cell_states]
            # TF cell does not wrap the state into list when there is only
            # one state.
            is_tf_rnn_cell = getattr(cell, "_is_tf_rnn_cell", None) is not None
            if is_tf_rnn_cell and len(cell_states) == 1:
                cell_states = cell_states[0]

            # `kwargs` is shared across iterations, so `training` must be
            # added or removed for each cell individually.
            if generic_utils.has_arg(cell.call, "training"):
                kwargs["training"] = training
            else:
                kwargs.pop("training", None)

            # Use the __call__ function for callable objects, eg layers, so
            # that it will have the proper name scopes for the ops, etc.
            cell_call_fn = cell.__call__ if callable(cell) else cell.call
            if generic_utils.has_arg(cell.call, "constants"):
                inputs, cell_states = cell_call_fn(
                    inputs, cell_states, constants=constants, **kwargs
                )
            else:
                inputs, cell_states = cell_call_fn(
                    inputs, cell_states, **kwargs
                )
            new_nested_states.append(cell_states)

        return inputs, tf.nest.pack_sequence_as(
            state_size, tf.nest.flatten(new_nested_states)
        )

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        """Build every unbuilt layer cell, chaining shapes through the stack.

        Each cell is built with the input shape derived from the previous
        cell's output size and the batch dimension of the current input
        shape.

        Args:
          input_shape: Shape of the stack's input. If a list is given, only
            its first element is used.
        """
        if isinstance(input_shape, list):
            input_shape = input_shape[0]

        def get_batch_input_shape(batch_size, dim):
            shape = tf.TensorShape(dim).as_list()
            return tuple([batch_size] + shape)

        for cell in self.cells:
            if isinstance(cell, base_layer.Layer) and not cell.built:
                with backend.name_scope(cell.name):
                    cell.build(input_shape)
                    cell.built = True
            output_dim = self._cell_output_size(cell)
            # The first flattened entry of the current shape is taken as
            # the batch dimension for the next cell's input shape.
            batch_size = tf.nest.flatten(input_shape)[0]
            if tf.nest.is_nested(output_dim):
                input_shape = tf.nest.map_structure(
                    functools.partial(get_batch_input_shape, batch_size),
                    output_dim,
                )
                input_shape = tuple(input_shape)
            else:
                input_shape = tuple(
                    [batch_size] + tf.TensorShape(output_dim).as_list()
                )
        self.built = True

    def get_config(self):
        """Return the layer config, including the serialized cells.

        `reverse_state_order` is emitted only when it is `True`, so configs
        of layers using the default stay unchanged while a non-default
        value survives a `get_config` / `from_config` round trip.
        """
        cells = [
            serialization_lib.serialize_keras_object(cell)
            for cell in self.cells
        ]
        config = dict(super().get_config())
        config["cells"] = cells
        if self.reverse_state_order:
            config["reverse_state_order"] = True
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Re-create the layer from the output of `get_config`.

        Args:
          config: Config dict containing a `"cells"` entry with serialized
            cells. It is not modified.
          custom_objects: Forwarded to the layer deserializer for each cell.

        Returns:
          A new instance of `cls`.
        """
        # NOTE(review): function-scope import, presumably to avoid a
        # circular import with `keras.src.layers` -- confirm before moving.
        from keras.src.layers import deserialize as deserialize_layer

        # Pop from a shallow copy so the caller's dict keeps its "cells"
        # entry and can be reused.
        config = dict(config)
        cells = [
            deserialize_layer(cell_config, custom_objects=custom_objects)
            for cell_config in config.pop("cells")
        ]
        return cls(cells, **config)

218