Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/constraints.py: 51%

112 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16 

17"""Constraints: functions that impose constraints on weight values.""" 

18 

19import tensorflow.compat.v2 as tf 

20 

21from keras.src import backend 

22from keras.src.saving.legacy import serialization as legacy_serialization 

23from keras.src.saving.serialization_lib import deserialize_keras_object 

24from keras.src.saving.serialization_lib import serialize_keras_object 

25 

26# isort: off 

27from tensorflow.python.util.tf_export import keras_export 

28from tensorflow.tools.docs import doc_controls 

29 

30 

@keras_export("keras.constraints.Constraint")
class Constraint:
    """Base class for weight constraints.

    A `Constraint` behaves like a stateless function: calling it on a weight
    returns a projected version of that weight (for example, a normalized or
    clipped copy). Subclasses customize the projection by overriding
    `__call__`, which receives the weight as its single argument. Constraints
    are attached to Keras layers through arguments such as
    `kernel_constraint` and `bias_constraint`.

    Here's a simple example of a non-negative weight constraint:

    >>> class NonNegative(tf.keras.constraints.Constraint):
    ...
    ...  def __call__(self, w):
    ...    return w * tf.cast(tf.math.greater_equal(w, 0.), w.dtype)

    >>> weight = tf.constant((-1.0, 1.0))
    >>> NonNegative()(weight)
    <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 1.],
    dtype=float32)>

    >>> tf.keras.layers.Dense(4, kernel_constraint=NonNegative())
    """

    def __call__(self, w):
        """Projects `w` onto the constraint set.

        The base implementation is the identity projection: the weight is
        handed back untouched. Subclasses override this to implement their
        own projection.

        Args:
          w: The weight variable to constrain.

        Returns:
          The projected weight (here, `w` itself).
        """
        return w

    def get_config(self):
        """Returns the constraint's configuration as a Python dict.

        The dict is meant to be JSON-serializable and sufficient to rebuild
        an equivalent constraint via `from_config`. The base class has no
        configurable state, so its config is empty.

        Returns:
          A dict describing the constraint's configuration.
        """
        return dict()

    @classmethod
    def from_config(cls, config):
        """Builds a constraint instance from a configuration dict.

        Example:

        ```python
        constraint = UnitNorm()
        config = constraint.get_config()
        constraint = UnitNorm.from_config(config)
        ```

        Args:
          config: A dict as produced by `get_config`; its entries are passed
            to the constructor as keyword arguments.

        Returns:
          A `tf.keras.constraints.Constraint` instance.
        """
        return cls(**config)

102 

103 

@keras_export("keras.constraints.MaxNorm", "keras.constraints.max_norm")
class MaxNorm(Constraint):
    """MaxNorm weight constraint.

    Caps the L2 norm of the weights incident to each hidden unit at
    `max_value`. Slices whose norm already lies within the limit are left
    essentially unchanged: the scale factor is `norm / (epsilon + norm)`.
    Slices with a larger norm are rescaled down to norm `max_value`.

    Also available via the shortcut function `tf.keras.constraints.max_norm`.

    Args:
      max_value: the maximum norm value for the incoming weights.
      axis: integer, axis along which to calculate weight norms.
        For instance, in a `Dense` layer the weight matrix
        has shape `(input_dim, output_dim)`,
        set `axis` to `0` to constrain each weight vector
        of length `(input_dim,)`.
        In a `Conv2D` layer with `data_format="channels_last"`,
        the weight tensor has shape
        `(rows, cols, input_depth, output_depth)`,
        set `axis` to `[0, 1, 2]`
        to constrain the weights of each filter tensor of size
        `(rows, cols, input_depth)`.
    """

    def __init__(self, max_value=2, axis=0):
        self.max_value, self.axis = max_value, axis

    @doc_controls.do_not_generate_docs
    def __call__(self, w):
        # L2 norm along `axis`; keepdims lets it broadcast back against `w`.
        squared_sum = tf.reduce_sum(tf.square(w), axis=self.axis, keepdims=True)
        l2 = backend.sqrt(squared_sum)
        target = backend.clip(l2, 0, self.max_value)
        # epsilon guards against division by zero for all-zero slices.
        scale = target / (backend.epsilon() + l2)
        return w * scale

    @doc_controls.do_not_generate_docs
    def get_config(self):
        config = dict(max_value=self.max_value, axis=self.axis)
        return config

