Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/lstm.py: 21%

336 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Long Short-Term Memory layer.""" 

16 

17 

18import uuid 

19 

20import tensorflow.compat.v2 as tf 

21 

22from keras.src import activations 

23from keras.src import backend 

24from keras.src import constraints 

25from keras.src import initializers 

26from keras.src import regularizers 

27from keras.src.engine import base_layer 

28from keras.src.engine.input_spec import InputSpec 

29from keras.src.layers.rnn import gru_lstm_utils 

30from keras.src.layers.rnn import rnn_utils 

31from keras.src.layers.rnn.base_rnn import RNN 

32from keras.src.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin 

33from keras.src.utils import tf_utils 

34 

35# isort: off 

36from tensorflow.python.platform import tf_logging as logging 

37from tensorflow.python.util.tf_export import keras_export 

38 

39RECURRENT_DROPOUT_WARNING_MSG = ( 

40 "RNN `implementation=2` is not supported when `recurrent_dropout` is set. " 

41 "Using `implementation=1`." 

42) 

43 
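# Editorial note (not part of the original source): `LSTMCell` also accepts an
# `implementation` kwarg (default 2). Implementation 1 computes the input
# projections as four smaller dot products, one per gate, and can apply a
# separate dropout mask to each; implementation 2 fuses them into one large
# matmul that is split afterwards. As the message above says, a non-zero
# `recurrent_dropout` forces implementation 1.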

44 

45@keras_export("keras.layers.LSTMCell", v1=[]) 

46class LSTMCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer): 

47 """Cell class for the LSTM layer. 

48 

49 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 

50 for details about the usage of RNN API. 

51 

52 This class processes one step within the whole time sequence input, whereas 

53 `tf.keras.layers.LSTM` processes the whole sequence. 

54 

55 For example: 

56 

57 >>> inputs = tf.random.normal([32, 10, 8]) 

58 >>> rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4)) 

59 >>> output = rnn(inputs) 

60 >>> print(output.shape) 

61 (32, 4) 

62 >>> rnn = tf.keras.layers.RNN( 

63 ... tf.keras.layers.LSTMCell(4), 

64 ... return_sequences=True, 

65 ... return_state=True) 

66 >>> whole_seq_output, final_memory_state, final_carry_state = rnn(inputs) 

67 >>> print(whole_seq_output.shape) 

68 (32, 10, 4) 

69 >>> print(final_memory_state.shape) 

70 (32, 4) 

71 >>> print(final_carry_state.shape) 

72 (32, 4) 

73 

74 Args: 

75 units: Positive integer, dimensionality of the output space. 

76 activation: Activation function to use. Default: hyperbolic tangent 

77 (`tanh`). If you pass `None`, no activation is applied (i.e. "linear" 

78 activation: `a(x) = x`). 

79 recurrent_activation: Activation function to use for the recurrent step. 

80 Default: sigmoid (`sigmoid`). If you pass `None`, no activation is 

81 applied (i.e. "linear" activation: `a(x) = x`). 

82 use_bias: Boolean (default `True`), whether the layer uses a bias vector. 

83 kernel_initializer: Initializer for the `kernel` weights matrix, used for 

84 the linear transformation of the inputs. Default: `glorot_uniform`. 

85 recurrent_initializer: Initializer for the `recurrent_kernel` weights 

86 matrix, used for the linear transformation of the recurrent state. 

87 Default: `orthogonal`. 

88 bias_initializer: Initializer for the bias vector. Default: `zeros`. 

89 unit_forget_bias: Boolean (default `True`). If True, add 1 to the bias of 

90 the forget gate at initialization. Setting it to true will also force 

91 `bias_initializer="zeros"`. This is recommended in [Jozefowicz et 

92 al.](https://github.com/mlresearch/v37/blob/gh-pages/jozefowicz15.pdf) 

93 kernel_regularizer: Regularizer function applied to the `kernel` weights 

94 matrix. Default: `None`. 

95 recurrent_regularizer: Regularizer function applied to 

96 the `recurrent_kernel` weights matrix. Default: `None`. 

97 bias_regularizer: Regularizer function applied to the bias vector. 

98 Default: `None`. 

99 kernel_constraint: Constraint function applied to the `kernel` weights 

100 matrix. Default: `None`. 

101 recurrent_constraint: Constraint function applied to the 

102 `recurrent_kernel` weights matrix. Default: `None`. 

103 bias_constraint: Constraint function applied to the bias vector. Default: 

104 `None`. 

105 dropout: Float between 0 and 1. Fraction of the units to drop for the 

106 linear transformation of the inputs. Default: 0. 

107 recurrent_dropout: Float between 0 and 1. Fraction of the units to drop 

108 for the linear transformation of the recurrent state. Default: 0. 

109 

110 Call arguments: 

111 inputs: A 2D tensor, with shape of `[batch, feature]`. 

112 states: List of 2 tensors that correspond to the cell's units. Both of 

113 them have shape `[batch, units]`; the first tensor is the memory state 

114 from the previous time step, and the second is the carry state from the 

115 previous time step. For timestep 0, the initial state provided by the 

116 user will be fed to the cell. 

117 training: Python boolean indicating whether the layer should behave in 

118 training mode or in inference mode. Only relevant when `dropout` or 

119 `recurrent_dropout` is used. 

120 """ 

121 

122 def __init__( 

123 self, 

124 units, 

125 activation="tanh", 

126 recurrent_activation="sigmoid", 

127 use_bias=True, 

128 kernel_initializer="glorot_uniform", 

129 recurrent_initializer="orthogonal", 

130 bias_initializer="zeros", 

131 unit_forget_bias=True, 

132 kernel_regularizer=None, 

133 recurrent_regularizer=None, 

134 bias_regularizer=None, 

135 kernel_constraint=None, 

136 recurrent_constraint=None, 

137 bias_constraint=None, 

138 dropout=0.0, 

139 recurrent_dropout=0.0, 

140 **kwargs, 

141 ): 

142 if units <= 0: 

143 raise ValueError( 

144 "Received an invalid value for argument `units`, " 

145 f"expected a positive integer, got {units}." 

146 ) 

147 # By default use cached variable under v2 mode, see b/143699808. 

148 if tf.compat.v1.executing_eagerly_outside_functions(): 

149 self._enable_caching_device = kwargs.pop( 

150 "enable_caching_device", True 

151 ) 

152 else: 

153 self._enable_caching_device = kwargs.pop( 

154 "enable_caching_device", False 

155 ) 

156 super().__init__(**kwargs) 

157 self.units = units 

158 self.activation = activations.get(activation) 

159 self.recurrent_activation = activations.get(recurrent_activation) 

160 self.use_bias = use_bias 

161 

162 self.kernel_initializer = initializers.get(kernel_initializer) 

163 self.recurrent_initializer = initializers.get(recurrent_initializer) 

164 self.bias_initializer = initializers.get(bias_initializer) 

165 self.unit_forget_bias = unit_forget_bias 

166 

167 self.kernel_regularizer = regularizers.get(kernel_regularizer) 

168 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 

169 self.bias_regularizer = regularizers.get(bias_regularizer) 

170 

171 self.kernel_constraint = constraints.get(kernel_constraint) 

172 self.recurrent_constraint = constraints.get(recurrent_constraint) 

173 self.bias_constraint = constraints.get(bias_constraint) 

174 

175 self.dropout = min(1.0, max(0.0, dropout)) 

176 self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout)) 

177 implementation = kwargs.pop("implementation", 2) 

178 if self.recurrent_dropout != 0 and implementation != 1: 

179 logging.debug(RECURRENT_DROPOUT_WARNING_MSG) 

180 self.implementation = 1 

181 else: 

182 self.implementation = implementation 

183 self.state_size = [self.units, self.units] 

184 self.output_size = self.units 

185 

186 @tf_utils.shape_type_conversion 

187 def build(self, input_shape): 

188 super().build(input_shape) 

189 default_caching_device = rnn_utils.caching_device(self) 

190 input_dim = input_shape[-1] 

191 self.kernel = self.add_weight( 

192 shape=(input_dim, self.units * 4), 

193 name="kernel", 

194 initializer=self.kernel_initializer, 

195 regularizer=self.kernel_regularizer, 

196 constraint=self.kernel_constraint, 

197 caching_device=default_caching_device, 

198 ) 

199 self.recurrent_kernel = self.add_weight( 

200 shape=(self.units, self.units * 4), 

201 name="recurrent_kernel", 

202 initializer=self.recurrent_initializer, 

203 regularizer=self.recurrent_regularizer, 

204 constraint=self.recurrent_constraint, 

205 caching_device=default_caching_device, 

206 ) 

207 

208 if self.use_bias: 

209 if self.unit_forget_bias: 

210 

211 def bias_initializer(_, *args, **kwargs): 

212 return backend.concatenate( 

213 [ 

214 self.bias_initializer( 

215 (self.units,), *args, **kwargs 

216 ), 

217 initializers.get("ones")( 

218 (self.units,), *args, **kwargs 

219 ), 

220 self.bias_initializer( 

221 (self.units * 2,), *args, **kwargs 

222 ), 

223 ] 

224 ) 

225 

226 else: 

227 bias_initializer = self.bias_initializer 

228 self.bias = self.add_weight( 

229 shape=(self.units * 4,), 

230 name="bias", 

231 initializer=bias_initializer, 

232 regularizer=self.bias_regularizer, 

233 constraint=self.bias_constraint, 

234 caching_device=default_caching_device, 

235 ) 

236 else: 

237 self.bias = None 

238 self.built = True 

239 

240 def _compute_carry_and_output(self, x, h_tm1, c_tm1): 

241 """Computes carry and output using split kernels.""" 

242 x_i, x_f, x_c, x_o = x 

243 h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 

244 i = self.recurrent_activation( 

245 x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, : self.units]) 

246 ) 

247 f = self.recurrent_activation( 

248 x_f 

249 + backend.dot( 

250 h_tm1_f, self.recurrent_kernel[:, self.units : self.units * 2] 

251 ) 

252 ) 

253 c = f * c_tm1 + i * self.activation( 

254 x_c 

255 + backend.dot( 

256 h_tm1_c, 

257 self.recurrent_kernel[:, self.units * 2 : self.units * 3], 

258 ) 

259 ) 

260 o = self.recurrent_activation( 

261 x_o 

262 + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3 :]) 

263 ) 

264 return c, o 

265 

266 def _compute_carry_and_output_fused(self, z, c_tm1): 

267 """Computes carry and output using fused kernels.""" 

268 z0, z1, z2, z3 = z 

269 i = self.recurrent_activation(z0) 

270 f = self.recurrent_activation(z1) 

271 c = f * c_tm1 + i * self.activation(z2) 

272 o = self.recurrent_activation(z3) 

273 return c, o 

274 

275 def call(self, inputs, states, training=None): 

276 h_tm1 = states[0] # previous memory state 

277 c_tm1 = states[1] # previous carry state 

278 

279 dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4) 

280 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( 

281 h_tm1, training, count=4 

282 ) 

283 

284 if self.implementation == 1: 

285 if 0 < self.dropout < 1.0: 

286 inputs_i = inputs * dp_mask[0] 

287 inputs_f = inputs * dp_mask[1] 

288 inputs_c = inputs * dp_mask[2] 

289 inputs_o = inputs * dp_mask[3] 

290 else: 

291 inputs_i = inputs 

292 inputs_f = inputs 

293 inputs_c = inputs 

294 inputs_o = inputs 

295 k_i, k_f, k_c, k_o = tf.split( 

296 self.kernel, num_or_size_splits=4, axis=1 

297 ) 

298 x_i = backend.dot(inputs_i, k_i) 

299 x_f = backend.dot(inputs_f, k_f) 

300 x_c = backend.dot(inputs_c, k_c) 

301 x_o = backend.dot(inputs_o, k_o) 

302 if self.use_bias: 

303 b_i, b_f, b_c, b_o = tf.split( 

304 self.bias, num_or_size_splits=4, axis=0 

305 ) 

306 x_i = backend.bias_add(x_i, b_i) 

307 x_f = backend.bias_add(x_f, b_f) 

308 x_c = backend.bias_add(x_c, b_c) 

309 x_o = backend.bias_add(x_o, b_o) 

310 

311 if 0 < self.recurrent_dropout < 1.0: 

312 h_tm1_i = h_tm1 * rec_dp_mask[0] 

313 h_tm1_f = h_tm1 * rec_dp_mask[1] 

314 h_tm1_c = h_tm1 * rec_dp_mask[2] 

315 h_tm1_o = h_tm1 * rec_dp_mask[3] 

316 else: 

317 h_tm1_i = h_tm1 

318 h_tm1_f = h_tm1 

319 h_tm1_c = h_tm1 

320 h_tm1_o = h_tm1 

321 x = (x_i, x_f, x_c, x_o) 

322 h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o) 

323 c, o = self._compute_carry_and_output(x, h_tm1, c_tm1) 

324 else: 

325 if 0.0 < self.dropout < 1.0: 

326 inputs = inputs * dp_mask[0] 

327 z = backend.dot(inputs, self.kernel) 

328 z += backend.dot(h_tm1, self.recurrent_kernel) 

329 if self.use_bias: 

330 z = backend.bias_add(z, self.bias) 

331 

332 z = tf.split(z, num_or_size_splits=4, axis=1) 

333 c, o = self._compute_carry_and_output_fused(z, c_tm1) 

334 

335 h = o * self.activation(c) 

336 return h, [h, c] 

337 

338 def get_config(self): 

339 config = { 

340 "units": self.units, 

341 "activation": activations.serialize(self.activation), 

342 "recurrent_activation": activations.serialize( 

343 self.recurrent_activation 

344 ), 

345 "use_bias": self.use_bias, 

346 "kernel_initializer": initializers.serialize( 

347 self.kernel_initializer 

348 ), 

349 "recurrent_initializer": initializers.serialize( 

350 self.recurrent_initializer 

351 ), 

352 "bias_initializer": initializers.serialize(self.bias_initializer), 

353 "unit_forget_bias": self.unit_forget_bias, 

354 "kernel_regularizer": regularizers.serialize( 

355 self.kernel_regularizer 

356 ), 

357 "recurrent_regularizer": regularizers.serialize( 

358 self.recurrent_regularizer 

359 ), 

360 "bias_regularizer": regularizers.serialize(self.bias_regularizer), 

361 "kernel_constraint": constraints.serialize(self.kernel_constraint), 

362 "recurrent_constraint": constraints.serialize( 

363 self.recurrent_constraint 

364 ), 

365 "bias_constraint": constraints.serialize(self.bias_constraint), 

366 "dropout": self.dropout, 

367 "recurrent_dropout": self.recurrent_dropout, 

368 "implementation": self.implementation, 

369 } 

370 config.update(rnn_utils.config_for_enable_caching_device(self)) 

371 base_config = super().get_config() 

372 return dict(list(base_config.items()) + list(config.items())) 

373 

374 def get_initial_state(self, inputs=None, batch_size=None, dtype=None): 

375 return list( 

376 rnn_utils.generate_zero_filled_state_for_cell( 

377 self, inputs, batch_size, dtype 

378 ) 

379 ) 

380 

381 
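# --- Editorial example (not part of the original module). A minimal sketch,
# assuming eager execution, of how `LSTMCell` can be driven one timestep at a
# time; this is essentially what the `RNN`/`LSTM` wrappers do internally. The
# function name and shapes are hypothetical.
def _example_manual_lstm_cell_loop(batch=2, timesteps=5, features=3, units=4):
    cell = LSTMCell(units)
    inputs = tf.random.normal([batch, timesteps, features])
    states = cell.get_initial_state(batch_size=batch, dtype=inputs.dtype)
    outputs = []
    for t in range(timesteps):
        # Each call consumes one timestep and returns (output, [h, c]).
        output, states = cell(inputs[:, t, :], states)
        outputs.append(output)
    return tf.stack(outputs, axis=1)  # shape [batch, timesteps, units]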

382@keras_export("keras.layers.LSTM", v1=[]) 

383class LSTM(DropoutRNNCellMixin, RNN, base_layer.BaseRandomLayer): 

384 """Long Short-Term Memory layer - Hochreiter 1997. 

385 

386 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) 

387 for details about the usage of RNN API. 

388 

389 Based on available runtime hardware and constraints, this layer 

390 will choose different implementations (cuDNN-based or pure-TensorFlow) 

391 to maximize the performance. If a GPU is available and all 

392 the arguments to the layer meet the requirement of the cuDNN kernel 

393 (see below for details), the layer will use a fast cuDNN implementation. 

394 

395 The requirements to use the cuDNN implementation are: 

396 

397 1. `activation` == `tanh` 

398 2. `recurrent_activation` == `sigmoid` 

399 3. `recurrent_dropout` == 0 

400 4. `unroll` is `False` 

401 5. `use_bias` is `True` 

402 6. Inputs, if masking is used, are strictly right-padded. 

403 7. Eager execution is enabled in the outermost context. 

404 

405 For example: 

406 

407 >>> inputs = tf.random.normal([32, 10, 8]) 

408 >>> lstm = tf.keras.layers.LSTM(4) 

409 >>> output = lstm(inputs) 

410 >>> print(output.shape) 

411 (32, 4) 

412 >>> lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True) 

413 >>> whole_seq_output, final_memory_state, final_carry_state = lstm(inputs) 

414 >>> print(whole_seq_output.shape) 

415 (32, 10, 4) 

416 >>> print(final_memory_state.shape) 

417 (32, 4) 

418 >>> print(final_carry_state.shape) 

419 (32, 4) 

420 

421 Args: 

422 units: Positive integer, dimensionality of the output space. 

423 activation: Activation function to use. 

424 Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation 

425 is applied (i.e. "linear" activation: `a(x) = x`). 

426 recurrent_activation: Activation function to use for the recurrent step. 

427 Default: sigmoid (`sigmoid`). If you pass `None`, no activation is 

428 applied (i.e. "linear" activation: `a(x) = x`). 

429 use_bias: Boolean (default `True`), whether the layer uses a bias vector. 

430 kernel_initializer: Initializer for the `kernel` weights matrix, used for 

431 the linear transformation of the inputs. Default: `glorot_uniform`. 

432 recurrent_initializer: Initializer for the `recurrent_kernel` weights 

433 matrix, used for the linear transformation of the recurrent state. 

434 Default: `orthogonal`. 

435 bias_initializer: Initializer for the bias vector. Default: `zeros`. 

436 unit_forget_bias: Boolean (default `True`). If True, add 1 to the bias of 

437 the forget gate at initialization. Setting it to true will also force 

438 `bias_initializer="zeros"`. This is recommended in [Jozefowicz et 

439 al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf). 

440 kernel_regularizer: Regularizer function applied to the `kernel` weights 

441 matrix. Default: `None`. 

442 recurrent_regularizer: Regularizer function applied to the 

443 `recurrent_kernel` weights matrix. Default: `None`. 

444 bias_regularizer: Regularizer function applied to the bias vector. 

445 Default: `None`. 

446 activity_regularizer: Regularizer function applied to the output of the 

447 layer (its "activation"). Default: `None`. 

448 kernel_constraint: Constraint function applied to the `kernel` weights 

449 matrix. Default: `None`. 

450 recurrent_constraint: Constraint function applied to the 

451 `recurrent_kernel` weights matrix. Default: `None`. 

452 bias_constraint: Constraint function applied to the bias vector. Default: 

453 `None`. 

454 dropout: Float between 0 and 1. Fraction of the units to drop for the 

455 linear transformation of the inputs. Default: 0. 

456 recurrent_dropout: Float between 0 and 1. Fraction of the units to drop 

457 for the linear transformation of the recurrent state. Default: 0. 

458 return_sequences: Boolean. Whether to return the last output in the output 

459 sequence, or the full sequence. Default: `False`. 

460 return_state: Boolean. Whether to return the last state in addition to the 

461 output. Default: `False`. 

462 go_backwards: Boolean (default `False`). If True, process the input 

463 sequence backwards and return the reversed sequence. 

464 stateful: Boolean (default `False`). If True, the last state for each 

465 sample at index i in a batch will be used as initial state for the sample 

466 of index i in the following batch. 

467 time_major: The shape format of the `inputs` and `outputs` tensors. 

468 If True, the inputs and outputs will be in shape 

469 `[timesteps, batch, feature]`, whereas in the False case, it will be 

470 `[batch, timesteps, feature]`. Using `time_major = True` is a bit more 

471 efficient because it avoids transposes at the beginning and end of the 

472 RNN calculation. However, most TensorFlow data is batch-major, so by 

473 default this function accepts input and emits output in batch-major 

474 form. 

475 unroll: Boolean (default `False`). If True, the network will be unrolled, 

476 else a symbolic loop will be used. Unrolling can speed up an RNN, 

477 although it tends to be more memory-intensive. Unrolling is only 

478 suitable for short sequences. 

479 

480 Call arguments: 

481 inputs: A 3D tensor with shape `[batch, timesteps, feature]`. 

482 mask: Binary tensor of shape `[batch, timesteps]` indicating whether 

483 a given timestep should be masked (optional). 

484 An individual `True` entry indicates that the corresponding timestep 

485 should be utilized, while a `False` entry indicates that the 

486 corresponding timestep should be ignored. Defaults to `None`. 

487 training: Python boolean indicating whether the layer should behave in 

488 training mode or in inference mode. This argument is passed to the cell 

489 when calling it. This is only relevant if `dropout` or 

490 `recurrent_dropout` is used (optional). Defaults to `None`. 

491 initial_state: List of initial state tensors to be passed to the first 

492 call of the cell (optional, `None` causes creation 

493 of zero-filled initial state tensors). Defaults to `None`. 

494 """ 

495 

496 def __init__( 

497 self, 

498 units, 

499 activation="tanh", 

500 recurrent_activation="sigmoid", 

501 use_bias=True, 

502 kernel_initializer="glorot_uniform", 

503 recurrent_initializer="orthogonal", 

504 bias_initializer="zeros", 

505 unit_forget_bias=True, 

506 kernel_regularizer=None, 

507 recurrent_regularizer=None, 

508 bias_regularizer=None, 

509 activity_regularizer=None, 

510 kernel_constraint=None, 

511 recurrent_constraint=None, 

512 bias_constraint=None, 

513 dropout=0.0, 

514 recurrent_dropout=0.0, 

515 return_sequences=False, 

516 return_state=False, 

517 go_backwards=False, 

518 stateful=False, 

519 time_major=False, 

520 unroll=False, 

521 **kwargs, 

522 ): 

523 # return_runtime is a flag for testing, which shows the real backend 

524 # implementation chosen by grappler in graph mode. 

525 self.return_runtime = kwargs.pop("return_runtime", False) 

526 implementation = kwargs.pop("implementation", 2) 

527 if implementation == 0: 

528 logging.warning( 

529 "`implementation=0` has been deprecated, " 

530 "and now defaults to `implementation=1`." 

531 "Please update your layer call." 

532 ) 

533 if "enable_caching_device" in kwargs: 

534 cell_kwargs = { 

535 "enable_caching_device": kwargs.pop("enable_caching_device") 

536 } 

537 else: 

538 cell_kwargs = {} 

539 cell = LSTMCell( 

540 units, 

541 activation=activation, 

542 recurrent_activation=recurrent_activation, 

543 use_bias=use_bias, 

544 kernel_initializer=kernel_initializer, 

545 recurrent_initializer=recurrent_initializer, 

546 unit_forget_bias=unit_forget_bias, 

547 bias_initializer=bias_initializer, 

548 kernel_regularizer=kernel_regularizer, 

549 recurrent_regularizer=recurrent_regularizer, 

550 bias_regularizer=bias_regularizer, 

551 kernel_constraint=kernel_constraint, 

552 recurrent_constraint=recurrent_constraint, 

553 bias_constraint=bias_constraint, 

554 dropout=dropout, 

555 recurrent_dropout=recurrent_dropout, 

556 implementation=implementation, 

557 dtype=kwargs.get("dtype"), 

558 trainable=kwargs.get("trainable", True), 

559 name="lstm_cell", 

560 **cell_kwargs, 

561 ) 

562 super().__init__( 

563 cell, 

564 return_sequences=return_sequences, 

565 return_state=return_state, 

566 go_backwards=go_backwards, 

567 stateful=stateful, 

568 time_major=time_major, 

569 unroll=unroll, 

570 **kwargs, 

571 ) 

572 self.activity_regularizer = regularizers.get(activity_regularizer) 

573 self.input_spec = [InputSpec(ndim=3)] 

574 self.state_spec = [ 

575 InputSpec(shape=(None, dim)) for dim in (self.units, self.units) 

576 ] 

577 self._could_use_gpu_kernel = ( 

578 self.activation in (activations.tanh, tf.tanh) 

579 and self.recurrent_activation in (activations.sigmoid, tf.sigmoid) 

580 and recurrent_dropout == 0 

581 and not unroll 

582 and use_bias 

583 and tf.compat.v1.executing_eagerly_outside_functions() 

584 ) 

585 if tf.config.list_logical_devices("GPU"): 

586 # Only show the message when there is a GPU available; users will not 

587 # care about cuDNN if there isn't any GPU. 

588 if self._could_use_gpu_kernel: 

589 logging.debug(gru_lstm_utils.CUDNN_AVAILABLE_MSG % self.name) 

590 else: 

591 logging.warning( 

592 gru_lstm_utils.CUDNN_NOT_AVAILABLE_MSG % self.name 

593 ) 

594 

595 if gru_lstm_utils.use_new_gru_lstm_impl(): 

596 self._defun_wrapper = gru_lstm_utils.DefunWrapper( 

597 time_major, go_backwards, "lstm" 

598 ) 

599 

600 def call(self, inputs, mask=None, training=None, initial_state=None): 

601 # The input should be dense, padded with zeros. If a ragged input is fed 

602 # into the layer, it is padded and the row lengths are used for masking. 

603 inputs, row_lengths = backend.convert_inputs_if_ragged(inputs) 

604 is_ragged_input = row_lengths is not None 

605 self._validate_args_if_ragged(is_ragged_input, mask) 

606 

607 # LSTM does not support constants. Ignore them during processing. 

608 inputs, initial_state, _ = self._process_inputs( 

609 inputs, initial_state, None 

610 ) 

611 

612 if isinstance(mask, list): 

613 mask = mask[0] 

614 

615 input_shape = backend.int_shape(inputs) 

616 timesteps = input_shape[0] if self.time_major else input_shape[1] 

617 

618 if not self._could_use_gpu_kernel: 

619 # Fall back to use the normal LSTM. 

620 kwargs = {"training": training} 

621 self._maybe_reset_cell_dropout_mask(self.cell) 

622 

623 def step(inputs, states): 

624 return self.cell(inputs, states, **kwargs) 

625 

626 last_output, outputs, states = backend.rnn( 

627 step, 

628 inputs, 

629 initial_state, 

630 constants=None, 

631 go_backwards=self.go_backwards, 

632 mask=mask, 

633 unroll=self.unroll, 

634 input_length=row_lengths 

635 if row_lengths is not None 

636 else timesteps, 

637 time_major=self.time_major, 

638 zero_output_for_mask=self.zero_output_for_mask, 

639 return_all_outputs=self.return_sequences, 

640 ) 

641 runtime = gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_UNKNOWN) 

642 else: 

643 # Use the new defun approach for backend implementation swap. 

644 # Note that the different implementations need to have the same function 

645 # signature, e.g. the tensor parameters need to have the same shapes and 

646 # dtypes. Since cuDNN has an extra set of biases, those biases will 

647 # be passed to both the normal and cuDNN implementations. 

648 self.reset_dropout_mask() 

649 dropout_mask = self.get_dropout_mask_for_cell( 

650 inputs, training, count=4 

651 ) 

652 if dropout_mask is not None: 

653 inputs = inputs * dropout_mask[0] 

654 if gru_lstm_utils.use_new_gru_lstm_impl(): 

655 lstm_kwargs = { 

656 "inputs": inputs, 

657 "init_h": gru_lstm_utils.read_variable_value( 

658 initial_state[0] 

659 ), 

660 "init_c": gru_lstm_utils.read_variable_value( 

661 initial_state[1] 

662 ), 

663 "kernel": gru_lstm_utils.read_variable_value( 

664 self.cell.kernel 

665 ), 

666 "recurrent_kernel": gru_lstm_utils.read_variable_value( 

667 self.cell.recurrent_kernel 

668 ), 

669 "bias": gru_lstm_utils.read_variable_value(self.cell.bias), 

670 "mask": mask, 

671 "time_major": self.time_major, 

672 "go_backwards": self.go_backwards, 

673 "sequence_lengths": row_lengths, 

674 "zero_output_for_mask": self.zero_output_for_mask, 

675 } 

676 ( 

677 last_output, 

678 outputs, 

679 new_h, 

680 new_c, 

681 runtime, 

682 ) = self._defun_wrapper.defun_layer(**lstm_kwargs) 

683 else: 

684 gpu_lstm_kwargs = { 

685 "inputs": inputs, 

686 "init_h": gru_lstm_utils.read_variable_value( 

687 initial_state[0] 

688 ), 

689 "init_c": gru_lstm_utils.read_variable_value( 

690 initial_state[1] 

691 ), 

692 "kernel": gru_lstm_utils.read_variable_value( 

693 self.cell.kernel 

694 ), 

695 "recurrent_kernel": gru_lstm_utils.read_variable_value( 

696 self.cell.recurrent_kernel 

697 ), 

698 "bias": gru_lstm_utils.read_variable_value(self.cell.bias), 

699 "mask": mask, 

700 "time_major": self.time_major, 

701 "go_backwards": self.go_backwards, 

702 "sequence_lengths": row_lengths, 

703 "return_sequences": self.return_sequences, 

704 } 

705 normal_lstm_kwargs = gpu_lstm_kwargs.copy() 

706 normal_lstm_kwargs.update( 

707 { 

708 "zero_output_for_mask": self.zero_output_for_mask, 

709 } 

710 ) 

711 

712 if tf.executing_eagerly(): 

713 device_type = gru_lstm_utils.get_context_device_type() 

714 can_use_gpu = ( 

715 # Either user specified GPU or unspecified but GPU is 

716 # available. 

717 ( 

718 device_type == gru_lstm_utils.GPU_DEVICE_NAME 

719 or ( 

720 device_type is None 

721 and tf.config.list_logical_devices("GPU") 

722 ) 

723 ) 

724 and gru_lstm_utils.is_cudnn_supported_inputs( 

725 mask, self.time_major, row_lengths 

726 ) 

727 ) 

728 # Under eager context, check the device placement and prefer 

729 # the GPU implementation when GPU is available. 

730 if can_use_gpu: 

731 last_output, outputs, new_h, new_c, runtime = gpu_lstm( 

732 **gpu_lstm_kwargs 

733 ) 

734 else: 

735 ( 

736 last_output, 

737 outputs, 

738 new_h, 

739 new_c, 

740 runtime, 

741 ) = standard_lstm(**normal_lstm_kwargs) 

742 else: 

743 ( 

744 last_output, 

745 outputs, 

746 new_h, 

747 new_c, 

748 runtime, 

749 ) = lstm_with_backend_selection(**normal_lstm_kwargs) 

750 

751 states = [new_h, new_c] 

752 

753 if self.stateful: 

754 updates = [ 

755 tf.compat.v1.assign( 

756 self_state, tf.cast(state, self_state.dtype) 

757 ) 

758 for self_state, state in zip(self.states, states) 

759 ] 

760 self.add_update(updates) 

761 

762 if self.return_sequences: 

763 output = backend.maybe_convert_to_ragged( 

764 is_ragged_input, 

765 outputs, 

766 row_lengths, 

767 go_backwards=self.go_backwards, 

768 ) 

769 else: 

770 output = last_output 

771 

772 if self.return_state: 

773 return [output] + list(states) 

774 elif self.return_runtime: 

775 return output, runtime 

776 else: 

777 return output 

778 

779 @property 

780 def units(self): 

781 return self.cell.units 

782 

783 @property 

784 def activation(self): 

785 return self.cell.activation 

786 

787 @property 

788 def recurrent_activation(self): 

789 return self.cell.recurrent_activation 

790 

791 @property 

792 def use_bias(self): 

793 return self.cell.use_bias 

794 

795 @property 

796 def kernel_initializer(self): 

797 return self.cell.kernel_initializer 

798 

799 @property 

800 def recurrent_initializer(self): 

801 return self.cell.recurrent_initializer 

802 

803 @property 

804 def bias_initializer(self): 

805 return self.cell.bias_initializer 

806 

807 @property 

808 def unit_forget_bias(self): 

809 return self.cell.unit_forget_bias 

810 

811 @property 

812 def kernel_regularizer(self): 

813 return self.cell.kernel_regularizer 

814 

815 @property 

816 def recurrent_regularizer(self): 

817 return self.cell.recurrent_regularizer 

818 

819 @property 

820 def bias_regularizer(self): 

821 return self.cell.bias_regularizer 

822 

823 @property 

824 def kernel_constraint(self): 

825 return self.cell.kernel_constraint 

826 

827 @property 

828 def recurrent_constraint(self): 

829 return self.cell.recurrent_constraint 

830 

831 @property 

832 def bias_constraint(self): 

833 return self.cell.bias_constraint 

834 

835 @property 

836 def dropout(self): 

837 return self.cell.dropout 

838 

839 @property 

840 def recurrent_dropout(self): 

841 return self.cell.recurrent_dropout 

842 

843 @property 

844 def implementation(self): 

845 return self.cell.implementation 

846 

847 def get_config(self): 

848 config = { 

849 "units": self.units, 

850 "activation": activations.serialize(self.activation), 

851 "recurrent_activation": activations.serialize( 

852 self.recurrent_activation 

853 ), 

854 "use_bias": self.use_bias, 

855 "kernel_initializer": initializers.serialize( 

856 self.kernel_initializer 

857 ), 

858 "recurrent_initializer": initializers.serialize( 

859 self.recurrent_initializer 

860 ), 

861 "bias_initializer": initializers.serialize(self.bias_initializer), 

862 "unit_forget_bias": self.unit_forget_bias, 

863 "kernel_regularizer": regularizers.serialize( 

864 self.kernel_regularizer 

865 ), 

866 "recurrent_regularizer": regularizers.serialize( 

867 self.recurrent_regularizer 

868 ), 

869 "bias_regularizer": regularizers.serialize(self.bias_regularizer), 

870 "activity_regularizer": regularizers.serialize( 

871 self.activity_regularizer 

872 ), 

873 "kernel_constraint": constraints.serialize(self.kernel_constraint), 

874 "recurrent_constraint": constraints.serialize( 

875 self.recurrent_constraint 

876 ), 

877 "bias_constraint": constraints.serialize(self.bias_constraint), 

878 "dropout": self.dropout, 

879 "recurrent_dropout": self.recurrent_dropout, 

880 "implementation": self.implementation, 

881 } 

882 config.update(rnn_utils.config_for_enable_caching_device(self.cell)) 

883 base_config = super().get_config() 

884 del base_config["cell"] 

885 return dict(list(base_config.items()) + list(config.items())) 

886 

887 @classmethod 

888 def from_config(cls, config): 

889 if "implementation" in config and config["implementation"] == 0: 

890 config["implementation"] = 1 

891 return cls(**config) 

892 

893 
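# --- Editorial example (not part of the original module). A minimal sketch of
# how the constructor arguments decide whether the cuDNN fast path can be
# used; `_could_use_gpu_kernel` is the internal flag set in `LSTM.__init__`
# above. Assumes a typical eager setup; the function name is hypothetical.
def _example_gpu_kernel_eligibility():
    fast = LSTM(8)  # default args satisfy all cuDNN requirements
    slow = LSTM(8, recurrent_dropout=0.2)  # recurrent_dropout != 0 rules it out
    return fast._could_use_gpu_kernel, slow._could_use_gpu_kernel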

894def standard_lstm( 

895 inputs, 

896 init_h, 

897 init_c, 

898 kernel, 

899 recurrent_kernel, 

900 bias, 

901 mask, 

902 time_major, 

903 go_backwards, 

904 sequence_lengths, 

905 zero_output_for_mask, 

906 return_sequences, 

907): 

908 """LSTM with standard kernel implementation. 

909 

910 This implementation can be run on all types of hardware. 

911 

912 This implementation lifts out all the layer weights and makes them function 

913 parameters. It has the same number of tensor input params as the cuDNN 

914 counterpart. The RNN step logic has been simplified, e.g. dropout and 

915 masking are removed since the cuDNN implementation does not support them. 

916 

917 Note that the first half of the bias tensor should be ignored by this 

918 implementation. The cuDNN implementation needs an extra set of input gate 

919 biases. In order to make both functions take parameters of the same shape, 

920 that extra set of biases is also fed here. 

921 

922 

923 Args: 

924 inputs: input tensor of LSTM layer. 

925 init_h: initial state tensor for the cell output. 

926 init_c: initial state tensor for the cell hidden state. 

927 kernel: weights for cell kernel. 

928 recurrent_kernel: weights for cell recurrent kernel. 

929 bias: weights for cell kernel bias and recurrent bias. Only recurrent bias 

930 is used in this case. 

931 mask: Boolean tensor used to mask out the steps within the sequence. 

932 An individual `True` entry indicates that the corresponding timestep 

933 should be utilized, while a `False` entry indicates that the 

934 corresponding timestep should be ignored. 

935 time_major: boolean, whether the inputs are in the format of 

936 [time, batch, feature] or [batch, time, feature]. 

937 go_backwards: Boolean (default False). If True, process the input sequence 

938 backwards and return the reversed sequence. 

939 sequence_lengths: The lengths of all sequences coming from a variable 

940 length input, such as ragged tensors. If the input has a fixed timestep 

941 size, this should be None. 

942 zero_output_for_mask: Boolean, whether to output zero for masked timestep. 

943 return_sequences: Boolean. If True, return the recurrent outputs for all 

944 timesteps in the sequence. If False, only return the output for the 

945 last timestep (which consumes less memory). 

946 

947 Returns: 

948 last_output: output tensor for the last timestep, which has shape 

949 [batch, units]. 

950 outputs: 

951 - If `return_sequences=True`: output tensor for all timesteps, 

952 which has shape [batch, time, units]. 

953 - Else, a tensor equal to `last_output` with shape [batch, 1, units] 

954 state_0: the cell output, which has same shape as init_h. 

955 state_1: the cell hidden state, which has same shape as init_c. 

956 runtime: constant string tensor which indicates the real runtime hardware. 

957 This value is for testing purposes and should not be used by users. 

958 """ 

959 input_shape = backend.int_shape(inputs) 

960 timesteps = input_shape[0] if time_major else input_shape[1] 

961 

962 def step(cell_inputs, cell_states): 

963 """Step function that will be used by Keras RNN backend.""" 

964 h_tm1 = cell_states[0] # previous memory state 

965 c_tm1 = cell_states[1] # previous carry state 

966 

967 z = backend.dot(cell_inputs, kernel) 

968 z += backend.dot(h_tm1, recurrent_kernel) 

969 z = backend.bias_add(z, bias) 

970 

971 z0, z1, z2, z3 = tf.split(z, 4, axis=1) 

972 

973 i = tf.sigmoid(z0) 

974 f = tf.sigmoid(z1) 

975 c = f * c_tm1 + i * tf.tanh(z2) 

976 o = tf.sigmoid(z3) 

977 

978 h = o * tf.tanh(c) 

979 return h, [h, c] 

980 

981 last_output, outputs, new_states = backend.rnn( 

982 step, 

983 inputs, 

984 [init_h, init_c], 

985 constants=None, 

986 unroll=False, 

987 time_major=time_major, 

988 mask=mask, 

989 go_backwards=go_backwards, 

990 input_length=( 

991 sequence_lengths if sequence_lengths is not None else timesteps 

992 ), 

993 zero_output_for_mask=zero_output_for_mask, 

994 return_all_outputs=return_sequences, 

995 ) 

996 return ( 

997 last_output, 

998 outputs, 

999 new_states[0], 

1000 new_states[1], 

1001 gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_CPU), 

1002 ) 

1003 

1004 

1005def gpu_lstm( 

1006 inputs, 

1007 init_h, 

1008 init_c, 

1009 kernel, 

1010 recurrent_kernel, 

1011 bias, 

1012 mask, 

1013 time_major, 

1014 go_backwards, 

1015 sequence_lengths, 

1016 return_sequences, 

1017): 

1018 """LSTM with either cuDNN or ROCm implementation which is only available for 

1019 GPU. 

1020 

1021 Note that currently only right padded data is supported, or the result will 

1022 be polluted by the unmasked data which should be filtered. 

1023 

1024 Args: 

1025 inputs: Input tensor of LSTM layer. 

1026 init_h: Initial state tensor for the cell output. 

1027 init_c: Initial state tensor for the cell hidden state. 

1028 kernel: Weights for cell kernel. 

1029 recurrent_kernel: Weights for cell recurrent kernel. 

1030 bias: Weights for cell kernel bias and recurrent bias. Only recurrent bias 

1031 is used in this case. 

1032 mask: Boolean tensor used to mask out steps within the sequence. An individual 

1033 `True` entry indicates that the corresponding timestep should be 

1034 utilized, while a `False` entry indicates that the corresponding 

1035 timestep should be ignored. 

1036 time_major: Boolean, whether the inputs are in the format of [time, batch, 

1037 feature] or [batch, time, feature]. 

1038 go_backwards: Boolean (default False). If True, process the input sequence 

1039 backwards and return the reversed sequence. 

1040 sequence_lengths: The lengths of all sequences coming from a variable 

1041 length input, such as ragged tensors. If the input has a fixed timestep 

1042 size, this should be None. 

1043 return_sequences: Boolean. If True, return the recurrent outputs for all 

1044 timesteps in the sequence. If False, only return the output for the 

1045 last timestep, matching the CPU function output format. 

1046 

1047 Returns: 

1048 last_output: Output tensor for the last timestep, which has shape 

1049 [batch, units]. 

1050 outputs: 

1051 - If `return_sequences=True`: output tensor for all timesteps, 

1052 which has shape [batch, time, units]. 

1053 - Else, a tensor equal to `last_output` with shape [batch, 1, units] 

1054 state_0: The cell output, which has same shape as init_h. 

1055 state_1: The cell hidden state, which has same shape as init_c. 

1056 runtime: Constant string tensor which indicates the real runtime hardware. 

1057 This value is for testing purposes and should not be used by users. 

1058 """ 

1059 if mask is not None: 

1060 sequence_lengths = gru_lstm_utils.calculate_sequence_by_mask( 

1061 mask, time_major 

1062 ) 

1063 

1064 if not time_major and sequence_lengths is None: 

1065 inputs = tf.transpose(inputs, perm=(1, 0, 2)) 

1066 seq_axis, batch_axis = (0, 1) 

1067 else: 

1068 seq_axis, batch_axis = (0, 1) if time_major else (1, 0) 

1069 # For init_h and init_c, cuDNN expects one more num_layers dim before or 

1070 # after the batch dim for time-major or batch-major inputs, respectively. 

1071 init_h = tf.expand_dims(init_h, axis=seq_axis) 

1072 init_c = tf.expand_dims(init_c, axis=seq_axis) 

1073 

1074 weights = tf.split(kernel, 4, axis=1) 

1075 weights += tf.split(recurrent_kernel, 4, axis=1) 

1076 # cuDNN has an extra set of biases for inputs; we disable them (set to 0) 

1077 # so that mathematically it is the same as the canonical LSTM implementation. 

1078 full_bias = tf.concat((tf.zeros_like(bias), bias), 0) 

1079 

1080 if tf.sysconfig.get_build_info()["is_rocm_build"]: 

1081 # ROCm MIOpen's weight sequence for LSTM is different from both 

1082 # canonical and Cudnn format 

1083 # MIOpen: [i, f, o, c] Cudnn/Canonical: [i, f, c, o] 

1084 # i is input gate weights. 

1085 # f is forget gate weights. 

1086 # o is output gate weights. 

1087 # c is cell gate weights. 

1088 weights = [weights[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)] 

1089 # full_bias is a tensor of shape (8*n,) 

1090 full_bias = tf.split(full_bias, 8, axis=0) 

1091 full_bias = [full_bias[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)] 

1092 

1093 params = gru_lstm_utils.canonical_to_params( 

1094 weights=weights, 

1095 biases=tf.split(full_bias, 8), 

1096 shape=tf.constant([-1]), 

1097 transpose_weights=True, 

1098 ) 

1099 

1100 if sequence_lengths is not None: 

1101 if go_backwards: 

1102 # Three reversals are required. E.g., 

1103 # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked 

1104 # reversed_input_to_cudnn = [3, 2, 1, 0, 0] 

1105 # output_from_cudnn = [6, 5, 4, 0, 0] 

1106 # expected_output = [0, 0, 6, 5 ,4] 

1107 inputs = tf.reverse_sequence( 

1108 inputs, 

1109 sequence_lengths, 

1110 seq_axis=seq_axis, 

1111 batch_axis=batch_axis, 

1112 ) 

1113 outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3( 

1114 input=inputs, 

1115 input_h=init_h, 

1116 input_c=init_c, 

1117 params=params, 

1118 is_training=True, 

1119 rnn_mode="lstm", 

1120 sequence_lengths=sequence_lengths, 

1121 time_major=time_major, 

1122 ) 

1123 if go_backwards: 

1124 outputs = tf.reverse_sequence( 

1125 outputs, 

1126 sequence_lengths, 

1127 seq_axis=seq_axis, 

1128 batch_axis=batch_axis, 

1129 ) 

1130 outputs = tf.reverse(outputs, axis=[seq_axis]) 

1131 else: 

1132 # # Fill the array with shape [batch] with value of max timesteps. 

1133 # sequence_length = array_ops.fill([array_ops.shape(inputs)[1]], 

1134 # array_ops.shape(inputs)[0]) 

1135 if go_backwards: 

1136 # Reverse axis 0 since the input has already been converted to time major. 

1137 inputs = tf.reverse(inputs, axis=[0]) 

1138 outputs, h, c, _ = tf.raw_ops.CudnnRNN( 

1139 input=inputs, 

1140 input_h=init_h, 

1141 input_c=init_c, 

1142 params=params, 

1143 is_training=True, 

1144 rnn_mode="lstm", 

1145 ) 

1146 

1147 last_output = outputs[-1] 

1148 if not time_major and sequence_lengths is None and return_sequences: 

1149 outputs = tf.transpose(outputs, perm=[1, 0, 2]) 

1150 h = tf.squeeze(h, axis=seq_axis) 

1151 c = tf.squeeze(c, axis=seq_axis) 

1152 

1153 # In the case of variable-length input, the cuDNN kernel fills the output 

1154 # with zeros, whereas the default Keras behavior is to carry over the 

1155 # previous output from t-1, so that in the return_sequences=False case the 

1156 # user can directly get the final effective output instead of just 0s at 

1157 # the last timestep. In order to mimic the default Keras behavior, we copy 

1158 # the final h state as the last_output, since it is numerically the same. 

1159 if sequence_lengths is not None: 

1160 last_output = h 

1161 

1162 # Match CPU return format 

1163 if not return_sequences: 

1164 outputs = tf.expand_dims(last_output, axis=0 if time_major else 1) 

1165 

1166 return ( 

1167 last_output, 

1168 outputs, 

1169 h, 

1170 c, 

1171 gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_GPU), 

1172 ) 

1173 

1174 
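# Editorial note (not part of the original source): `standard_lstm` and
# `gpu_lstm` above deliberately take the same tensor parameters (inputs,
# init_h, init_c, kernel, recurrent_kernel, bias, ...) so that the function
# below can register both and let Grappler swap in the device-appropriate
# one; only `zero_output_for_mask` is exclusive to the CPU path (`gpu_lstm`
# has no such argument).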

1175def lstm_with_backend_selection( 

1176 inputs, 

1177 init_h, 

1178 init_c, 

1179 kernel, 

1180 recurrent_kernel, 

1181 bias, 

1182 mask, 

1183 time_major, 

1184 go_backwards, 

1185 sequence_lengths, 

1186 zero_output_for_mask, 

1187 return_sequences, 

1188): 

1189 """Call the LSTM with optimized backend kernel selection. 

1190 

1191 Under the hood, this function will create two TF functions: one with the 

1192 most generic kernel, which can run under all device conditions, and a second 

1193 one with a cuDNN-specific kernel, which can only run on GPU. 

1194 

1195 The first function will be called with normal_lstm_params, while the second 

1196 function is not called but only registered in the graph. Grappler will then 

1197 do the proper graph rewrite and swap in the optimized TF function based on 

1198 the device placement. 

1199 

1200 Args: 

1201 inputs: Input tensor of LSTM layer. 

1202 init_h: Initial state tensor for the cell output. 

1203 init_c: Initial state tensor for the cell hidden state. 

1204 kernel: Weights for cell kernel. 

1205 recurrent_kernel: Weights for cell recurrent kernel. 

1206 bias: Weights for cell kernel bias and recurrent bias. Only recurrent bias 

1207 is used in this case. 

1208 mask: Boolean tensor used to mask out the steps within the sequence. 

1209 An individual `True` entry indicates that the corresponding timestep 

1210 should be utilized, while a `False` entry indicates that the 

1211 corresponding timestep should be ignored. 

1212 time_major: Boolean, whether the inputs are in the format of 

1213 [time, batch, feature] or [batch, time, feature]. 

1214 go_backwards: Boolean (default False). If True, process the input sequence 

1215 backwards and return the reversed sequence. 

1216 sequence_lengths: The lengths of all sequences coming from a variable 

1217 length input, such as ragged tensors. If the input has a fixed timestep 

1218 size, this should be None. 

1219 zero_output_for_mask: Boolean, whether to output zero for masked timestep. 

1220 return_sequences: Boolean. If True, return the recurrent outputs for all 

1221 timesteps in the sequence. If False, only return the output for the 

1222 last timestep (which consumes less memory). 

1223 

1224 Returns: 

1225 List of output tensors, same as standard_lstm. 

1226 """ 

1227 params = { 

1228 "inputs": inputs, 

1229 "init_h": init_h, 

1230 "init_c": init_c, 

1231 "kernel": kernel, 

1232 "recurrent_kernel": recurrent_kernel, 

1233 "bias": bias, 

1234 "mask": mask, 

1235 "time_major": time_major, 

1236 "go_backwards": go_backwards, 

1237 "sequence_lengths": sequence_lengths, 

1238 "zero_output_for_mask": zero_output_for_mask, 

1239 "return_sequences": return_sequences, 

1240 } 

1241 

1242 def gpu_lstm_with_fallback( 

1243 inputs, 

1244 init_h, 

1245 init_c, 

1246 kernel, 

1247 recurrent_kernel, 

1248 bias, 

1249 mask, 

1250 time_major, 

1251 go_backwards, 

1252 sequence_lengths, 

1253 zero_output_for_mask, 

1254 return_sequences, 

1255 ): 

1256 """Use cuDNN kernel when mask is none or strictly right padded.""" 

1257 

1258 def cudnn_lstm_fn(): 

1259 return gpu_lstm( 

1260 inputs=inputs, 

1261 init_h=init_h, 

1262 init_c=init_c, 

1263 kernel=kernel, 

1264 recurrent_kernel=recurrent_kernel, 

1265 bias=bias, 

1266 mask=mask, 

1267 time_major=time_major, 

1268 go_backwards=go_backwards, 

1269 sequence_lengths=sequence_lengths, 

1270 return_sequences=return_sequences, 

1271 ) 

1272 

1273 def standard_lstm_fn(): 

1274 return standard_lstm( 

1275 inputs=inputs, 

1276 init_h=init_h, 

1277 init_c=init_c, 

1278 kernel=kernel, 

1279 recurrent_kernel=recurrent_kernel, 

1280 bias=bias, 

1281 mask=mask, 

1282 time_major=time_major, 

1283 go_backwards=go_backwards, 

1284 sequence_lengths=sequence_lengths, 

1285 zero_output_for_mask=zero_output_for_mask, 

1286 return_sequences=return_sequences, 

1287 ) 

1288 

1289 return tf.__internal__.smart_cond.smart_cond( 

1290 gru_lstm_utils.is_cudnn_supported_inputs( 

1291 mask, time_major, sequence_lengths 

1292 ), 

1293 true_fn=cudnn_lstm_fn, 

1294 false_fn=standard_lstm_fn, 

1295 ) 

1296 

1297 if gru_lstm_utils.use_new_gru_lstm_impl(): 

1298 # Chooses the implementation dynamically based on the running device. 

1299 ( 

1300 last_output, 

1301 outputs, 

1302 new_h, 

1303 new_c, 

1304 runtime, 

1305 ) = tf.__internal__.execute_fn_for_device( 

1306 { 

1307 gru_lstm_utils.CPU_DEVICE_NAME: lambda: standard_lstm(**params), 

1308 gru_lstm_utils.GPU_DEVICE_NAME: lambda: gpu_lstm_with_fallback( 

1309 **params 

1310 ), 

1311 }, 

1312 lambda: standard_lstm(**params), 

1313 ) 

1314 else: 

1315 # Each time a `tf.function` is called, we will give it a unique 

1316 # identifiable API name, so that Grappler won't get confused when it 

1317 # sees multiple LSTM layers added into the same graph, and it will be able 

1318 # to pair up the different implementations across them. 

1319 api_name = "lstm_" + str(uuid.uuid4()) 

1320 supportive_attribute = { 

1321 "time_major": time_major, 

1322 "go_backwards": go_backwards, 

1323 } 

1324 defun_standard_lstm = gru_lstm_utils.generate_defun_backend( 

1325 api_name, 

1326 gru_lstm_utils.CPU_DEVICE_NAME, 

1327 standard_lstm, 

1328 supportive_attribute, 

1329 ) 

1330 defun_gpu_lstm = gru_lstm_utils.generate_defun_backend( 

1331 api_name, 

1332 gru_lstm_utils.GPU_DEVICE_NAME, 

1333 gpu_lstm_with_fallback, 

1334 supportive_attribute, 

1335 ) 

1336 

1337 # Call the normal LSTM impl and register the cuDNN impl function. The 

1338 # grappler will kick in during session execution to optimize the graph. 

1339 last_output, outputs, new_h, new_c, runtime = defun_standard_lstm( 

1340 **params 

1341 ) 

1342 gru_lstm_utils.function_register(defun_gpu_lstm, **params) 

1343 

1344 return last_output, outputs, new_h, new_c, runtime 

1345