# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gated Recurrent Unit layer."""

import uuid

import tensorflow.compat.v2 as tf

from keras.src import activations
from keras.src import backend
from keras.src import constraints
from keras.src import initializers
from keras.src import regularizers
from keras.src.engine import base_layer
from keras.src.engine.input_spec import InputSpec
from keras.src.layers.rnn import gru_lstm_utils
from keras.src.layers.rnn import rnn_utils
from keras.src.layers.rnn.base_rnn import RNN
from keras.src.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export

RECURRENT_DROPOUT_WARNING_MSG = (
    "RNN `implementation=2` is not supported when `recurrent_dropout` is set. "
    "Using `implementation=1`."
)


@keras_export("keras.layers.GRUCell", v1=[])
class GRUCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
    """Cell class for the GRU layer.

    See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
    for details about the usage of RNN API.

    This class processes one step within the whole time sequence input, whereas
    `tf.keras.layers.GRU` processes the whole sequence.

    For example:

    >>> inputs = tf.random.normal([32, 10, 8])
    >>> rnn = tf.keras.layers.RNN(tf.keras.layers.GRUCell(4))
    >>> output = rnn(inputs)
    >>> print(output.shape)
    (32, 4)
    >>> rnn = tf.keras.layers.RNN(
    ...    tf.keras.layers.GRUCell(4),
    ...    return_sequences=True,
    ...    return_state=True)
    >>> whole_sequence_output, final_state = rnn(inputs)
    >>> print(whole_sequence_output.shape)
    (32, 10, 4)
    >>> print(final_state.shape)
    (32, 4)

    Args:
      units: Positive integer, dimensionality of the output space.
      activation: Activation function to use. Default: hyperbolic tangent
        (`tanh`). If you pass `None`, no activation is applied
        (i.e. "linear" activation: `a(x) = x`).
      recurrent_activation: Activation function to use for the recurrent step.
        Default: sigmoid (`sigmoid`). If you pass `None`, no activation is
        applied (i.e. "linear" activation: `a(x) = x`).
      use_bias: Boolean (default `True`), whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs. Default:
        `glorot_uniform`.
      recurrent_initializer: Initializer for the `recurrent_kernel`
        weights matrix, used for the linear transformation of the recurrent
        state. Default: `orthogonal`.
      bias_initializer: Initializer for the bias vector. Default: `zeros`.
      kernel_regularizer: Regularizer function applied to the `kernel` weights
        matrix. Default: `None`.
      recurrent_regularizer: Regularizer function applied to the
        `recurrent_kernel` weights matrix. Default: `None`.
      bias_regularizer: Regularizer function applied to the bias vector.
        Default: `None`.
      kernel_constraint: Constraint function applied to the `kernel` weights
        matrix. Default: `None`.
      recurrent_constraint: Constraint function applied to the
        `recurrent_kernel` weights matrix. Default: `None`.
      bias_constraint: Constraint function applied to the bias vector. Default:
        `None`.
      dropout: Float between 0 and 1. Fraction of the units to drop for the
        linear transformation of the inputs. Default: 0.
      recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
        for the linear transformation of the recurrent state. Default: 0.
      reset_after: GRU convention (whether to apply the reset gate after or
        before the matrix multiplication). False = "before",
        True = "after" (default and cuDNN compatible).

    Call arguments:
      inputs: A 2D tensor, with shape of `[batch, feature]`.
      states: A 2D tensor with shape of `[batch, units]`, which is the state
        from the previous time step. For timestep 0, the initial state provided
        by the user will be fed to the cell.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. Only relevant when `dropout` or
        `recurrent_dropout` is used.
    """

    def __init__(
        self,
        units,
        activation="tanh",
        recurrent_activation="sigmoid",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        dropout=0.0,
        recurrent_dropout=0.0,
        reset_after=True,
        **kwargs,
    ):
        if units <= 0:
            raise ValueError(
                "Received an invalid value for argument `units`, "
                f"expected a positive integer, got {units}."
            )
        # By default use cached variable under v2 mode, see b/143699808.
        if tf.compat.v1.executing_eagerly_outside_functions():
            self._enable_caching_device = kwargs.pop(
                "enable_caching_device", True
            )
        else:
            self._enable_caching_device = kwargs.pop(
                "enable_caching_device", False
            )
        super().__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.dropout = min(1.0, max(0.0, dropout))
        self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))

        implementation = kwargs.pop("implementation", 2)
        if self.recurrent_dropout != 0 and implementation != 1:
            logging.debug(RECURRENT_DROPOUT_WARNING_MSG)
            self.implementation = 1
        else:
            self.implementation = implementation
        self.reset_after = reset_after
        self.state_size = self.units
        self.output_size = self.units

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        super().build(input_shape)
        input_dim = input_shape[-1]
        default_caching_device = rnn_utils.caching_device(self)
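        # Both `kernel` and `recurrent_kernel` concatenate the weights of the
        # update (z), reset (r) and candidate (h) gates along the last axis:
        # columns [0:units] -> z, [units:2*units] -> r, [2*units:] -> h.
        # `call` slices these blocks back apart.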
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 3),
            name="kernel",
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            caching_device=default_caching_device,
        )
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 3),
            name="recurrent_kernel",
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint,
            caching_device=default_caching_device,
        )

        if self.use_bias:
            if not self.reset_after:
                bias_shape = (3 * self.units,)
            else:
                # separate biases for input and recurrent kernels
                # Note: the shape is intentionally different from CuDNNGRU
                # biases `(2 * 3 * self.units,)`, so that we can distinguish the
                # classes when loading and converting saved weights.
                bias_shape = (2, 3 * self.units)
            self.bias = self.add_weight(
                shape=bias_shape,
                name="bias",
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                caching_device=default_caching_device,
            )
        else:
            self.bias = None
        self.built = True

    def call(self, inputs, states, training=None):
        h_tm1 = (
            states[0] if tf.nest.is_nested(states) else states
        )  # previous memory
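        # GRU step (biases omitted; shown for the default `reset_after=True`
        # convention, `reset_after=False` applies `r` before the matmul):
        #   z  = recurrent_activation(x_z + U_z . h_tm1)   (update gate)
        #   r  = recurrent_activation(x_r + U_r . h_tm1)   (reset gate)
        #   hh = activation(x_h + r * (U_h . h_tm1))        (candidate state)
        #   h  = z * h_tm1 + (1 - z) * hh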

        dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
        rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
            h_tm1, training, count=3
        )

        if self.use_bias:
            if not self.reset_after:
                input_bias, recurrent_bias = self.bias, None
            else:
                input_bias, recurrent_bias = tf.unstack(self.bias)
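        # implementation == 1 computes each gate with separate (smaller)
        # matmuls, which is required when recurrent dropout is used;
        # implementation == 2 fuses the three gates into a single larger
        # matmul and splits the result.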

        if self.implementation == 1:
            if 0.0 < self.dropout < 1.0:
                inputs_z = inputs * dp_mask[0]
                inputs_r = inputs * dp_mask[1]
                inputs_h = inputs * dp_mask[2]
            else:
                inputs_z = inputs
                inputs_r = inputs
                inputs_h = inputs

            x_z = backend.dot(inputs_z, self.kernel[:, : self.units])
            x_r = backend.dot(
                inputs_r, self.kernel[:, self.units : self.units * 2]
            )
            x_h = backend.dot(inputs_h, self.kernel[:, self.units * 2 :])

            if self.use_bias:
                x_z = backend.bias_add(x_z, input_bias[: self.units])
                x_r = backend.bias_add(
                    x_r, input_bias[self.units : self.units * 2]
                )
                x_h = backend.bias_add(x_h, input_bias[self.units * 2 :])

            if 0.0 < self.recurrent_dropout < 1.0:
                h_tm1_z = h_tm1 * rec_dp_mask[0]
                h_tm1_r = h_tm1 * rec_dp_mask[1]
                h_tm1_h = h_tm1 * rec_dp_mask[2]
            else:
                h_tm1_z = h_tm1
                h_tm1_r = h_tm1
                h_tm1_h = h_tm1

            recurrent_z = backend.dot(
                h_tm1_z, self.recurrent_kernel[:, : self.units]
            )
            recurrent_r = backend.dot(
                h_tm1_r, self.recurrent_kernel[:, self.units : self.units * 2]
            )
            if self.reset_after and self.use_bias:
                recurrent_z = backend.bias_add(
                    recurrent_z, recurrent_bias[: self.units]
                )
                recurrent_r = backend.bias_add(
                    recurrent_r, recurrent_bias[self.units : self.units * 2]
                )

            z = self.recurrent_activation(x_z + recurrent_z)
            r = self.recurrent_activation(x_r + recurrent_r)

            # reset gate applied after/before matrix multiplication
            if self.reset_after:
                recurrent_h = backend.dot(
                    h_tm1_h, self.recurrent_kernel[:, self.units * 2 :]
                )
                if self.use_bias:
                    recurrent_h = backend.bias_add(
                        recurrent_h, recurrent_bias[self.units * 2 :]
                    )
                recurrent_h = r * recurrent_h
            else:
                recurrent_h = backend.dot(
                    r * h_tm1_h, self.recurrent_kernel[:, self.units * 2 :]
                )

            hh = self.activation(x_h + recurrent_h)
        else:
            if 0.0 < self.dropout < 1.0:
                inputs = inputs * dp_mask[0]

            # inputs projected by all gate matrices at once
            matrix_x = backend.dot(inputs, self.kernel)
            if self.use_bias:
                # biases: bias_z_i, bias_r_i, bias_h_i
                matrix_x = backend.bias_add(matrix_x, input_bias)

            x_z, x_r, x_h = tf.split(matrix_x, 3, axis=-1)

            if self.reset_after:
                # hidden state projected by all gate matrices at once
                matrix_inner = backend.dot(h_tm1, self.recurrent_kernel)
                if self.use_bias:
                    matrix_inner = backend.bias_add(
                        matrix_inner, recurrent_bias
                    )
            else:
                # hidden state projected separately for update/reset and new
                matrix_inner = backend.dot(
                    h_tm1, self.recurrent_kernel[:, : 2 * self.units]
                )

            recurrent_z, recurrent_r, recurrent_h = tf.split(
                matrix_inner, [self.units, self.units, -1], axis=-1
            )

            z = self.recurrent_activation(x_z + recurrent_z)
            r = self.recurrent_activation(x_r + recurrent_r)

            if self.reset_after:
                recurrent_h = r * recurrent_h
            else:
                recurrent_h = backend.dot(
                    r * h_tm1, self.recurrent_kernel[:, 2 * self.units :]
                )

            hh = self.activation(x_h + recurrent_h)
        # previous and candidate state mixed by update gate
        h = z * h_tm1 + (1 - z) * hh
        new_state = [h] if tf.nest.is_nested(states) else h
        return h, new_state

    def get_config(self):
        config = {
            "units": self.units,
            "activation": activations.serialize(self.activation),
            "recurrent_activation": activations.serialize(
                self.recurrent_activation
            ),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
            "dropout": self.dropout,
            "recurrent_dropout": self.recurrent_dropout,
            "implementation": self.implementation,
            "reset_after": self.reset_after,
        }
        config.update(rnn_utils.config_for_enable_caching_device(self))
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        return rnn_utils.generate_zero_filled_state_for_cell(
            self, inputs, batch_size, dtype
        )


@keras_export("keras.layers.GRU", v1=[])
class GRU(DropoutRNNCellMixin, RNN, base_layer.BaseRandomLayer):
    """Gated Recurrent Unit - Cho et al. 2014.

    See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
    for details about the usage of RNN API.

    Based on available runtime hardware and constraints, this layer
    will choose different implementations (cuDNN-based or pure-TensorFlow)
    to maximize the performance. If a GPU is available and all
    the arguments to the layer meet the requirements of the cuDNN kernel
    (see below for details), the layer will use a fast cuDNN implementation.

    The requirements to use the cuDNN implementation are:

    1. `activation` == `tanh`
    2. `recurrent_activation` == `sigmoid`
    3. `recurrent_dropout` == 0
    4. `unroll` is `False`
    5. `use_bias` is `True`
    6. `reset_after` is `True`
    7. Inputs, if masking is used, are strictly right-padded.
    8. Eager execution is enabled in the outermost context.

    There are two variants of the GRU implementation. The default one is based
    on [v3](https://arxiv.org/abs/1406.1078v3) and has the reset gate applied
    to the hidden state before the matrix multiplication. The other one is
    based on the [original version](https://arxiv.org/abs/1406.1078v1) and has
    the order reversed.

    The second variant is compatible with CuDNNGRU (GPU-only) and allows
    inference on CPU. Thus it has separate biases for `kernel` and
    `recurrent_kernel`. To use this variant, set `reset_after=True` and
    `recurrent_activation='sigmoid'`.

    For example:

    >>> inputs = tf.random.normal([32, 10, 8])
    >>> gru = tf.keras.layers.GRU(4)
    >>> output = gru(inputs)
    >>> print(output.shape)
    (32, 4)
    >>> gru = tf.keras.layers.GRU(4, return_sequences=True, return_state=True)
    >>> whole_sequence_output, final_state = gru(inputs)
    >>> print(whole_sequence_output.shape)
    (32, 10, 4)
    >>> print(final_state.shape)
    (32, 4)

    Args:
      units: Positive integer, dimensionality of the output space.
      activation: Activation function to use.
        Default: hyperbolic tangent (`tanh`).
        If you pass `None`, no activation is applied
        (i.e. "linear" activation: `a(x) = x`).
      recurrent_activation: Activation function to use
        for the recurrent step.
        Default: sigmoid (`sigmoid`).
        If you pass `None`, no activation is applied
        (i.e. "linear" activation: `a(x) = x`).
      use_bias: Boolean (default `True`), whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs. Default:
        `glorot_uniform`.
      recurrent_initializer: Initializer for the `recurrent_kernel`
        weights matrix, used for the linear transformation of the recurrent
        state. Default: `orthogonal`.
      bias_initializer: Initializer for the bias vector. Default: `zeros`.
      kernel_regularizer: Regularizer function applied to the `kernel` weights
        matrix. Default: `None`.
      recurrent_regularizer: Regularizer function applied to the
        `recurrent_kernel` weights matrix. Default: `None`.
      bias_regularizer: Regularizer function applied to the bias vector.
        Default: `None`.
      activity_regularizer: Regularizer function applied to the output of the
        layer (its "activation"). Default: `None`.
      kernel_constraint: Constraint function applied to the `kernel` weights
        matrix. Default: `None`.
      recurrent_constraint: Constraint function applied to the
        `recurrent_kernel` weights matrix. Default: `None`.
      bias_constraint: Constraint function applied to the bias vector. Default:
        `None`.
      dropout: Float between 0 and 1. Fraction of the units to drop for the
        linear transformation of the inputs. Default: 0.
      recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
        for the linear transformation of the recurrent state. Default: 0.
      return_sequences: Boolean. Whether to return the last output
        in the output sequence, or the full sequence. Default: `False`.
      return_state: Boolean. Whether to return the last state in addition to
        the output. Default: `False`.
      go_backwards: Boolean (default `False`).
        If True, process the input sequence backwards and return the
        reversed sequence.
      stateful: Boolean (default False). If True, the last state
        for each sample at index i in a batch will be used as the initial
        state for the sample of index i in the following batch.
      unroll: Boolean (default False).
        If True, the network will be unrolled,
        else a symbolic loop will be used.
        Unrolling can speed up an RNN,
        although it tends to be more memory-intensive.
        Unrolling is only suitable for short sequences.
      time_major: The shape format of the `inputs` and `outputs` tensors.
        If True, the inputs and outputs will be in shape
        `[timesteps, batch, feature]`, whereas in the False case, it will be
        `[batch, timesteps, feature]`. Using `time_major = True` is a bit more
        efficient because it avoids transposes at the beginning and end of the
        RNN calculation. However, most TensorFlow data is batch-major, so by
        default this function accepts input and emits output in batch-major
        form.
      reset_after: GRU convention (whether to apply the reset gate after or
        before the matrix multiplication). False = "before",
        True = "after" (default and cuDNN compatible).

    Call arguments:
      inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
      mask: Binary tensor of shape `[samples, timesteps]` indicating whether
        a given timestep should be masked (optional).
        An individual `True` entry indicates that the corresponding timestep
        should be utilized, while a `False` entry indicates that the
        corresponding timestep should be ignored. Defaults to `None`.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. This argument is passed to the
        cell when calling it. This is only relevant if `dropout` or
        `recurrent_dropout` is used (optional). Defaults to `None`.
      initial_state: List of initial state tensors to be passed to the first
        call of the cell (optional, `None` causes creation
        of zero-filled initial state tensors). Defaults to `None`.
    """

    def __init__(
        self,
        units,
        activation="tanh",
        recurrent_activation="sigmoid",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        dropout=0.0,
        recurrent_dropout=0.0,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        unroll=False,
        time_major=False,
        reset_after=True,
        **kwargs,
    ):
        # return_runtime is a flag for testing, which shows the real backend
        # implementation chosen by grappler in graph mode.
        self._return_runtime = kwargs.pop("return_runtime", False)
        implementation = kwargs.pop("implementation", 2)
        if implementation == 0:
            logging.warning(
                "`implementation=0` has been deprecated, "
                "and now defaults to `implementation=2`. "
                "Please update your layer call."
            )
        if "enable_caching_device" in kwargs:
            cell_kwargs = {
                "enable_caching_device": kwargs.pop("enable_caching_device")
            }
        else:
            cell_kwargs = {}
        cell = GRUCell(
            units,
            activation=activation,
            recurrent_activation=recurrent_activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            recurrent_regularizer=recurrent_regularizer,
            bias_regularizer=bias_regularizer,
            kernel_constraint=kernel_constraint,
            recurrent_constraint=recurrent_constraint,
            bias_constraint=bias_constraint,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
            implementation=implementation,
            reset_after=reset_after,
            dtype=kwargs.get("dtype"),
            trainable=kwargs.get("trainable", True),
            name="gru_cell",
            **cell_kwargs,
        )
        super().__init__(
            cell,
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            unroll=unroll,
            time_major=time_major,
            **kwargs,
        )
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.input_spec = [InputSpec(ndim=3)]

        # The GPU kernel uses the following settings by default; they are not
        # configurable.
        self._could_use_gpu_kernel = (
            self.activation in (activations.tanh, tf.tanh)
            and self.recurrent_activation in (activations.sigmoid, tf.sigmoid)
            and recurrent_dropout == 0
            and not unroll
            and use_bias
            and reset_after
            and tf.compat.v1.executing_eagerly_outside_functions()
        )
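        # (These conditions mirror the cuDNN requirements listed in the class
        # docstring; whether masked inputs are strictly right-padded can only
        # be checked later, per call.)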
        if tf.config.list_logical_devices("GPU"):
            # Only show the message when there is a GPU available; users will
            # not care about cuDNN if there isn't any GPU.
            if self._could_use_gpu_kernel:
                logging.debug(gru_lstm_utils.CUDNN_AVAILABLE_MSG % self.name)
            else:
                logging.warning(
                    gru_lstm_utils.CUDNN_NOT_AVAILABLE_MSG % self.name
                )

        if gru_lstm_utils.use_new_gru_lstm_impl():
            self._defun_wrapper = gru_lstm_utils.DefunWrapper(
                time_major, go_backwards, "gru"
            )

    def call(self, inputs, mask=None, training=None, initial_state=None):
        # The input should be dense, padded with zeros. If a ragged input is
        # fed into the layer, it is padded and the row lengths are used for
        # masking.
        inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
        is_ragged_input = row_lengths is not None
        self._validate_args_if_ragged(is_ragged_input, mask)

        # GRU does not support constants. Ignore them during processing.
        inputs, initial_state, _ = self._process_inputs(
            inputs, initial_state, None
        )

        if isinstance(mask, list):
            mask = mask[0]

        input_shape = backend.int_shape(inputs)
        timesteps = input_shape[0] if self.time_major else input_shape[1]
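        # Dispatch: fall back to the generic `backend.rnn` loop over the cell
        # when the cuDNN requirements are not met; otherwise let
        # `_defun_gru_call` pick between the cuDNN and standard kernels.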

        if not self._could_use_gpu_kernel:
            kwargs = {"training": training}
            self._maybe_reset_cell_dropout_mask(self.cell)

            def step(cell_inputs, cell_states):
                return self.cell(cell_inputs, cell_states, **kwargs)

            last_output, outputs, states = backend.rnn(
                step,
                inputs,
                initial_state,
                constants=None,
                go_backwards=self.go_backwards,
                mask=mask,
                unroll=self.unroll,
                input_length=row_lengths
                if row_lengths is not None
                else timesteps,
                time_major=self.time_major,
                zero_output_for_mask=self.zero_output_for_mask,
                return_all_outputs=self.return_sequences,
            )
            # This is a dummy tensor for testing purposes.
            runtime = gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_UNKNOWN)
        else:
            last_output, outputs, runtime, states = self._defun_gru_call(
                inputs, initial_state, training, mask, row_lengths
            )

        if self.stateful:
            updates = [
                tf.compat.v1.assign(
                    self.states[0], tf.cast(states[0], self.states[0].dtype)
                )
            ]
            self.add_update(updates)

        if self.return_sequences:
            output = backend.maybe_convert_to_ragged(
                is_ragged_input,
                outputs,
                row_lengths,
                go_backwards=self.go_backwards,
            )
        else:
            output = last_output

        if self.return_state:
            return [output] + list(states)
        elif self._return_runtime:
            return output, runtime
        else:
            return output

    @property
    def units(self):
        return self.cell.units

    @property
    def activation(self):
        return self.cell.activation

    @property
    def recurrent_activation(self):
        return self.cell.recurrent_activation

    @property
    def use_bias(self):
        return self.cell.use_bias

    @property
    def kernel_initializer(self):
        return self.cell.kernel_initializer

    @property
    def recurrent_initializer(self):
        return self.cell.recurrent_initializer

    @property
    def bias_initializer(self):
        return self.cell.bias_initializer

    @property
    def kernel_regularizer(self):
        return self.cell.kernel_regularizer

    @property
    def recurrent_regularizer(self):
        return self.cell.recurrent_regularizer

    @property
    def bias_regularizer(self):
        return self.cell.bias_regularizer

    @property
    def kernel_constraint(self):
        return self.cell.kernel_constraint

    @property
    def recurrent_constraint(self):
        return self.cell.recurrent_constraint

    @property
    def bias_constraint(self):
        return self.cell.bias_constraint

    @property
    def dropout(self):
        return self.cell.dropout

    @property
    def recurrent_dropout(self):
        return self.cell.recurrent_dropout

    @property
    def implementation(self):
        return self.cell.implementation

    @property
    def reset_after(self):
        return self.cell.reset_after

    def get_config(self):
        config = {
            "units": self.units,
            "activation": activations.serialize(self.activation),
            "recurrent_activation": activations.serialize(
                self.recurrent_activation
            ),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self.activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
            "dropout": self.dropout,
            "recurrent_dropout": self.recurrent_dropout,
            "implementation": self.implementation,
            "reset_after": self.reset_after,
        }
        config.update(rnn_utils.config_for_enable_caching_device(self.cell))
        base_config = super().get_config()
        del base_config["cell"]
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        if "implementation" in config and config["implementation"] == 0:
            config["implementation"] = 1
        return cls(**config)

    def _defun_gru_call(
        self, inputs, initial_state, training, mask, sequence_lengths
    ):
        # Use the new defun approach for backend implementation swap.
        # Note that different implementations need to have the same function
        # signature, e.g. the tensor parameters need to have the same shapes
        # and dtypes.

        self.reset_dropout_mask()
        dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
        if dropout_mask is not None:
            inputs = inputs * dropout_mask[0]

        if gru_lstm_utils.use_new_gru_lstm_impl():
            gru_kwargs = {
                "inputs": inputs,
                "init_h": gru_lstm_utils.read_variable_value(initial_state[0]),
                "kernel": gru_lstm_utils.read_variable_value(self.cell.kernel),
                "recurrent_kernel": gru_lstm_utils.read_variable_value(
                    self.cell.recurrent_kernel
                ),
                "bias": gru_lstm_utils.read_variable_value(self.cell.bias),
                "mask": mask,
                "time_major": self.time_major,
                "go_backwards": self.go_backwards,
                "sequence_lengths": sequence_lengths,
                "zero_output_for_mask": self.zero_output_for_mask,
            }
            (
                last_output,
                outputs,
                new_h,
                runtime,
            ) = self._defun_wrapper.defun_layer(**gru_kwargs)
        else:
            gpu_gru_kwargs = {
                "inputs": inputs,
                "init_h": gru_lstm_utils.read_variable_value(initial_state[0]),
                "kernel": gru_lstm_utils.read_variable_value(self.cell.kernel),
                "recurrent_kernel": gru_lstm_utils.read_variable_value(
                    self.cell.recurrent_kernel
                ),
                "bias": gru_lstm_utils.read_variable_value(self.cell.bias),
                "mask": mask,
                "time_major": self.time_major,
                "go_backwards": self.go_backwards,
                "sequence_lengths": sequence_lengths,
                "return_sequences": self.return_sequences,
            }
            normal_gru_kwargs = gpu_gru_kwargs.copy()
            normal_gru_kwargs.update(
                {
                    "zero_output_for_mask": self.zero_output_for_mask,
                }
            )

            if tf.executing_eagerly():
                device_type = gru_lstm_utils.get_context_device_type()
                can_use_gpu = (
                    # Either user specified GPU or unspecified but GPU is
                    # available.
                    (
                        device_type == gru_lstm_utils.GPU_DEVICE_NAME
                        or (
                            device_type is None
                            and tf.config.list_logical_devices("GPU")
                        )
                    )
                    and (
                        gru_lstm_utils.is_cudnn_supported_inputs(
                            mask, self.time_major, sequence_lengths
                        )
                    )
                )
                # Under eager context, check the device placement and prefer
                # the GPU implementation when GPU is available.
                if can_use_gpu:
                    last_output, outputs, new_h, runtime = gpu_gru(
                        **gpu_gru_kwargs
                    )
                else:
                    last_output, outputs, new_h, runtime = standard_gru(
                        **normal_gru_kwargs
                    )
            else:
                (
                    last_output,
                    outputs,
                    new_h,
                    runtime,
                ) = gru_with_backend_selection(**normal_gru_kwargs)

        states = [new_h]
        return last_output, outputs, runtime, states


def standard_gru(
    inputs,
    init_h,
    kernel,
    recurrent_kernel,
    bias,
    mask,
    time_major,
    go_backwards,
    sequence_lengths,
    zero_output_for_mask,
    return_sequences,
):
    """GRU with standard kernel implementation.

    This implementation can be run on all types of hardware.

    This implementation lifts out all the layer weights and makes them
    function parameters. It has the same number of tensor input params as the
    cuDNN counterpart. The RNN step logic has been simplified, e.g. dropout
    and mask are removed since the cuDNN implementation does not support them.

    Args:
      inputs: Input tensor of GRU layer.
      init_h: Initial state tensor for the cell output.
      kernel: Weights for cell kernel.
      recurrent_kernel: Weights for cell recurrent kernel.
      bias: Weights for cell kernel bias and recurrent bias. The bias contains
        the combined input_bias and recurrent_bias.
      mask: Binary tensor of shape `(samples, timesteps)` indicating whether
        a given timestep should be masked. An individual `True` entry indicates
        that the corresponding timestep should be utilized, while a `False`
        entry indicates that the corresponding timestep should be ignored.
      time_major: Boolean, whether the inputs are in the format of
        [time, batch, feature] or [batch, time, feature].
      go_backwards: Boolean (default False). If True, process the input
        sequence backwards and return the reversed sequence.
      sequence_lengths: The lengths of all sequences coming from a variable
        length input, such as ragged tensors. If the input has a fixed timestep
        size, this should be None.
      zero_output_for_mask: Boolean, whether to output zero for masked
        timesteps.
      return_sequences: Boolean. If True, return the recurrent outputs for all
        timesteps in the sequence. If False, only return the output for the
        last timestep (which consumes less memory).

    Returns:
      last_output: Output tensor for the last timestep, which has shape
        [batch, units].
      outputs:
        - If `return_sequences=True`: output tensor for all timesteps,
          which has shape [batch, time, units].
        - Else, a tensor equal to `last_output` with shape [batch, 1, units].
      state_0: The cell output, which has the same shape as init_h.
      runtime: Constant string tensor which indicates the real runtime
        hardware. This value is for testing purposes and should not be used by
        users.
    """
    input_shape = backend.int_shape(inputs)
    timesteps = input_shape[0] if time_major else input_shape[1]

    input_bias, recurrent_bias = tf.unstack(bias)
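    # `bias` arrives with shape (2, 3 * units): row 0 holds the input biases
    # and row 1 the recurrent biases (the `reset_after=True` layout built in
    # `GRUCell.build`).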

    def step(cell_inputs, cell_states):
        """Step function that will be used by Keras RNN backend."""
        h_tm1 = cell_states[0]

        # inputs projected by all gate matrices at once
        matrix_x = backend.dot(cell_inputs, kernel)
        matrix_x = backend.bias_add(matrix_x, input_bias)

        x_z, x_r, x_h = tf.split(matrix_x, 3, axis=1)

        # hidden state projected by all gate matrices at once
        matrix_inner = backend.dot(h_tm1, recurrent_kernel)
        matrix_inner = backend.bias_add(matrix_inner, recurrent_bias)

        recurrent_z, recurrent_r, recurrent_h = tf.split(
            matrix_inner, 3, axis=1
        )
        z = tf.sigmoid(x_z + recurrent_z)
        r = tf.sigmoid(x_r + recurrent_r)
        hh = tf.tanh(x_h + r * recurrent_h)

        # previous and candidate state mixed by update gate
        h = z * h_tm1 + (1 - z) * hh
        return h, [h]

    last_output, outputs, new_states = backend.rnn(
        step,
        inputs,
        [init_h],
        constants=None,
        unroll=False,
        time_major=time_major,
        mask=mask,
        go_backwards=go_backwards,
        input_length=sequence_lengths
        if sequence_lengths is not None
        else timesteps,
        zero_output_for_mask=zero_output_for_mask,
        return_all_outputs=return_sequences,
    )
    return (
        last_output,
        outputs,
        new_states[0],
        gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_CPU),
    )


def gpu_gru(
    inputs,
    init_h,
    kernel,
    recurrent_kernel,
    bias,
    mask,
    time_major,
    go_backwards,
    sequence_lengths,
    return_sequences,
):
    """GRU with cuDNN implementation which is only available for GPU."""
    if mask is not None:
        sequence_lengths = gru_lstm_utils.calculate_sequence_by_mask(
            mask, time_major
        )
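
    # With no explicit sequence lengths, the plain CudnnRNN op used below
    # expects time-major input, so batch-major input is transposed up front;
    # CudnnRNNV3 (the sequence-length path) handles `time_major` itself.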
    if not time_major and sequence_lengths is None:
        inputs = tf.transpose(inputs, perm=(1, 0, 2))
        seq_axis, batch_axis = (0, 1)
    else:
        seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
    # For init_h, cuDNN expects one more dim of num_layers before or after the
    # batch dim for time major or batch major inputs respectively.
    init_h = tf.expand_dims(init_h, axis=seq_axis)

    weights = tf.split(kernel, 3, axis=1)
    weights += tf.split(recurrent_kernel, 3, axis=1)
    # Note that the bias was initialized as shape (2, 3 * units); flatten it
    # into (6 * units).
    bias = tf.split(backend.flatten(bias), 6)
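    # After the split, `weights` is [W_z, W_r, W_h, U_z, U_r, U_h] and `bias`
    # is [b_z, b_r, b_h, rb_z, rb_r, rb_h] (input biases first, then the
    # recurrent biases).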

    if tf.sysconfig.get_build_info()["is_cuda_build"]:
        # Note that the gate order for cuDNN is different from the canonical
        # format. The canonical format is [z, r, h], whereas cuDNN is
        # [r, z, h]. The swap needs to be done for kernel, recurrent_kernel,
        # input_bias and recurrent_bias.
        # z is the update gate weights.
        # r is the reset gate weights.
        # h is the candidate (new) state weights.
        weights[0], weights[1] = weights[1], weights[0]
        weights[3], weights[4] = weights[4], weights[3]
        bias[0], bias[1] = bias[1], bias[0]
        bias[3], bias[4] = bias[4], bias[3]

    params = gru_lstm_utils.canonical_to_params(
        weights=weights,
        biases=bias,
        shape=tf.constant([-1]),
        transpose_weights=True,
    )
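    # `params` packs the (transposed) per-gate weights followed by the biases
    # into the single flat parameter buffer consumed by the cuDNN RNN ops.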

    if sequence_lengths is not None:
        if go_backwards:
            # Three reversals are required. E.g.,
            # normal input = [1, 2, 3, 0, 0]  # where 0 needs to be masked
            # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
            # output_from_cudnn = [6, 5, 4, 0, 0]
            # expected_output = [0, 0, 6, 5, 4]
            inputs = tf.reverse_sequence(
                inputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
        outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
            input=inputs,
            input_h=init_h,
            input_c=0,
            params=params,
            is_training=True,
            rnn_mode="gru",
            sequence_lengths=sequence_lengths,
            time_major=time_major,
        )
        if go_backwards:
            outputs = tf.reverse_sequence(
                outputs,
                sequence_lengths,
                seq_axis=seq_axis,
                batch_axis=batch_axis,
            )
            outputs = tf.reverse(outputs, axis=[seq_axis])
    else:
        if go_backwards:
            # Reverse axis 0 since the input is already converted to time
            # major.
            inputs = tf.reverse(inputs, axis=[0])
        outputs, h, _, _ = tf.raw_ops.CudnnRNN(
            input=inputs,
            input_h=init_h,
            input_c=0,
            params=params,
            is_training=True,
            rnn_mode="gru",
        )

    last_output = outputs[-1]
    if not time_major and sequence_lengths is None and return_sequences:
        outputs = tf.transpose(outputs, perm=[1, 0, 2])
    h = tf.squeeze(h, axis=seq_axis)

    # In the case of variable length input, the cuDNN kernel will fill zeros
    # for the output, whereas the default Keras behavior is to bring over the
    # output from t-1, so that in the return_sequences=False case, the user
    # can quickly get the final effective output instead of just 0s at the
    # last timestep. In order to mimic the default Keras behavior, we copy the
    # final h state as the last_output, since it is numerically the same as
    # the output.
    if sequence_lengths is not None:
        last_output = h

    # Match CPU return format
    if not return_sequences:
        outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)

    return (
        last_output,
        outputs,
        h,
        gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_GPU),
    )


def gru_with_backend_selection(
    inputs,
    init_h,
    kernel,
    recurrent_kernel,
    bias,
    mask,
    time_major,
    go_backwards,
    sequence_lengths,
    zero_output_for_mask,
    return_sequences,
):
    """Call the GRU with optimized backend kernel selection.

    Under the hood, this function will create two TF functions, one with the
    most generic kernel that can run on all device conditions, and a second
    one with the cuDNN-specific kernel, which can only run on GPU.

    The first function will be called with the normal GRU params, while the
    second function is not called, but only registered in the graph. Grappler
    will do the proper graph rewrite and swap the optimized TF function based
    on the device placement.

    Args:
      inputs: Input tensor of GRU layer.
      init_h: Initial state tensor for the cell output.
      kernel: Weights for cell kernel.
      recurrent_kernel: Weights for cell recurrent kernel.
      bias: Weights for cell kernel bias and recurrent bias. Only recurrent
        bias is used in this case.
      mask: Boolean tensor for masking out the steps within the sequence.
        An individual `True` entry indicates that the corresponding timestep
        should be utilized, while a `False` entry indicates that the
        corresponding timestep should be ignored.
      time_major: Boolean, whether the inputs are in the format of
        [time, batch, feature] or [batch, time, feature].
      go_backwards: Boolean (default False). If True, process the input
        sequence backwards and return the reversed sequence.
      sequence_lengths: The lengths of all sequences coming from a variable
        length input, such as ragged tensors. If the input has a fixed timestep
        size, this should be None.
      zero_output_for_mask: Boolean, whether to output zero for masked
        timesteps.
      return_sequences: Boolean. If True, return the recurrent outputs for all
        timesteps in the sequence. If False, only return the output for the
        last timestep (which consumes less memory).

    Returns:
      List of output tensors, same as standard_gru.
    """
    params = {
        "inputs": inputs,
        "init_h": init_h,
        "kernel": kernel,
        "recurrent_kernel": recurrent_kernel,
        "bias": bias,
        "mask": mask,
        "time_major": time_major,
        "go_backwards": go_backwards,
        "sequence_lengths": sequence_lengths,
        "zero_output_for_mask": zero_output_for_mask,
        "return_sequences": return_sequences,
    }
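
    # Even when placed on a GPU, cuDNN can only be used if the mask (if any)
    # is strictly right-padded, so the GPU branch below still guards on
    # `is_cudnn_supported_inputs` and falls back to `standard_gru` otherwise.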

    def gpu_gru_with_fallback(
        inputs,
        init_h,
        kernel,
        recurrent_kernel,
        bias,
        mask,
        time_major,
        go_backwards,
        sequence_lengths,
        zero_output_for_mask,
        return_sequences,
    ):
        """Use cuDNN kernel when mask is none or strictly right padded."""

        def cudnn_gru_fn():
            return gpu_gru(
                inputs=inputs,
                init_h=init_h,
                kernel=kernel,
                recurrent_kernel=recurrent_kernel,
                bias=bias,
                mask=mask,
                time_major=time_major,
                go_backwards=go_backwards,
                sequence_lengths=sequence_lengths,
                return_sequences=return_sequences,
            )

        def standard_gru_fn():
            return standard_gru(
                inputs=inputs,
                init_h=init_h,
                kernel=kernel,
                recurrent_kernel=recurrent_kernel,
                bias=bias,
                mask=mask,
                time_major=time_major,
                go_backwards=go_backwards,
                sequence_lengths=sequence_lengths,
                zero_output_for_mask=zero_output_for_mask,
                return_sequences=return_sequences,
            )

        return tf.__internal__.smart_cond.smart_cond(
            gru_lstm_utils.is_cudnn_supported_inputs(
                mask, time_major, sequence_lengths
            ),
            true_fn=cudnn_gru_fn,
            false_fn=standard_gru_fn,
        )

    if gru_lstm_utils.use_new_gru_lstm_impl():
        # Chooses the implementation dynamically based on the running device.
        (
            last_output,
            outputs,
            new_h,
            runtime,
        ) = tf.__internal__.execute_fn_for_device(
            {
                gru_lstm_utils.CPU_DEVICE_NAME: lambda: standard_gru(**params),
                gru_lstm_utils.GPU_DEVICE_NAME: lambda: gpu_gru_with_fallback(
                    **params
                ),
            },
            lambda: standard_gru(**params),
        )
    else:
        # Each time a `tf.function` is called, we will give it a unique
        # identifiable API name, so that Grappler won't get confused when it
        # sees multiple GRU layers added into the same graph, and it will be
        # able to pair up the different implementations across them.
        api_name = "gru_" + str(uuid.uuid4())
        supportive_attribute = {
            "time_major": time_major,
            "go_backwards": go_backwards,
        }
        defun_standard_gru = gru_lstm_utils.generate_defun_backend(
            api_name,
            gru_lstm_utils.CPU_DEVICE_NAME,
            standard_gru,
            supportive_attribute,
        )
        defun_gpu_gru = gru_lstm_utils.generate_defun_backend(
            api_name,
            gru_lstm_utils.GPU_DEVICE_NAME,
            gpu_gru_with_fallback,
            supportive_attribute,
        )

        # Call the normal GRU impl and register the cuDNN impl function. The
        # grappler will kick in during session execution to optimize the
        # graph.
        last_output, outputs, new_h, runtime = defun_standard_gru(**params)
        gru_lstm_utils.function_register(defun_gpu_gru, **params)

    return last_output, outputs, new_h, runtime