144 

145 

@keras_export("keras.constraints.NonNeg", "keras.constraints.non_neg")
class NonNeg(Constraint):
    """Constrains the weights to be non-negative.

    Every negative entry of the weight tensor is replaced by zero, and
    entries that are `>= 0` are kept as they are.

    Also available via the shortcut function `tf.keras.constraints.non_neg`.
    """

    def __call__(self, w):
        # Build the 0/1 mask in the weight's own dtype rather than
        # `backend.floatx()`. Otherwise a weight whose dtype differs from
        # floatx (e.g. float64 weights under the default float32 floatx)
        # fails the elementwise multiply with a dtype mismatch.
        return w * tf.cast(tf.greater_equal(w, 0.0), w.dtype)

155 

156 

@keras_export("keras.constraints.UnitNorm", "keras.constraints.unit_norm")
class UnitNorm(Constraint):
    """Constrains the weights incident to each hidden unit to have unit norm.

    Each slice along `axis` is divided by its L2 norm. A small `epsilon` is
    added to the denominator, so all-zero slices stay zero rather than
    producing NaNs.

    Also available via the shortcut function `tf.keras.constraints.unit_norm`.

    Args:
      axis: integer, axis along which to calculate weight norms.
        For instance, in a `Dense` layer the weight matrix
        has shape `(input_dim, output_dim)`,
        set `axis` to `0` to constrain each weight vector
        of length `(input_dim,)`.
        In a `Conv2D` layer with `data_format="channels_last"`,
        the weight tensor has shape
        `(rows, cols, input_depth, output_depth)`,
        set `axis` to `[0, 1, 2]`
        to constrain the weights of each filter tensor of size
        `(rows, cols, input_depth)`.
    """

    def __init__(self, axis=0):
        self.axis = axis

    @doc_controls.do_not_generate_docs
    def __call__(self, w):
        l2 = backend.sqrt(
            tf.reduce_sum(tf.square(w), axis=self.axis, keepdims=True)
        )
        denominator = backend.epsilon() + l2
        return w / denominator

    @doc_controls.do_not_generate_docs
    def get_config(self):
        return dict(axis=self.axis)

192 

193 

@keras_export("keras.constraints.MinMaxNorm", "keras.constraints.min_max_norm")
class MinMaxNorm(Constraint):
    """MinMaxNorm weight constraint.

    Keeps the L2 norm of the weights incident to each hidden unit between a
    lower and an upper bound, optionally moving towards the interval only
    partially on each application (see `rate`).

    Also available via the shortcut function
    `tf.keras.constraints.min_max_norm`.

    Args:
      min_value: the minimum norm for the incoming weights.
      max_value: the maximum norm for the incoming weights.
      rate: rate for enforcing the constraint: weights will be
        rescaled to yield
        `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
        Effectively, this means that rate=1.0 stands for strict
        enforcement of the constraint, while rate<1.0 means that
        weights will be rescaled at each step to slowly move
        towards a value inside the desired interval.
      axis: integer, axis along which to calculate weight norms.
        For instance, in a `Dense` layer the weight matrix
        has shape `(input_dim, output_dim)`,
        set `axis` to `0` to constrain each weight vector
        of length `(input_dim,)`.
        In a `Conv2D` layer with `data_format="channels_last"`,
        the weight tensor has shape
        `(rows, cols, input_depth, output_depth)`,
        set `axis` to `[0, 1, 2]`
        to constrain the weights of each filter tensor of size
        `(rows, cols, input_depth)`.
    """

    def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
        self.min_value = min_value
        self.max_value = max_value
        self.rate = rate
        self.axis = axis

    @doc_controls.do_not_generate_docs
    def __call__(self, w):
        l2 = backend.sqrt(
            tf.reduce_sum(tf.square(w), axis=self.axis, keepdims=True)
        )
        clipped = backend.clip(l2, self.min_value, self.max_value)
        # Interpolate between the current norm and its clipped value.
        blended = self.rate * clipped + (1 - self.rate) * l2
        return w * (blended / (backend.epsilon() + l2))

    @doc_controls.do_not_generate_docs
    def get_config(self):
        field_names = ("min_value", "max_value", "rate", "axis")
        return {name: getattr(self, name) for name in field_names}

