# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Implements Echo State recurrent Network (ESN) layer."""
17import tensorflow as tf
18from tensorflow_addons.rnn import ESNCell
19from typeguard import typechecked
21from tensorflow_addons.utils.types import (
22 Activation,
23 FloatTensorLike,
24 TensorLike,
25 Initializer,
26)
29@tf.keras.utils.register_keras_serializable(package="Addons")
30class ESN(tf.keras.layers.RNN):
31 """Echo State Network layer.
33 This implements the recurrent layer using the ESNCell.
35 This is based on the paper
36 H. Jaeger
37 ["The "echo state" approach to analysing and training recurrent neural networks"]
38 (https://www.researchgate.net/publication/215385037).
39 GMD Report148, German National Research Center for Information Technology, 2001.
    Args:
        units: Positive integer, dimensionality of the reservoir.
        connectivity: Float between 0 and 1.
            Connection probability between two reservoir units.
            Default: 0.1.
        leaky: Float between 0 and 1.
            Leaking rate of the reservoir.
            If you pass 1, it is the special case in which the model does not
            use leaky integration.
            Default: 1.
        spectral_radius: Float between 0 and 1.
            Desired spectral radius of the recurrent weight matrix.
            Default: 0.9.
        use_norm2: Boolean, whether to use the 2-norm of the recurrent weight
            matrix as an upper bound on its spectral radius, so that the echo
            state property is satisfied (see the note below). This avoids
            computing the eigenvalues, which is considerably more expensive.
            Default: False.
        use_bias: Boolean, whether the layer uses a bias vector.
            Default: True.
        activation: Activation function to use.
            Default: hyperbolic tangent (`tanh`).
            If you pass `None`, no activation is applied
            (i.e. "linear" activation: `a(x) = x`).
        kernel_initializer: Initializer for the `kernel` weights matrix,
            used for the linear transformation of the inputs.
            Default: `glorot_uniform`.
        recurrent_initializer: Initializer for the `recurrent_kernel` weights
            matrix, used for the linear transformation of the recurrent state.
            Default: `glorot_uniform`.
        bias_initializer: Initializer for the bias vector.
            Default: `zeros`.
        return_sequences: Boolean. Whether to return the last output
            in the output sequence, or the full sequence.
        go_backwards: Boolean (default False).
            If True, process the input sequence backwards and return the
            reversed sequence.
        unroll: Boolean (default False).
            If True, the network will be unrolled,
            else a symbolic loop will be used.
            Unrolling can speed up an RNN, although it tends to be more
            memory-intensive. Unrolling is only suitable for short sequences.
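
    Note: for any induced matrix norm, `rho(W) <= ||W||`, so rescaling
    `recurrent_kernel` so that its 2-norm equals `spectral_radius` (< 1) is a
    cheaper, more conservative way to guarantee the echo state property than
    rescaling by the exact spectral radius.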
    Call arguments:
        inputs: A 3D tensor.
        mask: Binary tensor of shape `(samples, timesteps)` indicating whether
            a given timestep should be masked.
        training: Python boolean indicating whether the layer should behave in
            training mode or in inference mode. This argument is passed to the
            cell when calling it. This is only relevant if `dropout` or
            `recurrent_dropout` is used.
        initial_state: List of initial state tensors to be passed to the first
            call of the cell.
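
    Examples:

    A minimal usage sketch (the input shape below is an illustrative choice,
    not a requirement):

    >>> import numpy as np
    >>> inputs = np.random.random([30, 23, 9]).astype(np.float32)
    >>> esn = tfa.layers.ESN(4)
    >>> outputs = esn(inputs)
    >>> outputs.shape
    TensorShape([30, 4])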
93 """

    @typechecked
    def __init__(
        self,
        units: TensorLike,
        connectivity: FloatTensorLike = 0.1,
        leaky: FloatTensorLike = 1,
        spectral_radius: FloatTensorLike = 0.9,
        use_norm2: bool = False,
        use_bias: bool = True,
        activation: Activation = "tanh",
        kernel_initializer: Initializer = "glorot_uniform",
        recurrent_initializer: Initializer = "glorot_uniform",
        bias_initializer: Initializer = "zeros",
        return_sequences=False,
        go_backwards=False,
        unroll=False,
        **kwargs,
    ):
        # Build the reservoir cell, forwarding the layer dtype so the cell's
        # weights are created with the same precision as the layer.
        cell = ESNCell(
            units,
            connectivity=connectivity,
            leaky=leaky,
            spectral_radius=spectral_radius,
            use_norm2=use_norm2,
            use_bias=use_bias,
            activation=activation,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            dtype=kwargs.get("dtype"),
        )
        # Delegate sequence iteration to the generic Keras RNN wrapper.
        super().__init__(
            cell,
            return_sequences=return_sequences,
            go_backwards=go_backwards,
            unroll=unroll,
            **kwargs,
        )

    def call(self, inputs, mask=None, training=None, initial_state=None):
        return super().call(
            inputs,
            mask=mask,
            training=training,
            initial_state=initial_state,
            constants=None,
        )

    @property
    def units(self):
        return self.cell.units

    @property
    def connectivity(self):
        return self.cell.connectivity

    @property
    def leaky(self):
        return self.cell.leaky

    @property
    def spectral_radius(self):
        return self.cell.spectral_radius

    @property
    def use_norm2(self):
        return self.cell.use_norm2

    @property
    def use_bias(self):
        return self.cell.use_bias

    @property
    def activation(self):
        return self.cell.activation

    @property
    def kernel_initializer(self):
        return self.cell.kernel_initializer

    @property
    def recurrent_initializer(self):
        return self.cell.recurrent_initializer

    @property
    def bias_initializer(self):
        return self.cell.bias_initializer

    def get_config(self):
        config = {
            "units": self.units,
            "connectivity": self.connectivity,
            "leaky": self.leaky,
            "spectral_radius": self.spectral_radius,
            "use_norm2": self.use_norm2,
            "use_bias": self.use_bias,
            "activation": tf.keras.activations.serialize(self.activation),
            "kernel_initializer": tf.keras.initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": tf.keras.initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": tf.keras.initializers.serialize(self.bias_initializer),
        }
        base_config = super().get_config()
        # The cell is rebuilt from the hyperparameters above, so drop it from
        # the base RNN config rather than serializing it twice.
        del base_config["cell"]
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config):
        return cls(**config)
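

# The sketch below is an illustrative addition (not part of the upstream
# module): it runs a forward pass with assumed shapes and checks that the
# get_config / from_config round trip preserves the layer's hyperparameters.
if __name__ == "__main__":
    layer = ESN(units=64, connectivity=0.1, leaky=1.0, spectral_radius=0.9)
    x = tf.random.normal((8, 10, 3))  # (batch, timesteps, features), assumed sizes
    y = layer(x)  # last-step reservoir output: shape (8, 64)
    restored = ESN.from_config(layer.get_config())
    assert restored.units == layer.units
    assert restored.spectral_radius == layer.spectral_radius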