Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/noisy_dense.py: 28%

71 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16import tensorflow as tf 

17from tensorflow.keras import ( 

18 activations, 

19 initializers, 

20 regularizers, 

21 constraints, 

22) 

23from tensorflow.keras import backend as K 

24from tensorflow.keras.layers import InputSpec 

25from typeguard import typechecked 

26 

27from tensorflow_addons.utils import types 

28 

29 

def _scaled_noise(size, dtype):
    """Draw a random tensor and squash it with `sign(x) * sqrt(|x|)`.

    Args:
      size: Shape of the tensor to sample.
      dtype: dtype passed to `tf.random.normal` for the sample.

    Returns:
      A tensor of shape `size` holding `sign(x) * sqrt(|x|)`, where `x` is
      the sample drawn from `tf.random.normal` with its default parameters.
    """
    gaussian = tf.random.normal(shape=size, dtype=dtype)
    # Square-root the magnitude and re-apply the sign afterwards.
    magnitude = tf.sqrt(tf.abs(gaussian))
    return tf.sign(gaussian) * magnitude

33 

34 

@tf.keras.utils.register_keras_serializable(package="Addons")
class NoisyDense(tf.keras.layers.Dense):
    r"""Noisy dense layer that injects random noise to the weights of dense layer.

    Noisy dense layers are fully connected layers whose weights and biases are
    augmented by Gaussian noise. The scale of the noise is controlled through
    gradient descent by a second set of weights.

    A `NoisyDense` layer implements the operation:
    $$
    \mathrm{NoisyDense}(x) =
    \mathrm{activation}(\mathrm{dot}(x, \mu + (\sigma \cdot \epsilon))
    + \mathrm{bias})
    $$
    where $\mu$ is the standard weights layer, $\epsilon$ is the (non-trainable)
    Gaussian noise, and $\sigma$ is a second weights layer which scales
    $\epsilon$. The bias is perturbed in the same way:
    $\mathrm{bias} = \mu_b + (\sigma_b \cdot \epsilon_b)$.

    Note: bias only added if `use_bias` is `True`.

    Example:

    >>> # Create a `Sequential` model and add a NoisyDense
    >>> # layer as the first layer.
    >>> model = tf.keras.models.Sequential()
    >>> model.add(tf.keras.Input(shape=(16,)))
    >>> model.add(NoisyDense(32, activation='relu'))
    >>> # Now the model will take as input arrays of shape (None, 16)
    >>> # and output arrays of shape (None, 32).
    >>> # Note that after the first layer, you don't need to specify
    >>> # the size of the input anymore:
    >>> model.add(NoisyDense(32))
    >>> model.output_shape
    (None, 32)

    Both noise variants are implemented:
      1. Independent Gaussian noise.
      2. Factorised Gaussian noise.
    The `use_factorised` parameter selects between them.

    Args:
      units: Positive integer, dimensionality of the output space.
      sigma: Float scale for the initial value of the noise weights
        (`sigma_kernel` and `sigma_bias`), which start at
        `sigma / sqrt(input_dim)`. Only used when `use_factorised=True`;
        with independent noise they start at the constant `0.017`.
      use_factorised: Boolean. `True` uses factorised Gaussian noise,
        `False` uses independent Gaussian noise.
      activation: Activation function to use.
        If you don't specify anything, no activation is applied
        (ie. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_regularizer: Regularizer function applied to
        the `mu_kernel` and `sigma_kernel` weights matrices.
      bias_regularizer: Regularizer function applied to the
        `mu_bias` and `sigma_bias` vectors.
      activity_regularizer: Regularizer function applied to
        the output of the layer (its "activation").
      kernel_constraint: Constraint function applied to
        the `mu_kernel` and `sigma_kernel` weights matrices.
      bias_constraint: Constraint function applied to the
        `mu_bias` and `sigma_bias` vectors.

    Input shape:
      N-D tensor with shape: `(batch_size, ..., input_dim)`.
      The most common situation would be
      a 2D input with shape `(batch_size, input_dim)`.

    Output shape:
      N-D tensor with shape: `(batch_size, ..., units)`.
      For instance, for a 2D input with shape `(batch_size, input_dim)`,
      the output would have shape `(batch_size, units)`.

    References:
      - [Noisy Networks for Exploration](https://arxiv.org/pdf/1706.10295.pdf)
    """

    @typechecked
    def __init__(
        self,
        units: int,
        sigma: float = 0.5,
        use_factorised: bool = True,
        activation: types.Activation = None,
        use_bias: bool = True,
        kernel_regularizer: types.Regularizer = None,
        bias_regularizer: types.Regularizer = None,
        activity_regularizer: types.Regularizer = None,
        kernel_constraint: types.Constraint = None,
        bias_constraint: types.Constraint = None,
        **kwargs,
    ):
        super().__init__(
            units=units,
            activation=activation,
            use_bias=use_bias,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs,
        )
        # The initializers are chosen inside `build()`, so the attributes the
        # parent constructor created for them are dropped.
        delattr(self, "kernel_initializer")
        delattr(self, "bias_initializer")
        self.sigma = sigma
        self.use_factorised = use_factorised

    def build(self, input_shape):
        """Create the `mu`, `sigma` and `eps` weights and draw initial noise.

        Args:
          input_shape: Shape of the layer input; its last dimension must be
            defined.

        Raises:
          TypeError: If the layer dtype is neither floating nor complex.
          ValueError: If the last dimension of `input_shape` is undefined.
        """
        # Make sure dtype is correct
        dtype = tf.dtypes.as_dtype(self.dtype or K.floatx())
        if not (dtype.is_floating or dtype.is_complex):
            raise TypeError(
                "Unable to build `Dense` layer with non-floating point "
                "dtype %s" % (dtype,)
            )

        input_shape = tf.TensorShape(input_shape)
        self.last_dim = tf.compat.dimension_value(input_shape[-1])
        # Validate before doing any arithmetic on `last_dim`; otherwise an
        # undefined dimension surfaces as a TypeError from `None ** 0.5`.
        if self.last_dim is None:
            raise ValueError(
                "The last dimension of the inputs to `Dense` "
                "should be defined. Found `None`."
            )
        self.input_spec = InputSpec(min_ndim=2, axes={-1: self.last_dim})

        if self.use_factorised:
            # use factorising Gaussian variables
            sqrt_dim = self.last_dim ** (1 / 2)
            mu_bound = 1.0 / sqrt_dim
            sigma_value = self.sigma / sqrt_dim
        else:
            # use independent Gaussian variables
            mu_bound = (3.0 / self.last_dim) ** (1 / 2)
            sigma_value = 0.017

        sigma_init = initializers.Constant(value=sigma_value)
        mu_init = initializers.RandomUniform(minval=-mu_bound, maxval=mu_bound)

        kernel_shape = [self.last_dim, self.units]
        bias_shape = [self.units]

        # Learnable parameters
        self.sigma_kernel = self.add_weight(
            "sigma_kernel",
            shape=kernel_shape,
            initializer=sigma_init,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True,
        )

        self.mu_kernel = self.add_weight(
            "mu_kernel",
            shape=kernel_shape,
            initializer=mu_init,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True,
        )

        # Noise buffer: non-trainable, overwritten by `reset_noise()` and
        # zeroed by `remove_noise()`.
        self.eps_kernel = self.add_weight(
            "eps_kernel",
            shape=kernel_shape,
            initializer=initializers.Zeros(),
            regularizer=None,
            constraint=None,
            dtype=self.dtype,
            trainable=False,
        )

        if self.use_bias:
            self.sigma_bias = self.add_weight(
                "sigma_bias",
                shape=bias_shape,
                initializer=sigma_init,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                dtype=self.dtype,
                trainable=True,
            )

            self.mu_bias = self.add_weight(
                "mu_bias",
                shape=bias_shape,
                initializer=mu_init,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                dtype=self.dtype,
                trainable=True,
            )

            self.eps_bias = self.add_weight(
                "eps_bias",
                shape=bias_shape,
                initializer=initializers.Zeros(),
                regularizer=None,
                constraint=None,
                dtype=self.dtype,
                trainable=False,
            )
        else:
            self.sigma_bias = None
            self.mu_bias = None
            self.eps_bias = None
        self.reset_noise()
        self.built = True

    @property
    def kernel(self):
        """Effective kernel: `mu_kernel + sigma_kernel * eps_kernel`."""
        return self.mu_kernel + (self.sigma_kernel * self.eps_kernel)

    @property
    def bias(self):
        """Effective bias `mu_bias + sigma_bias * eps_bias`, or `None` without bias."""
        if self.use_bias:
            return self.mu_bias + (self.sigma_bias * self.eps_bias)
        return None

    def reset_noise(self):
        """Draw fresh noise into `eps_kernel` (and `eps_bias` when bias is used).

        With `use_factorised=True` the kernel noise is the outer product of
        two scaled noise vectors and the bias noise reuses the output-side
        vector; otherwise every entry is an independent normal sample.
        """
        if self.use_factorised:
            # Generate random noise
            in_eps = _scaled_noise([self.last_dim, 1], dtype=self.dtype)
            out_eps = _scaled_noise([1, self.units], dtype=self.dtype)

            # Scale the random noise
            self.eps_kernel.assign(tf.matmul(in_eps, out_eps))
            # `eps_bias` is None when the layer has no bias.
            if self.use_bias:
                self.eps_bias.assign(out_eps[0])
        else:
            # generate independent variables
            self.eps_kernel.assign(
                tf.random.normal(shape=[self.last_dim, self.units], dtype=self.dtype)
            )
            if self.use_bias:
                self.eps_bias.assign(
                    tf.random.normal(shape=[self.units], dtype=self.dtype)
                )

    def remove_noise(self):
        """Zero the noise so the layer computes with `mu_kernel` / `mu_bias` only."""
        self.eps_kernel.assign(tf.zeros([self.last_dim, self.units], dtype=self.dtype))
        # `eps_bias` is None when the layer has no bias.
        if self.use_bias:
            self.eps_bias.assign(tf.zeros([self.units], dtype=self.dtype))

    def call(self, inputs):
        """Apply the parent `Dense.call`, which reads the noisy `kernel`/`bias` properties."""
        # TODO(WindQAQ): Replace this with `dense()` once public.
        return super().call(inputs)

    def get_config(self):
        """Return the serialisable configuration of the layer."""
        # TODO(WindQAQ): Get rid of this hacky way.
        # `super(tf.keras.layers.Dense, self)` deliberately skips
        # `Dense.get_config()` -- presumably because it would read the
        # initializer attributes removed in `__init__` (verify against the
        # Keras version in use).
        config = super(tf.keras.layers.Dense, self).get_config()
        config.update(
            {
                "units": self.units,
                "sigma": self.sigma,
                "use_factorised": self.use_factorised,
                "activation": activations.serialize(self.activation),
                "use_bias": self.use_bias,
                "kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
                "bias_regularizer": regularizers.serialize(self.bias_regularizer),
                "activity_regularizer": regularizers.serialize(
                    self.activity_regularizer
                ),
                "kernel_constraint": constraints.serialize(self.kernel_constraint),
                "bias_constraint": constraints.serialize(self.bias_constraint),
            }
        )
        return config