252 

253 

@keras_export(
    "keras.constraints.RadialConstraint", "keras.constraints.radial_constraint"
)
class RadialConstraint(Constraint):
    """Constrains `Conv2D` kernel weights to be the same for each radius.

    Also available via the shortcut function
    `tf.keras.constraints.radial_constraint`.

    For example, the desired output for the following 4-by-4 kernel:

    ```
        kernel = [[v_00, v_01, v_02, v_03],
                  [v_10, v_11, v_12, v_13],
                  [v_20, v_21, v_22, v_23],
                  [v_30, v_31, v_32, v_33]]
    ```

    is this:

    ```
        kernel = [[v_11, v_11, v_11, v_11],
                  [v_11, v_33, v_33, v_11],
                  [v_11, v_33, v_33, v_11],
                  [v_11, v_11, v_11, v_11]]
    ```

    NOTE(review): tracing `_kernel_constraint` on a 4x4 slice gives the
    reverse of the example above. The centre 2x2 is filled with v_11
    (`kernel[1:2, 1:2]` broadcast to 2x2), and the border is padded with
    v_33 (`kernel[start + 1, start + 1]` with start=2). Confirm which
    layout is intended.

    This constraint can be applied to any `Conv2D` layer version, including
    `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"`
    or `"channels_first"` data format. The method assumes the weight tensor is
    of shape `(rows, cols, input_depth, output_depth)`.
    """

    @doc_controls.do_not_generate_docs
    def __call__(self, w):
        """Applies the radial projection to every 2-D slice of a rank-4 `w`.

        Raises:
          ValueError: if `w` does not have a known rank of exactly 4.
        """
        w_shape = w.shape
        if w_shape.rank is None or w_shape.rank != 4:
            raise ValueError(
                "The weight tensor must have rank 4. "
                f"Received weight tensor with shape: {w_shape}"
            )

        height, width, channels, kernels = w_shape
        # Merge the two trailing dims so each (height, width) slice can be
        # processed independently.
        w = backend.reshape(w, (height, width, channels * kernels))
        # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once
        # backend.switch is supported.
        # Move the merged channel dim to the front: map_fn then sees one
        # (height, width) slice per element.
        w = backend.map_fn(
            self._kernel_constraint,
            backend.stack(tf.unstack(w, axis=-1), axis=0),
        )
        # Move the slice dim back to the end and restore the original 4-D
        # shape.
        return backend.reshape(
            backend.stack(tf.unstack(w, axis=0), axis=-1),
            (height, width, channels, kernels),
        )

    def _kernel_constraint(self, kernel):
        """Radially constrains a single 2-D kernel slice.

        `__call__` passes 2-D `(height, width)` slices here via `map_fn`. The
        result is built from a 1x1 (odd size) or 2x2 (even size) seed that
        is grown one ring at a time with `tf.pad`.
        """
        # Pad one row/column on each side of both dims per loop iteration.
        padding = backend.constant([[1, 1], [1, 1]], dtype="int32")

        # Only dim 0 is read; presumably the slice is square -- TODO confirm.
        kernel_shape = backend.shape(kernel)[0]
        start = backend.cast(kernel_shape / 2, "int32")

        # Seed: odd size -> the 1x1 value kernel[start-1, start-1];
        # even size -> that same value broadcast to a 2x2 block.
        kernel_new = backend.switch(
            backend.cast(tf.math.floormod(kernel_shape, 2), "bool"),
            lambda: kernel[start - 1 : start, start - 1 : start],
            lambda: kernel[start - 1 : start, start - 1 : start]
            + backend.zeros((2, 2), dtype=kernel.dtype),
        )
        # The loop counter starts at 0 for odd sizes and 1 for even sizes,
        # so the number of padding rings brings the seed back to full size.
        index = backend.switch(
            backend.cast(tf.math.floormod(kernel_shape, 2), "bool"),
            lambda: backend.constant(0, dtype="int32"),
            lambda: backend.constant(1, dtype="int32"),
        )
        while_condition = lambda index, *args: backend.less(index, start)

        def body_fn(i, array):
            # Add one ring filled with the diagonal value kernel[start+i, start+i].
            return i + 1, tf.pad(
                array, padding, constant_values=kernel[start + i, start + i]
            )

        # The array grows each iteration, so its shape invariant is fully
        # dynamic.
        _, kernel_new = tf.compat.v1.while_loop(
            while_condition,
            body_fn,
            [index, kernel_new],
            shape_invariants=[index.get_shape(), tf.TensorShape([None, None])],
        )
        return kernel_new

