Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/gru_v1.py: 60%
92 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Gated Recurrent Unit V1 layer."""
18from keras.src import activations
19from keras.src import constraints
20from keras.src import initializers
21from keras.src import regularizers
22from keras.src.engine.input_spec import InputSpec
23from keras.src.layers.rnn import gru
24from keras.src.layers.rnn import rnn_utils
25from keras.src.layers.rnn.base_rnn import RNN
27# isort: off
28from tensorflow.python.platform import tf_logging as logging
29from tensorflow.python.util.tf_export import keras_export
@keras_export(v1=["keras.layers.GRUCell"])
class GRUCell(gru.GRUCell):
    """Single-step GRU cell using the TF1-era (v1) defaults.

    This is a thin subclass of `gru.GRUCell`. It only changes the defaults
    that are passed up to the parent class:
    `recurrent_activation="hard_sigmoid"`, `reset_after=False` and
    `implementation=1`.

    Args:
      units: Positive integer, dimensionality of the output space.
      activation: Activation function to use.
        Default: hyperbolic tangent (`tanh`).
        If you pass `None`, no activation is applied
        (ie. "linear" activation: `a(x) = x`).
      recurrent_activation: Activation function to use
        for the recurrent step.
        Default: hard sigmoid (`hard_sigmoid`).
        If you pass `None`, no activation is applied
        (ie. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs.
      recurrent_initializer: Initializer for the `recurrent_kernel`
        weights matrix, used for the linear transformation of the
        recurrent state.
      bias_initializer: Initializer for the bias vector.
      kernel_regularizer: Regularizer function applied to
        the `kernel` weights matrix.
      recurrent_regularizer: Regularizer function applied to
        the `recurrent_kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      kernel_constraint: Constraint function applied to
        the `kernel` weights matrix.
      recurrent_constraint: Constraint function applied to
        the `recurrent_kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.
      dropout: Float between 0 and 1. Fraction of the units to drop for the
        linear transformation of the inputs.
      recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
        for the linear transformation of the recurrent state.
      reset_after: GRU convention (whether to apply reset gate after or
        before matrix multiplication). False = "before" (default),
        True = "after" (cuDNN compatible).
      **kwargs: Forwarded unchanged to `gru.GRUCell`, except for
        `implementation`. That key is taken out here and passed explicitly,
        with a default of 1.

    Call arguments:
      inputs: A 2D tensor.
      states: List of state tensors corresponding to the previous timestep.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. Only relevant when `dropout` or
        `recurrent_dropout` is used.
    """

    def __init__(
        self,
        units,
        activation="tanh",
        recurrent_activation="hard_sigmoid",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        dropout=0.0,
        recurrent_dropout=0.0,
        reset_after=False,
        **kwargs
    ):
        # Remove `implementation` from the pass-through kwargs first, so it
        # is sent to the parent exactly once, with the v1 default of 1.
        mode = kwargs.pop("implementation", 1)

        cell_options = dict(
            activation=activation,
            recurrent_activation=recurrent_activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            recurrent_regularizer=recurrent_regularizer,
            bias_regularizer=bias_regularizer,
            kernel_constraint=kernel_constraint,
            recurrent_constraint=recurrent_constraint,
            bias_constraint=bias_constraint,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
            implementation=mode,
            reset_after=reset_after,
        )
        super().__init__(units, **cell_options, **kwargs)
@keras_export(v1=["keras.layers.GRU"])
class GRU(RNN):
    """Gated Recurrent Unit - Cho et al. 2014.

    There are two variants. The default one is based on 1406.1078v3 and
    has reset gate applied to hidden state before matrix multiplication. The
    other one is based on original 1406.1078v1 and has the order reversed.

    The second variant is compatible with CuDNNGRU (GPU-only) and allows
    inference on CPU. Thus it has separate biases for `kernel` and
    `recurrent_kernel`. Use `reset_after=True` and
    `recurrent_activation='sigmoid'`.

    Internally this layer builds one v1 `GRUCell` (named `"gru_cell"`) and
    wraps it in an `RNN`. The hyperparameter properties of this layer
    (`units`, `activation`, `dropout`, ...) are read from that cell.

    Args:
      units: Positive integer, dimensionality of the output space.
      activation: Activation function to use.
        Default: hyperbolic tangent (`tanh`).
        If you pass `None`, no activation is applied
        (ie. "linear" activation: `a(x) = x`).
      recurrent_activation: Activation function to use
        for the recurrent step.
        Default: hard sigmoid (`hard_sigmoid`).
        If you pass `None`, no activation is applied
        (ie. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs.
      recurrent_initializer: Initializer for the `recurrent_kernel` weights
        matrix, used for the linear transformation of the recurrent state.
      bias_initializer: Initializer for the bias vector.
      kernel_regularizer: Regularizer function applied to
        the `kernel` weights matrix.
      recurrent_regularizer: Regularizer function applied to
        the `recurrent_kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to
        the output of the layer (its "activation").
      kernel_constraint: Constraint function applied to
        the `kernel` weights matrix.
      recurrent_constraint: Constraint function applied to
        the `recurrent_kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.
      dropout: Float between 0 and 1.
        Fraction of the units to drop for
        the linear transformation of the inputs.
      recurrent_dropout: Float between 0 and 1.
        Fraction of the units to drop for
        the linear transformation of the recurrent state.
      return_sequences: Boolean. Whether to return the last output
        in the output sequence, or the full sequence.
      return_state: Boolean. Whether to return the last state
        in addition to the output.
      go_backwards: Boolean (default False).
        If True, process the input sequence backwards and return the
        reversed sequence.
      stateful: Boolean (default False). If True, the last state
        for each sample at index i in a batch will be used as initial
        state for the sample of index i in the following batch.
      unroll: Boolean (default False).
        If True, the network will be unrolled,
        else a symbolic loop will be used.
        Unrolling can speed-up a RNN,
        although it tends to be more memory-intensive.
        Unrolling is only suitable for short sequences.
      time_major: The shape format of the `inputs` and `outputs` tensors.
        If True, the inputs and outputs will be in shape
        `(timesteps, batch, ...)`, whereas in the False case, it will be
        `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
        efficient because it avoids transposes at the beginning and end of
        the RNN calculation. However, most TensorFlow data is batch-major,
        so by default this function accepts input and emits output in
        batch-major form.
      reset_after: GRU convention (whether to apply reset gate after or
        before matrix multiplication). False = "before" (default),
        True = "after" (cuDNN compatible).
      **kwargs: Forwarded to the base `RNN` layer. Two keys are consumed
        here instead of being forwarded. `implementation` defaults to 1;
        the deprecated value 0 is replaced by 1 and a warning is logged.
        `enable_caching_device` is passed only to the cell. `dtype` and
        `trainable` are also copied onto the cell.

    Call arguments:
      inputs: A 3D tensor.
      mask: Binary tensor of shape `(samples, timesteps)` indicating whether
        a given timestep should be masked. An individual `True` entry
        indicates that the corresponding timestep should be utilized, while
        a `False` entry indicates that the corresponding timestep should be
        ignored.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. This argument is passed to the
        cell when calling it. This is only relevant if `dropout` or
        `recurrent_dropout` is used.
      initial_state: List of initial state tensors to be passed to the first
        call of the cell.
    """

    def __init__(
        self,
        units,
        activation="tanh",
        recurrent_activation="hard_sigmoid",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        dropout=0.0,
        recurrent_dropout=0.0,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        unroll=False,
        reset_after=False,
        **kwargs
    ):
        implementation = kwargs.pop("implementation", 1)
        if implementation == 0:
            logging.warning(
                "`implementation=0` has been deprecated, "
                "and now defaults to `implementation=1`. "
                "Please update your layer call."
            )
            # Make the behavior match the warning above and `from_config`,
            # which both map 0 -> 1, instead of forwarding the deprecated
            # value to the cell.
            implementation = 1

        # `enable_caching_device` is a cell-only option. It is removed from
        # kwargs so that it does not reach the `RNN` base constructor.
        cell_kwargs = {}
        if "enable_caching_device" in kwargs:
            cell_kwargs["enable_caching_device"] = kwargs.pop(
                "enable_caching_device"
            )

        cell = GRUCell(
            units,
            activation=activation,
            recurrent_activation=recurrent_activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            recurrent_regularizer=recurrent_regularizer,
            bias_regularizer=bias_regularizer,
            kernel_constraint=kernel_constraint,
            recurrent_constraint=recurrent_constraint,
            bias_constraint=bias_constraint,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
            implementation=implementation,
            reset_after=reset_after,
            # Copied (not popped): these are also forwarded to the layer.
            dtype=kwargs.get("dtype"),
            trainable=kwargs.get("trainable", True),
            name="gru_cell",
            **cell_kwargs
        )
        super().__init__(
            cell,
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            unroll=unroll,
            **kwargs
        )
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.input_spec = [InputSpec(ndim=3)]

    def call(self, inputs, mask=None, training=None, initial_state=None):
        """Run the wrapped cell over `inputs` by delegating to `RNN.call`."""
        return super().call(
            inputs, mask=mask, training=training, initial_state=initial_state
        )

    # Read-only views of the wrapped cell's hyperparameters.

    @property
    def units(self):
        return self.cell.units

    @property
    def activation(self):
        return self.cell.activation

    @property
    def recurrent_activation(self):
        return self.cell.recurrent_activation

    @property
    def use_bias(self):
        return self.cell.use_bias

    @property
    def kernel_initializer(self):
        return self.cell.kernel_initializer

    @property
    def recurrent_initializer(self):
        return self.cell.recurrent_initializer

    @property
    def bias_initializer(self):
        return self.cell.bias_initializer

    @property
    def kernel_regularizer(self):
        return self.cell.kernel_regularizer

    @property
    def recurrent_regularizer(self):
        return self.cell.recurrent_regularizer

    @property
    def bias_regularizer(self):
        return self.cell.bias_regularizer

    @property
    def kernel_constraint(self):
        return self.cell.kernel_constraint

    @property
    def recurrent_constraint(self):
        return self.cell.recurrent_constraint

    @property
    def bias_constraint(self):
        return self.cell.bias_constraint

    @property
    def dropout(self):
        return self.cell.dropout

    @property
    def recurrent_dropout(self):
        return self.cell.recurrent_dropout

    @property
    def implementation(self):
        return self.cell.implementation

    @property
    def reset_after(self):
        return self.cell.reset_after

    def get_config(self):
        """Return a flat config that `from_config` can rebuild from.

        The base `RNN` config's `"cell"` entry is dropped. The cell's
        hyperparameters are serialized at the top level instead, because
        `__init__` rebuilds the cell itself from those arguments.
        """
        config = {
            "units": self.units,
            "activation": activations.serialize(self.activation),
            "recurrent_activation": activations.serialize(
                self.recurrent_activation
            ),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self.activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
            "dropout": self.dropout,
            "recurrent_dropout": self.recurrent_dropout,
            "implementation": self.implementation,
            "reset_after": self.reset_after,
        }
        config.update(rnn_utils.config_for_enable_caching_device(self.cell))
        base_config = super().get_config()
        del base_config["cell"]
        # Layer-specific keys override base keys on collision. Key order is
        # the same as the previous `dict(list(...) + list(...))` form.
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config):
        """Build a layer from `config`, upgrading a legacy `implementation=0`.

        The caller's dict is never mutated. It is copied before the
        deprecated value 0 is replaced by 1.
        """
        if config.get("implementation") == 0:
            config = dict(config)
            config["implementation"] = 1
        return cls(**config)