Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/cudnn_gru.py: 29%
58 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Fast GRU layer backed by cuDNN."""
18import collections
20import tensorflow.compat.v2 as tf
22from keras.src import constraints
23from keras.src import initializers
24from keras.src import regularizers
25from keras.src.layers.rnn import gru_lstm_utils
26from keras.src.layers.rnn.base_cudnn_rnn import _CuDNNRNN
28# isort: off
29from tensorflow.python.util.tf_export import keras_export
@keras_export(v1=["keras.layers.CuDNNGRU"])
class CuDNNGRU(_CuDNNRNN):
    """Fast GRU implementation backed by cuDNN.

    More information about cuDNN can be found on the [NVIDIA
    developer website](https://developer.nvidia.com/cudnn).
    Can only be run on GPU.

    Args:
        units: Positive integer, dimensionality of the output space.
        kernel_initializer: Initializer for the `kernel` weights matrix, used
            for the linear transformation of the inputs.
        recurrent_initializer: Initializer for the `recurrent_kernel` weights
            matrix, used for the linear transformation of the recurrent state.
        bias_initializer: Initializer for the bias vector.
        kernel_regularizer: Regularizer function applied to the `kernel` weights
            matrix.
        recurrent_regularizer: Regularizer function applied to the
            `recurrent_kernel` weights matrix.
        bias_regularizer: Regularizer function applied to the bias vector.
        activity_regularizer: Regularizer function applied to the output of the
            layer (its "activation").
        kernel_constraint: Constraint function applied to the `kernel` weights
            matrix.
        recurrent_constraint: Constraint function applied to the
            `recurrent_kernel` weights matrix.
        bias_constraint: Constraint function applied to the bias vector.
        return_sequences: Boolean. Whether to return the last output in the
            output sequence, or the full sequence.
        return_state: Boolean. Whether to return the last state in addition to
            the output.
        go_backwards: Boolean (default False). If True, process the input
            sequence backwards and return the reversed sequence.
        stateful: Boolean (default False). If True, the last state for each
            sample at index i in a batch will be used as initial state for the
            sample of index i in the following batch.
    """

    def __init__(
        self,
        units,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        **kwargs
    ):
        self.units = units
        # The base RNN machinery expects a `cell` object exposing
        # `state_size`; cuDNN fuses the whole recurrence into a single op,
        # so a lightweight namedtuple stands in for a real cell here.
        cell_spec = collections.namedtuple("cell", "state_size")
        self._cell = cell_spec(state_size=self.units)
        super().__init__(
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            **kwargs
        )

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

    @property
    def cell(self):
        # Read-only accessor for the stand-in cell created in `__init__`.
        return self._cell

    def build(self, input_shape):
        """Creates the layer weights from the (possibly nested) input shape."""
        super().build(input_shape)
        # When the layer receives [inputs, initial_state, ...], the first
        # entry carries the actual input shape.
        if isinstance(input_shape, list):
            input_shape = input_shape[0]
        input_dim = int(input_shape[-1])

        # Input projection for the three GRU gates, packed side by side
        # along the last axis (hence `units * 3`).
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 3),
            name="kernel",
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )

        # Recurrent projection for the three gates, same packing.
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 3),
            name="recurrent_kernel",
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint,
        )

        # Six bias chunks rather than three: `_process_batch` below splits
        # this into two groups of three (presumably cuDNN's separate input
        # and recurrent bias terms per gate — see the slicing there).
        self.bias = self.add_weight(
            shape=(self.units * 6,),
            name="bias",
            initializer=self.bias_initializer,
            regularizer=self.bias_regularizer,
            constraint=self.bias_constraint,
        )

        self.built = True

    def _process_batch(self, inputs, initial_state):
        """Runs one batch through the fused cuDNN GRU op.

        Returns a `(output, [state])` pair, with `output` shaped according
        to `return_sequences` / `time_major`.
        """
        # The raw cuDNN op consumes time-major input: (time, batch, feature).
        if not self.time_major:
            inputs = tf.transpose(inputs, perm=(1, 0, 2))
        input_h = initial_state[0]
        # Add a leading axis for the hidden state — NOTE(review): this looks
        # like cuDNN's per-layer state dimension (num_layers == 1 here).
        input_h = tf.expand_dims(input_h, axis=0)

        # Repack the Keras-layout weights into cuDNN's canonical parameter
        # buffer. The slices swap the first two gate sub-blocks of each
        # matrix/bias group: cuDNN expects the gates in a different order
        # than the packed Keras kernels. Do not reorder these entries —
        # the list order defines the layout `canonical_to_params` builds.
        params = gru_lstm_utils.canonical_to_params(
            weights=[
                self.kernel[:, self.units : self.units * 2],
                self.kernel[:, : self.units],
                self.kernel[:, self.units * 2 :],
                self.recurrent_kernel[:, self.units : self.units * 2],
                self.recurrent_kernel[:, : self.units],
                self.recurrent_kernel[:, self.units * 2 :],
            ],
            biases=[
                # First triple: bias chunks from [0, 3*units).
                self.bias[self.units : self.units * 2],
                self.bias[: self.units],
                self.bias[self.units * 2 : self.units * 3],
                # Second triple: bias chunks from [3*units, 6*units).
                self.bias[self.units * 4 : self.units * 5],
                self.bias[self.units * 3 : self.units * 4],
                self.bias[self.units * 5 :],
            ],
            shape=self._vector_shape,
        )

        args = {
            "input": inputs,
            "input_h": input_h,
            # GRU has no cell state; presumably ignored for rnn_mode="gru".
            "input_c": 0,
            "params": params,
            "is_training": True,
            "rnn_mode": "gru",
        }

        # Only outputs and the final hidden state are used; the remaining
        # results of the raw op are discarded.
        outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV2(**args)

        if self.stateful or self.return_state:
            # Drop the leading (layer) axis added above.
            h = h[0]
        if self.return_sequences:
            if self.time_major:
                output = outputs
            else:
                # Convert back from the op's time-major layout.
                output = tf.transpose(outputs, perm=(1, 0, 2))
        else:
            # Last timestep only (outputs are time-major at this point).
            output = outputs[-1]
        return output, [h]

    def get_config(self):
        """Returns the serializable config, merged with the base layer's."""
        config = {
            "units": self.units,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self.activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
        }
        base_config = super().get_config()
        # Subclass keys take precedence over base keys on collision.
        return dict(list(base_config.items()) + list(config.items()))