Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/constraints.py: 54%

107 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=invalid-name 

16# pylint: disable=g-classes-have-attributes 

17"""Constraints: functions that impose constraints on weight values.""" 

18 

19from tensorflow.python.framework import tensor_shape 

20from tensorflow.python.keras import backend 

21from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object 

22from tensorflow.python.keras.utils.generic_utils import serialize_keras_object 

23from tensorflow.python.ops import array_ops 

24from tensorflow.python.ops import array_ops_stack 

25from tensorflow.python.ops import math_ops 

26from tensorflow.python.ops import while_loop 

27from tensorflow.python.util.tf_export import keras_export 

28from tensorflow.tools.docs import doc_controls 

29 

30 

@keras_export('keras.constraints.Constraint')
class Constraint:
  """Base class for weight constraints.

  A `Constraint` instance works like a stateless function.
  Users who subclass this
  class should override the `__call__` method, which takes a single
  weight parameter and returns a projected version of that parameter
  (e.g. normalized or clipped). Constraints can be used with various Keras
  layers via the `kernel_constraint` or `bias_constraint` arguments.

  Here's a simple example of a non-negative weight constraint:

  >>> class NonNegative(tf.keras.constraints.Constraint):
  ...
  ...  def __call__(self, w):
  ...    return w * tf.cast(tf.math.greater_equal(w, 0.), w.dtype)

  >>> weight = tf.constant((-1.0, 1.0))
  >>> NonNegative()(weight)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 1.], dtype=float32)>

  >>> tf.keras.layers.Dense(4, kernel_constraint=NonNegative())
  """

  def __call__(self, w):
    """Applies the constraint to the input weight variable.

    By default, the input weight variable is not modified.
    Users should override this method to implement their own projection
    function.

    Args:
      w: Input weight variable.

    Returns:
      Projected variable (by default, returns unmodified inputs).
    """
    return w

  def get_config(self):
    """Returns a Python dict of the object config.

    A constraint config is a Python dictionary (JSON-serializable) that can
    be used to reinstantiate the same object.

    Returns:
      Python dict containing the configuration of the constraint object.
    """
    return {}

  @classmethod
  def from_config(cls, config):
    """Instantiates a constraint from a configuration dictionary.

    This is the inverse of `get_config`: the config entries are passed as
    keyword arguments to the constructor.

    Args:
      config: A Python dictionary, typically the output of `get_config`.

    Returns:
      A new instance of this constraint class.
    """
    return cls(**config)

81 

82 

@keras_export('keras.constraints.MaxNorm', 'keras.constraints.max_norm')
class MaxNorm(Constraint):
  """MaxNorm weight constraint.

  Rescales the weights incident to each hidden unit so that their norm is
  at most a desired maximum value.

  Also available via the shortcut function `tf.keras.constraints.max_norm`.

  Args:
    max_value: the maximum norm value for the incoming weights.
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.

  """

  def __init__(self, max_value=2, axis=0):
    self.max_value = max_value
    self.axis = axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    # L2 norm of `w` along `self.axis`; reduced dims are kept so the scale
    # factor broadcasts back against `w`.
    squared_sum = math_ops.reduce_sum(
        math_ops.square(w), axis=self.axis, keepdims=True)
    current_norm = backend.sqrt(squared_sum)
    # Target norm: the current norm capped at `max_value`.
    target_norm = backend.clip(current_norm, 0, self.max_value)
    # `epsilon()` in the denominator avoids division by zero for zero slices.
    scale = target_norm / (backend.epsilon() + current_norm)
    return w * scale

  @doc_controls.do_not_generate_docs
  def get_config(self):
    return dict(max_value=self.max_value, axis=self.axis)

122 

123 

@keras_export('keras.constraints.NonNeg', 'keras.constraints.non_neg')
class NonNeg(Constraint):
  """Constrains the weights to be non-negative.

  Every entry of the weight tensor is multiplied by the mask `w >= 0`, so
  negative entries become zero and non-negative entries are unchanged.

  Also available via the shortcut function `tf.keras.constraints.non_neg`.
  """

  def __call__(self, w):
    # Cast the mask to the weight's own dtype rather than always to
    # `backend.floatx()`: otherwise multiplying e.g. a float16 or float64
    # weight by a float32 mask fails with a dtype mismatch. Inputs without a
    # `dtype` attribute keep the previous `floatx()` behavior.
    mask_dtype = getattr(w, 'dtype', backend.floatx())
    return w * math_ops.cast(math_ops.greater_equal(w, 0.), mask_dtype)

133 

134 

