Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/rnn/peephole_lstm_cell.py: 29%
24 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implements PeepholeLSTM Cell."""
17import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package="Addons")
class PeepholeLSTMCell(tf.keras.layers.LSTMCell):
    """`tf.keras.layers.LSTMCell` variant whose gates also look at the cell state.

    The class defines no constructor of its own, so it takes exactly the
    arguments of `tf.keras.layers.LSTMCell`. What it adds are peephole
    connections. Three learned vectors of shape ``(units,)`` are multiplied
    element-wise with the cell (carry) state and added to the input, forget and
    output gate pre-activations. As a result, the gates see the internal cell
    state and not only the previous hidden state. This lets the cell pick up
    precise timing information that a plain `LSTMCell` struggles with.

    From [Gers et al., 2002](
      http://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf):

      "We find that LSTM augmented by 'peephole connections' from its internal
      cells to its multiplicative gates can learn the fine distinction between
      sequences of spikes spaced either 50 or 49 time steps apart without the help
      of any short training exemplars."

    The gate equations follow
    [Sak et al., 2014](https://research.google.com/pubs/archive/43905.pdf).
    The input and forget gates peek at the previous cell state ``c_{t-1}``.
    The output gate peeks at the freshly computed cell state ``c_t``.

    Example:

    >>> inputs = np.random.random([30,23,9]).astype(np.float32)
    >>> LSTMCell = tfa.rnn.PeepholeLSTMCell(4)
    >>> rnn = tf.keras.layers.RNN(LSTMCell, return_sequences=True, return_state=True)
    >>> outputs, memory_state, carry_state = rnn(inputs)
    >>> outputs.shape
    TensorShape([30, 23, 4])
    >>> memory_state.shape
    TensorShape([30, 4])
    >>> carry_state.shape
    TensorShape([30, 4])
    """

    def build(self, input_shape):
        """Build the parent LSTM weights, then add one peephole vector per gate.

        Each peephole vector has shape ``(units,)``. It is initialised with the
        cell's ``kernel_initializer``.
        """
        super().build(input_shape)

        def make_peephole(weight_name):
            # Element-wise (diagonal) peephole: one scalar per unit.
            return self.add_weight(
                name=weight_name,
                shape=(self.units,),
                initializer=self.kernel_initializer,
            )

        # Keep the creation order input -> forget -> output. The order of
        # add_weight calls determines the order of the layer's weight list.
        self.input_gate_peephole_weights = make_peephole(
            "input_gate_peephole_weights"
        )
        self.forget_gate_peephole_weights = make_peephole(
            "forget_gate_peephole_weights"
        )
        self.output_gate_peephole_weights = make_peephole(
            "output_gate_peephole_weights"
        )

    def _compute_carry_and_output(self, x, h_tm1, c_tm1):
        """Non-fused gate computation with peephole terms.

        Args:
          x: 4-tuple of per-gate input contributions, ordered input, forget,
            candidate, output.
          h_tm1: 4-tuple of previous hidden states, one per gate, same order.
          c_tm1: previous cell (carry) state.

        Returns:
          Tuple ``(c, o)``: the new cell state and the output-gate activation.
        """
        x_i, x_f, x_c, x_o = x
        h_i, h_f, h_c, h_o = h_tm1
        n = self.units
        kernel = self.recurrent_kernel
        dot = tf.keras.backend.dot
        w_i = self.input_gate_peephole_weights
        w_f = self.forget_gate_peephole_weights
        w_o = self.output_gate_peephole_weights

        # The recurrent kernel holds the four gate blocks side by side along
        # its last axis, each `units` columns wide.
        # Input and forget gates peek at the *previous* cell state.
        i = self.recurrent_activation(x_i + dot(h_i, kernel[:, :n]) + w_i * c_tm1)
        f = self.recurrent_activation(
            x_f + dot(h_f, kernel[:, n : 2 * n]) + w_f * c_tm1
        )
        candidate = self.activation(x_c + dot(h_c, kernel[:, 2 * n : 3 * n]))
        c = f * c_tm1 + i * candidate
        # The output gate peeks at the *new* cell state.
        o = self.recurrent_activation(x_o + dot(h_o, kernel[:, 3 * n :]) + w_o * c)
        return c, o

    def _compute_carry_and_output_fused(self, z, c_tm1):
        """Fused gate computation with peephole terms.

        Args:
          z: 4-tuple of gate pre-activations without the peephole terms,
            ordered input, forget, candidate, output.
          c_tm1: previous cell (carry) state.

        Returns:
          Tuple ``(c, o)``: the new cell state and the output-gate activation.
        """
        z_i, z_f, z_c, z_o = z
        w_i = self.input_gate_peephole_weights
        w_f = self.forget_gate_peephole_weights
        w_o = self.output_gate_peephole_weights

        i = self.recurrent_activation(z_i + w_i * c_tm1)
        f = self.recurrent_activation(z_f + w_f * c_tm1)
        c = f * c_tm1 + i * self.activation(z_c)
        # Unlike the input/forget gates, the output gate uses the new state.
        o = self.recurrent_activation(z_o + w_o * c)
        return c, o