Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/rnn/peephole_lstm_cell.py: 29%

24 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implements PeepholeLSTM Cell.""" 

16 

17import tensorflow as tf 

18 

19 

@tf.keras.utils.register_keras_serializable(package="Addons")
class PeepholeLSTMCell(tf.keras.layers.LSTMCell):
    """LSTM cell with peephole connections, otherwise like `tf.keras.layers.LSTMCell`.

    Peephole connections let the gates look at the internal (carry) state in
    addition to the previous hidden state, which is all that LSTMCell uses.
    This allows PeepholeLSTMCell to better learn precise timings over LSTMCell.

    From [Gers et al., 2002](
    http://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf):

    "We find that LSTM augmented by 'peephole connections' from its internal
    cells to its multiplicative gates can learn the fine distinction between
    sequences of spikes spaced either 50 or 49 time steps apart without the help
    of any short training exemplars."

    The peephole implementation is based on:

    [Sak et al., 2014](https://research.google.com/pubs/archive/43905.pdf)

    Example:

    >>> inputs = np.random.random([30,23,9]).astype(np.float32)
    >>> LSTMCell = tfa.rnn.PeepholeLSTMCell(4)
    >>> rnn = tf.keras.layers.RNN(LSTMCell, return_sequences=True, return_state=True)
    >>> outputs, memory_state, carry_state = rnn(inputs)
    >>> outputs.shape
    TensorShape([30, 23, 4])
    >>> memory_state.shape
    TensorShape([30, 4])
    >>> carry_state.shape
    TensorShape([30, 4])
    """

    def build(self, input_shape):
        """Build the parent cell, then add one peephole weight vector per gate.

        Args:
            input_shape: Forwarded unchanged to the parent class's `build`.
        """
        super().build(input_shape)

        # Each peephole weight is a vector of length `units` that is multiplied
        # element-wise with a carry state when a gate is computed. The three
        # weights are created in input / forget / output order and reuse the
        # cell's kernel initializer.
        peephole_shape = (self.units,)
        peephole_initializer = self.kernel_initializer

        self.input_gate_peephole_weights = self.add_weight(
            name="input_gate_peephole_weights",
            shape=peephole_shape,
            initializer=peephole_initializer,
        )
        self.forget_gate_peephole_weights = self.add_weight(
            name="forget_gate_peephole_weights",
            shape=peephole_shape,
            initializer=peephole_initializer,
        )
        self.output_gate_peephole_weights = self.add_weight(
            name="output_gate_peephole_weights",
            shape=peephole_shape,
            initializer=peephole_initializer,
        )

    def _peephole_gate(self, pre_activation, peephole_weights, carry):
        """Apply the recurrent activation to a gate term plus its peephole term.

        Args:
            pre_activation: The gate's term before the peephole contribution.
            peephole_weights: The peephole weight vector belonging to the gate.
            carry: The carry state the gate is allowed to look at.

        Returns:
            `recurrent_activation(pre_activation + peephole_weights * carry)`.
        """
        return self.recurrent_activation(pre_activation + peephole_weights * carry)

    def _compute_carry_and_output(self, x, h_tm1, c_tm1):
        """Compute the new carry state and the output gate from per-gate terms.

        Args:
            x: Sequence of four per-gate input terms, in input / forget /
                candidate / output order (presumably already projected by the
                parent class -- confirm against `LSTMCell.call`).
            h_tm1: Sequence of four previous-hidden-state tensors, one per
                gate, in the same order as `x`.
            c_tm1: The previous carry state.

        Returns:
            A `(c, o)` tuple: the new carry state and the output gate
            activation.
        """
        x_input, x_forget, x_candidate, x_output = x
        h_input, h_forget, h_candidate, h_output = h_tm1

        dot = tf.keras.backend.dot
        units = self.units
        # The recurrent kernel stores the four gates side by side along its
        # second axis, `units` columns each, in input / forget / candidate /
        # output order.
        recurrent_kernel = self.recurrent_kernel

        # The input and forget gates peek at the previous carry state.
        input_gate = self._peephole_gate(
            x_input + dot(h_input, recurrent_kernel[:, :units]),
            self.input_gate_peephole_weights,
            c_tm1,
        )
        forget_gate = self._peephole_gate(
            x_forget + dot(h_forget, recurrent_kernel[:, units : units * 2]),
            self.forget_gate_peephole_weights,
            c_tm1,
        )

        # The candidate term has no peephole connection.
        retained = forget_gate * c_tm1
        candidate = self.activation(
            x_candidate
            + dot(h_candidate, recurrent_kernel[:, units * 2 : units * 3])
        )
        carry = retained + input_gate * candidate

        # The output gate peeks at the freshly computed carry state.
        output_gate = self._peephole_gate(
            x_output + dot(h_output, recurrent_kernel[:, units * 3 :]),
            self.output_gate_peephole_weights,
            carry,
        )
        return carry, output_gate

    def _compute_carry_and_output_fused(self, z, c_tm1):
        """Compute the new carry state and the output gate from fused terms.

        Args:
            z: Sequence of four per-gate terms, in input / forget / candidate /
                output order (presumably the complete pre-activations computed
                by the parent class -- confirm against `LSTMCell.call`).
            c_tm1: The previous carry state.

        Returns:
            A `(c, o)` tuple: the new carry state and the output gate
            activation.
        """
        z_input, z_forget, z_candidate, z_output = z

        # The input and forget gates peek at the previous carry state.
        input_gate = self._peephole_gate(
            z_input, self.input_gate_peephole_weights, c_tm1
        )
        forget_gate = self._peephole_gate(
            z_forget, self.forget_gate_peephole_weights, c_tm1
        )

        retained = forget_gate * c_tm1
        carry = retained + input_gate * self.activation(z_candidate)

        # The output gate peeks at the freshly computed carry state.
        output_gate = self._peephole_gate(
            z_output, self.output_gate_peephole_weights, carry
        )
        return carry, output_gate