Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/simple_rnn.py: 41% (141 statements)
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fully connected RNN layer."""

import tensorflow.compat.v2 as tf

from keras.src import activations
from keras.src import backend
from keras.src import constraints
from keras.src import initializers
from keras.src import regularizers
from keras.src.engine import base_layer
from keras.src.engine.input_spec import InputSpec
from keras.src.layers.rnn import rnn_utils
from keras.src.layers.rnn.base_rnn import RNN
from keras.src.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export

@keras_export("keras.layers.SimpleRNNCell")
class SimpleRNNCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
    """Cell class for SimpleRNN.

    See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
    for details about the usage of the RNN API.

    This class processes one step within the whole time sequence input,
    whereas `tf.keras.layers.SimpleRNN` processes the whole sequence.

    Args:
      units: Positive integer, dimensionality of the output space.
      activation: Activation function to use.
        Default: hyperbolic tangent (`tanh`).
        If you pass `None`, no activation is applied
        (i.e. "linear" activation: `a(x) = x`).
      use_bias: Boolean (default `True`), whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs. Default:
        `glorot_uniform`.
      recurrent_initializer: Initializer for the `recurrent_kernel`
        weights matrix, used for the linear transformation of the recurrent
        state. Default: `orthogonal`.
      bias_initializer: Initializer for the bias vector. Default: `zeros`.
      kernel_regularizer: Regularizer function applied to the `kernel` weights
        matrix. Default: `None`.
      recurrent_regularizer: Regularizer function applied to the
        `recurrent_kernel` weights matrix. Default: `None`.
      bias_regularizer: Regularizer function applied to the bias vector.
        Default: `None`.
      kernel_constraint: Constraint function applied to the `kernel` weights
        matrix. Default: `None`.
      recurrent_constraint: Constraint function applied to the
        `recurrent_kernel` weights matrix. Default: `None`.
      bias_constraint: Constraint function applied to the bias vector. Default:
        `None`.
      dropout: Float between 0 and 1. Fraction of the units to drop for the
        linear transformation of the inputs. Default: 0.
      recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
        for the linear transformation of the recurrent state. Default: 0.

    Call arguments:
      inputs: A 2D tensor, with shape of `[batch, feature]`.
      states: A 2D tensor with shape of `[batch, units]`, which is the state
        from the previous time step. For timestep 0, the initial state provided
        by the user will be fed to the cell.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. Only relevant when `dropout` or
        `recurrent_dropout` is used.

    Examples:

    ```python
    inputs = np.random.random([32, 10, 8]).astype(np.float32)
    rnn = tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(4))
    output = rnn(inputs)  # The output has shape `[32, 4]`.

    rnn = tf.keras.layers.RNN(
        tf.keras.layers.SimpleRNNCell(4),
        return_sequences=True,
        return_state=True)

    # whole_sequence_output has shape `[32, 10, 4]`.
    # final_state has shape `[32, 4]`.
    whole_sequence_output, final_state = rnn(inputs)
    ```
    """
    def __init__(
        self,
        units,
        activation="tanh",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        dropout=0.0,
        recurrent_dropout=0.0,
        **kwargs,
    ):
        if units <= 0:
            raise ValueError(
                "Received an invalid value for argument `units`, "
                f"expected a positive integer, got {units}."
            )
        # By default use cached variable under v2 mode, see b/143699808.
        if tf.compat.v1.executing_eagerly_outside_functions():
            self._enable_caching_device = kwargs.pop(
                "enable_caching_device", True
            )
        else:
            self._enable_caching_device = kwargs.pop(
                "enable_caching_device", False
            )
        super().__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.dropout = min(1.0, max(0.0, dropout))
        self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))
        self.state_size = self.units
        self.output_size = self.units

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        super().build(input_shape)
        default_caching_device = rnn_utils.caching_device(self)
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units),
            name="kernel",
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            caching_device=default_caching_device,
        )
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            name="recurrent_kernel",
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint,
            caching_device=default_caching_device,
        )
        if self.use_bias:
            self.bias = self.add_weight(
                shape=(self.units,),
                name="bias",
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                caching_device=default_caching_device,
            )
        else:
            self.bias = None
        self.built = True

    def call(self, inputs, states, training=None):
        prev_output = states[0] if tf.nest.is_nested(states) else states
        dp_mask = self.get_dropout_mask_for_cell(inputs, training)
        rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
            prev_output, training
        )
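
        # Project the (optionally dropout-masked) inputs through `kernel`,
        # add the bias if present, then add the previous state (optionally
        # masked by recurrent dropout) projected through `recurrent_kernel`,
        # and finally apply the activation.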
        if dp_mask is not None:
            h = backend.dot(inputs * dp_mask, self.kernel)
        else:
            h = backend.dot(inputs, self.kernel)
        if self.bias is not None:
            h = backend.bias_add(h, self.bias)

        if rec_dp_mask is not None:
            prev_output = prev_output * rec_dp_mask
        output = h + backend.dot(prev_output, self.recurrent_kernel)
        if self.activation is not None:
            output = self.activation(output)

        new_state = [output] if tf.nest.is_nested(states) else output
        return output, new_state

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        return rnn_utils.generate_zero_filled_state_for_cell(
            self, inputs, batch_size, dtype
        )

    def get_config(self):
        config = {
            "units": self.units,
            "activation": activations.serialize(self.activation),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
            "dropout": self.dropout,
            "recurrent_dropout": self.recurrent_dropout,
        }
        config.update(rnn_utils.config_for_enable_caching_device(self))
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


@keras_export("keras.layers.SimpleRNN")
class SimpleRNN(RNN):
    """Fully-connected RNN where the output is to be fed back as input.

    See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
    for details about the usage of the RNN API.

    Args:
      units: Positive integer, dimensionality of the output space.
      activation: Activation function to use.
        Default: hyperbolic tangent (`tanh`).
        If you pass `None`, no activation is applied
        (i.e. "linear" activation: `a(x) = x`).
      use_bias: Boolean (default `True`), whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs. Default:
        `glorot_uniform`.
      recurrent_initializer: Initializer for the `recurrent_kernel`
        weights matrix, used for the linear transformation of the recurrent
        state. Default: `orthogonal`.
      bias_initializer: Initializer for the bias vector. Default: `zeros`.
      kernel_regularizer: Regularizer function applied to the `kernel` weights
        matrix. Default: `None`.
      recurrent_regularizer: Regularizer function applied to the
        `recurrent_kernel` weights matrix. Default: `None`.
      bias_regularizer: Regularizer function applied to the bias vector.
        Default: `None`.
      activity_regularizer: Regularizer function applied to the output of the
        layer (its "activation"). Default: `None`.
      kernel_constraint: Constraint function applied to the `kernel` weights
        matrix. Default: `None`.
      recurrent_constraint: Constraint function applied to the
        `recurrent_kernel` weights matrix. Default: `None`.
      bias_constraint: Constraint function applied to the bias vector. Default:
        `None`.
      dropout: Float between 0 and 1.
        Fraction of the units to drop for the linear transformation of the
        inputs. Default: 0.
      recurrent_dropout: Float between 0 and 1.
        Fraction of the units to drop for the linear transformation of the
        recurrent state. Default: 0.
      return_sequences: Boolean. Whether to return the last output
        in the output sequence, or the full sequence. Default: `False`.
      return_state: Boolean. Whether to return the last state
        in addition to the output. Default: `False`.
      go_backwards: Boolean (default `False`).
        If `True`, process the input sequence backwards and return the
        reversed sequence.
      stateful: Boolean (default `False`). If `True`, the last state
        for each sample at index i in a batch will be used as the initial
        state for the sample of index i in the following batch.
      unroll: Boolean (default `False`).
        If `True`, the network will be unrolled,
        else a symbolic loop will be used.
        Unrolling can speed up an RNN,
        although it tends to be more memory-intensive.
        Unrolling is only suitable for short sequences.

    Call arguments:
      inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
      mask: Binary tensor of shape `[batch, timesteps]` indicating whether
        a given timestep should be masked. An individual `True` entry indicates
        that the corresponding timestep should be utilized, while a `False`
        entry indicates that the corresponding timestep should be ignored.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. This argument is passed to the cell
        when calling it. This is only relevant if `dropout` or
        `recurrent_dropout` is used.
      initial_state: List of initial state tensors to be passed to the first
        call of the cell.

    Examples:

    ```python
    inputs = np.random.random([32, 10, 8]).astype(np.float32)
    simple_rnn = tf.keras.layers.SimpleRNN(4)
    output = simple_rnn(inputs)  # The output has shape `[32, 4]`.

    simple_rnn = tf.keras.layers.SimpleRNN(
        4, return_sequences=True, return_state=True)

    # whole_sequence_output has shape `[32, 10, 4]`.
    # final_state has shape `[32, 4]`.
    whole_sequence_output, final_state = simple_rnn(inputs)
    ```
    """

    def __init__(
        self,
        units,
        activation="tanh",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        dropout=0.0,
        recurrent_dropout=0.0,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        unroll=False,
        **kwargs,
    ):
        if "implementation" in kwargs:
            kwargs.pop("implementation")
            logging.warning(
                "The `implementation` argument "
                "in `SimpleRNN` has been deprecated. "
                "Please remove it from your layer call."
            )
        if "enable_caching_device" in kwargs:
            cell_kwargs = {
                "enable_caching_device": kwargs.pop("enable_caching_device")
            }
        else:
            cell_kwargs = {}
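        # Build the wrapped cell; dtype and trainable are forwarded so the
        # cell's weights match the outer layer's configuration.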
        cell = SimpleRNNCell(
            units,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            recurrent_regularizer=recurrent_regularizer,
            bias_regularizer=bias_regularizer,
            kernel_constraint=kernel_constraint,
            recurrent_constraint=recurrent_constraint,
            bias_constraint=bias_constraint,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
            dtype=kwargs.get("dtype"),
            trainable=kwargs.get("trainable", True),
            name="simple_rnn_cell",
            **cell_kwargs,
        )
        super().__init__(
            cell,
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            unroll=unroll,
            **kwargs,
        )
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.input_spec = [InputSpec(ndim=3)]

    def call(self, inputs, mask=None, training=None, initial_state=None):
        return super().call(
            inputs, mask=mask, training=training, initial_state=initial_state
        )

    @property
    def units(self):
        return self.cell.units

    @property
    def activation(self):
        return self.cell.activation

    @property
    def use_bias(self):
        return self.cell.use_bias

    @property
    def kernel_initializer(self):
        return self.cell.kernel_initializer

    @property
    def recurrent_initializer(self):
        return self.cell.recurrent_initializer

    @property
    def bias_initializer(self):
        return self.cell.bias_initializer

    @property
    def kernel_regularizer(self):
        return self.cell.kernel_regularizer

    @property
    def recurrent_regularizer(self):
        return self.cell.recurrent_regularizer

    @property
    def bias_regularizer(self):
        return self.cell.bias_regularizer

    @property
    def kernel_constraint(self):
        return self.cell.kernel_constraint

    @property
    def recurrent_constraint(self):
        return self.cell.recurrent_constraint

    @property
    def bias_constraint(self):
        return self.cell.bias_constraint

    @property
    def dropout(self):
        return self.cell.dropout

    @property
    def recurrent_dropout(self):
        return self.cell.recurrent_dropout

    def get_config(self):
        config = {
            "units": self.units,
            "activation": activations.serialize(self.activation),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self.activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
            "dropout": self.dropout,
            "recurrent_dropout": self.recurrent_dropout,
        }
        base_config = super().get_config()
        config.update(rnn_utils.config_for_enable_caching_device(self.cell))
        del base_config["cell"]
        return dict(list(base_config.items()) + list(config.items()))
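
    # `implementation` was accepted by older versions of this layer; pop it
    # during deserialization so old configs load without triggering the
    # deprecation warning in `__init__`.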
    @classmethod
    def from_config(cls, config):
        if "implementation" in config:
            config.pop("implementation")
        return cls(**config)