Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/cudnn_lstm.py: 26%
66 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Fast LSTM layer backed by cuDNN."""
18import collections
20import tensorflow.compat.v2 as tf
22from keras.src import constraints
23from keras.src import initializers
24from keras.src import regularizers
25from keras.src.layers.rnn import gru_lstm_utils
26from keras.src.layers.rnn.base_cudnn_rnn import _CuDNNRNN
28# isort: off
29from tensorflow.python.util.tf_export import keras_export
@keras_export(v1=["keras.layers.CuDNNLSTM"])
class CuDNNLSTM(_CuDNNRNN):
    """Fast LSTM implementation backed by cuDNN.

    More information about cuDNN can be found on the [NVIDIA
    developer website](https://developer.nvidia.com/cudnn).
    Can only be run on GPU.

    Args:
        units: Positive integer, dimensionality of the output space.
        kernel_initializer: Initializer for the `kernel` weights matrix, used
            for the linear transformation of the inputs.
        unit_forget_bias: Boolean. If True, add 1 to the bias of the forget
            gate at initialization. Setting it to true will also force
            `bias_initializer="zeros"`. This is recommended in [Jozefowicz et
            al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
        recurrent_initializer: Initializer for the `recurrent_kernel` weights
            matrix, used for the linear transformation of the recurrent state.
        bias_initializer: Initializer for the bias vector.
        kernel_regularizer: Regularizer function applied to the `kernel`
            weights matrix.
        recurrent_regularizer: Regularizer function applied to the
            `recurrent_kernel` weights matrix.
        bias_regularizer: Regularizer function applied to the bias vector.
        activity_regularizer: Regularizer function applied to the output of
            the layer (its "activation").
        kernel_constraint: Constraint function applied to the `kernel`
            weights matrix.
        recurrent_constraint: Constraint function applied to the
            `recurrent_kernel` weights matrix.
        bias_constraint: Constraint function applied to the bias vector.
        return_sequences: Boolean. Whether to return the last output in the
            output sequence, or the full sequence.
        return_state: Boolean. Whether to return the last state in addition
            to the output.
        go_backwards: Boolean (default False). If True, process the input
            sequence backwards and return the reversed sequence.
        stateful: Boolean (default False). If True, the last state for each
            sample at index i in a batch will be used as initial state for
            the sample of index i in the following batch.
    """

    def __init__(
        self,
        units,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        unit_forget_bias=True,
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        **kwargs
    ):
        self.units = units
        # The base RNN machinery expects a `cell` object that exposes
        # `state_size`.  cuDNN fuses the whole layer into a single op, so a
        # minimal namedtuple stands in for a real cell; an LSTM carries two
        # states (h, c), each of size `units`.
        cell_spec = collections.namedtuple("cell", "state_size")
        self._cell = cell_spec(state_size=(self.units, self.units))
        super().__init__(
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            **kwargs
        )

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

    @property
    def cell(self):
        # Expose the stand-in cell created in `__init__` (read-only).
        return self._cell

    def build(self, input_shape):
        """Create the kernel, recurrent kernel and bias weights.

        Weights are stored in the Keras canonical layout (4 gates
        concatenated along the last axis) and re-sliced into cuDNN's
        canonical parameter layout at call time in `_process_batch`.
        """
        super().build(input_shape)
        if isinstance(input_shape, list):
            input_shape = input_shape[0]
        input_dim = int(input_shape[-1])

        # 4 gates concatenated: shape (input_dim, units * 4).
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 4),
            name="kernel",
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )

        # 4 gates concatenated: shape (units, units * 4).
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name="recurrent_kernel",
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint,
        )

        if self.unit_forget_bias:

            def bias_initializer(_, *args, **kwargs):
                # The bias holds 8 chunks of size `units` (cuDNN keeps
                # separate input-side and recurrent-side biases).  The 6th
                # chunk (index 5) is initialized to ones — presumably the
                # forget-gate bias in cuDNN's canonical gate order; confirm
                # against `gru_lstm_utils.canonical_to_params` / cuDNN docs.
                return tf.concat(
                    [
                        self.bias_initializer(
                            (self.units * 5,), *args, **kwargs
                        ),
                        tf.compat.v1.ones_initializer()(
                            (self.units,), *args, **kwargs
                        ),
                        self.bias_initializer(
                            (self.units * 2,), *args, **kwargs
                        ),
                    ],
                    axis=0,
                )

        else:
            bias_initializer = self.bias_initializer
        # 8 bias chunks of size `units` (see above), flattened.
        self.bias = self.add_weight(
            shape=(self.units * 8,),
            name="bias",
            initializer=bias_initializer,
            regularizer=self.bias_regularizer,
            constraint=self.bias_constraint,
        )

        self.built = True

    def _process_batch(self, inputs, initial_state):
        """Run one batch through the fused cuDNN LSTM op.

        Args:
            inputs: Input tensor; transposed to time-major layout if the
                layer is batch-major, since the cuDNN op expects
                (time, batch, feature).
            initial_state: List/tuple of two tensors `[h, c]`.

        Returns:
            Tuple of `(output, [h, c])` matching the layer's
            `return_sequences` / `time_major` settings.
        """
        if not self.time_major:
            # cuDNN expects time-major input: (time, batch, feature).
            inputs = tf.transpose(inputs, perm=(1, 0, 2))
        input_h = initial_state[0]
        input_c = initial_state[1]
        # The cuDNN op expects a leading num_layers*num_dirs axis (1 here).
        input_h = tf.expand_dims(input_h, axis=0)
        input_c = tf.expand_dims(input_c, axis=0)

        # Re-slice the Keras-canonical weights (4 gate columns each in
        # kernel and recurrent_kernel, 8 bias chunks) into the single flat
        # opaque parameter buffer that cuDNN consumes.  The slice order here
        # must match cuDNN's canonical gate ordering exactly.
        params = gru_lstm_utils.canonical_to_params(
            weights=[
                self.kernel[:, : self.units],
                self.kernel[:, self.units : self.units * 2],
                self.kernel[:, self.units * 2 : self.units * 3],
                self.kernel[:, self.units * 3 :],
                self.recurrent_kernel[:, : self.units],
                self.recurrent_kernel[:, self.units : self.units * 2],
                self.recurrent_kernel[:, self.units * 2 : self.units * 3],
                self.recurrent_kernel[:, self.units * 3 :],
            ],
            biases=[
                self.bias[: self.units],
                self.bias[self.units : self.units * 2],
                self.bias[self.units * 2 : self.units * 3],
                self.bias[self.units * 3 : self.units * 4],
                self.bias[self.units * 4 : self.units * 5],
                self.bias[self.units * 5 : self.units * 6],
                self.bias[self.units * 6 : self.units * 7],
                self.bias[self.units * 7 :],
            ],
            shape=self._vector_shape,
        )

        args = {
            "input": inputs,
            "input_h": input_h,
            "input_c": input_c,
            "params": params,
            "is_training": True,
        }

        outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV2(**args)

        if self.stateful or self.return_state:
            # Drop the leading num_layers*num_dirs axis added above.
            h = h[0]
            c = c[0]
        if self.return_sequences:
            if self.time_major:
                output = outputs
            else:
                # Transpose back to batch-major to match the input layout.
                output = tf.transpose(outputs, perm=(1, 0, 2))
        else:
            # Last time step only (outputs are time-major at this point).
            output = outputs[-1]
        return output, [h, c]

    def get_config(self):
        """Return the layer config for serialization.

        Initializers/regularizers/constraints are serialized to their
        Keras-registered representations; merged over the base config.
        """
        config = {
            "units": self.units,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "unit_forget_bias": self.unit_forget_bias,
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self.activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
        }
        base_config = super().get_config()
        # Block-level keys win over base keys on collision.
        return dict(list(base_config.items()) + list(config.items()))