@keras_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm')
class UnitNorm(Constraint):
  """Constrains the weights incident to each hidden unit to have unit norm.

  Each slice of the weights along `axis` is divided by its L2 norm (plus a
  small epsilon).

  Also available via the shortcut function `tf.keras.constraints.unit_norm`.

  Args:
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, axis=0):
    self.axis = axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    # Keep reduced dims so the norm broadcasts against `w`.
    l2_norm = backend.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    # `epsilon()` prevents division by zero for all-zero slices.
    denominator = backend.epsilon() + l2_norm
    return w / denominator

  @doc_controls.do_not_generate_docs
  def get_config(self):
    return dict(axis=self.axis)

168 

169 

@keras_export('keras.constraints.MinMaxNorm', 'keras.constraints.min_max_norm')
class MinMaxNorm(Constraint):
  """MinMaxNorm weight constraint.

  Rescales the weights incident to each hidden unit so that their norm moves
  into the interval `[min_value, max_value]`.

  Also available via the shortcut function `tf.keras.constraints.min_max_norm`.

  Args:
    min_value: the minimum norm for the incoming weights.
    max_value: the maximum norm for the incoming weights.
    rate: rate for enforcing the constraint: weights will be
      rescaled to yield
      `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
      Effectively, this means that rate=1.0 stands for strict
      enforcement of the constraint, while rate<1.0 means that
      weights will be rescaled at each step to slowly move
      towards a value inside the desired interval.
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
    self.min_value = min_value
    self.max_value = max_value
    self.rate = rate
    self.axis = axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    current_norm = backend.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    clipped_norm = backend.clip(current_norm, self.min_value, self.max_value)
    # Interpolate between the clipped norm and the current norm by `rate`.
    target_norm = self.rate * clipped_norm + (1 - self.rate) * current_norm
    # `epsilon()` prevents division by zero for all-zero slices.
    return w * (target_norm / (backend.epsilon() + current_norm))

  @doc_controls.do_not_generate_docs
  def get_config(self):
    return dict(
        min_value=self.min_value,
        max_value=self.max_value,
        rate=self.rate,
        axis=self.axis)

225 

226 

@keras_export('keras.constraints.RadialConstraint',
              'keras.constraints.radial_constraint')
class RadialConstraint(Constraint):
  """Constrains `Conv2D` kernel weights to be the same for each radius.

  Also available via the shortcut function
  `tf.keras.constraints.radial_constraint`.

  For example, the desired output for the following 4-by-4 kernel:

  ```
      kernel = [[v_00, v_01, v_02, v_03],
                [v_10, v_11, v_12, v_13],
                [v_20, v_21, v_22, v_23],
                [v_30, v_31, v_32, v_33]]
  ```

  is this:

  ```
      kernel = [[v_11, v_11, v_11, v_11],
                [v_11, v_33, v_33, v_11],
                [v_11, v_33, v_33, v_11],
                [v_11, v_11, v_11, v_11]]
  ```

  This constraint can be applied to any `Conv2D` layer version, including
  `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"` or
  `"channels_first"` data format. The method assumes the weight tensor is of
  shape `(rows, cols, input_depth, output_depth)`.
  """

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    # Only rank-4 weight tensors are accepted; anything else raises.
    w_shape = w.shape
    if w_shape.rank is None or w_shape.rank != 4:
      raise ValueError(
          'The weight tensor must be of rank 4, but is of shape: %s' % w_shape)

    # NOTE(review): the static dims are used directly in reshape below, so
    # this presumably requires a fully-defined static shape — confirm.
    height, width, channels, kernels = w_shape
    # Merge the last two axes -> (height, width, channels * kernels).
    w = backend.reshape(w, (height, width, channels * kernels))
    # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once
    # backend.switch is supported.
    # unstack(-1) + stack(0) moves the merged axis to the front, giving
    # (channels * kernels, height, width); map_fn then applies
    # `_kernel_constraint` to each 2-D (height, width) slice.
    w = backend.map_fn(
        self._kernel_constraint,
        backend.stack(array_ops_stack.unstack(w, axis=-1), axis=0))
    # Move the slice axis back to the end and restore the original 4-D shape.
    return backend.reshape(
        backend.stack(array_ops_stack.unstack(w, axis=0), axis=-1),
        (height, width, channels, kernels))

  def _kernel_constraint(self, kernel):
    """Radially constrains a single 2-D kernel slice of shape (height, width).

    The result is built from the center outwards. It starts from a center
    block (1x1 when the side length is odd, 2x2 when it is even) holding one
    diagonal entry of `kernel`. It then grows one ring at a time with
    `array_ops.pad`, and each ring is a constant taken from the diagonal
    entry `kernel[start + i, start + i]`. Only `shape(kernel)[0]` is read,
    so the slice is presumably square — confirm with callers.

    Args:
      kernel: A 2-D slice of the weight tensor.

    Returns:
      A tensor with the same side length in which every ring of entries
      shares a single value.
    """
    # NOTE(review): tracing this code, a 4x4 kernel yields a 2x2 center of
    # kernel[1, 1] surrounded by a ring of kernel[3, 3] (the reverse of the
    # class docstring example), and a 3x3 kernel yields a center of
    # kernel[0, 0] surrounded by kernel[1, 1]. Confirm which is intended.

    # Pads one element on each side of both axes, i.e. adds one ring.
    padding = backend.constant([[1, 1], [1, 1]], dtype='int32')

    # Side length of the (square) kernel slice.
    kernel_shape = backend.shape(kernel)[0]
    # Half the side length; the float -> int32 cast truncates, so odd sizes
    # round down.
    start = backend.cast(kernel_shape / 2, 'int32')

    # Odd side length (floormod == 1 -> True): 1x1 center block.
    # Even side length: the single entry is broadcast to a 2x2 center block.
    kernel_new = backend.switch(
        backend.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
        lambda: kernel[start - 1:start, start - 1:start],
        lambda: kernel[start - 1:start, start - 1:start] + backend.zeros(  # pylint: disable=g-long-lambda
            (2, 2), dtype=kernel.dtype))
    # First ring offset: 0 for odd sizes, 1 for even sizes (the 2x2 center
    # already covers the first step).
    index = backend.switch(
        backend.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
        lambda: backend.constant(0, dtype='int32'),
        lambda: backend.constant(1, dtype='int32'))
    # Keep adding rings until the offset reaches `start`.
    while_condition = lambda index, *args: backend.less(index, start)

    def body_fn(i, array):
      # Wrap `array` in one ring filled with the diagonal value at offset i.
      return i + 1, array_ops.pad(
          array,
          padding,
          constant_values=kernel[start + i, start + i])

    # The loop variable grows by 2 along each axis per iteration, so its
    # shape is declared as an unknown 2-D shape.
    _, kernel_new = while_loop.while_loop(
        while_condition,
        body_fn, [index, kernel_new],
        shape_invariants=[
            index.get_shape(),
            tensor_shape.TensorShape([None, None])
        ])
    return kernel_new

309 

310 

# Aliases.
# snake_case names bound to the constraint classes. Because `deserialize`
# resolves names through this module's `globals()`, these names also work as
# string identifiers.
max_norm = MaxNorm
non_neg = NonNeg
unit_norm = UnitNorm
min_max_norm = MinMaxNorm
radial_constraint = RadialConstraint

# Legacy aliases (names without underscores), bound to the same classes.
maxnorm = MaxNorm
nonneg = NonNeg
unitnorm = UnitNorm

323 

324 

@keras_export('keras.constraints.serialize')
def serialize(constraint):
  """Serializes a constraint with `serialize_keras_object`.

  Args:
    constraint: The constraint object to serialize.

  Returns:
    The serialized form produced by `serialize_keras_object`.
  """
  serialized = serialize_keras_object(constraint)
  return serialized

328 

329 

@keras_export('keras.constraints.deserialize')
def deserialize(config, custom_objects=None):
  """Re-creates a constraint from its serialized form.

  Class names in `config` are looked up among this module's global names
  (the constraint classes and their aliases) and the user-supplied
  `custom_objects`.

  Args:
    config: The serialized constraint, e.g. the output of `serialize`.
    custom_objects: Optional mapping of extra names to objects to consider
      during lookup.

  Returns:
    The object produced by `deserialize_keras_object`.
  """
  lookup_kwargs = dict(
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='constraint')
  return deserialize_keras_object(config, **lookup_kwargs)

337 

338 

@keras_export('keras.constraints.get')
def get(identifier):
  """Resolves `identifier` into a constraint.

  Args:
    identifier: `None`, a serialized constraint config (`dict`), the name of
      a constraint (`str`), or a callable.

  Returns:
    `None` when `identifier` is `None`; the result of `deserialize` for a
    dict or a string name (the latter with an empty config); the object
    itself when it is callable.

  Raises:
    ValueError: If `identifier` is none of the above.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  if isinstance(identifier, str):
    return deserialize({'class_name': str(identifier), 'config': {}})
  if callable(identifier):
    return identifier
  raise ValueError('Could not interpret constraint identifier: ' +
                   str(identifier))