Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/gru.py: 21%

339 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Gated Recurrent Unit layer.""" 

16 

17 

18import uuid 

19 

20import tensorflow.compat.v2 as tf 

21 

22from keras.src import activations 

23from keras.src import backend 

24from keras.src import constraints 

25from keras.src import initializers 

26from keras.src import regularizers 

27from keras.src.engine import base_layer 

28from keras.src.engine.input_spec import InputSpec 

29from keras.src.layers.rnn import gru_lstm_utils 

30from keras.src.layers.rnn import rnn_utils 

31from keras.src.layers.rnn.base_rnn import RNN 

32from keras.src.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin 

33from keras.src.utils import tf_utils 

34 

35# isort: off 

36from tensorflow.python.platform import tf_logging as logging 

37from tensorflow.python.util.tf_export import keras_export 

38 

39RECURRENT_DROPOUT_WARNING_MSG = ( 

40 "RNN `implementation=2` is not supported when `recurrent_dropout` is set. " 

41 "Using `implementation=1`." 

42) 

43 

44 

45@keras_export("keras.layers.GRUCell", v1=[]) 

46class GRUCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer): 

47 """Cell class for the GRU layer. 

48 

49 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 

50 for details about the usage of RNN API. 

51 

52 This class processes one step within the whole time sequence input, whereas 

53 `tf.keras.layers.GRU` processes the whole sequence.

54 

55 For example: 

56 

57 >>> inputs = tf.random.normal([32, 10, 8]) 

58 >>> rnn = tf.keras.layers.RNN(tf.keras.layers.GRUCell(4)) 

59 >>> output = rnn(inputs) 

60 >>> print(output.shape) 

61 (32, 4) 

62 >>> rnn = tf.keras.layers.RNN( 

63 ... tf.keras.layers.GRUCell(4), 

64 ... return_sequences=True, 

65 ... return_state=True) 

66 >>> whole_sequence_output, final_state = rnn(inputs) 

67 >>> print(whole_sequence_output.shape) 

68 (32, 10, 4) 

69 >>> print(final_state.shape) 

70 (32, 4) 

71 

72 Args: 

73 units: Positive integer, dimensionality of the output space. 

74 activation: Activation function to use. Default: hyperbolic tangent 

75 (`tanh`). If you pass `None`, no activation is applied

76 (i.e. "linear" activation: `a(x) = x`).

77 recurrent_activation: Activation function to use for the recurrent step. 

78 Default: sigmoid (`sigmoid`). If you pass `None`, no activation is 

79 applied (i.e. "linear" activation: `a(x) = x`).

80 use_bias: Boolean (default `True`), whether the layer uses a bias vector.

81 kernel_initializer: Initializer for the `kernel` weights matrix, 

82 used for the linear transformation of the inputs. Default: 

83 `glorot_uniform`. 

84 recurrent_initializer: Initializer for the `recurrent_kernel` 

85 weights matrix, used for the linear transformation of the recurrent 

86 state. Default: `orthogonal`. 

87 bias_initializer: Initializer for the bias vector. Default: `zeros`. 

88 kernel_regularizer: Regularizer function applied to the `kernel` weights 

89 matrix. Default: `None`. 

90 recurrent_regularizer: Regularizer function applied to the 

91 `recurrent_kernel` weights matrix. Default: `None`. 

92 bias_regularizer: Regularizer function applied to the bias vector. 

93 Default: `None`. 

94 kernel_constraint: Constraint function applied to the `kernel` weights 

95 matrix. Default: `None`. 

96 recurrent_constraint: Constraint function applied to the 

97 `recurrent_kernel` weights matrix. Default: `None`. 

98 bias_constraint: Constraint function applied to the bias vector. Default: 

99 `None`. 

100 dropout: Float between 0 and 1. Fraction of the units to drop for the 

101 linear transformation of the inputs. Default: 0. 

102 recurrent_dropout: Float between 0 and 1. Fraction of the units to drop 

103 for the linear transformation of the recurrent state. Default: 0. 

104 reset_after: GRU convention (whether to apply reset gate after or 

105 before matrix multiplication). False = "before", 

106 True = "after" (default and cuDNN compatible). 

107 

108 Call arguments: 

109 inputs: A 2D tensor, with shape `[batch, feature]`.

110 states: A 2D tensor with shape `[batch, units]`, which is the state

111 from the previous time step. For timestep 0, the initial state provided

112 by the user will be fed to the cell.

113 training: Python boolean indicating whether the layer should behave in 

114 training mode or in inference mode. Only relevant when `dropout` or 

115 `recurrent_dropout` is used. 

116 """ 

117 

118 def __init__( 

119 self, 

120 units, 

121 activation="tanh", 

122 recurrent_activation="sigmoid", 

123 use_bias=True, 

124 kernel_initializer="glorot_uniform", 

125 recurrent_initializer="orthogonal", 

126 bias_initializer="zeros", 

127 kernel_regularizer=None, 

128 recurrent_regularizer=None, 

129 bias_regularizer=None, 

130 kernel_constraint=None, 

131 recurrent_constraint=None, 

132 bias_constraint=None, 

133 dropout=0.0, 

134 recurrent_dropout=0.0, 

135 reset_after=True, 

136 **kwargs, 

137 ): 

138 if units <= 0: 

139 raise ValueError( 

140 "Received an invalid value for argument `units`, " 

141 f"expected a positive integer, got {units}." 

142 ) 

143 # By default use cached variable under v2 mode, see b/143699808. 

144 if tf.compat.v1.executing_eagerly_outside_functions(): 

145 self._enable_caching_device = kwargs.pop( 

146 "enable_caching_device", True 

147 ) 

148 else: 

149 self._enable_caching_device = kwargs.pop( 

150 "enable_caching_device", False 

151 ) 

152 super().__init__(**kwargs) 

153 self.units = units 

154 self.activation = activations.get(activation) 

155 self.recurrent_activation = activations.get(recurrent_activation) 

156 self.use_bias = use_bias 

157 

158 self.kernel_initializer = initializers.get(kernel_initializer) 

159 self.recurrent_initializer = initializers.get(recurrent_initializer) 

160 self.bias_initializer = initializers.get(bias_initializer) 

161 

162 self.kernel_regularizer = regularizers.get(kernel_regularizer) 

163 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 

164 self.bias_regularizer = regularizers.get(bias_regularizer) 

165 

166 self.kernel_constraint = constraints.get(kernel_constraint) 

167 self.recurrent_constraint = constraints.get(recurrent_constraint) 

168 self.bias_constraint = constraints.get(bias_constraint) 

169 

170 self.dropout = min(1.0, max(0.0, dropout)) 

171 self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout)) 

172 

173 implementation = kwargs.pop("implementation", 2) 

174 if self.recurrent_dropout != 0 and implementation != 1: 

175 logging.debug(RECURRENT_DROPOUT_WARNING_MSG) 

176 self.implementation = 1 

177 else: 

178 self.implementation = implementation 

179 self.reset_after = reset_after 

180 self.state_size = self.units 

181 self.output_size = self.units 

182 

183 @tf_utils.shape_type_conversion 

184 def build(self, input_shape): 

185 super().build(input_shape) 

186 input_dim = input_shape[-1] 

187 default_caching_device = rnn_utils.caching_device(self) 

188 self.kernel = self.add_weight( 

189 shape=(input_dim, self.units * 3), 

190 name="kernel", 

191 initializer=self.kernel_initializer, 

192 regularizer=self.kernel_regularizer, 

193 constraint=self.kernel_constraint, 

194 caching_device=default_caching_device, 

195 ) 

196 self.recurrent_kernel = self.add_weight( 

197 shape=(self.units, self.units * 3), 

198 name="recurrent_kernel", 

199 initializer=self.recurrent_initializer, 

200 regularizer=self.recurrent_regularizer, 

201 constraint=self.recurrent_constraint, 

202 caching_device=default_caching_device, 

203 ) 

204 

205 if self.use_bias: 

206 if not self.reset_after: 

207 bias_shape = (3 * self.units,) 

208 else: 

209 # separate biases for input and recurrent kernels 

210 # Note: the shape is intentionally different from CuDNNGRU 

211 # biases `(2 * 3 * self.units,)`, so that we can distinguish the 

212 # classes when loading and converting saved weights. 

213 bias_shape = (2, 3 * self.units) 

214 self.bias = self.add_weight( 

215 shape=bias_shape, 

216 name="bias", 

217 initializer=self.bias_initializer, 

218 regularizer=self.bias_regularizer, 

219 constraint=self.bias_constraint, 

220 caching_device=default_caching_device, 

221 ) 

222 else: 

223 self.bias = None 

224 self.built = True 
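    # Hedged shape sketch for the weights created in `build` above (values are
    # illustrative, not from the original source). With input_dim=8, units=4:
    #   kernel           -> (8, 12)   # the 3 gates concatenated on the last axis
    #   recurrent_kernel -> (4, 12)
    #   bias             -> (2, 12) if reset_after else (12,)
    # The (2, 3 * units) layout keeps the input and recurrent biases separate,
    # unlike the flat (2 * 3 * units,) CuDNNGRU layout noted in the comment above.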

225 

226 def call(self, inputs, states, training=None): 

227 h_tm1 = ( 

228 states[0] if tf.nest.is_nested(states) else states 

229 ) # previous memory 

230 

231 dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3) 

232 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( 

233 h_tm1, training, count=3 

234 ) 

235 

236 if self.use_bias: 

237 if not self.reset_after: 

238 input_bias, recurrent_bias = self.bias, None 

239 else: 

240 input_bias, recurrent_bias = tf.unstack(self.bias) 

241 

242 if self.implementation == 1: 

243 if 0.0 < self.dropout < 1.0: 

244 inputs_z = inputs * dp_mask[0] 

245 inputs_r = inputs * dp_mask[1] 

246 inputs_h = inputs * dp_mask[2] 

247 else: 

248 inputs_z = inputs 

249 inputs_r = inputs 

250 inputs_h = inputs 

251 

252 x_z = backend.dot(inputs_z, self.kernel[:, : self.units]) 

253 x_r = backend.dot( 

254 inputs_r, self.kernel[:, self.units : self.units * 2] 

255 ) 

256 x_h = backend.dot(inputs_h, self.kernel[:, self.units * 2 :]) 

257 

258 if self.use_bias: 

259 x_z = backend.bias_add(x_z, input_bias[: self.units]) 

260 x_r = backend.bias_add( 

261 x_r, input_bias[self.units : self.units * 2] 

262 ) 

263 x_h = backend.bias_add(x_h, input_bias[self.units * 2 :]) 

264 

265 if 0.0 < self.recurrent_dropout < 1.0: 

266 h_tm1_z = h_tm1 * rec_dp_mask[0] 

267 h_tm1_r = h_tm1 * rec_dp_mask[1] 

268 h_tm1_h = h_tm1 * rec_dp_mask[2] 

269 else: 

270 h_tm1_z = h_tm1 

271 h_tm1_r = h_tm1 

272 h_tm1_h = h_tm1 

273 

274 recurrent_z = backend.dot( 

275 h_tm1_z, self.recurrent_kernel[:, : self.units] 

276 ) 

277 recurrent_r = backend.dot( 

278 h_tm1_r, self.recurrent_kernel[:, self.units : self.units * 2] 

279 ) 

280 if self.reset_after and self.use_bias: 

281 recurrent_z = backend.bias_add( 

282 recurrent_z, recurrent_bias[: self.units] 

283 ) 

284 recurrent_r = backend.bias_add( 

285 recurrent_r, recurrent_bias[self.units : self.units * 2] 

286 ) 

287 

288 z = self.recurrent_activation(x_z + recurrent_z) 

289 r = self.recurrent_activation(x_r + recurrent_r) 

290 

291 # reset gate applied after/before matrix multiplication 

292 if self.reset_after: 

293 recurrent_h = backend.dot( 

294 h_tm1_h, self.recurrent_kernel[:, self.units * 2 :] 

295 ) 

296 if self.use_bias: 

297 recurrent_h = backend.bias_add( 

298 recurrent_h, recurrent_bias[self.units * 2 :] 

299 ) 

300 recurrent_h = r * recurrent_h 

301 else: 

302 recurrent_h = backend.dot( 

303 r * h_tm1_h, self.recurrent_kernel[:, self.units * 2 :] 

304 ) 

305 

306 hh = self.activation(x_h + recurrent_h) 

307 else: 

308 if 0.0 < self.dropout < 1.0: 

309 inputs = inputs * dp_mask[0] 

310 

311 # inputs projected by all gate matrices at once 

312 matrix_x = backend.dot(inputs, self.kernel) 

313 if self.use_bias: 

314 # biases: bias_z_i, bias_r_i, bias_h_i 

315 matrix_x = backend.bias_add(matrix_x, input_bias) 

316 

317 x_z, x_r, x_h = tf.split(matrix_x, 3, axis=-1) 

318 

319 if self.reset_after: 

320 # hidden state projected by all gate matrices at once 

321 matrix_inner = backend.dot(h_tm1, self.recurrent_kernel) 

322 if self.use_bias: 

323 matrix_inner = backend.bias_add( 

324 matrix_inner, recurrent_bias 

325 ) 

326 else: 

327 # hidden state projected separately for update/reset and new 

328 matrix_inner = backend.dot( 

329 h_tm1, self.recurrent_kernel[:, : 2 * self.units] 

330 ) 

331 

332 recurrent_z, recurrent_r, recurrent_h = tf.split( 

333 matrix_inner, [self.units, self.units, -1], axis=-1 

334 ) 

335 

336 z = self.recurrent_activation(x_z + recurrent_z) 

337 r = self.recurrent_activation(x_r + recurrent_r) 

338 

339 if self.reset_after: 

340 recurrent_h = r * recurrent_h 

341 else: 

342 recurrent_h = backend.dot( 

343 r * h_tm1, self.recurrent_kernel[:, 2 * self.units :] 

344 ) 

345 

346 hh = self.activation(x_h + recurrent_h) 

347 # previous and candidate state mixed by update gate 

348 h = z * h_tm1 + (1 - z) * hh 

349 new_state = [h] if tf.nest.is_nested(states) else h 

350 return h, new_state 
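    # A minimal single-step usage sketch for the `call` above (assumed
    # illustration, not part of the original source):
    #   cell = GRUCell(4)
    #   x_t = tf.random.normal([2, 8])           # [batch, feature]
    #   h_tm1 = tf.zeros([2, 4])                 # [batch, units]
    #   h_t, new_states = cell(x_t, [h_tm1])     # h_t.shape == (2, 4)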

351 

352 def get_config(self): 

353 config = { 

354 "units": self.units, 

355 "activation": activations.serialize(self.activation), 

356 "recurrent_activation": activations.serialize( 

357 self.recurrent_activation 

358 ), 

359 "use_bias": self.use_bias, 

360 "kernel_initializer": initializers.serialize( 

361 self.kernel_initializer 

362 ), 

363 "recurrent_initializer": initializers.serialize( 

364 self.recurrent_initializer 

365 ), 

366 "bias_initializer": initializers.serialize(self.bias_initializer), 

367 "kernel_regularizer": regularizers.serialize( 

368 self.kernel_regularizer 

369 ), 

370 "recurrent_regularizer": regularizers.serialize( 

371 self.recurrent_regularizer 

372 ), 

373 "bias_regularizer": regularizers.serialize(self.bias_regularizer), 

374 "kernel_constraint": constraints.serialize(self.kernel_constraint), 

375 "recurrent_constraint": constraints.serialize( 

376 self.recurrent_constraint 

377 ), 

378 "bias_constraint": constraints.serialize(self.bias_constraint), 

379 "dropout": self.dropout, 

380 "recurrent_dropout": self.recurrent_dropout, 

381 "implementation": self.implementation, 

382 "reset_after": self.reset_after, 

383 } 

384 config.update(rnn_utils.config_for_enable_caching_device(self)) 

385 base_config = super().get_config() 

386 return dict(list(base_config.items()) + list(config.items())) 

387 

388 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 

389 return rnn_utils.generate_zero_filled_state_for_cell( 

390 self, inputs, batch_size, dtype 

391 ) 
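    # Hedged sketch (illustrative only): because `state_size` is the flat
    # integer `units`, the generated initial state is a single zero tensor,
    #   GRUCell(4).get_initial_state(batch_size=2, dtype=tf.float32)
    # which has shape (2, 4).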

392 

393 

394@keras_export("keras.layers.GRU", v1=[]) 

395class GRU(DropoutRNNCellMixin, RNN, base_layer.BaseRandomLayer): 

396 """Gated Recurrent Unit - Cho et al. 2014. 

397 

398 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 

399 for details about the usage of RNN API. 

400 

401 Based on available runtime hardware and constraints, this layer 

402 will choose different implementations (cuDNN-based or pure-TensorFlow) 

403 to maximize the performance. If a GPU is available and all 

404 the arguments to the layer meet the requirement of the cuDNN kernel 

405 (see below for details), the layer will use a fast cuDNN implementation. 

406 

407 The requirements to use the cuDNN implementation are: 

408 

409 1. `activation` == `tanh` 

410 2. `recurrent_activation` == `sigmoid` 

411 3. `recurrent_dropout` == 0 

412 4. `unroll` is `False` 

413 5. `use_bias` is `True` 

414 6. `reset_after` is `True` 

415 7. Inputs, if masking is used, are strictly right-padded.

416 8. Eager execution is enabled in the outermost context. 

417 

418 There are two variants of the GRU implementation. The default one is based 

419 on [v3](https://arxiv.org/abs/1406.1078v3) and has reset gate applied to 

420 hidden state before matrix multiplication. The other one is based on 

421 [original](https://arxiv.org/abs/1406.1078v1) and has the order reversed. 

422 

423 The second variant is compatible with CuDNNGRU (GPU-only) and allows 

424 inference on CPU. Thus it has separate biases for `kernel` and 

425 `recurrent_kernel`. To use this variant, set `reset_after=True` and 

426 `recurrent_activation='sigmoid'`. 

427 

428 For example: 

429 

430 >>> inputs = tf.random.normal([32, 10, 8]) 

431 >>> gru = tf.keras.layers.GRU(4) 

432 >>> output = gru(inputs) 

433 >>> print(output.shape) 

434 (32, 4) 

435 >>> gru = tf.keras.layers.GRU(4, return_sequences=True, return_state=True) 

436 >>> whole_sequence_output, final_state = gru(inputs) 

437 >>> print(whole_sequence_output.shape) 

438 (32, 10, 4) 

439 >>> print(final_state.shape) 

440 (32, 4) 

441 

442 Args: 

443 units: Positive integer, dimensionality of the output space. 

444 activation: Activation function to use. 

445 Default: hyperbolic tangent (`tanh`). 

446 If you pass `None`, no activation is applied 

447 (i.e. "linear" activation: `a(x) = x`).

448 recurrent_activation: Activation function to use 

449 for the recurrent step. 

450 Default: sigmoid (`sigmoid`). 

451 If you pass `None`, no activation is applied 

452 (i.e. "linear" activation: `a(x) = x`).

453 use_bias: Boolean (default `True`), whether the layer uses a bias vector.

454 kernel_initializer: Initializer for the `kernel` weights matrix, 

455 used for the linear transformation of the inputs. Default: 

456 `glorot_uniform`. 

457 recurrent_initializer: Initializer for the `recurrent_kernel` 

458 weights matrix, used for the linear transformation of the recurrent 

459 state. Default: `orthogonal`. 

460 bias_initializer: Initializer for the bias vector. Default: `zeros`. 

461 kernel_regularizer: Regularizer function applied to the `kernel` weights 

462 matrix. Default: `None`. 

463 recurrent_regularizer: Regularizer function applied to the 

464 `recurrent_kernel` weights matrix. Default: `None`. 

465 bias_regularizer: Regularizer function applied to the bias vector. 

466 Default: `None`. 

467 activity_regularizer: Regularizer function applied to the output of the 

468 layer (its "activation"). Default: `None`. 

469 kernel_constraint: Constraint function applied to the `kernel` weights 

470 matrix. Default: `None`. 

471 recurrent_constraint: Constraint function applied to the 

472 `recurrent_kernel` weights matrix. Default: `None`. 

473 bias_constraint: Constraint function applied to the bias vector. Default: 

474 `None`. 

475 dropout: Float between 0 and 1. Fraction of the units to drop for the 

476 linear transformation of the inputs. Default: 0. 

477 recurrent_dropout: Float between 0 and 1. Fraction of the units to drop 

478 for the linear transformation of the recurrent state. Default: 0. 

479 return_sequences: Boolean. Whether to return the last output 

480 in the output sequence, or the full sequence. Default: `False`. 

481 return_state: Boolean. Whether to return the last state in addition to the 

482 output. Default: `False`. 

483 go_backwards: Boolean (default `False`). 

484 If True, process the input sequence backwards and return the 

485 reversed sequence. 

486 stateful: Boolean (default False). If True, the last state 

487 for each sample at index i in a batch will be used as initial 

488 state for the sample of index i in the following batch. 

489 unroll: Boolean (default False). 

490 If True, the network will be unrolled, 

491 else a symbolic loop will be used. 

492 Unrolling can speed up an RNN,

493 although it tends to be more memory-intensive. 

494 Unrolling is only suitable for short sequences. 

495 time_major: The shape format of the `inputs` and `outputs` tensors. 

496 If True, the inputs and outputs will be in shape 

497 `[timesteps, batch, feature]`, whereas in the False case, it will be 

498 `[batch, timesteps, feature]`. Using `time_major = True` is a bit more 

499 efficient because it avoids transposes at the beginning and end of the 

500 RNN calculation. However, most TensorFlow data is batch-major, so by 

501 default this function accepts input and emits output in batch-major 

502 form. 

503 reset_after: GRU convention (whether to apply reset gate after or 

504 before matrix multiplication). False = "before", 

505 True = "after" (default and cuDNN compatible). 

506 

507 Call arguments: 

508 inputs: A 3D tensor, with shape `[batch, timesteps, feature]`. 

509 mask: Binary tensor of shape `[samples, timesteps]` indicating whether 

510 a given timestep should be masked (optional). 

511 An individual `True` entry indicates that the corresponding timestep 

512 should be utilized, while a `False` entry indicates that the 

513 corresponding timestep should be ignored. Defaults to `None`. 

514 training: Python boolean indicating whether the layer should behave in 

515 training mode or in inference mode. This argument is passed to the cell 

516 when calling it. This is only relevant if `dropout` or 

517 `recurrent_dropout` is used (optional). Defaults to `None`. 

518 initial_state: List of initial state tensors to be passed to the first 

519 call of the cell (optional, `None` causes creation 

520 of zero-filled initial state tensors). Defaults to `None`. 

521 """ 

522 

523 def __init__( 

524 self, 

525 units, 

526 activation="tanh", 

527 recurrent_activation="sigmoid", 

528 use_bias=True, 

529 kernel_initializer="glorot_uniform", 

530 recurrent_initializer="orthogonal", 

531 bias_initializer="zeros", 

532 kernel_regularizer=None, 

533 recurrent_regularizer=None, 

534 bias_regularizer=None, 

535 activity_regularizer=None, 

536 kernel_constraint=None, 

537 recurrent_constraint=None, 

538 bias_constraint=None, 

539 dropout=0.0, 

540 recurrent_dropout=0.0, 

541 return_sequences=False, 

542 return_state=False, 

543 go_backwards=False, 

544 stateful=False, 

545 unroll=False, 

546 time_major=False, 

547 reset_after=True, 

548 **kwargs, 

549 ): 

550 # return_runtime is a flag for testing, which shows the real backend 

551 # implementation chosen by grappler in graph mode. 

552 self._return_runtime = kwargs.pop("return_runtime", False) 

553 implementation = kwargs.pop("implementation", 2) 

554 if implementation == 0: 

555 logging.warning( 

556 "`implementation=0` has been deprecated, " 

557 "and now defaults to `implementation=2`." 

558 "Please update your layer call." 

559 ) 

560 if "enable_caching_device" in kwargs: 

561 cell_kwargs = { 

562 "enable_caching_device": kwargs.pop("enable_caching_device") 

563 } 

564 else: 

565 cell_kwargs = {} 

566 cell = GRUCell( 

567 units, 

568 activation=activation, 

569 recurrent_activation=recurrent_activation, 

570 use_bias=use_bias, 

571 kernel_initializer=kernel_initializer, 

572 recurrent_initializer=recurrent_initializer, 

573 bias_initializer=bias_initializer, 

574 kernel_regularizer=kernel_regularizer, 

575 recurrent_regularizer=recurrent_regularizer, 

576 bias_regularizer=bias_regularizer, 

577 kernel_constraint=kernel_constraint, 

578 recurrent_constraint=recurrent_constraint, 

579 bias_constraint=bias_constraint, 

580 dropout=dropout, 

581 recurrent_dropout=recurrent_dropout, 

582 implementation=implementation, 

583 reset_after=reset_after, 

584 dtype=kwargs.get("dtype"), 

585 trainable=kwargs.get("trainable", True), 

586 name="gru_cell", 

587 **cell_kwargs, 

588 ) 

589 super().__init__( 

590 cell, 

591 return_sequences=return_sequences, 

592 return_state=return_state, 

593 go_backwards=go_backwards, 

594 stateful=stateful, 

595 unroll=unroll, 

596 time_major=time_major, 

597 **kwargs, 

598 ) 

599 self.activity_regularizer = regularizers.get(activity_regularizer) 

600 self.input_spec = [InputSpec(ndim=3)] 

601 

602 # The GPU kernel uses the following settings by default; they are not configurable.

603 self._could_use_gpu_kernel = ( 

604 self.activation in (activations.tanh, tf.tanh) 

605 and self.recurrent_activation in (activations.sigmoid, tf.sigmoid) 

606 and recurrent_dropout == 0 

607 and not unroll 

608 and use_bias 

609 and reset_after 

610 and tf.compat.v1.executing_eagerly_outside_functions() 

611 ) 

612 if tf.config.list_logical_devices("GPU"): 

613 # Only show the message when a GPU is available; the user will not

614 # care about cuDNN if there is no GPU.

615 if self._could_use_gpu_kernel: 

616 logging.debug(gru_lstm_utils.CUDNN_AVAILABLE_MSG % self.name) 

617 else: 

618 logging.warning( 

619 gru_lstm_utils.CUDNN_NOT_AVAILABLE_MSG % self.name 

620 ) 

621 

622 if gru_lstm_utils.use_new_gru_lstm_impl(): 

623 self._defun_wrapper = gru_lstm_utils.DefunWrapper( 

624 time_major, go_backwards, "gru" 

625 ) 

626 

627 def call(self, inputs, mask=None, training=None, initial_state=None): 

628 # The input should be dense, padded with zeros. If a ragged input is fed 

629 # into the layer, it is padded and the row lengths are used for masking. 

630 inputs, row_lengths = backend.convert_inputs_if_ragged(inputs) 

631 is_ragged_input = row_lengths is not None 

632 self._validate_args_if_ragged(is_ragged_input, mask) 

633 

634 # GRU does not support constants; ignore them during processing.

635 inputs, initial_state, _ = self._process_inputs( 

636 inputs, initial_state, None 

637 ) 

638 

639 if isinstance(mask, list): 

640 mask = mask[0] 

641 

642 input_shape = backend.int_shape(inputs) 

643 timesteps = input_shape[0] if self.time_major else input_shape[1] 

644 

645 if not self._could_use_gpu_kernel: 

646 kwargs = {"training": training} 

647 self._maybe_reset_cell_dropout_mask(self.cell) 

648 

649 def step(cell_inputs, cell_states): 

650 return self.cell(cell_inputs, cell_states, **kwargs) 

651 

652 last_output, outputs, states = backend.rnn( 

653 step, 

654 inputs, 

655 initial_state, 

656 constants=None, 

657 go_backwards=self.go_backwards, 

658 mask=mask, 

659 unroll=self.unroll, 

660 input_length=row_lengths 

661 if row_lengths is not None 

662 else timesteps, 

663 time_major=self.time_major, 

664 zero_output_for_mask=self.zero_output_for_mask, 

665 return_all_outputs=self.return_sequences, 

666 ) 

667 # This is a dummy tensor for testing purposes.

668 runtime = gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_UNKNOWN) 

669 else: 

670 last_output, outputs, runtime, states = self._defun_gru_call( 

671 inputs, initial_state, training, mask, row_lengths 

672 ) 

673 

674 if self.stateful: 

675 updates = [ 

676 tf.compat.v1.assign( 

677 self.states[0], tf.cast(states[0], self.states[0].dtype) 

678 ) 

679 ] 

680 self.add_update(updates) 

681 

682 if self.return_sequences: 

683 output = backend.maybe_convert_to_ragged( 

684 is_ragged_input, 

685 outputs, 

686 row_lengths, 

687 go_backwards=self.go_backwards, 

688 ) 

689 else: 

690 output = last_output 

691 

692 if self.return_state: 

693 return [output] + list(states) 

694 elif self._return_runtime: 

695 return output, runtime 

696 else: 

697 return output 
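    # Hedged masking sketch for the `call` above (illustrative only): a boolean
    # `[batch, timesteps]` mask is accepted directly; a right-padded mask such
    # as this one also keeps the input cuDNN-compatible,
    #   x = tf.random.normal([2, 5, 8])
    #   m = tf.constant([[True, True, True, False, False],
    #                    [True, True, True, True, True]])
    #   out = GRU(3)(x, mask=m)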

698 

699 @property 

700 def units(self): 

701 return self.cell.units 

702 

703 @property 

704 def activation(self): 

705 return self.cell.activation 

706 

707 @property 

708 def recurrent_activation(self): 

709 return self.cell.recurrent_activation 

710 

711 @property 

712 def use_bias(self): 

713 return self.cell.use_bias 

714 

715 @property 

716 def kernel_initializer(self): 

717 return self.cell.kernel_initializer 

718 

719 @property 

720 def recurrent_initializer(self): 

721 return self.cell.recurrent_initializer 

722 

723 @property 

724 def bias_initializer(self): 

725 return self.cell.bias_initializer 

726 

727 @property 

728 def kernel_regularizer(self): 

729 return self.cell.kernel_regularizer 

730 

731 @property 

732 def recurrent_regularizer(self): 

733 return self.cell.recurrent_regularizer 

734 

735 @property 

736 def bias_regularizer(self): 

737 return self.cell.bias_regularizer 

738 

739 @property 

740 def kernel_constraint(self): 

741 return self.cell.kernel_constraint 

742 

743 @property 

744 def recurrent_constraint(self): 

745 return self.cell.recurrent_constraint 

746 

747 @property 

748 def bias_constraint(self): 

749 return self.cell.bias_constraint 

750 

751 @property 

752 def dropout(self): 

753 return self.cell.dropout 

754 

755 @property 

756 def recurrent_dropout(self): 

757 return self.cell.recurrent_dropout 

758 

759 @property 

760 def implementation(self): 

761 return self.cell.implementation 

762 

763 @property 

764 def reset_after(self): 

765 return self.cell.reset_after 

766 

767 def get_config(self): 

768 config = { 

769 "units": self.units, 

770 "activation": activations.serialize(self.activation), 

771 "recurrent_activation": activations.serialize( 

772 self.recurrent_activation 

773 ), 

774 "use_bias": self.use_bias, 

775 "kernel_initializer": initializers.serialize( 

776 self.kernel_initializer 

777 ), 

778 "recurrent_initializer": initializers.serialize( 

779 self.recurrent_initializer 

780 ), 

781 "bias_initializer": initializers.serialize(self.bias_initializer), 

782 "kernel_regularizer": regularizers.serialize( 

783 self.kernel_regularizer 

784 ), 

785 "recurrent_regularizer": regularizers.serialize( 

786 self.recurrent_regularizer 

787 ), 

788 "bias_regularizer": regularizers.serialize(self.bias_regularizer), 

789 "activity_regularizer": regularizers.serialize( 

790 self.activity_regularizer 

791 ), 

792 "kernel_constraint": constraints.serialize(self.kernel_constraint), 

793 "recurrent_constraint": constraints.serialize( 

794 self.recurrent_constraint 

795 ), 

796 "bias_constraint": constraints.serialize(self.bias_constraint), 

797 "dropout": self.dropout, 

798 "recurrent_dropout": self.recurrent_dropout, 

799 "implementation": self.implementation, 

800 "reset_after": self.reset_after, 

801 } 

802 config.update(rnn_utils.config_for_enable_caching_device(self.cell)) 

803 base_config = super().get_config() 

804 del base_config["cell"] 

805 return dict(list(base_config.items()) + list(config.items())) 

806 

807 @classmethod 

808 def from_config(cls, config): 

809 if "implementation" in config and config["implementation"] == 0: 

810 config["implementation"] = 1 

811 return cls(**config) 

812 

813 def _defun_gru_call( 

814 self, inputs, initial_state, training, mask, sequence_lengths 

815 ): 

816 # Use the new defun approach for backend implementation swap. 

817 # Note that the different implementations need to have the same function

818 # signature, e.g. the tensor parameters need to have the same shapes and

819 # dtypes.

820 

821 self.reset_dropout_mask() 

822 dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=3) 

823 if dropout_mask is not None: 

824 inputs = inputs * dropout_mask[0] 

825 

826 if gru_lstm_utils.use_new_gru_lstm_impl(): 

827 gru_kwargs = { 

828 "inputs": inputs, 

829 "init_h": gru_lstm_utils.read_variable_value(initial_state[0]), 

830 "kernel": gru_lstm_utils.read_variable_value(self.cell.kernel), 

831 "recurrent_kernel": gru_lstm_utils.read_variable_value( 

832 self.cell.recurrent_kernel 

833 ), 

834 "bias": gru_lstm_utils.read_variable_value(self.cell.bias), 

835 "mask": mask, 

836 "time_major": self.time_major, 

837 "go_backwards": self.go_backwards, 

838 "sequence_lengths": sequence_lengths, 

839 "zero_output_for_mask": self.zero_output_for_mask, 

840 } 

841 ( 

842 last_output, 

843 outputs, 

844 new_h, 

845 runtime, 

846 ) = self._defun_wrapper.defun_layer(**gru_kwargs) 

847 else: 

848 gpu_gru_kwargs = { 

849 "inputs": inputs, 

850 "init_h": gru_lstm_utils.read_variable_value(initial_state[0]), 

851 "kernel": gru_lstm_utils.read_variable_value(self.cell.kernel), 

852 "recurrent_kernel": gru_lstm_utils.read_variable_value( 

853 self.cell.recurrent_kernel 

854 ), 

855 "bias": gru_lstm_utils.read_variable_value(self.cell.bias), 

856 "mask": mask, 

857 "time_major": self.time_major, 

858 "go_backwards": self.go_backwards, 

859 "sequence_lengths": sequence_lengths, 

860 "return_sequences": self.return_sequences, 

861 } 

862 normal_gru_kwargs = gpu_gru_kwargs.copy() 

863 normal_gru_kwargs.update( 

864 { 

865 "zero_output_for_mask": self.zero_output_for_mask, 

866 } 

867 ) 

868 

869 if tf.executing_eagerly(): 

870 device_type = gru_lstm_utils.get_context_device_type() 

871 can_use_gpu = ( 

872 # Either user specified GPU or unspecified but GPU is 

873 # available. 

874 ( 

875 device_type == gru_lstm_utils.GPU_DEVICE_NAME 

876 or ( 

877 device_type is None 

878 and tf.config.list_logical_devices("GPU") 

879 ) 

880 ) 

881 and ( 

882 gru_lstm_utils.is_cudnn_supported_inputs( 

883 mask, self.time_major, sequence_lengths 

884 ) 

885 ) 

886 ) 

887 # Under eager context, check the device placement and prefer the GPU implementation when a GPU is available.

888 if can_use_gpu: 

889 last_output, outputs, new_h, runtime = gpu_gru( 

890 **gpu_gru_kwargs 

891 ) 

892 else: 

893 last_output, outputs, new_h, runtime = standard_gru( 

894 **normal_gru_kwargs 

895 ) 

896 else: 

897 ( 

898 last_output, 

899 outputs, 

900 new_h, 

901 runtime, 

902 ) = gru_with_backend_selection(**normal_gru_kwargs) 

903 

904 states = [new_h] 

905 return last_output, outputs, runtime, states 
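    # Hedged sketch tied to the internal `return_runtime` testing flag consumed
    # in `__init__` (illustrative; as noted there, it is for testing only):
    #   gru = GRU(4, return_runtime=True)
    #   out, runtime = gru(tf.random.normal([2, 5, 8]))
    #   # `runtime` is the constant tensor produced above, identifying the
    #   # kernel that actually ran.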

906 

907 

908def standard_gru( 

909 inputs, 

910 init_h, 

911 kernel, 

912 recurrent_kernel, 

913 bias, 

914 mask, 

915 time_major, 

916 go_backwards, 

917 sequence_lengths, 

918 zero_output_for_mask, 

919 return_sequences, 

920): 

921 """GRU with standard kernel implementation. 

922 

923 This implementation can be run on all types of hardware. 

924 

925 This implementation lifts out all the layer weights and makes them function

926 parameters. It has the same number of tensor input params as the cuDNN

927 counterpart. The RNN step logic has been simplified, e.g. dropout and masking

928 are removed since the cuDNN implementation does not support them.

929 

930 Args: 

931 inputs: Input tensor of GRU layer. 

932 init_h: Initial state tensor for the cell output. 

933 kernel: Weights for cell kernel. 

934 recurrent_kernel: Weights for cell recurrent kernel. 

935 bias: Weights for cell kernel bias and recurrent bias. The bias contains 

936 the combined input_bias and recurrent_bias. 

937 mask: Binary tensor of shape `(samples, timesteps)` indicating whether 

938 a given timestep should be masked. An individual `True` entry indicates 

939 that the corresponding timestep should be utilized, while a `False` 

940 entry indicates that the corresponding timestep should be ignored. 

941 time_major: Boolean, whether the inputs are in the format of 

942 [time, batch, feature] or [batch, time, feature]. 

943 go_backwards: Boolean (default False). If True, process the input sequence 

944 backwards and return the reversed sequence. 

945 sequence_lengths: The lengths of all sequences coming from a variable 

946 length input, such as ragged tensors. If the input has a fixed timestep 

947 size, this should be None. 

948 zero_output_for_mask: Boolean, whether to output zero for masked timestep. 

949 return_sequences: Boolean. If True, return the recurrent outputs for all 

950 timesteps in the sequence. If False, only return the output for the 

951 last timestep (which consumes less memory). 

952 

953 Returns: 

954 last_output: output tensor for the last timestep, which has shape 

955 [batch, units]. 

956 outputs: 

957 - If `return_sequences=True`: output tensor for all timesteps, 

958 which has shape [batch, time, units]. 

959 - Else, a tensor equal to `last_output` with shape [batch, 1, units] 

960 state_0: the cell output, which has the same shape as init_h.

961 runtime: constant string tensor which indicates the real runtime hardware.

962 This value is for testing purposes and should not be used by the user.

963 """ 

964 input_shape = backend.int_shape(inputs) 

965 timesteps = input_shape[0] if time_major else input_shape[1] 

966 

967 input_bias, recurrent_bias = tf.unstack(bias) 

968 

969 def step(cell_inputs, cell_states): 

970 """Step function that will be used by Keras RNN backend.""" 

971 h_tm1 = cell_states[0] 

972 

973 # inputs projected by all gate matrices at once 

974 matrix_x = backend.dot(cell_inputs, kernel) 

975 matrix_x = backend.bias_add(matrix_x, input_bias) 

976 

977 x_z, x_r, x_h = tf.split(matrix_x, 3, axis=1) 

978 

979 # hidden state projected by all gate matrices at once 

980 matrix_inner = backend.dot(h_tm1, recurrent_kernel) 

981 matrix_inner = backend.bias_add(matrix_inner, recurrent_bias) 

982 

983 recurrent_z, recurrent_r, recurrent_h = tf.split( 

984 matrix_inner, 3, axis=1 

985 ) 

986 z = tf.sigmoid(x_z + recurrent_z) 

987 r = tf.sigmoid(x_r + recurrent_r) 

988 hh = tf.tanh(x_h + r * recurrent_h) 

989 

990 # previous and candidate state mixed by update gate 

991 h = z * h_tm1 + (1 - z) * hh 

992 return h, [h] 

993 

994 last_output, outputs, new_states = backend.rnn( 

995 step, 

996 inputs, 

997 [init_h], 

998 constants=None, 

999 unroll=False, 

1000 time_major=time_major, 

1001 mask=mask, 

1002 go_backwards=go_backwards, 

1003 input_length=sequence_lengths 

1004 if sequence_lengths is not None 

1005 else timesteps, 

1006 zero_output_for_mask=zero_output_for_mask, 

1007 return_all_outputs=return_sequences, 

1008 ) 

1009 return ( 

1010 last_output, 

1011 outputs, 

1012 new_states[0], 

1013 gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_CPU), 

1014 ) 
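# Hedged direct-call sketch for `standard_gru` (illustrative shapes and weights,
# not part of the original source):
#   units, feat = 3, 5
#   last, seq, h, runtime = standard_gru(
#       inputs=tf.random.normal([2, 7, feat]),
#       init_h=tf.zeros([2, units]),
#       kernel=tf.random.normal([feat, 3 * units]),
#       recurrent_kernel=tf.random.normal([units, 3 * units]),
#       bias=tf.zeros([2, 3 * units]),   # stacked (input_bias, recurrent_bias)
#       mask=None,
#       time_major=False,
#       go_backwards=False,
#       sequence_lengths=None,
#       zero_output_for_mask=False,
#       return_sequences=True,
#   )   # seq.shape == (2, 7, units), last.shape == h.shape == (2, units)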

1015 

1016 

1017def gpu_gru( 

1018 inputs, 

1019 init_h, 

1020 kernel, 

1021 recurrent_kernel, 

1022 bias, 

1023 mask, 

1024 time_major, 

1025 go_backwards, 

1026 sequence_lengths, 

1027 return_sequences, 

1028): 

1029 """GRU with cuDNN implementation which is only available for GPU.""" 

1030 if mask is not None: 

1031 sequence_lengths = gru_lstm_utils.calculate_sequence_by_mask( 

1032 mask, time_major 

1033 ) 

1034 

1035 if not time_major and sequence_lengths is None: 

1036 inputs = tf.transpose(inputs, perm=(1, 0, 2)) 

1037 seq_axis, batch_axis = (0, 1) 

1038 else: 

1039 seq_axis, batch_axis = (0, 1) if time_major else (1, 0) 

1040 # For init_h, cuDNN expects an extra num_layers dim: before the batch dim

1041 # for time-major inputs, or after it for batch-major inputs.

1042 init_h = tf.expand_dims(init_h, axis=seq_axis) 

1043 

1044 weights = tf.split(kernel, 3, axis=1) 

1045 weights += tf.split(recurrent_kernel, 3, axis=1) 

1046 # Note that the bias was initialized with shape (2, 3 * units); flatten it

1047 # into (6 * units,).

1048 bias = tf.split(backend.flatten(bias), 6) 

1049 

1050 if tf.sysconfig.get_build_info()["is_cuda_build"]: 

1051 # Note that the gate order for cuDNN is different from the canonical 

1052 # format. The canonical format is [z, r, h], whereas cuDNN uses [r, z, h].

1053 # The swap needs to be done for kernel, recurrent_kernel, input_bias, and

1054 # recurrent_bias.

1055 # z is the update gate weights.

1056 # r is the reset gate weights.

1057 # h is the candidate ("new") state weights.

1058 weights[0], weights[1] = weights[1], weights[0] 

1059 weights[3], weights[4] = weights[4], weights[3] 

1060 bias[0], bias[1] = bias[1], bias[0] 

1061 bias[3], bias[4] = bias[4], bias[3] 

1062 

1063 params = gru_lstm_utils.canonical_to_params( 

1064 weights=weights, 

1065 biases=bias, 

1066 shape=tf.constant([-1]), 

1067 transpose_weights=True, 

1068 ) 

1069 

1070 if sequence_lengths is not None: 

1071 if go_backwards: 

1072 # Three reversals are required. E.g., 

1073 # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked 

1074 # reversed_input_to_cudnn = [3, 2, 1, 0, 0] 

1075 # output_from_cudnn = [6, 5, 4, 0, 0] 

1076 # expected_output = [0, 0, 6, 5, 4]

1077 inputs = tf.reverse_sequence( 

1078 inputs, 

1079 sequence_lengths, 

1080 seq_axis=seq_axis, 

1081 batch_axis=batch_axis, 

1082 ) 

1083 outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3( 

1084 input=inputs, 

1085 input_h=init_h, 

1086 input_c=0, 

1087 params=params, 

1088 is_training=True, 

1089 rnn_mode="gru", 

1090 sequence_lengths=sequence_lengths, 

1091 time_major=time_major, 

1092 ) 

1093 if go_backwards: 

1094 outputs = tf.reverse_sequence( 

1095 outputs, 

1096 sequence_lengths, 

1097 seq_axis=seq_axis, 

1098 batch_axis=batch_axis, 

1099 ) 

1100 outputs = tf.reverse(outputs, axis=[seq_axis]) 

1101 else: 

1102 if go_backwards: 

1103 # Reverse axis 0 since the input has already been converted to time major.

1104 inputs = tf.reverse(inputs, axis=[0]) 

1105 outputs, h, _, _ = tf.raw_ops.CudnnRNN( 

1106 input=inputs, 

1107 input_h=init_h, 

1108 input_c=0, 

1109 params=params, 

1110 is_training=True, 

1111 rnn_mode="gru", 

1112 ) 

1113 

1114 last_output = outputs[-1] 

1115 if not time_major and sequence_lengths is None and return_sequences: 

1116 outputs = tf.transpose(outputs, perm=[1, 0, 2]) 

1117 h = tf.squeeze(h, axis=seq_axis) 

1118 

1119 # In the case of variable length input, the cuDNN kernel will fill zeros for

1120 # the output, whereas the default Keras behavior is to carry over the

1121 # previous output from t-1, so that in the return_sequences=False case the

1122 # user gets the final effective output instead of just 0s at the last

1123 # timestep. To mimic the default Keras behavior, we copy the final h state

1124 # as the last_output, since it is numerically the same as the output.

1125 if sequence_lengths is not None: 

1126 last_output = h 

1127 

1128 # Match CPU return format 

1129 if not return_sequences: 

1130 outputs = tf.expand_dims(last_output, axis=0 if time_major else 1) 

1131 

1132 return ( 

1133 last_output, 

1134 outputs, 

1135 h, 

1136 gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_GPU), 

1137 ) 
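# Hedged illustration of the gate reorder handled above (assuming the canonical
# [z, r, h] column order described in the comments): for CUDA builds the kernel
# columns are rearranged into cuDNN's [r, z, h] order before being packed by
# canonical_to_params, e.g.
#   z_w, r_w, h_w = tf.split(kernel, 3, axis=1)
#   cudnn_ordered = [r_w, z_w, h_w]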

1138 

1139 

1140def gru_with_backend_selection( 

1141 inputs, 

1142 init_h, 

1143 kernel, 

1144 recurrent_kernel, 

1145 bias, 

1146 mask, 

1147 time_major, 

1148 go_backwards, 

1149 sequence_lengths, 

1150 zero_output_for_mask, 

1151 return_sequences, 

1152): 

1153 """Call the GRU with optimized backend kernel selection. 

1154 

1155 Under the hood, this function creates two TF functions: one with the most

1156 generic kernel, which can run on any device, and a second one with the

1157 cuDNN-specific kernel, which can only run on GPU.

1158 

1159 The first function is called with the normal GRU params, while the second

1160 function is not called, but only registered in the graph. Grappler will

1161 do the proper graph rewrite and swap in the optimized TF function based on

1162 the device placement.

1163 

1164 Args: 

1165 inputs: Input tensor of GRU layer. 

1166 init_h: Initial state tensor for the cell output. 

1167 kernel: Weights for cell kernel. 

1168 recurrent_kernel: Weights for cell recurrent kernel. 

1169 bias: Weights for cell kernel bias and recurrent bias. Only recurrent bias 

1170 is used in this case. 

1171 mask: Boolean tensor for mask out the steps within sequence. 

1172 An individual `True` entry indicates that the corresponding timestep 

1173 should be utilized, while a `False` entry indicates that the 

1174 corresponding timestep should be ignored. 

1175 time_major: Boolean, whether the inputs are in the format of 

1176 [time, batch, feature] or [batch, time, feature]. 

1177 go_backwards: Boolean (default False). If True, process the input sequence 

1178 backwards and return the reversed sequence. 

1179 sequence_lengths: The lengths of all sequences coming from a variable 

1180 length input, such as ragged tensors. If the input has a fixed timestep 

1181 size, this should be None. 

1182 zero_output_for_mask: Boolean, whether to output zero for masked timestep. 

1183 return_sequences: Boolean. If True, return the recurrent outputs for all 

1184 timesteps in the sequence. If False, only return the output for the 

1185 last timestep (which consumes less memory). 

1186 

1187 Returns: 

1188 List of output tensors, same as standard_gru. 

1189 """ 

1190 params = { 

1191 "inputs": inputs, 

1192 "init_h": init_h, 

1193 "kernel": kernel, 

1194 "recurrent_kernel": recurrent_kernel, 

1195 "bias": bias, 

1196 "mask": mask, 

1197 "time_major": time_major, 

1198 "go_backwards": go_backwards, 

1199 "sequence_lengths": sequence_lengths, 

1200 "zero_output_for_mask": zero_output_for_mask, 

1201 "return_sequences": return_sequences, 

1202 } 

1203 

1204 def gpu_gru_with_fallback( 

1205 inputs, 

1206 init_h, 

1207 kernel, 

1208 recurrent_kernel, 

1209 bias, 

1210 mask, 

1211 time_major, 

1212 go_backwards, 

1213 sequence_lengths, 

1214 zero_output_for_mask, 

1215 return_sequences, 

1216 ): 

1217 """Use cuDNN kernel when mask is none or strictly right padded.""" 

1218 

1219 def cudnn_gru_fn(): 

1220 return gpu_gru( 

1221 inputs=inputs, 

1222 init_h=init_h, 

1223 kernel=kernel, 

1224 recurrent_kernel=recurrent_kernel, 

1225 bias=bias, 

1226 mask=mask, 

1227 time_major=time_major, 

1228 go_backwards=go_backwards, 

1229 sequence_lengths=sequence_lengths, 

1230 return_sequences=return_sequences, 

1231 ) 

1232 

1233 def standard_gru_fn(): 

1234 return standard_gru( 

1235 inputs=inputs, 

1236 init_h=init_h, 

1237 kernel=kernel, 

1238 recurrent_kernel=recurrent_kernel, 

1239 bias=bias, 

1240 mask=mask, 

1241 time_major=time_major, 

1242 go_backwards=go_backwards, 

1243 sequence_lengths=sequence_lengths, 

1244 zero_output_for_mask=zero_output_for_mask, 

1245 return_sequences=return_sequences, 

1246 ) 

1247 

1248 return tf.__internal__.smart_cond.smart_cond( 

1249 gru_lstm_utils.is_cudnn_supported_inputs( 

1250 mask, time_major, sequence_lengths 

1251 ), 

1252 true_fn=cudnn_gru_fn, 

1253 false_fn=standard_gru_fn, 

1254 ) 

1255 

1256 if gru_lstm_utils.use_new_gru_lstm_impl(): 

1257 # Chooses the implementation dynamically based on the running device. 

1258 ( 

1259 last_output, 

1260 outputs, 

1261 new_h, 

1262 runtime, 

1263 ) = tf.__internal__.execute_fn_for_device( 

1264 { 

1265 gru_lstm_utils.CPU_DEVICE_NAME: lambda: standard_gru(**params), 

1266 gru_lstm_utils.GPU_DEVICE_NAME: lambda: gpu_gru_with_fallback( 

1267 **params 

1268 ), 

1269 }, 

1270 lambda: standard_gru(**params), 

1271 ) 

1272 else: 

1273 # Each time a `tf.function` is called, we will give it a unique 

1274 # identifiable API name, so that Grappler won't get confused when it 

1275 # sees multiple GRU layers added into the same graph, and it will be able

1276 # to pair up the different implementations across them. 

1277 api_name = "gru_" + str(uuid.uuid4()) 

1278 supportive_attribute = { 

1279 "time_major": time_major, 

1280 "go_backwards": go_backwards, 

1281 } 

1282 defun_standard_gru = gru_lstm_utils.generate_defun_backend( 

1283 api_name, 

1284 gru_lstm_utils.CPU_DEVICE_NAME, 

1285 standard_gru, 

1286 supportive_attribute, 

1287 ) 

1288 defun_gpu_gru = gru_lstm_utils.generate_defun_backend( 

1289 api_name, 

1290 gru_lstm_utils.GPU_DEVICE_NAME, 

1291 gpu_gru_with_fallback, 

1292 supportive_attribute, 

1293 ) 

1294 

1295 # Call the normal GRU impl and register the cuDNN impl function. The 

1296 # grappler will kick in during session execution to optimize the graph. 

1297 last_output, outputs, new_h, runtime = defun_standard_gru(**params) 

1298 gru_lstm_utils.function_register(defun_gpu_gru, **params) 

1299 

1300 return last_output, outputs, new_h, runtime 
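# Hedged usage note for the selector above (illustrative, not original source):
# it takes the same tensors as `standard_gru`/`gpu_gru`, e.g. for units=4 and
# feature size 8,
#   last, seq, h, runtime = gru_with_backend_selection(
#       inputs=tf.random.normal([2, 5, 8]), init_h=tf.zeros([2, 4]),
#       kernel=tf.random.normal([8, 12]),
#       recurrent_kernel=tf.random.normal([4, 12]), bias=tf.zeros([2, 12]),
#       mask=None, time_major=False, go_backwards=False, sequence_lengths=None,
#       zero_output_for_mask=False, return_sequences=True,
#   )
# and Grappler may later swap the traced generic kernel for the registered
# cuDNN one based on device placement.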

1301