# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements ESN Cell."""

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
from tensorflow_addons.utils.types import (
    Activation,
    Initializer,
)


@tf.keras.utils.register_keras_serializable(package="Addons")
class ESNCell(AbstractRNNCell):
    """Echo State recurrent Network (ESN) cell.

    This implements the recurrent cell from the paper:

        H. Jaeger
        "The "echo state" approach to analysing and training recurrent neural networks".
        GMD Report 148, German National Research Center for Information Technology, 2001.
        https://www.researchgate.net/publication/215385037

    Example:

    >>> inputs = np.random.random([30, 23, 9]).astype(np.float32)
    >>> cell = tfa.rnn.ESNCell(4)
    >>> rnn = tf.keras.layers.RNN(cell, return_sequences=True, return_state=True)
    >>> outputs, memory_state = rnn(inputs)
    >>> outputs.shape
    TensorShape([30, 23, 4])
    >>> memory_state.shape
    TensorShape([30, 4])

    Args:
        units: Positive integer, dimensionality of the reservoir.
        connectivity: Float between 0 and 1.
            Connection probability between two reservoir units.
            Default: 0.1.
        leaky: Float between 0 and 1.
            Leaking rate of the reservoir.
            If you pass 1, it is the special case in which the model does not
            have leaky integration.
            Default: 1.
        spectral_radius: Float between 0 and 1.
            Desired spectral radius of the recurrent weight matrix.
            Default: 0.9.
        use_norm2: Boolean, whether to use the entry-wise 2-norm (Frobenius
            norm) of the reservoir matrix as an upper bound of its spectral
            radius, so that the echo state property is still satisfied.
            This avoids computing the eigenvalues, which is expensive (cubic
            in the number of units).
            Default: False.
        use_bias: Boolean, whether the layer uses a bias vector.
            Default: True.
        activation: Activation function to use.
            Default: hyperbolic tangent (`tanh`).
        kernel_initializer: Initializer for the `kernel` weights matrix,
            used for the linear transformation of the inputs.
            Default: `glorot_uniform`.
        recurrent_initializer: Initializer for the `recurrent_kernel` weights
            matrix, used for the linear transformation of the recurrent state.
            Default: `glorot_uniform`.
        bias_initializer: Initializer for the bias vector.
            Default: `zeros`.

    Call arguments:
        inputs: A 2D tensor, with shape `[batch, feature]`.
        states: List of state tensors corresponding to the previous timestep.
    """

    @typechecked
    def __init__(
        self,
        units: int,
        connectivity: float = 0.1,
        leaky: float = 1,
        spectral_radius: float = 0.9,
        use_norm2: bool = False,
        use_bias: bool = True,
        activation: Activation = "tanh",
        kernel_initializer: Initializer = "glorot_uniform",
        recurrent_initializer: Initializer = "glorot_uniform",
        bias_initializer: Initializer = "zeros",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        self.connectivity = connectivity
        self.leaky = leaky
        self.spectral_radius = spectral_radius
        self.use_norm2 = use_norm2
        self.use_bias = use_bias
        self.activation = tf.keras.activations.get(activation)
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self.recurrent_initializer = tf.keras.initializers.get(recurrent_initializer)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)

        self._state_size = units
        self._output_size = units

    @property
    def state_size(self):
        return self._state_size

    @property
    def output_size(self):
        return self._output_size

    def build(self, inputs_shape):
        input_size = tf.compat.dimension_value(tf.TensorShape(inputs_shape)[-1])
        if input_size is None:
            raise ValueError(
                "Could not infer input size from inputs.get_shape()[-1]. Shape received is %s"
                % inputs_shape
            )

        def _esn_recurrent_initializer(shape, dtype, partition_info=None):
            recurrent_weights = tf.keras.initializers.get(self.recurrent_initializer)(
                shape, dtype
            )

            # Keep each recurrent connection with probability `connectivity`,
            # zeroing out the rest to obtain a sparse reservoir.
            connectivity_mask = tf.cast(
                tf.math.less_equal(tf.random.uniform(shape), self.connectivity),
                dtype,
            )
            recurrent_weights = tf.math.multiply(recurrent_weights, connectivity_mask)

            # Satisfy the necessary condition for the echo state property,
            # `max(abs(eig(W))) < 1`, by rescaling the reservoir matrix.
            if self.use_norm2:
                # The condition is approximated by scaling with the entry-wise
                # 2-norm (Frobenius norm) of the reservoir matrix, which is an
                # upper bound of the spectral radius and much cheaper to
                # compute than the eigenvalues.
                recurrent_norm2 = tf.math.sqrt(
                    tf.math.reduce_sum(tf.math.square(recurrent_weights))
                )
                # Guard against division by zero when the matrix is all zeros.
                is_norm2_0 = tf.cast(tf.math.equal(recurrent_norm2, 0), dtype)
                scaling_factor = tf.cast(self.spectral_radius, dtype) / (
                    recurrent_norm2 + 1 * is_norm2_0
                )
            else:
                abs_eig_values = tf.abs(tf.linalg.eig(recurrent_weights)[0])
                scaling_factor = tf.math.divide_no_nan(
                    tf.cast(self.spectral_radius, dtype), tf.reduce_max(abs_eig_values)
                )

            recurrent_weights = tf.multiply(recurrent_weights, scaling_factor)

            return recurrent_weights

        self.recurrent_kernel = self.add_weight(
            name="recurrent_kernel",
            shape=[self.units, self.units],
            initializer=_esn_recurrent_initializer,
            trainable=False,
            dtype=self.dtype,
        )
        self.kernel = self.add_weight(
            name="kernel",
            shape=[input_size, self.units],
            initializer=self.kernel_initializer,
            trainable=False,
            dtype=self.dtype,
        )

        if self.use_bias:
            self.bias = self.add_weight(
                name="bias",
                shape=[self.units],
                initializer=self.bias_initializer,
                trainable=False,
                dtype=self.dtype,
            )

        self.built = True
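    # Note on the `use_norm2` shortcut in `build`: for any square matrix W,
    #
    #     spectral_radius(W) <= ||W||_2  (largest singular value)
    #                        <= ||W||_F  = sqrt(sum_ij W_ij ** 2),
    #
    # and the initializer scales W by spectral_radius / ||W||_F. The rescaled
    # reservoir therefore has spectral radius at most the requested value, at
    # the cost of possibly shrinking it more than strictly necessary.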

    def call(self, inputs, state):
        # Concatenating [inputs, state] and [kernel, recurrent_kernel] computes
        # inputs @ kernel + state @ recurrent_kernel in a single matmul.
        in_matrix = tf.concat([inputs, state[0]], axis=1)
        weights_matrix = tf.concat([self.kernel, self.recurrent_kernel], axis=0)

        output = tf.linalg.matmul(in_matrix, weights_matrix)
        if self.use_bias:
            output = output + self.bias
        output = self.activation(output)
        # Leaky integration: blend the previous state with the new activation.
        output = (1 - self.leaky) * state[0] + self.leaky * output

        return output, output

    def get_config(self):
        config = {
            "units": self.units,
            "connectivity": self.connectivity,
            "leaky": self.leaky,
            "spectral_radius": self.spectral_radius,
            "use_norm2": self.use_norm2,
            "use_bias": self.use_bias,
            "activation": tf.keras.activations.serialize(self.activation),
            "kernel_initializer": tf.keras.initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": tf.keras.initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": tf.keras.initializers.serialize(self.bias_initializer),
        }
        base_config = super().get_config()
        return {**base_config, **config}
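

if __name__ == "__main__":
    # Minimal usage sketch, assuming `tensorflow` and `numpy` are installed.
    # It mirrors the docstring example above and additionally checks that the
    # eigenvalue branch of `_esn_recurrent_initializer` rescaled the reservoir
    # to (roughly) the requested spectral radius.
    import numpy as np

    inputs = np.random.random([30, 23, 9]).astype(np.float32)
    cell = ESNCell(64, spectral_radius=0.9)
    rnn = tf.keras.layers.RNN(cell, return_sequences=True, return_state=True)
    outputs, memory_state = rnn(inputs)
    print(outputs.shape)       # (30, 23, 64)
    print(memory_state.shape)  # (30, 64)

    # max |eig(W)| should be ~0.9; with very small reservoirs the random
    # connectivity mask can zero out the whole matrix, in which case it is 0.
    abs_eig = tf.abs(tf.linalg.eig(cell.recurrent_kernel)[0])
    print(float(tf.reduce_max(abs_eig)))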