Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/cudnn_gru.py: 29%

58 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Fast GRU layer backed by cuDNN.""" 

16 

17 

18import collections 

19 

20import tensorflow.compat.v2 as tf 

21 

22from keras.src import constraints 

23from keras.src import initializers 

24from keras.src import regularizers 

25from keras.src.layers.rnn import gru_lstm_utils 

26from keras.src.layers.rnn.base_cudnn_rnn import _CuDNNRNN 

27 

28# isort: off 

29from tensorflow.python.util.tf_export import keras_export 

30 

31 

@keras_export(v1=["keras.layers.CuDNNGRU"])
class CuDNNGRU(_CuDNNRNN):
    """Fast GRU implementation backed by cuDNN.

    More information about cuDNN can be found on the [NVIDIA
    developer website](https://developer.nvidia.com/cudnn).
    Can only be run on GPU.

    Args:
        units: Positive integer, dimensionality of the output space.
        kernel_initializer: Initializer for the `kernel` weights matrix,
            used for the linear transformation of the inputs.
        recurrent_initializer: Initializer for the `recurrent_kernel`
            weights matrix, used for the linear transformation of the
            recurrent state.
        bias_initializer: Initializer for the bias vector.
        kernel_regularizer: Regularizer function applied to the `kernel`
            weights matrix.
        recurrent_regularizer: Regularizer function applied to the
            `recurrent_kernel` weights matrix.
        bias_regularizer: Regularizer function applied to the bias vector.
        activity_regularizer: Regularizer function applied to the output of
            the layer (its "activation").
        kernel_constraint: Constraint function applied to the `kernel`
            weights matrix.
        recurrent_constraint: Constraint function applied to the
            `recurrent_kernel` weights matrix.
        bias_constraint: Constraint function applied to the bias vector.
        return_sequences: Boolean. Whether to return the last output in the
            output sequence, or the full sequence.
        return_state: Boolean. Whether to return the last state in addition
            to the output.
        go_backwards: Boolean (default False). If True, process the input
            sequence backwards and return the reversed sequence.
        stateful: Boolean (default False). If True, the last state for each
            sample at index i in a batch will be used as initial state for
            the sample of index i in the following batch.
    """

    def __init__(
        self,
        units,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        **kwargs
    ):
        self.units = units
        # The base RNN machinery only needs an object exposing a
        # `state_size` attribute, so a tiny namedtuple stands in for a
        # full cell implementation.
        minimal_cell = collections.namedtuple("cell", "state_size")
        self._cell = minimal_cell(state_size=self.units)
        super().__init__(
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            **kwargs
        )

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

    @property
    def cell(self):
        # Read-only view of the stand-in cell built in `__init__`.
        return self._cell

    def build(self, input_shape):
        """Create the fused kernel, recurrent kernel and bias variables."""
        super().build(input_shape)
        if isinstance(input_shape, list):
            input_shape = input_shape[0]
        input_dim = int(input_shape[-1])

        # Each weight fuses the three GRU gate sub-matrices along the last
        # axis, hence the factor of 3.
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 3),
            name="kernel",
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 3),
            name="recurrent_kernel",
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint,
        )
        # cuDNN keeps distinct input-side and recurrent-side biases for the
        # three gates, hence 2 * 3 = 6 bias segments.
        self.bias = self.add_weight(
            shape=(self.units * 6,),
            name="bias",
            initializer=self.bias_initializer,
            regularizer=self.bias_regularizer,
            constraint=self.bias_constraint,
        )

        self.built = True

    def _process_batch(self, inputs, initial_state):
        """Run one batch through the cuDNN GRU kernel.

        Returns an `(output, [state])` pair following the Keras RNN
        contract.
        """
        if not self.time_major:
            # The cuDNN op consumes time-major input: (time, batch, dim).
            inputs = tf.transpose(inputs, perm=(1, 0, 2))
        init_h = tf.expand_dims(initial_state[0], axis=0)

        units = self.units
        kernel = self.kernel
        recurrent = self.recurrent_kernel
        bias = self.bias
        # The first two gate sub-matrices (and bias segments) are swapped
        # relative to their storage order to match the gate ordering the
        # cuDNN canonical parameter layout expects.
        params = gru_lstm_utils.canonical_to_params(
            weights=[
                kernel[:, units : units * 2],
                kernel[:, :units],
                kernel[:, units * 2 :],
                recurrent[:, units : units * 2],
                recurrent[:, :units],
                recurrent[:, units * 2 :],
            ],
            biases=[
                bias[units : units * 2],
                bias[:units],
                bias[units * 2 : units * 3],
                bias[units * 4 : units * 5],
                bias[units * 3 : units * 4],
                bias[units * 5 :],
            ],
            shape=self._vector_shape,
        )

        outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV2(
            input=inputs,
            input_h=init_h,
            input_c=0,
            params=params,
            is_training=True,
            rnn_mode="gru",
        )

        if self.stateful or self.return_state:
            h = h[0]
        if not self.return_sequences:
            output = outputs[-1]
        elif self.time_major:
            output = outputs
        else:
            output = tf.transpose(outputs, perm=(1, 0, 2))
        return output, [h]

    def get_config(self):
        """Return the layer config merged over the base layer's config."""
        serializable = (
            ("kernel_initializer", initializers, self.kernel_initializer),
            ("recurrent_initializer", initializers, self.recurrent_initializer),
            ("bias_initializer", initializers, self.bias_initializer),
            ("kernel_regularizer", regularizers, self.kernel_regularizer),
            ("recurrent_regularizer", regularizers, self.recurrent_regularizer),
            ("bias_regularizer", regularizers, self.bias_regularizer),
            ("activity_regularizer", regularizers, self.activity_regularizer),
            ("kernel_constraint", constraints, self.kernel_constraint),
            ("recurrent_constraint", constraints, self.recurrent_constraint),
            ("bias_constraint", constraints, self.bias_constraint),
        )
        config = {"units": self.units}
        for key, module, value in serializable:
            config[key] = module.serialize(value)
        return {**super().get_config(), **config}

225