Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/gru_v1.py: 60%

92 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Gated Recurrent Unit V1 layer.""" 

16 

17 

18from keras.src import activations 

19from keras.src import constraints 

20from keras.src import initializers 

21from keras.src import regularizers 

22from keras.src.engine.input_spec import InputSpec 

23from keras.src.layers.rnn import gru 

24from keras.src.layers.rnn import rnn_utils 

25from keras.src.layers.rnn.base_rnn import RNN 

26 

27# isort: off 

28from tensorflow.python.platform import tf_logging as logging 

29from tensorflow.python.util.tf_export import keras_export 

30 

31 

@keras_export(v1=["keras.layers.GRUCell"])
class GRUCell(gru.GRUCell):
    """Single-step GRU cell exposed under the v1 Keras API.

    This subclass only changes construction defaults: it forwards every
    argument to `gru.GRUCell` and, when the caller does not supply
    `implementation`, passes `implementation=1`.

    Args:
      units: Positive integer, dimensionality of the output space.
      activation: Activation function to use.
        Default: hyperbolic tangent (`tanh`).
        Passing `None` applies no activation
        (i.e. "linear" activation: `a(x) = x`).
      recurrent_activation: Activation function to use for the recurrent
        step.
        Default: hard sigmoid (`hard_sigmoid`).
        Passing `None` applies no activation
        (i.e. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs.
      recurrent_initializer: Initializer for the `recurrent_kernel`
        weights matrix, used for the linear transformation of the
        recurrent state.
      bias_initializer: Initializer for the bias vector.
      kernel_regularizer: Regularizer function applied to the `kernel`
        weights matrix.
      recurrent_regularizer: Regularizer function applied to the
        `recurrent_kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      kernel_constraint: Constraint function applied to the `kernel`
        weights matrix.
      recurrent_constraint: Constraint function applied to the
        `recurrent_kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.
      dropout: Float between 0 and 1. Fraction of the units to drop for the
        linear transformation of the inputs.
      recurrent_dropout: Float between 0 and 1. Fraction of the units to
        drop for the linear transformation of the recurrent state.
      reset_after: GRU convention (whether to apply reset gate after or
        before matrix multiplication). False = "before" (default),
        True = "after" (cuDNN compatible).
      **kwargs: Extra keyword arguments forwarded to `gru.GRUCell`. An
        `implementation` entry, if present, is extracted and forwarded
        explicitly.

    Call arguments:
      inputs: A 2D tensor.
      states: List of state tensors corresponding to the previous timestep.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. Only relevant when `dropout` or
        `recurrent_dropout` is used.
    """

    def __init__(
        self,
        units,
        activation="tanh",
        recurrent_activation="hard_sigmoid",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        dropout=0.0,
        recurrent_dropout=0.0,
        reset_after=False,
        **kwargs
    ):
        # Take `implementation` out of kwargs up front so it is not passed
        # twice when the remaining kwargs are expanded below.
        implementation = kwargs.pop("implementation", 1)
        super().__init__(
            units,
            activation=activation,
            recurrent_activation=recurrent_activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            recurrent_regularizer=recurrent_regularizer,
            bias_regularizer=bias_regularizer,
            kernel_constraint=kernel_constraint,
            recurrent_constraint=recurrent_constraint,
            bias_constraint=bias_constraint,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
            implementation=implementation,
            reset_after=reset_after,
            **kwargs
        )

121 

122 

@keras_export(v1=["keras.layers.GRU"])
class GRU(RNN):
    """Gated Recurrent Unit - Cho et al. 2014.

    There are two variants. The default one is based on 1406.1078v3 and
    has reset gate applied to hidden state before matrix multiplication. The
    other one is based on original 1406.1078v1 and has the order reversed.

    The second variant is compatible with CuDNNGRU (GPU-only) and allows
    inference on CPU. Thus it has separate biases for `kernel` and
    `recurrent_kernel`. Use `reset_after=True` and
    `recurrent_activation='sigmoid'`.

    Args:
      units: Positive integer, dimensionality of the output space.
      activation: Activation function to use.
        Default: hyperbolic tangent (`tanh`).
        If you pass `None`, no activation is applied
        (ie. "linear" activation: `a(x) = x`).
      recurrent_activation: Activation function to use
        for the recurrent step.
        Default: hard sigmoid (`hard_sigmoid`).
        If you pass `None`, no activation is applied
        (ie. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix,
        used for the linear transformation of the inputs.
      recurrent_initializer: Initializer for the `recurrent_kernel` weights
        matrix, used for the linear transformation of the recurrent state.
      bias_initializer: Initializer for the bias vector.
      kernel_regularizer: Regularizer function applied to
        the `kernel` weights matrix.
      recurrent_regularizer: Regularizer function applied to
        the `recurrent_kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to
        the output of the layer (its "activation").
      kernel_constraint: Constraint function applied to
        the `kernel` weights matrix.
      recurrent_constraint: Constraint function applied to
        the `recurrent_kernel` weights matrix.
      bias_constraint: Constraint function applied to the bias vector.
      dropout: Float between 0 and 1.
        Fraction of the units to drop for
        the linear transformation of the inputs.
      recurrent_dropout: Float between 0 and 1.
        Fraction of the units to drop for
        the linear transformation of the recurrent state.
      return_sequences: Boolean. Whether to return the last output
        in the output sequence, or the full sequence.
      return_state: Boolean. Whether to return the last state
        in addition to the output.
      go_backwards: Boolean (default False).
        If True, process the input sequence backwards and return the
        reversed sequence.
      stateful: Boolean (default False). If True, the last state
        for each sample at index i in a batch will be used as initial
        state for the sample of index i in the following batch.
      unroll: Boolean (default False).
        If True, the network will be unrolled,
        else a symbolic loop will be used.
        Unrolling can speed-up a RNN,
        although it tends to be more memory-intensive.
        Unrolling is only suitable for short sequences.
      time_major: The shape format of the `inputs` and `outputs` tensors.
        If True, the inputs and outputs will be in shape
        `(timesteps, batch, ...)`, whereas in the False case, it will be
        `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
        efficient because it avoids transposes at the beginning and end of the
        RNN calculation. However, most TensorFlow data is batch-major, so by
        default this function accepts input and emits output in batch-major
        form.
      reset_after: GRU convention (whether to apply reset gate after or
        before matrix multiplication). False = "before" (default),
        True = "after" (cuDNN compatible).

    Call arguments:
      inputs: A 3D tensor.
      mask: Binary tensor of shape `(samples, timesteps)` indicating whether
        a given timestep should be masked. An individual `True` entry indicates
        that the corresponding timestep should be utilized, while a `False`
        entry indicates that the corresponding timestep should be ignored.
      training: Python boolean indicating whether the layer should behave in
        training mode or in inference mode. This argument is passed to the cell
        when calling it. This is only relevant if `dropout` or
        `recurrent_dropout` is used.
      initial_state: List of initial state tensors to be passed to the first
        call of the cell.
    """

    def __init__(
        self,
        units,
        activation="tanh",
        recurrent_activation="hard_sigmoid",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        dropout=0.0,
        recurrent_dropout=0.0,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        unroll=False,
        reset_after=False,
        **kwargs
    ):
        """Build the wrapped `GRUCell` and initialise the `RNN` base with it.

        `implementation` and `enable_caching_device` are accepted through
        `**kwargs`; both are removed from it and handed to the cell. All
        remaining keyword arguments go to the `RNN` constructor.
        """
        implementation = kwargs.pop("implementation", 1)
        if implementation == 0:
            # NOTE(review): the value 0 is still forwarded to the cell
            # unchanged below; how the cell interprets it is decided by
            # `gru.GRUCell` -- confirm it matches this message.
            logging.warning(
                "`implementation=0` has been deprecated, "
                "and now defaults to `implementation=1`. "
                "Please update your layer call."
            )
        # `enable_caching_device` is a cell-level option; it must not reach
        # the RNN constructor, so it is moved out of kwargs when present.
        if "enable_caching_device" in kwargs:
            cell_kwargs = {
                "enable_caching_device": kwargs.pop("enable_caching_device")
            }
        else:
            cell_kwargs = {}
        cell = GRUCell(
            units,
            activation=activation,
            recurrent_activation=recurrent_activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            recurrent_regularizer=recurrent_regularizer,
            bias_regularizer=bias_regularizer,
            kernel_constraint=kernel_constraint,
            recurrent_constraint=recurrent_constraint,
            bias_constraint=bias_constraint,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
            implementation=implementation,
            reset_after=reset_after,
            # `dtype` and `trainable` are read, not popped: they are also
            # passed on to the RNN constructor through **kwargs.
            dtype=kwargs.get("dtype"),
            trainable=kwargs.get("trainable", True),
            name="gru_cell",
            **cell_kwargs
        )
        super().__init__(
            cell,
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            unroll=unroll,
            **kwargs
        )
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.input_spec = [InputSpec(ndim=3)]

    def call(self, inputs, mask=None, training=None, initial_state=None):
        """Delegate to `RNN.call` with the same arguments."""
        return super().call(
            inputs, mask=mask, training=training, initial_state=initial_state
        )

    # The properties below expose the wrapped cell's hyperparameters on the
    # layer itself; each one simply reads the attribute of the same name
    # from `self.cell`.

    @property
    def units(self):
        return self.cell.units

    @property
    def activation(self):
        return self.cell.activation

    @property
    def recurrent_activation(self):
        return self.cell.recurrent_activation

    @property
    def use_bias(self):
        return self.cell.use_bias

    @property
    def kernel_initializer(self):
        return self.cell.kernel_initializer

    @property
    def recurrent_initializer(self):
        return self.cell.recurrent_initializer

    @property
    def bias_initializer(self):
        return self.cell.bias_initializer

    @property
    def kernel_regularizer(self):
        return self.cell.kernel_regularizer

    @property
    def recurrent_regularizer(self):
        return self.cell.recurrent_regularizer

    @property
    def bias_regularizer(self):
        return self.cell.bias_regularizer

    @property
    def kernel_constraint(self):
        return self.cell.kernel_constraint

    @property
    def recurrent_constraint(self):
        return self.cell.recurrent_constraint

    @property
    def bias_constraint(self):
        return self.cell.bias_constraint

    @property
    def dropout(self):
        return self.cell.dropout

    @property
    def recurrent_dropout(self):
        return self.cell.recurrent_dropout

    @property
    def implementation(self):
        return self.cell.implementation

    @property
    def reset_after(self):
        return self.cell.reset_after

    def get_config(self):
        """Return the layer configuration as a flat dict.

        The base `RNN` config is used without its nested `"cell"` entry;
        the cell's hyperparameters are written as top-level keys instead,
        so the result can be passed back to `from_config`.
        """
        config = {
            "units": self.units,
            "activation": activations.serialize(self.activation),
            "recurrent_activation": activations.serialize(
                self.recurrent_activation
            ),
            "use_bias": self.use_bias,
            "kernel_initializer": initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(
                self.kernel_regularizer
            ),
            "recurrent_regularizer": regularizers.serialize(
                self.recurrent_regularizer
            ),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(
                self.activity_regularizer
            ),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(
                self.recurrent_constraint
            ),
            "bias_constraint": constraints.serialize(self.bias_constraint),
            "dropout": self.dropout,
            "recurrent_dropout": self.recurrent_dropout,
            "implementation": self.implementation,
            "reset_after": self.reset_after,
        }
        config.update(rnn_utils.config_for_enable_caching_device(self.cell))
        base_config = super().get_config()
        # The cell is rebuilt from the flat keys above, so its nested
        # config is dropped.
        del base_config["cell"]
        # Base keys come first; entries in `config` win on any name clash.
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config):
        """Instantiate the layer from a dict produced by `get_config`.

        A stored `implementation` of 0 is replaced by 1 before construction.
        The substitution is made on a shallow copy, so the caller's dict is
        left unmodified.
        """
        config = dict(config)
        if config.get("implementation") == 0:
            config["implementation"] = 1
        return cls(**config)

405