Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/lstm_v1.py: 60%
92 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Long Short-Term Memory V1 layer."""

from keras.src import activations
from keras.src import constraints
from keras.src import initializers
from keras.src import regularizers
from keras.src.engine.input_spec import InputSpec
from keras.src.layers.rnn import lstm
from keras.src.layers.rnn import rnn_utils
from keras.src.layers.rnn.base_rnn import RNN

# isort: off
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export


@keras_export(v1=["keras.layers.LSTMCell"])
class LSTMCell(lstm.LSTMCell):
    """Cell class for the LSTM layer.

    Args:
      units: Positive integer, dimensionality of the output space.
      activation: Activation function to use.
        Default: hyperbolic tangent (`tanh`).
        If you pass `None`, no activation is applied
        (i.e. "linear" activation: `a(x) = x`).
      recurrent_activation: Activation function to use
        for the recurrent step.
        Default: hard sigmoid (`hard_sigmoid`).
        If you pass `None`, no activation is applied
        (i.e. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs.
      recurrent_initializer: Initializer for the `recurrent_kernel`
        weights matrix,
        used for the linear transformation of the recurrent state.
      bias_initializer: Initializer for the bias vector.
      unit_forget_bias: Boolean.
        If True, add 1 to the bias of the forget gate at initialization.
        Setting it to true will also force `bias_initializer="zeros"`.
        This is recommended in [Jozefowicz et al., 2015](
          http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
      kernel_regularizer: Regularizer function applied to
        the `kernel` weights matrix.
      recurrent_regularizer: Regularizer function applied to
        the `recurrent_kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      kernel_constraint: Constraint function applied to
        the `kernel` weights matrix.
      recurrent_constraint: Constraint function applied to
        the `recurrent_kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.
      dropout: Float between 0 and 1.
        Fraction of the units to drop for
        the linear transformation of the inputs.
      recurrent_dropout: Float between 0 and 1.
        Fraction of the units to drop for
        the linear transformation of the recurrent state.

    Call arguments:
      inputs: A 2D tensor.
      states: List of state tensors corresponding to the previous timestep.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. Only relevant when `dropout` or
        `recurrent_dropout` is used.
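
    Example (a minimal usage sketch; the tensor shapes and the
    `tf.compat.v1` export path below are illustrative assumptions):

    ```python
    import numpy as np
    import tensorflow as tf

    # Batch of 32 sequences, 10 timesteps, 8 features (illustrative sizes).
    inputs = np.random.random((32, 10, 8)).astype("float32")

    # The cell processes a single timestep; wrap it in an RNN layer to run
    # it over the whole sequence.
    cell = tf.compat.v1.keras.layers.LSTMCell(4)
    layer = tf.keras.layers.RNN(cell)
    output = layer(inputs)  # shape: (32, 4)
    ```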
82 """
84 def __init__(
85 self,
86 units,
87 activation="tanh",
88 recurrent_activation="hard_sigmoid",
89 use_bias=True,
90 kernel_initializer="glorot_uniform",
91 recurrent_initializer="orthogonal",
92 bias_initializer="zeros",
93 unit_forget_bias=True,
94 kernel_regularizer=None,
95 recurrent_regularizer=None,
96 bias_regularizer=None,
97 kernel_constraint=None,
98 recurrent_constraint=None,
99 bias_constraint=None,
100 dropout=0.0,
101 recurrent_dropout=0.0,
102 **kwargs
103 ):
104 super().__init__(
105 units,
106 activation=activation,
107 recurrent_activation=recurrent_activation,
108 use_bias=use_bias,
109 kernel_initializer=kernel_initializer,
110 recurrent_initializer=recurrent_initializer,
111 bias_initializer=bias_initializer,
112 unit_forget_bias=unit_forget_bias,
113 kernel_regularizer=kernel_regularizer,
114 recurrent_regularizer=recurrent_regularizer,
115 bias_regularizer=bias_regularizer,
116 kernel_constraint=kernel_constraint,
117 recurrent_constraint=recurrent_constraint,
118 bias_constraint=bias_constraint,
119 dropout=dropout,
120 recurrent_dropout=recurrent_dropout,
121 implementation=kwargs.pop("implementation", 1),
122 **kwargs
123 )


@keras_export(v1=["keras.layers.LSTM"])
class LSTM(RNN):
    """Long Short-Term Memory layer - Hochreiter 1997.

    Note that this cell is not optimized for performance on GPU. Please use
    `tf.compat.v1.keras.layers.CuDNNLSTM` for better performance on GPU.

    Args:
      units: Positive integer, dimensionality of the output space.
      activation: Activation function to use.
        Default: hyperbolic tangent (`tanh`).
        If you pass `None`, no activation is applied
        (i.e. "linear" activation: `a(x) = x`).
      recurrent_activation: Activation function to use
        for the recurrent step.
        Default: hard sigmoid (`hard_sigmoid`).
        If you pass `None`, no activation is applied
        (i.e. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs.
      recurrent_initializer: Initializer for the `recurrent_kernel`
        weights matrix,
        used for the linear transformation of the recurrent state.
      bias_initializer: Initializer for the bias vector.
      unit_forget_bias: Boolean.
        If True, add 1 to the bias of the forget gate at initialization.
        Setting it to true will also force `bias_initializer="zeros"`.
        This is recommended in [Jozefowicz et al., 2015](
          http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
      kernel_regularizer: Regularizer function applied to
        the `kernel` weights matrix.
      recurrent_regularizer: Regularizer function applied to
        the `recurrent_kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to
        the output of the layer (its "activation").
      kernel_constraint: Constraint function applied to
        the `kernel` weights matrix.
      recurrent_constraint: Constraint function applied to
        the `recurrent_kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.
      dropout: Float between 0 and 1.
        Fraction of the units to drop for
        the linear transformation of the inputs.
      recurrent_dropout: Float between 0 and 1.
        Fraction of the units to drop for
        the linear transformation of the recurrent state.
      return_sequences: Boolean. Whether to return the last output
        in the output sequence, or the full sequence.
      return_state: Boolean. Whether to return the last state
        in addition to the output.
      go_backwards: Boolean (default False).
        If True, process the input sequence backwards and return the
        reversed sequence.
      stateful: Boolean (default False). If True, the last state
        for each sample at index i in a batch will be used as initial
        state for the sample of index i in the following batch.
      unroll: Boolean (default False).
        If True, the network will be unrolled,
        else a symbolic loop will be used.
        Unrolling can speed up an RNN,
        although it tends to be more memory-intensive.
        Unrolling is only suitable for short sequences.
      time_major: The shape format of the `inputs` and `outputs` tensors.
        If True, the inputs and outputs will be in shape
        `(timesteps, batch, ...)`, whereas in the False case, it will be
        `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
        efficient because it avoids transposes at the beginning and end of the
        RNN calculation. However, most TensorFlow data is batch-major, so by
        default this function accepts input and emits output in batch-major
        form.

    Call arguments:
      inputs: A 3D tensor.
      mask: Binary tensor of shape `(samples, timesteps)` indicating whether
        a given timestep should be masked. An individual `True` entry indicates
        that the corresponding timestep should be utilized, while a `False`
        entry indicates that the corresponding timestep should be ignored.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. This argument is passed to the cell
        when calling it. This is only relevant if `dropout` or
        `recurrent_dropout` is used.
      initial_state: List of initial state tensors to be passed to the first
        call of the cell.
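
    Example (a minimal usage sketch; the tensor shapes and the
    `tf.compat.v1` export path below are illustrative assumptions):

    ```python
    import numpy as np
    import tensorflow as tf

    # Batch of 32 sequences, 10 timesteps, 8 features (illustrative sizes).
    inputs = np.random.random((32, 10, 8)).astype("float32")
    layer = tf.compat.v1.keras.layers.LSTM(4)
    output = layer(inputs)  # shape: (32, 4), since return_sequences=False
    ```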
211 """
213 def __init__(
214 self,
215 units,
216 activation="tanh",
217 recurrent_activation="hard_sigmoid",
218 use_bias=True,
219 kernel_initializer="glorot_uniform",
220 recurrent_initializer="orthogonal",
221 bias_initializer="zeros",
222 unit_forget_bias=True,
223 kernel_regularizer=None,
224 recurrent_regularizer=None,
225 bias_regularizer=None,
226 activity_regularizer=None,
227 kernel_constraint=None,
228 recurrent_constraint=None,
229 bias_constraint=None,
230 dropout=0.0,
231 recurrent_dropout=0.0,
232 return_sequences=False,
233 return_state=False,
234 go_backwards=False,
235 stateful=False,
236 unroll=False,
237 **kwargs
238 ):
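        # Pop the legacy `implementation` kwarg; mode 0 has been removed, so
        # warn callers who still request it.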
        implementation = kwargs.pop("implementation", 1)
        if implementation == 0:
            logging.warning(
                "`implementation=0` has been deprecated, "
                "and now defaults to `implementation=1`. "
                "Please update your layer call."
            )
246 if "enable_caching_device" in kwargs:
247 cell_kwargs = {
248 "enable_caching_device": kwargs.pop("enable_caching_device")
249 }
250 else:
251 cell_kwargs = {}
252 cell = LSTMCell(
253 units,
254 activation=activation,
255 recurrent_activation=recurrent_activation,
256 use_bias=use_bias,
257 kernel_initializer=kernel_initializer,
258 recurrent_initializer=recurrent_initializer,
259 unit_forget_bias=unit_forget_bias,
260 bias_initializer=bias_initializer,
261 kernel_regularizer=kernel_regularizer,
262 recurrent_regularizer=recurrent_regularizer,
263 bias_regularizer=bias_regularizer,
264 kernel_constraint=kernel_constraint,
265 recurrent_constraint=recurrent_constraint,
266 bias_constraint=bias_constraint,
267 dropout=dropout,
268 recurrent_dropout=recurrent_dropout,
269 implementation=implementation,
270 dtype=kwargs.get("dtype"),
271 trainable=kwargs.get("trainable", True),
272 name="lstm_cell",
273 **cell_kwargs
274 )
275 super().__init__(
276 cell,
277 return_sequences=return_sequences,
278 return_state=return_state,
279 go_backwards=go_backwards,
280 stateful=stateful,
281 unroll=unroll,
282 **kwargs
283 )
284 self.activity_regularizer = regularizers.get(activity_regularizer)
285 self.input_spec = [InputSpec(ndim=3)]
287 def call(self, inputs, mask=None, training=None, initial_state=None):
288 return super().call(
289 inputs, mask=mask, training=training, initial_state=initial_state
290 )

    @property
    def units(self):
        return self.cell.units

    @property
    def activation(self):
        return self.cell.activation

    @property
    def recurrent_activation(self):
        return self.cell.recurrent_activation

    @property
    def use_bias(self):
        return self.cell.use_bias

    @property
    def kernel_initializer(self):
        return self.cell.kernel_initializer

    @property
    def recurrent_initializer(self):
        return self.cell.recurrent_initializer

    @property
    def bias_initializer(self):
        return self.cell.bias_initializer

    @property
    def unit_forget_bias(self):
        return self.cell.unit_forget_bias

    @property
    def kernel_regularizer(self):
        return self.cell.kernel_regularizer

    @property
    def recurrent_regularizer(self):
        return self.cell.recurrent_regularizer

    @property
    def bias_regularizer(self):
        return self.cell.bias_regularizer

    @property
    def kernel_constraint(self):
        return self.cell.kernel_constraint

    @property
    def recurrent_constraint(self):
        return self.cell.recurrent_constraint

    @property
    def bias_constraint(self):
        return self.cell.bias_constraint

    @property
    def dropout(self):
        return self.cell.dropout

    @property
    def recurrent_dropout(self):
        return self.cell.recurrent_dropout

    @property
    def implementation(self):
        return self.cell.implementation

    def get_config(self):
        config = {
            "units": self.units,
            "activation": activations.serialize(self.activation),
            "recurrent_activation": activations.serialize(
                self.recurrent_activation
            ),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "unit_forget_bias": self.unit_forget_bias,
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self.activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
            "dropout": self.dropout,
            "recurrent_dropout": self.recurrent_dropout,
            "implementation": self.implementation,
        }
        config.update(rnn_utils.config_for_enable_caching_device(self.cell))
        base_config = super().get_config()
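        # The cell's hyperparameters are already serialized at the layer
        # level above, so drop the nested cell config from the base config.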
        del base_config["cell"]
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
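        # Older saved configs may still specify the removed
        # `implementation=0`; fall back to the default mode 1.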
402 if "implementation" in config and config["implementation"] == 0:
403 config["implementation"] = 1
404 return cls(**config)