Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/base_cudnn_rnn.py: 26%

69 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Base class for recurrent layers backed by cuDNN.""" 

16 

17 

18import tensorflow.compat.v2 as tf 

19 

20from keras.src import backend 

21from keras.src.engine.input_spec import InputSpec 

22from keras.src.layers.rnn.base_rnn import RNN 

23 

24 

class _CuDNNRNN(RNN):
    """Private base class for CuDNNGRU and CuDNNLSTM layers.

    This class reads, but does not itself define, the following members;
    concrete subclasses are expected to supply them:

    * `cell`: an object exposing `state_size` (read in `__init__`).
    * `_process_batch(inputs, initial_state)`: returns `(output, states)`.
    * `kernel`, `recurrent_kernel`, `bias`: the layer weights, available
      once the layer is built.

    Args:
      return_sequences: Boolean. Whether to return the last output
          in the output sequence, or the full sequence.
      return_state: Boolean. Whether to return the last state
          in addition to the output.
      go_backwards: Boolean (default False).
          If True, process the input sequence backwards and return the
          reversed sequence.
      stateful: Boolean (default False). If True, the last state
          for each sample at index i in a batch will be used as initial
          state for the sample of index i in the following batch.
      time_major: Boolean (default False). If true, the inputs and outputs
          will be in shape `(timesteps, batch, ...)`, whereas in the False
          case, it will be `(batch, timesteps, ...)`.
    """

    def __init__(
        self,
        return_sequences=False,
        return_state=False,
        go_backwards=False,
        stateful=False,
        time_major=False,
        **kwargs
    ):
        """Store the layer options and set up input/state specs.

        Args:
          return_sequences: See class docstring.
          return_state: See class docstring.
          go_backwards: See class docstring.
          stateful: See class docstring.
          time_major: See class docstring.
          **kwargs: Forwarded unchanged to the initializer that follows
              `RNN` in the method resolution order.
        """
        # We invoke the base layer's initializer directly here because we do
        # not want to create RNN cell instance.
        super(RNN, self).__init__(**kwargs)
        self.return_sequences = return_sequences
        self.return_state = return_state
        self.go_backwards = go_backwards
        self.stateful = stateful
        self.time_major = time_major
        self.supports_masking = False
        self.input_spec = [InputSpec(ndim=3)]

        # `self.cell` is never assigned in this initializer, so it must be
        # provided by the subclass. Normalize a scalar state size to a list
        # so that one `InputSpec` is created per state.
        if hasattr(self.cell.state_size, "__len__"):
            state_size = self.cell.state_size
        else:
            state_size = [self.cell.state_size]
        self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size]

        self.constants_spec = None
        self._states = None
        self._num_constants = 0
        self._vector_shape = tf.constant([-1])

    def call(self, inputs, mask=None, training=None, initial_state=None):
        """Run the layer on `inputs`.

        Args:
          inputs: The input tensor, or a list whose first element is the
              input tensor and whose remaining elements are used as the
              initial states.
          mask: Must be `None`, or a list whose first element is `None`;
              masking is not supported.
          training: Accepted but not used by this method.
          initial_state: Optional initial states. Ignored when `inputs` is
              a list. When omitted, `self.states` is used if the layer is
              stateful, otherwise `self.get_initial_state(inputs)`.

        Returns:
          The output returned by `_process_batch`, or `[output] + states`
          when `return_state` is true.

        Raises:
          ValueError: If a mask is supplied, or if the number of initial
              states differs from `len(self.states)`.
        """
        if isinstance(mask, list):
            mask = mask[0]
        if mask is not None:
            raise ValueError("Masking is not supported for CuDNN RNNs.")

        # input shape: `(samples, time (padded with zeros), input_dim)`
        # note that the .build() method of subclasses MUST define
        # self.input_spec and self.state_spec with complete input shapes.
        if isinstance(inputs, list):
            initial_state = inputs[1:]
            inputs = inputs[0]
        elif initial_state is not None:
            # Caller-provided initial state is used as-is.
            pass
        elif self.stateful:
            initial_state = self.states
        else:
            initial_state = self.get_initial_state(inputs)

        if len(initial_state) != len(self.states):
            raise ValueError(
                f"Layer has {len(self.states)} states but was passed "
                f"{len(initial_state)} initial states."
            )

        if self.go_backwards:
            # Reverse the time axis. Time-major inputs are laid out as
            # `(timesteps, batch, ...)`, so the time axis is 0 there and 1
            # for the default batch-major layout. Always reversing axis 1
            # would flip the batch order of time-major inputs instead.
            time_axis = 0 if self.time_major else 1
            inputs = backend.reverse(inputs, time_axis)
        output, states = self._process_batch(inputs, initial_state)

        if self.stateful:
            # Carry the final states over to the next batch.
            updates = [
                tf.compat.v1.assign(self_state, state)
                for self_state, state in zip(self.states, states)
            ]
            self.add_update(updates)

        if self.return_state:
            return [output] + states
        return output

    def get_config(self):
        """Return the layer config as a dict.

        The result is the config of the class following `RNN` in the method
        resolution order, updated with this class's five constructor
        options.
        """
        config = {
            "return_sequences": self.return_sequences,
            "return_state": self.return_state,
            "go_backwards": self.go_backwards,
            "stateful": self.stateful,
            "time_major": self.time_major,
        }
        # Skip `RNN.get_config` on purpose, mirroring `__init__`.
        base_config = super(RNN, self).get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config):
        """Build a layer by passing `config` as keyword arguments to `cls`."""
        return cls(**config)

    @property
    def trainable_weights(self):
        """Kernel, recurrent kernel and bias if trainable and built, else []."""
        if self.trainable and self.built:
            return [self.kernel, self.recurrent_kernel, self.bias]
        return []

    @property
    def non_trainable_weights(self):
        """Kernel, recurrent kernel and bias if frozen and built, else []."""
        if not self.trainable and self.built:
            return [self.kernel, self.recurrent_kernel, self.bias]
        return []

    @property
    def losses(self):
        """Losses as reported by the class following `RNN` in the MRO."""
        return super(RNN, self).losses

    def get_losses_for(self, inputs=None):
        """Delegate to the `get_losses_for` of the class following `RNN`."""
        return super(RNN, self).get_losses_for(inputs=inputs)

151