Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/lstm.py: 21% of 336 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Long Short-Term Memory layer."""
18import uuid
20import tensorflow.compat.v2 as tf
22from keras.src import activations
23from keras.src import backend
24from keras.src import constraints
25from keras.src import initializers
26from keras.src import regularizers
27from keras.src.engine import base_layer
28from keras.src.engine.input_spec import InputSpec
29from keras.src.layers.rnn import gru_lstm_utils
30from keras.src.layers.rnn import rnn_utils
31from keras.src.layers.rnn.base_rnn import RNN
32from keras.src.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin
33from keras.src.utils import tf_utils
35# isort: off
36from tensorflow.python.platform import tf_logging as logging
37from tensorflow.python.util.tf_export import keras_export
39RECURRENT_DROPOUT_WARNING_MSG = (
40 "RNN `implementation=2` is not supported when `recurrent_dropout` is set. "
41 "Using `implementation=1`."
42)
45@keras_export("keras.layers.LSTMCell", v1=[])
46class LSTMCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
47 """Cell class for the LSTM layer.
49 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
50 for details about the usage of RNN API.
52 This class processes one step within the whole time sequence input, whereas
53 `tf.keras.layers.LSTM` processes the whole sequence.
55 For example:
57 >>> inputs = tf.random.normal([32, 10, 8])
58 >>> rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4))
59 >>> output = rnn(inputs)
60 >>> print(output.shape)
61 (32, 4)
62 >>> rnn = tf.keras.layers.RNN(
63 ... tf.keras.layers.LSTMCell(4),
64 ... return_sequences=True,
65 ... return_state=True)
66 >>> whole_seq_output, final_memory_state, final_carry_state = rnn(inputs)
67 >>> print(whole_seq_output.shape)
68 (32, 10, 4)
69 >>> print(final_memory_state.shape)
70 (32, 4)
71 >>> print(final_carry_state.shape)
72 (32, 4)
74 Args:
75 units: Positive integer, dimensionality of the output space.
76 activation: Activation function to use. Default: hyperbolic tangent
77 (`tanh`). If you pass `None`, no activation is applied (ie. "linear"
78 activation: `a(x) = x`).
79 recurrent_activation: Activation function to use for the recurrent step.
80 Default: sigmoid (`sigmoid`). If you pass `None`, no activation is
81 applied (ie. "linear" activation: `a(x) = x`).
82 use_bias: Boolean (default `True`), whether the layer uses a bias vector.
83 kernel_initializer: Initializer for the `kernel` weights matrix, used for
84 the linear transformation of the inputs. Default: `glorot_uniform`.
85 recurrent_initializer: Initializer for the `recurrent_kernel` weights
86 matrix, used for the linear transformation of the recurrent state.
87 Default: `orthogonal`.
88 bias_initializer: Initializer for the bias vector. Default: `zeros`.
89 unit_forget_bias: Boolean (default `True`). If True, add 1 to the bias of
90 the forget gate at initialization. Setting it to true will also force
91 `bias_initializer="zeros"`. This is recommended in [Jozefowicz et
92 al.](https://github.com/mlresearch/v37/blob/gh-pages/jozefowicz15.pdf)
93 kernel_regularizer: Regularizer function applied to the `kernel` weights
94 matrix. Default: `None`.
95 recurrent_regularizer: Regularizer function applied to
96 the `recurrent_kernel` weights matrix. Default: `None`.
97 bias_regularizer: Regularizer function applied to the bias vector.
98 Default: `None`.
99 kernel_constraint: Constraint function applied to the `kernel` weights
100 matrix. Default: `None`.
101 recurrent_constraint: Constraint function applied to the
102 `recurrent_kernel` weights matrix. Default: `None`.
103 bias_constraint: Constraint function applied to the bias vector. Default:
104 `None`.
105 dropout: Float between 0 and 1. Fraction of the units to drop for the
106 linear transformation of the inputs. Default: 0.
107 recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
108 for the linear transformation of the recurrent state. Default: 0.
110 Call arguments:
111 inputs: A 2D tensor, with shape of `[batch, feature]`.
112 states: List of 2 tensors corresponding to the cell's state. Both
113 have shape `[batch, units]`; the first tensor is the memory state
114 from the previous time step, and the second is the carry state from
115 the previous time step. For timestep 0, the initial state provided by
116 the user will be fed to the cell.
117 training: Python boolean indicating whether the layer should behave in
118 training mode or in inference mode. Only relevant when `dropout` or
119 `recurrent_dropout` is used.
120 """
122 def __init__(
123 self,
124 units,
125 activation="tanh",
126 recurrent_activation="sigmoid",
127 use_bias=True,
128 kernel_initializer="glorot_uniform",
129 recurrent_initializer="orthogonal",
130 bias_initializer="zeros",
131 unit_forget_bias=True,
132 kernel_regularizer=None,
133 recurrent_regularizer=None,
134 bias_regularizer=None,
135 kernel_constraint=None,
136 recurrent_constraint=None,
137 bias_constraint=None,
138 dropout=0.0,
139 recurrent_dropout=0.0,
140 **kwargs,
141 ):
142 if units <= 0:
143 raise ValueError(
144 "Received an invalid value for argument `units`, "
145 f"expected a positive integer, got {units}."
146 )
147 # By default use cached variable under v2 mode, see b/143699808.
148 if tf.compat.v1.executing_eagerly_outside_functions():
149 self._enable_caching_device = kwargs.pop(
150 "enable_caching_device", True
151 )
152 else:
153 self._enable_caching_device = kwargs.pop(
154 "enable_caching_device", False
155 )
156 super().__init__(**kwargs)
157 self.units = units
158 self.activation = activations.get(activation)
159 self.recurrent_activation = activations.get(recurrent_activation)
160 self.use_bias = use_bias
162 self.kernel_initializer = initializers.get(kernel_initializer)
163 self.recurrent_initializer = initializers.get(recurrent_initializer)
164 self.bias_initializer = initializers.get(bias_initializer)
165 self.unit_forget_bias = unit_forget_bias
167 self.kernel_regularizer = regularizers.get(kernel_regularizer)
168 self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
169 self.bias_regularizer = regularizers.get(bias_regularizer)
171 self.kernel_constraint = constraints.get(kernel_constraint)
172 self.recurrent_constraint = constraints.get(recurrent_constraint)
173 self.bias_constraint = constraints.get(bias_constraint)
175 self.dropout = min(1.0, max(0.0, dropout))
176 self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))
177 implementation = kwargs.pop("implementation", 2)
178 if self.recurrent_dropout != 0 and implementation != 1:
179 logging.debug(RECURRENT_DROPOUT_WARNING_MSG)
180 self.implementation = 1
181 else:
182 self.implementation = implementation
183 self.state_size = [self.units, self.units]
184 self.output_size = self.units
186 @tf_utils.shape_type_conversion
187 def build(self, input_shape):
188 super().build(input_shape)
189 default_caching_device = rnn_utils.caching_device(self)
190 input_dim = input_shape[-1]
191 self.kernel = self.add_weight(
192 shape=(input_dim, self.units * 4),
193 name="kernel",
194 initializer=self.kernel_initializer,
195 regularizer=self.kernel_regularizer,
196 constraint=self.kernel_constraint,
197 caching_device=default_caching_device,
198 )
199 self.recurrent_kernel = self.add_weight(
200 shape=(self.units, self.units * 4),
201 name="recurrent_kernel",
202 initializer=self.recurrent_initializer,
203 regularizer=self.recurrent_regularizer,
204 constraint=self.recurrent_constraint,
205 caching_device=default_caching_device,
206 )
208 if self.use_bias:
209 if self.unit_forget_bias:
211 def bias_initializer(_, *args, **kwargs):
212 return backend.concatenate(
213 [
214 self.bias_initializer(
215 (self.units,), *args, **kwargs
216 ),
217 initializers.get("ones")(
218 (self.units,), *args, **kwargs
219 ),
220 self.bias_initializer(
221 (self.units * 2,), *args, **kwargs
222 ),
223 ]
224 )
226 else:
227 bias_initializer = self.bias_initializer
228 self.bias = self.add_weight(
229 shape=(self.units * 4,),
230 name="bias",
231 initializer=bias_initializer,
232 regularizer=self.bias_regularizer,
233 constraint=self.bias_constraint,
234 caching_device=default_caching_device,
235 )
236 else:
237 self.bias = None
238 self.built = True
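# A note on the fused weight layout (inferred from the slicing in the methods
# below): `kernel`, `recurrent_kernel` and `bias` each pack the four gates
# along their last axis in the order [i | f | c | o], one block of size
# `units` per gate. With `unit_forget_bias=True`, the initializer above builds
# the bias as [bias_init(units), ones(units), bias_init(2 * units)], so only
# the forget-gate block starts at 1 (the forget gate starts "open").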
240 def _compute_carry_and_output(self, x, h_tm1, c_tm1):
241 """Computes carry and output using split kernels."""
242 x_i, x_f, x_c, x_o = x
243 h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
244 i = self.recurrent_activation(
245 x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, : self.units])
246 )
247 f = self.recurrent_activation(
248 x_f
249 + backend.dot(
250 h_tm1_f, self.recurrent_kernel[:, self.units : self.units * 2]
251 )
252 )
253 c = f * c_tm1 + i * self.activation(
254 x_c
255 + backend.dot(
256 h_tm1_c,
257 self.recurrent_kernel[:, self.units * 2 : self.units * 3],
258 )
259 )
260 o = self.recurrent_activation(
261 x_o
262 + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3 :])
263 )
264 return c, o
266 def _compute_carry_and_output_fused(self, z, c_tm1):
267 """Computes carry and output using fused kernels."""
268 z0, z1, z2, z3 = z
269 i = self.recurrent_activation(z0)
270 f = self.recurrent_activation(z1)
271 c = f * c_tm1 + i * self.activation(z2)
272 o = self.recurrent_activation(z3)
273 return c, o
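# `_compute_carry_and_output` (implementation=1) and
# `_compute_carry_and_output_fused` (implementation=2) compute the same gate
# math. The split variant exists so that `call` below can apply a separate
# dropout mask per gate, while the fused variant does a single matmul against
# the whole `4 * units` kernel.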
275 def call(self, inputs, states, training=None):
276 h_tm1 = states[0] # previous memory state
277 c_tm1 = states[1] # previous carry state
279 dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
280 rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
281 h_tm1, training, count=4
282 )
284 if self.implementation == 1:
285 if 0 < self.dropout < 1.0:
286 inputs_i = inputs * dp_mask[0]
287 inputs_f = inputs * dp_mask[1]
288 inputs_c = inputs * dp_mask[2]
289 inputs_o = inputs * dp_mask[3]
290 else:
291 inputs_i = inputs
292 inputs_f = inputs
293 inputs_c = inputs
294 inputs_o = inputs
295 k_i, k_f, k_c, k_o = tf.split(
296 self.kernel, num_or_size_splits=4, axis=1
297 )
298 x_i = backend.dot(inputs_i, k_i)
299 x_f = backend.dot(inputs_f, k_f)
300 x_c = backend.dot(inputs_c, k_c)
301 x_o = backend.dot(inputs_o, k_o)
302 if self.use_bias:
303 b_i, b_f, b_c, b_o = tf.split(
304 self.bias, num_or_size_splits=4, axis=0
305 )
306 x_i = backend.bias_add(x_i, b_i)
307 x_f = backend.bias_add(x_f, b_f)
308 x_c = backend.bias_add(x_c, b_c)
309 x_o = backend.bias_add(x_o, b_o)
311 if 0 < self.recurrent_dropout < 1.0:
312 h_tm1_i = h_tm1 * rec_dp_mask[0]
313 h_tm1_f = h_tm1 * rec_dp_mask[1]
314 h_tm1_c = h_tm1 * rec_dp_mask[2]
315 h_tm1_o = h_tm1 * rec_dp_mask[3]
316 else:
317 h_tm1_i = h_tm1
318 h_tm1_f = h_tm1
319 h_tm1_c = h_tm1
320 h_tm1_o = h_tm1
321 x = (x_i, x_f, x_c, x_o)
322 h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
323 c, o = self._compute_carry_and_output(x, h_tm1, c_tm1)
324 else:
325 if 0.0 < self.dropout < 1.0:
326 inputs = inputs * dp_mask[0]
327 z = backend.dot(inputs, self.kernel)
328 z += backend.dot(h_tm1, self.recurrent_kernel)
329 if self.use_bias:
330 z = backend.bias_add(z, self.bias)
332 z = tf.split(z, num_or_size_splits=4, axis=1)
333 c, o = self._compute_carry_and_output_fused(z, c_tm1)
335 h = o * self.activation(c)
336 return h, [h, c]
338 def get_config(self):
339 config = {
340 "units": self.units,
341 "activation": activations.serialize(self.activation),
342 "recurrent_activation": activations.serialize(
343 self.recurrent_activation
344 ),
345 "use_bias": self.use_bias,
346 "kernel_initializer": initializers.serialize(
347 self.kernel_initializer
348 ),
349 "recurrent_initializer": initializers.serialize(
350 self.recurrent_initializer
351 ),
352 "bias_initializer": initializers.serialize(self.bias_initializer),
353 "unit_forget_bias": self.unit_forget_bias,
354 "kernel_regularizer": regularizers.serialize(
355 self.kernel_regularizer
356 ),
357 "recurrent_regularizer": regularizers.serialize(
358 self.recurrent_regularizer
359 ),
360 "bias_regularizer": regularizers.serialize(self.bias_regularizer),
361 "kernel_constraint": constraints.serialize(self.kernel_constraint),
362 "recurrent_constraint": constraints.serialize(
363 self.recurrent_constraint
364 ),
365 "bias_constraint": constraints.serialize(self.bias_constraint),
366 "dropout": self.dropout,
367 "recurrent_dropout": self.recurrent_dropout,
368 "implementation": self.implementation,
369 }
370 config.update(rnn_utils.config_for_enable_caching_device(self))
371 base_config = super().get_config()
372 return dict(list(base_config.items()) + list(config.items()))
374 def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
375 return list(
376 rnn_utils.generate_zero_filled_state_for_cell(
377 self, inputs, batch_size, dtype
378 )
379 )
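# A minimal NumPy sketch of the single-step math in the fused
# (`implementation=2`) path of `LSTMCell.call`, using the default `tanh` and
# `sigmoid` activations and no dropout. The helper names below are
# illustrative only and are not part of Keras.
import numpy as np


def _np_sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def numpy_lstm_step(x, h_tm1, c_tm1, kernel, recurrent_kernel, bias):
    """One LSTM step; the kernels pack the gates as [i, f, c, o]."""
    z = x @ kernel + h_tm1 @ recurrent_kernel + bias
    z_i, z_f, z_c, z_o = np.split(z, 4, axis=-1)
    i = _np_sigmoid(z_i)              # input gate
    f = _np_sigmoid(z_f)              # forget gate
    c = f * c_tm1 + i * np.tanh(z_c)  # new carry state
    o = _np_sigmoid(z_o)              # output gate
    h = o * np.tanh(c)                # new memory state (cell output)
    return h, c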
382@keras_export("keras.layers.LSTM", v1=[])
383class LSTM(DropoutRNNCellMixin, RNN, base_layer.BaseRandomLayer):
384 """Long Short-Term Memory layer - Hochreiter 1997.
386 See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
387 for details about the usage of RNN API.
389 Based on available runtime hardware and constraints, this layer
390 will choose different implementations (cuDNN-based or pure-TensorFlow)
391 to maximize the performance. If a GPU is available and all
392 the arguments to the layer meet the requirements of the cuDNN kernel
393 (see below for details), the layer will use a fast cuDNN implementation.
395 The requirements to use the cuDNN implementation are:
397 1. `activation` == `tanh`
398 2. `recurrent_activation` == `sigmoid`
399 3. `recurrent_dropout` == 0
400 4. `unroll` is `False`
401 5. `use_bias` is `True`
402 6. Inputs, if masking is used, are strictly right-padded.
403 7. Eager execution is enabled in the outermost context.
405 For example:
407 >>> inputs = tf.random.normal([32, 10, 8])
408 >>> lstm = tf.keras.layers.LSTM(4)
409 >>> output = lstm(inputs)
410 >>> print(output.shape)
411 (32, 4)
412 >>> lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)
413 >>> whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)
414 >>> print(whole_seq_output.shape)
415 (32, 10, 4)
416 >>> print(final_memory_state.shape)
417 (32, 4)
418 >>> print(final_carry_state.shape)
419 (32, 4)
421 Args:
422 units: Positive integer, dimensionality of the output space.
423 activation: Activation function to use.
424 Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation
425 is applied (ie. "linear" activation: `a(x) = x`).
426 recurrent_activation: Activation function to use for the recurrent step.
427 Default: sigmoid (`sigmoid`). If you pass `None`, no activation is
428 applied (ie. "linear" activation: `a(x) = x`).
429 use_bias: Boolean (default `True`), whether the layer uses a bias vector.
430 kernel_initializer: Initializer for the `kernel` weights matrix, used for
431 the linear transformation of the inputs. Default: `glorot_uniform`.
432 recurrent_initializer: Initializer for the `recurrent_kernel` weights
433 matrix, used for the linear transformation of the recurrent state.
434 Default: `orthogonal`.
435 bias_initializer: Initializer for the bias vector. Default: `zeros`.
436 unit_forget_bias: Boolean (default `True`). If True, add 1 to the bias of
437 the forget gate at initialization. Setting it to true will also force
438 `bias_initializer="zeros"`. This is recommended in [Jozefowicz et
439 al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
440 kernel_regularizer: Regularizer function applied to the `kernel` weights
441 matrix. Default: `None`.
442 recurrent_regularizer: Regularizer function applied to the
443 `recurrent_kernel` weights matrix. Default: `None`.
444 bias_regularizer: Regularizer function applied to the bias vector.
445 Default: `None`.
446 activity_regularizer: Regularizer function applied to the output of the
447 layer (its "activation"). Default: `None`.
448 kernel_constraint: Constraint function applied to the `kernel` weights
449 matrix. Default: `None`.
450 recurrent_constraint: Constraint function applied to the
451 `recurrent_kernel` weights matrix. Default: `None`.
452 bias_constraint: Constraint function applied to the bias vector. Default:
453 `None`.
454 dropout: Float between 0 and 1. Fraction of the units to drop for the
455 linear transformation of the inputs. Default: 0.
456 recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
457 for the linear transformation of the recurrent state. Default: 0.
458 return_sequences: Boolean. Whether to return the last output in the output
459 sequence, or the full sequence. Default: `False`.
460 return_state: Boolean. Whether to return the last state in addition to the
461 output. Default: `False`.
462 go_backwards: Boolean (default `False`). If True, process the input
463 sequence backwards and return the reversed sequence.
464 stateful: Boolean (default `False`). If True, the last state for each
465 sample at index i in a batch will be used as initial state for the sample
466 of index i in the following batch.
467 time_major: The shape format of the `inputs` and `outputs` tensors.
468 If True, the inputs and outputs will be in shape
469 `[timesteps, batch, feature]`, whereas in the False case, it will be
470 `[batch, timesteps, feature]`. Using `time_major = True` is a bit more
471 efficient because it avoids transposes at the beginning and end of the
472 RNN calculation. However, most TensorFlow data is batch-major, so by
473 default this function accepts input and emits output in batch-major
474 form.
475 unroll: Boolean (default `False`). If True, the network will be unrolled,
476 else a symbolic loop will be used. Unrolling can speed-up a RNN,
477 although it tends to be more memory-intensive. Unrolling is only
478 suitable for short sequences.
480 Call arguments:
481 inputs: A 3D tensor with shape `[batch, timesteps, feature]`.
482 mask: Binary tensor of shape `[batch, timesteps]` indicating whether
483 a given timestep should be masked (optional).
484 An individual `True` entry indicates that the corresponding timestep
485 should be utilized, while a `False` entry indicates that the
486 corresponding timestep should be ignored. Defaults to `None`.
487 training: Python boolean indicating whether the layer should behave in
488 training mode or in inference mode. This argument is passed to the cell
489 when calling it. This is only relevant if `dropout` or
490 `recurrent_dropout` is used (optional). Defaults to `None`.
491 initial_state: List of initial state tensors to be passed to the first
492 call of the cell (optional, `None` causes creation
493 of zero-filled initial state tensors). Defaults to `None`.
494 """
496 def __init__(
497 self,
498 units,
499 activation="tanh",
500 recurrent_activation="sigmoid",
501 use_bias=True,
502 kernel_initializer="glorot_uniform",
503 recurrent_initializer="orthogonal",
504 bias_initializer="zeros",
505 unit_forget_bias=True,
506 kernel_regularizer=None,
507 recurrent_regularizer=None,
508 bias_regularizer=None,
509 activity_regularizer=None,
510 kernel_constraint=None,
511 recurrent_constraint=None,
512 bias_constraint=None,
513 dropout=0.0,
514 recurrent_dropout=0.0,
515 return_sequences=False,
516 return_state=False,
517 go_backwards=False,
518 stateful=False,
519 time_major=False,
520 unroll=False,
521 **kwargs,
522 ):
523 # return_runtime is a flag for testing, which shows the real backend
524 # implementation chosen by grappler in graph mode.
525 self.return_runtime = kwargs.pop("return_runtime", False)
526 implementation = kwargs.pop("implementation", 2)
527 if implementation == 0:
528 logging.warning(
529 "`implementation=0` has been deprecated, "
530 "and now defaults to `implementation=1`."
531 "Please update your layer call."
532 )
533 if "enable_caching_device" in kwargs:
534 cell_kwargs = {
535 "enable_caching_device": kwargs.pop("enable_caching_device")
536 }
537 else:
538 cell_kwargs = {}
539 cell = LSTMCell(
540 units,
541 activation=activation,
542 recurrent_activation=recurrent_activation,
543 use_bias=use_bias,
544 kernel_initializer=kernel_initializer,
545 recurrent_initializer=recurrent_initializer,
546 unit_forget_bias=unit_forget_bias,
547 bias_initializer=bias_initializer,
548 kernel_regularizer=kernel_regularizer,
549 recurrent_regularizer=recurrent_regularizer,
550 bias_regularizer=bias_regularizer,
551 kernel_constraint=kernel_constraint,
552 recurrent_constraint=recurrent_constraint,
553 bias_constraint=bias_constraint,
554 dropout=dropout,
555 recurrent_dropout=recurrent_dropout,
556 implementation=implementation,
557 dtype=kwargs.get("dtype"),
558 trainable=kwargs.get("trainable", True),
559 name="lstm_cell",
560 **cell_kwargs,
561 )
562 super().__init__(
563 cell,
564 return_sequences=return_sequences,
565 return_state=return_state,
566 go_backwards=go_backwards,
567 stateful=stateful,
568 time_major=time_major,
569 unroll=unroll,
570 **kwargs,
571 )
572 self.activity_regularizer = regularizers.get(activity_regularizer)
573 self.input_spec = [InputSpec(ndim=3)]
574 self.state_spec = [
575 InputSpec(shape=(None, dim)) for dim in (self.units, self.units)
576 ]
577 self._could_use_gpu_kernel = (
578 self.activation in (activations.tanh, tf.tanh)
579 and self.recurrent_activation in (activations.sigmoid, tf.sigmoid)
580 and recurrent_dropout == 0
581 and not unroll
582 and use_bias
583 and tf.compat.v1.executing_eagerly_outside_functions()
584 )
585 if tf.config.list_logical_devices("GPU"):
586 # Only show the message when there is a GPU available; the user will
587 # not care about cuDNN if there isn't any GPU.
588 if self._could_use_gpu_kernel:
589 logging.debug(gru_lstm_utils.CUDNN_AVAILABLE_MSG % self.name)
590 else:
591 logging.warning(
592 gru_lstm_utils.CUDNN_NOT_AVAILABLE_MSG % self.name
593 )
595 if gru_lstm_utils.use_new_gru_lstm_impl():
596 self._defun_wrapper = gru_lstm_utils.DefunWrapper(
597 time_major, go_backwards, "lstm"
598 )
600 def call(self, inputs, mask=None, training=None, initial_state=None):
601 # The input should be dense, padded with zeros. If a ragged input is fed
602 # into the layer, it is padded and the row lengths are used for masking.
603 inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
604 is_ragged_input = row_lengths is not None
605 self._validate_args_if_ragged(is_ragged_input, mask)
607 # LSTM does not support constants; ignore them during processing.
608 inputs, initial_state, _ = self._process_inputs(
609 inputs, initial_state, None
610 )
612 if isinstance(mask, list):
613 mask = mask[0]
615 input_shape = backend.int_shape(inputs)
616 timesteps = input_shape[0] if self.time_major else input_shape[1]
618 if not self._could_use_gpu_kernel:
619 # Fall back to use the normal LSTM.
620 kwargs = {"training": training}
621 self._maybe_reset_cell_dropout_mask(self.cell)
623 def step(inputs, states):
624 return self.cell(inputs, states, **kwargs)
626 last_output, outputs, states = backend.rnn(
627 step,
628 inputs,
629 initial_state,
630 constants=None,
631 go_backwards=self.go_backwards,
632 mask=mask,
633 unroll=self.unroll,
634 input_length=row_lengths
635 if row_lengths is not None
636 else timesteps,
637 time_major=self.time_major,
638 zero_output_for_mask=self.zero_output_for_mask,
639 return_all_outputs=self.return_sequences,
640 )
641 runtime = gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_UNKNOWN)
642 else:
643 # Use the new defun approach for backend implementation swap.
644 # Note that the different implementations need to have the same
645 # function signature, e.g. the tensor parameters need to have the same
646 # shapes and dtypes. Since cuDNN has an extra set of biases, those
647 # biases will be passed to both the normal and cuDNN implementations.
648 self.reset_dropout_mask()
649 dropout_mask = self.get_dropout_mask_for_cell(
650 inputs, training, count=4
651 )
652 if dropout_mask is not None:
653 inputs = inputs * dropout_mask[0]
654 if gru_lstm_utils.use_new_gru_lstm_impl():
655 lstm_kwargs = {
656 "inputs": inputs,
657 "init_h": gru_lstm_utils.read_variable_value(
658 initial_state[0]
659 ),
660 "init_c": gru_lstm_utils.read_variable_value(
661 initial_state[1]
662 ),
663 "kernel": gru_lstm_utils.read_variable_value(
664 self.cell.kernel
665 ),
666 "recurrent_kernel": gru_lstm_utils.read_variable_value(
667 self.cell.recurrent_kernel
668 ),
669 "bias": gru_lstm_utils.read_variable_value(self.cell.bias),
670 "mask": mask,
671 "time_major": self.time_major,
672 "go_backwards": self.go_backwards,
673 "sequence_lengths": row_lengths,
674 "zero_output_for_mask": self.zero_output_for_mask,
675 }
676 (
677 last_output,
678 outputs,
679 new_h,
680 new_c,
681 runtime,
682 ) = self._defun_wrapper.defun_layer(**lstm_kwargs)
683 else:
684 gpu_lstm_kwargs = {
685 "inputs": inputs,
686 "init_h": gru_lstm_utils.read_variable_value(
687 initial_state[0]
688 ),
689 "init_c": gru_lstm_utils.read_variable_value(
690 initial_state[1]
691 ),
692 "kernel": gru_lstm_utils.read_variable_value(
693 self.cell.kernel
694 ),
695 "recurrent_kernel": gru_lstm_utils.read_variable_value(
696 self.cell.recurrent_kernel
697 ),
698 "bias": gru_lstm_utils.read_variable_value(self.cell.bias),
699 "mask": mask,
700 "time_major": self.time_major,
701 "go_backwards": self.go_backwards,
702 "sequence_lengths": row_lengths,
703 "return_sequences": self.return_sequences,
704 }
705 normal_lstm_kwargs = gpu_lstm_kwargs.copy()
706 normal_lstm_kwargs.update(
707 {
708 "zero_output_for_mask": self.zero_output_for_mask,
709 }
710 )
712 if tf.executing_eagerly():
713 device_type = gru_lstm_utils.get_context_device_type()
714 can_use_gpu = (
715 # Either user specified GPU or unspecified but GPU is
716 # available.
717 (
718 device_type == gru_lstm_utils.GPU_DEVICE_NAME
719 or (
720 device_type is None
721 and tf.config.list_logical_devices("GPU")
722 )
723 )
724 and gru_lstm_utils.is_cudnn_supported_inputs(
725 mask, self.time_major, row_lengths
726 )
727 )
728 # Under eager context, check the device placement and prefer
729 # the GPU implementation when GPU is available.
730 if can_use_gpu:
731 last_output, outputs, new_h, new_c, runtime = gpu_lstm(
732 **gpu_lstm_kwargs
733 )
734 else:
735 (
736 last_output,
737 outputs,
738 new_h,
739 new_c,
740 runtime,
741 ) = standard_lstm(**normal_lstm_kwargs)
742 else:
743 (
744 last_output,
745 outputs,
746 new_h,
747 new_c,
748 runtime,
749 ) = lstm_with_backend_selection(**normal_lstm_kwargs)
751 states = [new_h, new_c]
753 if self.stateful:
754 updates = [
755 tf.compat.v1.assign(
756 self_state, tf.cast(state, self_state.dtype)
757 )
758 for self_state, state in zip(self.states, states)
759 ]
760 self.add_update(updates)
762 if self.return_sequences:
763 output = backend.maybe_convert_to_ragged(
764 is_ragged_input,
765 outputs,
766 row_lengths,
767 go_backwards=self.go_backwards,
768 )
769 else:
770 output = last_output
772 if self.return_state:
773 return [output] + list(states)
774 elif self.return_runtime:
775 return output, runtime
776 else:
777 return output
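# Ragged inputs are handled by the conversion at the top of `call` (an
# illustrative sketch, not part of this module):
#
#     ragged = tf.ragged.constant([[[1.0, 2.0]], [[3.0, 4.0], [5.0, 6.0]]])
#     layer = tf.keras.layers.LSTM(4, return_sequences=True)
#     out = layer(ragged)  # a RaggedTensor with the same row lengths as `ragged`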
779 @property
780 def units(self):
781 return self.cell.units
783 @property
784 def activation(self):
785 return self.cell.activation
787 @property
788 def recurrent_activation(self):
789 return self.cell.recurrent_activation
791 @property
792 def use_bias(self):
793 return self.cell.use_bias
795 @property
796 def kernel_initializer(self):
797 return self.cell.kernel_initializer
799 @property
800 def recurrent_initializer(self):
801 return self.cell.recurrent_initializer
803 @property
804 def bias_initializer(self):
805 return self.cell.bias_initializer
807 @property
808 def unit_forget_bias(self):
809 return self.cell.unit_forget_bias
811 @property
812 def kernel_regularizer(self):
813 return self.cell.kernel_regularizer
815 @property
816 def recurrent_regularizer(self):
817 return self.cell.recurrent_regularizer
819 @property
820 def bias_regularizer(self):
821 return self.cell.bias_regularizer
823 @property
824 def kernel_constraint(self):
825 return self.cell.kernel_constraint
827 @property
828 def recurrent_constraint(self):
829 return self.cell.recurrent_constraint
831 @property
832 def bias_constraint(self):
833 return self.cell.bias_constraint
835 @property
836 def dropout(self):
837 return self.cell.dropout
839 @property
840 def recurrent_dropout(self):
841 return self.cell.recurrent_dropout
843 @property
844 def implementation(self):
845 return self.cell.implementation
847 def get_config(self):
848 config = {
849 "units": self.units,
850 "activation": activations.serialize(self.activation),
851 "recurrent_activation": activations.serialize(
852 self.recurrent_activation
853 ),
854 "use_bias": self.use_bias,
855 "kernel_initializer": initializers.serialize(
856 self.kernel_initializer
857 ),
858 "recurrent_initializer": initializers.serialize(
859 self.recurrent_initializer
860 ),
861 "bias_initializer": initializers.serialize(self.bias_initializer),
862 "unit_forget_bias": self.unit_forget_bias,
863 "kernel_regularizer": regularizers.serialize(
864 self.kernel_regularizer
865 ),
866 "recurrent_regularizer": regularizers.serialize(
867 self.recurrent_regularizer
868 ),
869 "bias_regularizer": regularizers.serialize(self.bias_regularizer),
870 "activity_regularizer": regularizers.serialize(
871 self.activity_regularizer
872 ),
873 "kernel_constraint": constraints.serialize(self.kernel_constraint),
874 "recurrent_constraint": constraints.serialize(
875 self.recurrent_constraint
876 ),
877 "bias_constraint": constraints.serialize(self.bias_constraint),
878 "dropout": self.dropout,
879 "recurrent_dropout": self.recurrent_dropout,
880 "implementation": self.implementation,
881 }
882 config.update(rnn_utils.config_for_enable_caching_device(self.cell))
883 base_config = super().get_config()
884 del base_config["cell"]
885 return dict(list(base_config.items()) + list(config.items()))
887 @classmethod
888 def from_config(cls, config):
889 if "implementation" in config and config["implementation"] == 0:
890 config["implementation"] = 1
891 return cls(**config)
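# A small config round-trip sketch of the serialization hooks above (the
# function and variable names are illustrative and not part of this module):
def _lstm_config_roundtrip_example():
    layer = LSTM(8, recurrent_dropout=0.1, return_sequences=True)
    restored = LSTM.from_config(layer.get_config())
    assert restored.units == 8
    assert restored.return_sequences
    return restored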
894def standard_lstm(
895 inputs,
896 init_h,
897 init_c,
898 kernel,
899 recurrent_kernel,
900 bias,
901 mask,
902 time_major,
903 go_backwards,
904 sequence_lengths,
905 zero_output_for_mask,
906 return_sequences,
907):
908 """LSTM with standard kernel implementation.
910 This implementation can be run on all types of hardware.
912 This implementation lifts all the layer weights out and makes them function
913 parameters. It has the same number of tensor input params as the cuDNN
914 counterpart. The RNN step logic has been simplified, e.g. dropout and mask
915 handling inside the step are removed since the cuDNN implementation does
916 not support them.
917 Note that the first half of the bias tensor should be ignored by this impl.
918 The cuDNN impl needs an extra set of input gate biases. In order to make
919 both functions take the same parameter shapes, that extra set of biases is
920 also fed here.
923 Args:
924 inputs: input tensor of LSTM layer.
925 init_h: initial state tensor for the cell output.
926 init_c: initial state tensor for the cell hidden state.
927 kernel: weights for cell kernel.
928 recurrent_kernel: weights for cell recurrent kernel.
929 bias: weights for cell kernel bias and recurrent bias. Only recurrent bias
930 is used in this case.
931 mask: Boolean tensor used to mask out steps within the sequence.
932 An individual `True` entry indicates that the corresponding timestep
933 should be utilized, while a `False` entry indicates that the
934 corresponding timestep should be ignored.
935 time_major: boolean, whether the inputs are in the format of
936 [time, batch, feature] or [batch, time, feature].
937 go_backwards: Boolean (default False). If True, process the input sequence
938 backwards and return the reversed sequence.
939 sequence_lengths: The lengths of all sequences coming from a variable
940 length input, such as ragged tensors. If the input has a fixed timestep
941 size, this should be None.
942 zero_output_for_mask: Boolean, whether to output zero for masked timestep.
943 return_sequences: Boolean. If True, return the recurrent outputs for all
944 timesteps in the sequence. If False, only return the output for the
945 last timestep (which consumes less memory).
947 Returns:
948 last_output: output tensor for the last timestep, which has shape
949 [batch, units].
950 outputs:
951 - If `return_sequences=True`: output tensor for all timesteps,
952 which has shape [batch, time, units].
953 - Else, a tensor equal to `last_output` with shape [batch, 1, units]
954 state_0: the cell output, which has same shape as init_h.
955 state_1: the cell hidden state, which has same shape as init_c.
956 runtime: Constant string tensor which indicates the real runtime hardware.
957 This value is for testing purposes and should not be used by the user.
958 """
959 input_shape = backend.int_shape(inputs)
960 timesteps = input_shape[0] if time_major else input_shape[1]
962 def step(cell_inputs, cell_states):
963 """Step function that will be used by Keras RNN backend."""
964 h_tm1 = cell_states[0] # previous memory state
965 c_tm1 = cell_states[1] # previous carry state
967 z = backend.dot(cell_inputs, kernel)
968 z += backend.dot(h_tm1, recurrent_kernel)
969 z = backend.bias_add(z, bias)
971 z0, z1, z2, z3 = tf.split(z, 4, axis=1)
973 i = tf.sigmoid(z0)
974 f = tf.sigmoid(z1)
975 c = f * c_tm1 + i * tf.tanh(z2)
976 o = tf.sigmoid(z3)
978 h = o * tf.tanh(c)
979 return h, [h, c]
981 last_output, outputs, new_states = backend.rnn(
982 step,
983 inputs,
984 [init_h, init_c],
985 constants=None,
986 unroll=False,
987 time_major=time_major,
988 mask=mask,
989 go_backwards=go_backwards,
990 input_length=(
991 sequence_lengths if sequence_lengths is not None else timesteps
992 ),
993 zero_output_for_mask=zero_output_for_mask,
994 return_all_outputs=return_sequences,
995 )
996 return (
997 last_output,
998 outputs,
999 new_states[0],
1000 new_states[1],
1001 gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_CPU),
1002 )
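# A direct-call sketch of `standard_lstm` with hand-built tensors
# (illustrative only; in normal use these weights come from an `LSTMCell`, and
# the shapes below are an assumption that follows the step function above):
def _standard_lstm_example():
    batch, timesteps, features, units = 2, 5, 3, 4
    return standard_lstm(
        inputs=tf.random.normal([batch, timesteps, features]),
        init_h=tf.zeros([batch, units]),
        init_c=tf.zeros([batch, units]),
        kernel=tf.random.normal([features, 4 * units]),
        recurrent_kernel=tf.random.normal([units, 4 * units]),
        bias=tf.zeros([4 * units]),
        mask=None,
        time_major=False,
        go_backwards=False,
        sequence_lengths=None,
        zero_output_for_mask=False,
        return_sequences=True,
    )  # -> (last_output, outputs, new_h, new_c, runtime)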
1005def gpu_lstm(
1006 inputs,
1007 init_h,
1008 init_c,
1009 kernel,
1010 recurrent_kernel,
1011 bias,
1012 mask,
1013 time_major,
1014 go_backwards,
1015 sequence_lengths,
1016 return_sequences,
1017):
1018 """LSTM with either cuDNN or ROCm implementation which is only available for
1019 GPU.
1021 Note that currently only right padded data is supported, or the result will
1022 be polluted by the unmasked data which should be filtered.
1024 Args:
1025 inputs: Input tensor of LSTM layer.
1026 init_h: Initial state tensor for the cell output.
1027 init_c: Initial state tensor for the cell hidden state.
1028 kernel: Weights for cell kernel.
1029 recurrent_kernel: Weights for cell recurrent kernel.
1030 bias: Weights for cell kernel bias and recurrent bias. Only recurrent bias
1031 is used in this case.
1032 mask: Boolean tensor used to mask out steps within the sequence. An individual
1033 `True` entry indicates that the corresponding timestep should be
1034 utilized, while a `False` entry indicates that the corresponding
1035 timestep should be ignored.
1036 time_major: Boolean, whether the inputs are in the format of [time, batch,
1037 feature] or [batch, time, feature].
1038 go_backwards: Boolean (default False). If True, process the input sequence
1039 backwards and return the reversed sequence.
1040 sequence_lengths: The lengths of all sequences coming from a variable
1041 length input, such as ragged tensors. If the input has a fixed timestep
1042 size, this should be None.
1043 return_sequences: Boolean. If True, return the recurrent outputs for all
1044 timesteps in the sequence. If False, only return the output for the
1045 last timestep, matching the CPU function output format.
1047 Returns:
1048 last_output: Output tensor for the last timestep, which has shape
1049 [batch, units].
1050 outputs:
1051 - If `return_sequences=True`: output tensor for all timesteps,
1052 which has shape [batch, time, units].
1053 - Else, a tensor equal to `last_output` with shape [batch, 1, units]
1054 state_0: The cell output, which has same shape as init_h.
1055 state_1: The cell hidden state, which has same shape as init_c.
1056 runtime: Constant string tensor which indicates the real runtime hardware.
1057 This value is for testing purposes and should not be used by the user.
1058 """
1059 if mask is not None:
1060 sequence_lengths = gru_lstm_utils.calculate_sequence_by_mask(
1061 mask, time_major
1062 )
1064 if not time_major and sequence_lengths is None:
1065 inputs = tf.transpose(inputs, perm=(1, 0, 2))
1066 seq_axis, batch_axis = (0, 1)
1067 else:
1068 seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
1069 # For init_h and init_c, cuDNN expects an extra num_layers dim before the
1070 # batch dim for time-major inputs, or after it for batch-major inputs.
1071 init_h = tf.expand_dims(init_h, axis=seq_axis)
1072 init_c = tf.expand_dims(init_c, axis=seq_axis)
1074 weights = tf.split(kernel, 4, axis=1)
1075 weights += tf.split(recurrent_kernel, 4, axis=1)
1076 # cuDNN has an extra set of biases for the inputs; we disable them (set them
1077 # to 0) so that mathematically it is the same as the canonical LSTM.
1078 full_bias = tf.concat((tf.zeros_like(bias), bias), 0)
1080 if tf.sysconfig.get_build_info()["is_rocm_build"]:
1081 # ROCm MIOpen's weight sequence for LSTM is different from both
1082 # canonical and Cudnn format
1083 # MIOpen: [i, f, o, c] Cudnn/Canonical: [i, f, c, o]
1084 # i is input gate weights.
1085 # f is forget gate weights.
1086 # o is output gate weights.
1087 # c is cell gate weights.
1088 weights = [weights[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]
1089 # full_bias is a tensor of shape (8*n,)
1090 full_bias = tf.split(full_bias, 8, axis=0)
1091 full_bias = [full_bias[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]
1093 params = gru_lstm_utils.canonical_to_params(
1094 weights=weights,
1095 biases=tf.split(full_bias, 8),
1096 shape=tf.constant([-1]),
1097 transpose_weights=True,
1098 )
1100 if sequence_lengths is not None:
1101 if go_backwards:
1102 # Three reversals are required. E.g.,
1103 # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
1104 # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
1105 # output_from_cudnn = [6, 5, 4, 0, 0]
1106 # expected_output = [0, 0, 6, 5 ,4]
1107 inputs = tf.reverse_sequence(
1108 inputs,
1109 sequence_lengths,
1110 seq_axis=seq_axis,
1111 batch_axis=batch_axis,
1112 )
1113 outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(
1114 input=inputs,
1115 input_h=init_h,
1116 input_c=init_c,
1117 params=params,
1118 is_training=True,
1119 rnn_mode="lstm",
1120 sequence_lengths=sequence_lengths,
1121 time_major=time_major,
1122 )
1123 if go_backwards:
1124 outputs = tf.reverse_sequence(
1125 outputs,
1126 sequence_lengths,
1127 seq_axis=seq_axis,
1128 batch_axis=batch_axis,
1129 )
1130 outputs = tf.reverse(outputs, axis=[seq_axis])
1131 else:
1132 # # Fill the array with shape [batch] with value of max timesteps.
1133 # sequence_length = array_ops.fill([array_ops.shape(inputs)[1]],
1134 # array_ops.shape(inputs)[0])
1135 if go_backwards:
1136 # Reverse axis 0 since the input has already been converted to time major.
1137 inputs = tf.reverse(inputs, axis=[0])
1138 outputs, h, c, _ = tf.raw_ops.CudnnRNN(
1139 input=inputs,
1140 input_h=init_h,
1141 input_c=init_c,
1142 params=params,
1143 is_training=True,
1144 rnn_mode="lstm",
1145 )
1147 last_output = outputs[-1]
1148 if not time_major and sequence_lengths is None and return_sequences:
1149 outputs = tf.transpose(outputs, perm=[1, 0, 2])
1150 h = tf.squeeze(h, axis=seq_axis)
1151 c = tf.squeeze(c, axis=seq_axis)
1153 # In the case of variable-length input, the cuDNN kernel fills the output
1154 # with zeros past each sequence's end, whereas the default Keras behavior
1155 # is to carry over the output from t-1, so that in the return_sequences=False
1156 # case the user gets the final effective output instead of just 0s at the
1157 # last timestep. To mimic the default Keras behavior, we copy the final h
1158 # state as the last_output, since it is numerically the same as the output.
1159 if sequence_lengths is not None:
1160 last_output = h
1162 # Match CPU return format
1163 if not return_sequences:
1164 outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)
1166 return (
1167 last_output,
1168 outputs,
1169 h,
1170 c,
1171 gru_lstm_utils.runtime(gru_lstm_utils.RUNTIME_GPU),
1172 )
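# The right-padding requirement from the docstring above, illustrated with two
# masks (a sketch; these tensors are not used anywhere in this module):
_right_padded_mask = tf.constant([[True, True, True, False, False]])  # cuDNN ok
_interior_gap_mask = tf.constant([[True, False, True, True, False]])  # not supported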
1175def lstm_with_backend_selection(
1176 inputs,
1177 init_h,
1178 init_c,
1179 kernel,
1180 recurrent_kernel,
1181 bias,
1182 mask,
1183 time_major,
1184 go_backwards,
1185 sequence_lengths,
1186 zero_output_for_mask,
1187 return_sequences,
1188):
1189 """Call the LSTM with optimized backend kernel selection.
1191 Under the hood, this function creates two TF functions: one with the most
1192 generic kernel, which can run on any device, and a second one with the
1193 cuDNN-specific kernel, which can only run on GPU.
1195 The first function is called with the normal LSTM params, while the second
1196 function is not called but only registered in the graph. Grappler will
1197 do the proper graph rewrite and swap in the optimized TF function based on
1198 the device placement.
1200 Args:
1201 inputs: Input tensor of LSTM layer.
1202 init_h: Initial state tensor for the cell output.
1203 init_c: Initial state tensor for the cell hidden state.
1204 kernel: Weights for cell kernel.
1205 recurrent_kernel: Weights for cell recurrent kernel.
1206 bias: Weights for cell kernel bias and recurrent bias. Only recurrent bias
1207 is used in this case.
1208 mask: Boolean tensor used to mask out steps within the sequence.
1209 An individual `True` entry indicates that the corresponding timestep
1210 should be utilized, while a `False` entry indicates that the
1211 corresponding timestep should be ignored.
1212 time_major: Boolean, whether the inputs are in the format of
1213 [time, batch, feature] or [batch, time, feature].
1214 go_backwards: Boolean (default False). If True, process the input sequence
1215 backwards and return the reversed sequence.
1216 sequence_lengths: The lengths of all sequences coming from a variable
1217 length input, such as ragged tensors. If the input has a fixed timestep
1218 size, this should be None.
1219 zero_output_for_mask: Boolean, whether to output zero for masked timestep.
1220 return_sequences: Boolean. If True, return the recurrent outputs for all
1221 timesteps in the sequence. If False, only return the output for the
1222 last timestep (which consumes less memory).
1224 Returns:
1225 List of output tensors, same as standard_lstm.
1226 """
1227 params = {
1228 "inputs": inputs,
1229 "init_h": init_h,
1230 "init_c": init_c,
1231 "kernel": kernel,
1232 "recurrent_kernel": recurrent_kernel,
1233 "bias": bias,
1234 "mask": mask,
1235 "time_major": time_major,
1236 "go_backwards": go_backwards,
1237 "sequence_lengths": sequence_lengths,
1238 "zero_output_for_mask": zero_output_for_mask,
1239 "return_sequences": return_sequences,
1240 }
1242 def gpu_lstm_with_fallback(
1243 inputs,
1244 init_h,
1245 init_c,
1246 kernel,
1247 recurrent_kernel,
1248 bias,
1249 mask,
1250 time_major,
1251 go_backwards,
1252 sequence_lengths,
1253 zero_output_for_mask,
1254 return_sequences,
1255 ):
1256 """Use cuDNN kernel when mask is none or strictly right padded."""
1258 def cudnn_lstm_fn():
1259 return gpu_lstm(
1260 inputs=inputs,
1261 init_h=init_h,
1262 init_c=init_c,
1263 kernel=kernel,
1264 recurrent_kernel=recurrent_kernel,
1265 bias=bias,
1266 mask=mask,
1267 time_major=time_major,
1268 go_backwards=go_backwards,
1269 sequence_lengths=sequence_lengths,
1270 return_sequences=return_sequences,
1271 )
1273 def standard_lstm_fn():
1274 return standard_lstm(
1275 inputs=inputs,
1276 init_h=init_h,
1277 init_c=init_c,
1278 kernel=kernel,
1279 recurrent_kernel=recurrent_kernel,
1280 bias=bias,
1281 mask=mask,
1282 time_major=time_major,
1283 go_backwards=go_backwards,
1284 sequence_lengths=sequence_lengths,
1285 zero_output_for_mask=zero_output_for_mask,
1286 return_sequences=return_sequences,
1287 )
1289 return tf.__internal__.smart_cond.smart_cond(
1290 gru_lstm_utils.is_cudnn_supported_inputs(
1291 mask, time_major, sequence_lengths
1292 ),
1293 true_fn=cudnn_lstm_fn,
1294 false_fn=standard_lstm_fn,
1295 )
1297 if gru_lstm_utils.use_new_gru_lstm_impl():
1298 # Chooses the implementation dynamically based on the running device.
1299 (
1300 last_output,
1301 outputs,
1302 new_h,
1303 new_c,
1304 runtime,
1305 ) = tf.__internal__.execute_fn_for_device(
1306 {
1307 gru_lstm_utils.CPU_DEVICE_NAME: lambda: standard_lstm(**params),
1308 gru_lstm_utils.GPU_DEVICE_NAME: lambda: gpu_lstm_with_fallback(
1309 **params
1310 ),
1311 },
1312 lambda: standard_lstm(**params),
1313 )
1314 else:
1315 # Each time a `tf.function` is called, we give it a unique,
1316 # identifiable API name, so that Grappler won't get confused when it
1317 # sees multiple LSTM layers added into the same graph, and it will be
1318 # able to pair up the different implementations across them.
1319 api_name = "lstm_" + str(uuid.uuid4())
1320 supportive_attribute = {
1321 "time_major": time_major,
1322 "go_backwards": go_backwards,
1323 }
1324 defun_standard_lstm = gru_lstm_utils.generate_defun_backend(
1325 api_name,
1326 gru_lstm_utils.CPU_DEVICE_NAME,
1327 standard_lstm,
1328 supportive_attribute,
1329 )
1330 defun_gpu_lstm = gru_lstm_utils.generate_defun_backend(
1331 api_name,
1332 gru_lstm_utils.GPU_DEVICE_NAME,
1333 gpu_lstm_with_fallback,
1334 supportive_attribute,
1335 )
1337 # Call the normal LSTM impl and register the cuDNN impl function. The
1338 # grappler will kick in during session execution to optimize the graph.
1339 last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(
1340 **params
1341 )
1342 gru_lstm_utils.function_register(defun_gpu_lstm, **params)
1344 return last_output, outputs, new_h, new_c, runtime
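# In summary, the kernel selection wiring above is: `LSTM.call` either runs
# eagerly and picks `gpu_lstm` or `standard_lstm` directly (based on device
# placement and `is_cudnn_supported_inputs`), or, in graph mode, calls
# `lstm_with_backend_selection`, which registers both implementations under a
# shared API name so Grappler can swap in the cuDNN kernel when the op lands
# on a GPU.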