342 

343 

# Aliases: snake_case names bound to the constraint classes.
(
    max_norm,
    non_neg,
    unit_norm,
    min_max_norm,
    radial_constraint,
) = (MaxNorm, NonNeg, UnitNorm, MinMaxNorm, RadialConstraint)

# Legacy aliases (older spellings without underscores).
maxnorm, nonneg, unitnorm = max_norm, non_neg, unit_norm

356 

357 

@keras_export("keras.constraints.serialize")
def serialize(constraint, use_legacy_format=False):
    """Serializes a constraint object.

    Args:
      constraint: The constraint to serialize.
      use_legacy_format: If True, use the legacy serialization helper
        instead of `serialize_keras_object` from `serialization_lib`.

    Returns:
      Whatever the selected serialization helper returns for `constraint`.
    """
    serializer = (
        legacy_serialization.serialize_keras_object
        if use_legacy_format
        else serialize_keras_object
    )
    return serializer(constraint)

363 

364 

@keras_export("keras.constraints.deserialize")
def deserialize(config, custom_objects=None, use_legacy_format=False):
    """Rebuilds a constraint from its serialized config.

    Names are resolved against this module's globals, so the built-in
    constraint classes and their aliases are found automatically.

    Args:
      config: The serialized constraint config.
      custom_objects: Optional mapping of extra names to objects to consult
        during deserialization.
      use_legacy_format: If True, use the legacy deserialization helper
        instead of `deserialize_keras_object` from `serialization_lib`.

    Returns:
      The deserialized constraint, as returned by the selected helper.
    """
    loader = (
        legacy_serialization.deserialize_keras_object
        if use_legacy_format
        else deserialize_keras_object
    )
    return loader(
        config,
        module_objects=globals(),
        custom_objects=custom_objects,
        printable_module_name="constraint",
    )

380 

381 

@keras_export("keras.constraints.get")
def get(identifier):
    """Retrieves a Keras constraint function.

    Args:
      identifier: One of
        - `None`, which is returned as-is;
        - a class name string, which is wrapped into a config with an
          empty `"config"` and deserialized in the legacy format;
        - a config dict, which is deserialized (legacy format when it has
          no `"module"` key);
        - any other callable, which is returned unchanged.

    Returns:
      The resolved constraint, or `None`.

    Raises:
      ValueError: if `identifier` is none of the above.
    """
    if identifier is None:
        return None
    if isinstance(identifier, str):
        # A bare class name becomes a minimal config dict with no "module"
        # key, which selects the legacy deserialization path below.
        identifier = {"class_name": str(identifier), "config": {}}
    if isinstance(identifier, dict):
        return deserialize(
            identifier, use_legacy_format="module" not in identifier
        )
    if callable(identifier):
        return identifier
    raise ValueError(
        f"Could not interpret constraint function identifier: {identifier}"
    )

399