# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements ESN Cell."""

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell
from tensorflow_addons.utils.types import (
    Activation,
    Initializer,
)


@tf.keras.utils.register_keras_serializable(package="Addons")
class ESNCell(AbstractRNNCell):
29 """Echo State recurrent Network (ESN) cell.
30 This implements the recurrent cell from the paper:
31 H. Jaeger
32 "The "echo state" approach to analysing and training recurrent neural networks".
33 GMD Report148, German National Research Center for Information Technology, 2001.
34 https://www.researchgate.net/publication/215385037
36 Example:
38 >>> inputs = np.random.random([30,23,9]).astype(np.float32)
39 >>> ESNCell = tfa.rnn.ESNCell(4)
40 >>> rnn = tf.keras.layers.RNN(ESNCell, return_sequences=True, return_state=True)
41 >>> outputs, memory_state = rnn(inputs)
42 >>> outputs.shape
43 TensorShape([30, 23, 4])
44 >>> memory_state.shape
45 TensorShape([30, 4])

    Args:
      units: Positive integer, dimensionality of the reservoir.
      connectivity: Float between 0 and 1.
        Connection probability between two reservoir units.
        Default: 0.1.
      leaky: Float between 0 and 1.
        Leaking rate of the reservoir.
        If you pass 1, it is the special case in which the model does not have
        leaky integration.
        Default: 1.
      spectral_radius: Float between 0 and 1.
        Desired spectral radius of the recurrent weight matrix.
        Default: 0.9.
      use_norm2: Boolean, whether to use the 2-norm of the recurrent weight
        matrix as an upper bound of the spectral radius, so that the echo state
        property is satisfied. This avoids computing the eigenvalues, which is
        expensive (cubic in the number of units).
        Default: False.
      use_bias: Boolean, whether the layer uses a bias vector.
        Default: True.
      activation: Activation function to use.
        Default: hyperbolic tangent (`tanh`).
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs.
        Default: `glorot_uniform`.
      recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix,
        used for the linear transformation of the recurrent state.
        Default: `glorot_uniform`.
      bias_initializer: Initializer for the bias vector.
        Default: `zeros`.

    Call arguments:
      inputs: A 2D tensor, with shape `[batch, input_size]`.
      states: List of state tensors corresponding to the previous timestep.
    """

    @typechecked
    def __init__(
        self,
        units: int,
        connectivity: float = 0.1,
        leaky: float = 1,
        spectral_radius: float = 0.9,
        use_norm2: bool = False,
        use_bias: bool = True,
        activation: Activation = "tanh",
        kernel_initializer: Initializer = "glorot_uniform",
        recurrent_initializer: Initializer = "glorot_uniform",
        bias_initializer: Initializer = "zeros",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        self.connectivity = connectivity
        self.leaky = leaky
        self.spectral_radius = spectral_radius
        self.use_norm2 = use_norm2
        self.use_bias = use_bias
        self.activation = tf.keras.activations.get(activation)
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self.recurrent_initializer = tf.keras.initializers.get(recurrent_initializer)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)

        self._state_size = units
        self._output_size = units

    @property
    def state_size(self):
        return self._state_size

    @property
    def output_size(self):
        return self._output_size

    def build(self, inputs_shape):
        input_size = tf.compat.dimension_value(tf.TensorShape(inputs_shape)[-1])
        if input_size is None:
            raise ValueError(
                "Could not infer input size from inputs.get_shape()[-1]. "
                "Shape received is %s" % inputs_shape
            )

        def _esn_recurrent_initializer(shape, dtype, partition_info=None):
            recurrent_weights = tf.keras.initializers.get(self.recurrent_initializer)(
                shape, dtype
            )
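
            # Sparsify the reservoir: keep each weight only when an independent
            # uniform draw falls below `connectivity`, so on average a fraction
            # `connectivity` of the recurrent connections is non-zero.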
            connectivity_mask = tf.cast(
                tf.math.less_equal(tf.random.uniform(shape), self.connectivity),
                dtype,
            )
            recurrent_weights = tf.math.multiply(recurrent_weights, connectivity_mask)

            # Satisfy the necessary condition for the echo state property:
            # max(abs(eig(W))) < 1.
            if self.use_norm2:
                # The condition is approximated by scaling with the entry-wise
                # 2-norm (Frobenius norm) of the reservoir matrix, which is an
                # upper bound on its spectral radius.
                recurrent_norm2 = tf.math.sqrt(
                    tf.math.reduce_sum(tf.math.square(recurrent_weights))
                )
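                # Guard against division by zero: if the masked matrix is all
                # zeros its norm is 0, so add 1 to the denominator in that case.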
                is_norm2_0 = tf.cast(tf.math.equal(recurrent_norm2, 0), dtype)
                scaling_factor = tf.cast(self.spectral_radius, dtype) / (
                    recurrent_norm2 + 1 * is_norm2_0
                )
            else:
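                # Exact variant: divide the target spectral radius by the
                # largest eigenvalue magnitude; `divide_no_nan` yields 0 when
                # the masked matrix is all zeros.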
                abs_eig_values = tf.abs(tf.linalg.eig(recurrent_weights)[0])
                scaling_factor = tf.math.divide_no_nan(
                    tf.cast(self.spectral_radius, dtype), tf.reduce_max(abs_eig_values)
                )

            recurrent_weights = tf.multiply(recurrent_weights, scaling_factor)

            return recurrent_weights

        self.recurrent_kernel = self.add_weight(
            name="recurrent_kernel",
            shape=[self.units, self.units],
            initializer=_esn_recurrent_initializer,
            trainable=False,
            dtype=self.dtype,
        )
        self.kernel = self.add_weight(
            name="kernel",
            shape=[input_size, self.units],
            initializer=self.kernel_initializer,
            trainable=False,
            dtype=self.dtype,
        )

        if self.use_bias:
            self.bias = self.add_weight(
                name="bias",
                shape=[self.units],
                initializer=self.bias_initializer,
                trainable=False,
                dtype=self.dtype,
            )

        self.built = True
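
    # `call` fuses the input and recurrent transforms into a single matmul by
    # concatenating [inputs, state] column-wise and stacking [kernel;
    # recurrent_kernel] row-wise, then applies the leaky-integration update.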
    def call(self, inputs, state):
        in_matrix = tf.concat([inputs, state[0]], axis=1)
        weights_matrix = tf.concat([self.kernel, self.recurrent_kernel], axis=0)

        output = tf.linalg.matmul(in_matrix, weights_matrix)
        if self.use_bias:
            output = output + self.bias
        output = self.activation(output)
        output = (1 - self.leaky) * state[0] + self.leaky * output

        return output, output

    def get_config(self):
        config = {
            "units": self.units,
            "connectivity": self.connectivity,
            "leaky": self.leaky,
            "spectral_radius": self.spectral_radius,
            "use_norm2": self.use_norm2,
            "use_bias": self.use_bias,
            "activation": tf.keras.activations.serialize(self.activation),
            "kernel_initializer": tf.keras.initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": tf.keras.initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": tf.keras.initializers.serialize(self.bias_initializer),
        }
        base_config = super().get_config()
        return {**base_config, **config}
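

if __name__ == "__main__":
    # Minimal usage sketch (illustration only, not part of the library):
    # drive the cell with a Keras RNN layer, mirroring the docstring example.
    import numpy as np

    inputs = np.random.random([30, 23, 9]).astype(np.float32)
    cell = ESNCell(4)
    rnn = tf.keras.layers.RNN(cell, return_sequences=True, return_state=True)
    outputs, memory_state = rnn(inputs)
    print(outputs.shape)       # (30, 23, 4)
    print(memory_state.shape)  # (30, 4)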