Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/rnn/nas_cell.py: 23%

77 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implements NAS Cell.""" 

16 

17import tensorflow as tf 

18from typeguard import typechecked 

19 

20from tensorflow_addons.rnn.abstract_rnn_cell import AbstractRNNCell 

21from tensorflow_addons.utils.types import ( 

22 FloatTensorLike, 

23 TensorLike, 

24 Initializer, 

25) 

26from typing import Optional 

27 

28 

@tf.keras.utils.register_keras_serializable(package="Addons")
class NASCell(AbstractRNNCell):
    """Recurrent cell found by Neural Architecture Search (NAS).

    The cell architecture is taken from:

      Barret Zoph and Quoc V. Le.
      "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017.
      https://arxiv.org/abs/1611.01578

    An optional projection layer can be applied to the cell output.

    Example:

    >>> inputs = np.random.random([30,23,9]).astype(np.float32)
    >>> cell = tfa.rnn.NASCell(4)
    >>> rnn = tf.keras.layers.RNN(cell, return_sequences=True, return_state=True)
    >>> outputs, c_state, m_state = rnn(inputs)
    >>> outputs.shape
    TensorShape([30, 23, 4])
    >>> c_state.shape
    TensorShape([30, 4])
    >>> m_state.shape
    TensorShape([30, 4])
    """

    # Number of slices that both the input projection and the hidden-state
    # projection are split into inside `call`.
    _NAS_BASE = 8

    @typechecked
    def __init__(
        self,
        units: TensorLike,
        projection: Optional[FloatTensorLike] = None,
        use_bias: bool = False,
        kernel_initializer: Initializer = "glorot_uniform",
        recurrent_initializer: Initializer = "glorot_uniform",
        projection_initializer: Initializer = "glorot_uniform",
        bias_initializer: Initializer = "zeros",
        **kwargs,
    ):
        """Store the hyper-parameters of the NAS cell.

        Args:
          units: int, number of units in the NAS cell.
          projection: (optional) int, output dimensionality of the projection
            matrix. When None, the output is not projected.
          use_bias: (optional) bool, whether a bias is used inside the cell.
            Defaults to False.
          kernel_initializer: Initializer for the input kernel.
          recurrent_initializer: Initializer for the recurrent kernel.
          projection_initializer: Initializer for the projection weights; only
            used when `projection` is not None.
          bias_initializer: Initializer for the bias; only used when
            `use_bias` is True.
          **kwargs: Additional keyword arguments forwarded to the base class.
        """
        super().__init__(**kwargs)
        self.units = units
        self.projection = projection
        self.use_bias = use_bias
        self.kernel_initializer = kernel_initializer
        self.recurrent_initializer = recurrent_initializer
        self.projection_initializer = projection_initializer
        self.bias_initializer = bias_initializer

        # The cell state always has `units` columns; the hidden state (which
        # is also the output) has `projection` columns when a projection is
        # requested and `units` columns otherwise.
        hidden_dim = units if projection is None else projection
        self._state_size = [units, hidden_dim]
        self._output_size = hidden_dim

    @property
    def state_size(self):
        """Sizes of the `[c, m]` state pair."""
        return self._state_size

    @property
    def output_size(self):
        """Size of the per-step output."""
        return self._output_size

    def build(self, inputs_shape):
        """Create the weights of the cell.

        Args:
          inputs_shape: Shape of the inputs; must be compatible with rank 2.

        Raises:
          ValueError: If the size of the second dimension of `inputs_shape`
            is not statically known.
        """
        rank2_shape = tf.TensorShape(inputs_shape).with_rank(2)
        input_size = tf.compat.dimension_value(rank2_shape[1])
        if input_size is None:
            raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

        # Every kernel holds the matrices of all `_NAS_BASE` branches side by
        # side, hence the shared column count.
        total_width = self._NAS_BASE * self.units

        # `recurrent_kernel` multiplies the hidden state, `kernel` multiplies
        # the inputs.
        self.recurrent_kernel = self.add_weight(
            name="recurrent_kernel",
            shape=[self.output_size, total_width],
            initializer=self.recurrent_initializer,
        )
        self.kernel = self.add_weight(
            name="kernel",
            shape=[input_size, total_width],
            initializer=self.kernel_initializer,
        )

        if self.use_bias:
            self.bias = self.add_weight(
                name="bias",
                shape=[total_width],
                initializer=self.bias_initializer,
            )

        # Optional projection of the hidden state.
        if self.projection is not None:
            self.projection_weights = self.add_weight(
                name="projection_weights",
                shape=[self.units, self.projection],
                initializer=self.projection_initializer,
            )

        self.built = True

    def call(self, inputs, state):
        """Run one step of the NAS cell.

        Args:
          inputs: 2-D input tensor, `[batch, input_size]`.
          state: Pair `[c_prev, m_prev]` of 2-D state tensors, the previous
            cell state and the previous hidden state.

        Returns:
          A tuple `(new_m, [new_c, new_m])`:
            - `new_m` is the 2-D `[batch, output_dim]` output of the cell,
              where `output_dim` is `projection` if it was set and `units`
              otherwise.
            - `[new_c, new_m]` is the new state, laid out like `state`.
        """
        c_prev, m_prev = state

        hidden_proj = tf.matmul(m_prev, self.recurrent_kernel)
        input_proj = tf.matmul(inputs, self.kernel)

        # The bias is applied to the hidden-state projection only.
        if self.use_bias:
            hidden_proj = tf.nn.bias_add(hidden_proj, self.bias)

        # Split both projections into the `_NAS_BASE` branches of the cell.
        h_parts = tf.split(
            value=hidden_proj, num_or_size_splits=self._NAS_BASE, axis=1
        )
        x_parts = tf.split(
            value=input_proj, num_or_size_splits=self._NAS_BASE, axis=1
        )

        # First layer: every branch combines one input slice with one hidden
        # slice. Branch 3 is the only one that combines by multiplication.
        branch = [
            tf.math.sigmoid(x_parts[0] + h_parts[0]),
            tf.nn.relu(x_parts[1] + h_parts[1]),
            tf.math.sigmoid(x_parts[2] + h_parts[2]),
            tf.nn.relu(x_parts[3] * h_parts[3]),
            tf.math.tanh(x_parts[4] + h_parts[4]),
            tf.math.sigmoid(x_parts[5] + h_parts[5]),
            tf.math.tanh(x_parts[6] + h_parts[6]),
            tf.math.sigmoid(x_parts[7] + h_parts[7]),
        ]

        # Second layer: merge neighbouring branches pairwise.
        pair_a = tf.math.tanh(branch[0] * branch[1])
        pair_b = tf.math.tanh(branch[2] + branch[3])
        pair_c = tf.math.tanh(branch[4] * branch[5])
        pair_d = tf.math.sigmoid(branch[6] + branch[7])

        # Inject the previous cell state.
        pair_a = tf.math.tanh(pair_a + c_prev)

        # Third layer: the left product is the new cell state.
        new_c = pair_a * pair_b
        right = tf.math.tanh(pair_c + pair_d)

        # Final layer.
        new_m = tf.math.tanh(new_c * right)

        # Optional projection of the hidden state.
        if self.projection is not None:
            new_m = tf.matmul(new_m, self.projection_weights)

        return new_m, [new_c, new_m]

    def get_config(self):
        """Return the base-class config extended with this cell's arguments."""
        merged = dict(super().get_config())
        merged.update(
            {
                "units": self.units,
                "projection": self.projection,
                "use_bias": self.use_bias,
                "kernel_initializer": self.kernel_initializer,
                "recurrent_initializer": self.recurrent_initializer,
                "bias_initializer": self.bias_initializer,
                "projection_initializer": self.projection_initializer,
            }
        )
        return merged