Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/esn.py: 65%

51 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implements Echo State recurrent Network (ESN) layer.""" 

16 

17import tensorflow as tf 

18from tensorflow_addons.rnn import ESNCell 

19from typeguard import typechecked 

20 

21from tensorflow_addons.utils.types import ( 

22 Activation, 

23 FloatTensorLike, 

24 TensorLike, 

25 Initializer, 

26) 

27 

28 

@tf.keras.utils.register_keras_serializable(package="Addons")
class ESN(tf.keras.layers.RNN):
    """Echo State Network (ESN) recurrent layer.

    A `tf.keras.layers.RNN` whose cell is an `ESNCell` built from the
    constructor arguments.

    Based on the paper
    H. Jaeger
    ["The "echo state" approach to analysing and training recurrent neural networks"]
    (https://www.researchgate.net/publication/215385037).
    GMD Report148, German National Research Center for Information Technology, 2001.

    Args:
        units: Positive integer, dimensionality of the reservoir.
        connectivity: Float between 0 and 1.
            Connection probability between two reservoir units.
            Default: 0.1.
        leaky: Float between 0 and 1.
            Leaking rate of the reservoir.
            Passing 1 is the special case in which the model has no leaky
            integration.
            Default: 1.
        spectral_radius: Float between 0 and 1.
            Desired spectral radius of the recurrent weight matrix.
            Default: 0.9.
        use_norm2: Boolean, whether to use the p-norm function (with p=2) as
            an upper bound of the spectral radius so that the echo state
            property is satisfied. This avoids computing the eigenvalues.
            Default: False.
        use_bias: Boolean, whether the layer uses a bias vector.
            Default: True.
        activation: Activation function to use.
            Default: hyperbolic tangent (`tanh`).
            If you pass `None`, no activation is applied
            (i.e. "linear" activation: `a(x) = x`).
        kernel_initializer: Initializer for the `kernel` weights matrix,
            used for the linear transformation of the inputs.
            Default: `glorot_uniform`.
        recurrent_initializer: Initializer for the `recurrent_kernel` weights
            matrix, used for the linear transformation of the recurrent state.
            Default: `glorot_uniform`.
        bias_initializer: Initializer for the bias vector.
            Default: `zeros`.
        return_sequences: Boolean (default False). Whether to return the last
            output in the output sequence, or the full sequence.
        go_backwards: Boolean (default False).
            If True, process the input sequence backwards and return the
            reversed sequence.
        unroll: Boolean (default False).
            If True, the network will be unrolled,
            else a symbolic loop will be used.
            Unrolling can speed up an RNN,
            although it tends to be more memory-intensive.
            Unrolling is only suitable for short sequences.
        **kwargs: Extra keyword arguments forwarded to the
            `tf.keras.layers.RNN` constructor. A `dtype` entry, when present,
            is also handed to the cell.

    Call arguments:
        inputs: A 3D tensor.
        mask: Binary tensor of shape `(samples, timesteps)` indicating whether
            a given timestep should be masked.
        training: Python boolean indicating whether the layer should behave in
            training mode or in inference mode. This argument is passed to the
            cell when calling it.
        initial_state: List of initial state tensors to be passed to the first
            call of the cell.
    """

    @typechecked
    def __init__(
        self,
        units: TensorLike,
        connectivity: FloatTensorLike = 0.1,
        leaky: FloatTensorLike = 1,
        spectral_radius: FloatTensorLike = 0.9,
        use_norm2: bool = False,
        use_bias: bool = True,
        activation: Activation = "tanh",
        kernel_initializer: Initializer = "glorot_uniform",
        recurrent_initializer: Initializer = "glorot_uniform",
        bias_initializer: Initializer = "zeros",
        return_sequences=False,
        go_backwards=False,
        unroll=False,
        **kwargs,
    ):
        """Build the underlying `ESNCell` and wrap it in the RNN base layer."""
        # Everything that configures the reservoir belongs to the cell. The
        # cell gets the layer's dtype (None when the caller gave none).
        reservoir_options = {
            "connectivity": connectivity,
            "leaky": leaky,
            "spectral_radius": spectral_radius,
            "use_norm2": use_norm2,
            "use_bias": use_bias,
            "activation": activation,
            "kernel_initializer": kernel_initializer,
            "recurrent_initializer": recurrent_initializer,
            "bias_initializer": bias_initializer,
            "dtype": kwargs.get("dtype"),
        }
        # Sequence-handling options are consumed by the RNN wrapper itself.
        sequence_options = {
            "return_sequences": return_sequences,
            "go_backwards": go_backwards,
            "unroll": unroll,
        }
        super().__init__(
            ESNCell(units, **reservoir_options),
            **sequence_options,
            **kwargs,
        )

    def call(self, inputs, mask=None, training=None, initial_state=None):
        """Run the layer by delegating to the RNN base `call`.

        `constants` is always passed as None.
        """
        forwarded = {
            "mask": mask,
            "training": training,
            "initial_state": initial_state,
            "constants": None,
        }
        return super().call(inputs, **forwarded)

    # The read-only properties below expose the wrapped cell's attributes on
    # the layer itself.

    @property
    def units(self):
        """The `units` attribute of the wrapped cell."""
        return self.cell.units

    @property
    def connectivity(self):
        """The `connectivity` attribute of the wrapped cell."""
        return self.cell.connectivity

    @property
    def leaky(self):
        """The `leaky` attribute of the wrapped cell."""
        return self.cell.leaky

    @property
    def spectral_radius(self):
        """The `spectral_radius` attribute of the wrapped cell."""
        return self.cell.spectral_radius

    @property
    def use_norm2(self):
        """The `use_norm2` attribute of the wrapped cell."""
        return self.cell.use_norm2

    @property
    def use_bias(self):
        """The `use_bias` attribute of the wrapped cell."""
        return self.cell.use_bias

    @property
    def activation(self):
        """The `activation` attribute of the wrapped cell."""
        return self.cell.activation

    @property
    def kernel_initializer(self):
        """The `kernel_initializer` attribute of the wrapped cell."""
        return self.cell.kernel_initializer

    @property
    def recurrent_initializer(self):
        """The `recurrent_initializer` attribute of the wrapped cell."""
        return self.cell.recurrent_initializer

    @property
    def bias_initializer(self):
        """The `bias_initializer` attribute of the wrapped cell."""
        return self.cell.bias_initializer

    def get_config(self):
        """Return the layer configuration as a dict.

        The base-class entry `"cell"` is removed; the cell's settings are
        stored as flat entries named after the constructor arguments instead.
        """
        inherited = super().get_config()
        del inherited["cell"]

        serialize_initializer = tf.keras.initializers.serialize
        esn_settings = {
            "units": self.units,
            "connectivity": self.connectivity,
            "leaky": self.leaky,
            "spectral_radius": self.spectral_radius,
            "use_norm2": self.use_norm2,
            "use_bias": self.use_bias,
            "activation": tf.keras.activations.serialize(self.activation),
            "kernel_initializer": serialize_initializer(self.kernel_initializer),
            "recurrent_initializer": serialize_initializer(
                self.recurrent_initializer
            ),
            "bias_initializer": serialize_initializer(self.bias_initializer),
        }
        # ESN-specific entries take precedence over same-named base entries.
        return dict(inherited, **esn_settings)

    @classmethod
    def from_config(cls, config):
        """Create a layer by passing the `config` entries to the constructor."""
        return cls(**config)