Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/cudnn_lstm.py: 26%

66 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

15"""Fast LSTM layer backed by cuDNN.""" 

16 

17 

18import collections 

19 

20import tensorflow.compat.v2 as tf 

21 

22from keras.src import constraints 

23from keras.src import initializers 

24from keras.src import regularizers 

25from keras.src.layers.rnn import gru_lstm_utils 

26from keras.src.layers.rnn.base_cudnn_rnn import _CuDNNRNN 

27 

28# isort: off 

29from tensorflow.python.util.tf_export import keras_export 

30 

31 

@keras_export(v1=["keras.layers.CuDNNLSTM"])
class CuDNNLSTM(_CuDNNRNN):
    """Fast LSTM implementation backed by cuDNN.

    More information about cuDNN can be found on the [NVIDIA
    developer website](https://developer.nvidia.com/cudnn).
    Can only be run on GPU.

    Args:
        units: Positive integer, dimensionality of the output space.
        kernel_initializer: Initializer for the `kernel` weights matrix, used
            for the linear transformation of the inputs.
        unit_forget_bias: Boolean. If True, add 1 to the bias of the forget
            gate at initialization. Setting it to true will also force
            `bias_initializer="zeros"`. This is recommended in [Jozefowicz et
            al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
        recurrent_initializer: Initializer for the `recurrent_kernel` weights
            matrix, used for the linear transformation of the recurrent state.
        bias_initializer: Initializer for the bias vector.
        kernel_regularizer: Regularizer function applied to the `kernel`
            weights matrix.
        recurrent_regularizer: Regularizer function applied to the
            `recurrent_kernel` weights matrix.
        bias_regularizer: Regularizer function applied to the bias vector.
        activity_regularizer: Regularizer function applied to the output of
            the layer (its "activation").
        kernel_constraint: Constraint function applied to the `kernel`
            weights matrix.
        recurrent_constraint: Constraint function applied to the
            `recurrent_kernel` weights matrix.
        bias_constraint: Constraint function applied to the bias vector.
        return_sequences: Boolean. Whether to return the last output in the
            output sequence, or the full sequence.
        return_state: Boolean. Whether to return the last state in addition
            to the output.
        go_backwards: Boolean (default False). If True, process the input
            sequence backwards and return the reversed sequence.
        stateful: Boolean (default False). If True, the last state for each
            sample at index i in a batch will be used as initial state for
            the sample of index i in the following batch.
    """

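    # A minimal usage sketch, not part of the original file. It assumes a
    # CUDA-capable GPU at runtime and uses illustrative shapes; the layer
    # is exported under the TF1 namespace, hence tf.compat.v1 below.
    #
    #     model = tf.keras.Sequential([
    #         tf.compat.v1.keras.layers.CuDNNLSTM(32, input_shape=(10, 8)),
    #     ])
    #     # Output shape: (batch, 32); with return_sequences=True it would
    #     # be (batch, 10, 32).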

    def __init__(
        self,
        units,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        unit_forget_bias=True,
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        **kwargs
    ):
        self.units = units
        # The base RNN machinery only needs an object exposing
        # `state_size`, so a namedtuple stands in for a full cell. An LSTM
        # carries two states (h and c), each of size `units`.
        cell_spec = collections.namedtuple("cell", "state_size")
        self._cell = cell_spec(state_size=(self.units, self.units))
        super().__init__(
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            **kwargs
        )

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

    @property
    def cell(self):
        return self._cell

    def build(self, input_shape):
        super().build(input_shape)
        if isinstance(input_shape, list):
            input_shape = input_shape[0]
        input_dim = int(input_shape[-1])

        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 4),
            name="kernel",
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )

        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name="recurrent_kernel",
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint,
        )

        if self.unit_forget_bias:
            # The bias holds eight blocks of `units` values (four
            # input-side and four recurrent-side gate biases). The ones
            # fill bias[units * 5 : units * 6], the slice that the cuDNN
            # canonical layout assigns to a forget-gate bias.
            def bias_initializer(_, *args, **kwargs):
                return tf.concat(
                    [
                        self.bias_initializer(
                            (self.units * 5,), *args, **kwargs
                        ),
                        tf.compat.v1.ones_initializer()(
                            (self.units,), *args, **kwargs
                        ),
                        self.bias_initializer(
                            (self.units * 2,), *args, **kwargs
                        ),
                    ],
                    axis=0,
                )

        else:
            bias_initializer = self.bias_initializer
        self.bias = self.add_weight(
            shape=(self.units * 8,),
            name="bias",
            initializer=bias_initializer,
            regularizer=self.bias_regularizer,
            constraint=self.bias_constraint,
        )

        self.built = True

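    # Shape summary for the variables created in build() above (a sketch;
    # D is the input feature size):
    #
    #     kernel:           (D, 4 * units)      fused i, f, c, o blocks
    #     recurrent_kernel: (units, 4 * units)
    #     bias:             (8 * units,)        separate input-side and
    #                                           recurrent-side gate biases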

    def _process_batch(self, inputs, initial_state):
        if not self.time_major:
            # The cuDNN kernel expects time-major input: (time, batch,
            # feature).
            inputs = tf.transpose(inputs, perm=(1, 0, 2))
        input_h = initial_state[0]
        input_c = initial_state[1]
        input_h = tf.expand_dims(input_h, axis=0)
        input_c = tf.expand_dims(input_c, axis=0)

        # Repack the fused Keras weights into cuDNN's canonical parameter
        # buffer: four gate blocks (i, f, c, o) from the input kernel,
        # four from the recurrent kernel, and eight bias blocks.
        params = gru_lstm_utils.canonical_to_params(
            weights=[
                self.kernel[:, : self.units],
                self.kernel[:, self.units : self.units * 2],
                self.kernel[:, self.units * 2 : self.units * 3],
                self.kernel[:, self.units * 3 :],
                self.recurrent_kernel[:, : self.units],
                self.recurrent_kernel[:, self.units : self.units * 2],
                self.recurrent_kernel[:, self.units * 2 : self.units * 3],
                self.recurrent_kernel[:, self.units * 3 :],
            ],
            biases=[
                self.bias[: self.units],
                self.bias[self.units : self.units * 2],
                self.bias[self.units * 2 : self.units * 3],
                self.bias[self.units * 3 : self.units * 4],
                self.bias[self.units * 4 : self.units * 5],
                self.bias[self.units * 5 : self.units * 6],
                self.bias[self.units * 6 : self.units * 7],
                self.bias[self.units * 7 :],
            ],
            shape=self._vector_shape,
        )

        args = {
            "input": inputs,
            "input_h": input_h,
            "input_c": input_c,
            "params": params,
            "is_training": True,
        }

        outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV2(**args)

        if self.stateful or self.return_state:
            h = h[0]
            c = c[0]
        if self.return_sequences:
            if self.time_major:
                output = outputs
            else:
                output = tf.transpose(outputs, perm=(1, 0, 2))
        else:
            output = outputs[-1]
        return output, [h, c]

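    # Tensor contract of the tf.raw_ops.CudnnRNNV2 call above (a sketch
    # for this single-layer, unidirectional configuration):
    #
    #     input:   (time, batch, feature)
    #     input_h: (num_layers=1, batch, units)
    #     input_c: (num_layers=1, batch, units)
    #     outputs: (time, batch, units)
    #
    # h and c come back with the leading num_layers dimension, which
    # _process_batch strips via h[0] / c[0] when states are kept.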

    def get_config(self):
        config = {
            "units": self.units,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "unit_forget_bias": self.unit_forget_bias,
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self.activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
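# A hedged round-trip sketch for get_config()/from_config(); the values
# below are illustrative, not from this file:
#
#     layer = CuDNNLSTM(16, return_sequences=True)
#     clone = CuDNNLSTM.from_config(layer.get_config())
#     assert clone.units == 16 and clone.return_sequences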