Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/constraints.py: 51%
112 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
17"""Constraints: functions that impose constraints on weight values."""
19import tensorflow.compat.v2 as tf
21from keras.src import backend
22from keras.src.saving.legacy import serialization as legacy_serialization
23from keras.src.saving.serialization_lib import deserialize_keras_object
24from keras.src.saving.serialization_lib import serialize_keras_object
26# isort: off
27from tensorflow.python.util.tf_export import keras_export
28from tensorflow.tools.docs import doc_controls
@keras_export("keras.constraints.Constraint")
class Constraint:
    """Base class for weight constraints.

    A `Constraint` instance works like a stateless function.
    Users who subclass this class should override the `__call__` method,
    which takes a single weight parameter and returns a projected version of
    that parameter (e.g. normalized or clipped). Constraints can be used with
    various Keras layers via the `kernel_constraint` or `bias_constraint`
    arguments.

    Here's a simple example of a non-negative weight constraint:

    >>> class NonNegative(tf.keras.constraints.Constraint):
    ...
    ...   def __call__(self, w):
    ...     return w * tf.cast(tf.math.greater_equal(w, 0.), w.dtype)

    >>> weight = tf.constant((-1.0, 1.0))
    >>> NonNegative()(weight)
    <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 1.],
    dtype=float32)>

    >>> tf.keras.layers.Dense(4, kernel_constraint=NonNegative())
    """

    def __call__(self, w):
        """Applies the constraint to the input weight variable.

        By default, the input weight variable is not modified.
        Users should override this method to implement their own projection
        function.

        Args:
            w: Input weight variable.

        Returns:
            Projected variable (by default, the unmodified input).
        """
        return w

    def get_config(self):
        """Returns a Python dict of the object config.

        A constraint config is a Python dictionary (JSON-serializable) that can
        be used to reinstantiate the same object. The base class has no
        configurable state, so it returns an empty dict.

        Returns:
            Python dict containing the configuration of the constraint object.
        """
        return {}

    @classmethod
    def from_config(cls, config):
        """Instantiates a weight constraint from a configuration dictionary.

        The config's items are passed as keyword arguments to the class
        constructor.

        Example:

        ```python
        constraint = UnitNorm()
        config = constraint.get_config()
        constraint = UnitNorm.from_config(config)
        ```

        Args:
            config: A Python dictionary, the output of `get_config`.

        Returns:
            A `tf.keras.constraints.Constraint` instance.
        """
        return cls(**config)
@keras_export("keras.constraints.MaxNorm", "keras.constraints.max_norm")
class MaxNorm(Constraint):
    """MaxNorm weight constraint.

    Rescales the weights incident to each hidden unit so that their norm is
    at most a desired value (up to the small `epsilon` added for numerical
    stability).

    Also available via the shortcut function `tf.keras.constraints.max_norm`.

    Args:
        max_value: the maximum norm value for the incoming weights.
        axis: integer (or list of integers), axis along which to calculate
            weight norms. For instance, in a `Dense` layer the weight matrix
            has shape `(input_dim, output_dim)`; set `axis` to `0` to
            constrain each weight vector of length `(input_dim,)`.
            In a `Conv2D` layer with `data_format="channels_last"`, the
            weight tensor has shape `(rows, cols, input_depth, output_depth)`;
            set `axis` to `[0, 1, 2]` to constrain the weights of each filter
            tensor of size `(rows, cols, input_depth)`.
    """

    def __init__(self, max_value=2, axis=0):
        self.max_value = max_value
        self.axis = axis

    @doc_controls.do_not_generate_docs
    def __call__(self, w):
        sum_of_squares = tf.reduce_sum(
            tf.square(w), axis=self.axis, keepdims=True
        )
        norms = backend.sqrt(sum_of_squares)
        # Norms above `max_value` are capped; smaller norms are kept as-is.
        capped = backend.clip(norms, 0, self.max_value)
        scale = capped / (backend.epsilon() + norms)
        return w * scale

    @doc_controls.do_not_generate_docs
    def get_config(self):
        config = {"max_value": self.max_value}
        config["axis"] = self.axis
        return config
@keras_export("keras.constraints.NonNeg", "keras.constraints.non_neg")
class NonNeg(Constraint):
    """Constrains the weights to be non-negative.

    Negative entries are multiplied by zero; non-negative entries are kept
    unchanged. The result has the same dtype as the input weights.

    Also available via the shortcut function `tf.keras.constraints.non_neg`.
    """

    def __call__(self, w):
        w = tf.convert_to_tensor(w)
        # Cast the mask to the weight's own dtype rather than
        # `backend.floatx()`: weights whose dtype differs from floatx
        # (e.g. float64 or float16) would otherwise fail the multiply with a
        # dtype mismatch.
        return w * tf.cast(tf.greater_equal(w, 0.0), w.dtype)
@keras_export("keras.constraints.UnitNorm", "keras.constraints.unit_norm")
class UnitNorm(Constraint):
    """Constrains the weights incident to each hidden unit to have unit norm.

    The weights are divided by their norm (plus a small `epsilon` for
    numerical stability) along `axis`.

    Also available via the shortcut function `tf.keras.constraints.unit_norm`.

    Args:
        axis: integer (or list of integers), axis along which to calculate
            weight norms. For instance, in a `Dense` layer the weight matrix
            has shape `(input_dim, output_dim)`; set `axis` to `0` to
            constrain each weight vector of length `(input_dim,)`.
            In a `Conv2D` layer with `data_format="channels_last"`, the
            weight tensor has shape `(rows, cols, input_depth, output_depth)`;
            set `axis` to `[0, 1, 2]` to constrain the weights of each filter
            tensor of size `(rows, cols, input_depth)`.
    """

    def __init__(self, axis=0):
        self.axis = axis

    @doc_controls.do_not_generate_docs
    def __call__(self, w):
        norms = backend.sqrt(
            tf.reduce_sum(tf.square(w), axis=self.axis, keepdims=True)
        )
        denominator = backend.epsilon() + norms
        return w / denominator

    @doc_controls.do_not_generate_docs
    def get_config(self):
        return dict(axis=self.axis)
@keras_export("keras.constraints.MinMaxNorm", "keras.constraints.min_max_norm")
class MinMaxNorm(Constraint):
    """MinMaxNorm weight constraint.

    Rescales the weights incident to each hidden unit so that their norm
    moves into the interval `[min_value, max_value]`.

    Also available via the shortcut function
    `tf.keras.constraints.min_max_norm`.

    Args:
        min_value: the minimum norm for the incoming weights.
        max_value: the maximum norm for the incoming weights.
        rate: rate for enforcing the constraint. Weights are rescaled to
            yield `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
            `rate=1.0` therefore enforces the constraint strictly, while
            `rate<1.0` rescales the weights at each step so that they move
            gradually towards a norm inside the desired interval.
        axis: integer (or list of integers), axis along which to calculate
            weight norms. For instance, in a `Dense` layer the weight matrix
            has shape `(input_dim, output_dim)`; set `axis` to `0` to
            constrain each weight vector of length `(input_dim,)`.
            In a `Conv2D` layer with `data_format="channels_last"`, the
            weight tensor has shape `(rows, cols, input_depth, output_depth)`;
            set `axis` to `[0, 1, 2]` to constrain the weights of each filter
            tensor of size `(rows, cols, input_depth)`.
    """

    def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
        self.min_value = min_value
        self.max_value = max_value
        self.rate = rate
        self.axis = axis

    @doc_controls.do_not_generate_docs
    def __call__(self, w):
        sum_of_squares = tf.reduce_sum(
            tf.square(w), axis=self.axis, keepdims=True
        )
        norms = backend.sqrt(sum_of_squares)
        clipped = backend.clip(norms, self.min_value, self.max_value)
        # Blend between the clipped norm (rate=1) and the current norm
        # (rate=0).
        desired = self.rate * clipped + (1 - self.rate) * norms
        scale = desired / (backend.epsilon() + norms)
        return w * scale

    @doc_controls.do_not_generate_docs
    def get_config(self):
        return dict(
            min_value=self.min_value,
            max_value=self.max_value,
            rate=self.rate,
            axis=self.axis,
        )
@keras_export(
    "keras.constraints.RadialConstraint", "keras.constraints.radial_constraint"
)
class RadialConstraint(Constraint):
    """Constrains `Conv2D` kernel weights to be the same for each radius.

    Also available via the shortcut function
    `tf.keras.constraints.radial_constraint`.

    Every `(rows, cols)` slice of the kernel is rebuilt as concentric square
    rings. Each ring holds a single value read from the slice's main
    diagonal. For example, the following 4-by-4 kernel:

    ```
    kernel = [[v_00, v_01, v_02, v_03],
              [v_10, v_11, v_12, v_13],
              [v_20, v_21, v_22, v_23],
              [v_30, v_31, v_32, v_33]]
    ```

    is transformed into:

    ```
    kernel = [[v_33, v_33, v_33, v_33],
              [v_33, v_11, v_11, v_33],
              [v_33, v_11, v_11, v_33],
              [v_33, v_33, v_33, v_33]]
    ```

    This constraint can be applied to any `Conv2D` layer version, including
    `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"`
    or `"channels_first"` data format. The method assumes the weight tensor is
    of shape `(rows, cols, input_depth, output_depth)` with `rows == cols`.
    """

    @doc_controls.do_not_generate_docs
    def __call__(self, w):
        w_shape = w.shape
        # `rank` is None for unknown-rank tensors, which also fails this test.
        if w_shape.rank != 4:
            raise ValueError(
                "The weight tensor must have rank 4. "
                f"Received weight tensor with shape: {w_shape}"
            )

        height, width, channels, kernels = w_shape
        # `_kernel_constraint` builds square outputs sized from the first
        # spatial dim only, so a non-square kernel could never be reshaped
        # back. Fail early with a clear message instead.
        if height is not None and width is not None and height != width:
            raise ValueError(
                "The weight tensor must have square spatial dimensions "
                "(rows == cols). "
                f"Received weight tensor with shape: {w_shape}"
            )

        # Fold (channels, kernels) into one axis and move it to the front so
        # that `map_fn` visits one (height, width) slice at a time.
        w = backend.reshape(w, (height, width, channels * kernels))
        w = tf.transpose(w, perm=[2, 0, 1])
        # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once
        # backend.switch is supported.
        w = backend.map_fn(self._kernel_constraint, w)
        # Move the folded axis back to the end and restore the 4D layout.
        return backend.reshape(
            tf.transpose(w, perm=[1, 2, 0]),
            (height, width, channels, kernels),
        )

    def _kernel_constraint(self, kernel):
        """Radially constrains a single 2D kernel slice.

        Args:
            kernel: A square 2D tensor of shape `(size, size)`, i.e. one
                `(rows, cols)` slice of the full 4D kernel.

        Returns:
            A `(size, size)` tensor made of concentric rings of constant
            value.
        """
        padding = backend.constant([[1, 1], [1, 1]], dtype="int32")

        size = backend.shape(kernel)[0]
        start = size // 2
        is_odd = backend.cast(tf.math.floormod(size, 2), "bool")

        # Seed the innermost ring with kernel[start - 1, start - 1]: a 1x1
        # block for odd sizes, broadcast to a 2x2 block for even sizes.
        # NOTE(review): for odd sizes this seed sits one step up-left of the
        # true centre kernel[start, start], so each ring takes its value from
        # one ring further in (e.g. a 3x3 kernel becomes centre v_00, border
        # v_11). Confirm this is intended before relying on it.
        kernel_new = backend.switch(
            is_odd,
            lambda: kernel[start - 1 : start, start - 1 : start],
            lambda: kernel[start - 1 : start, start - 1 : start]
            + backend.zeros((2, 2), dtype=kernel.dtype),
        )
        # The loop below runs `start - index` times and adds one ring per
        # iteration, so the result grows back to `size x size`: from 1x1 over
        # `start` steps (odd) or from 2x2 over `start - 1` steps (even).
        index = backend.switch(
            is_odd,
            lambda: backend.constant(0, dtype="int32"),
            lambda: backend.constant(1, dtype="int32"),
        )

        def while_condition(i, *args):
            return backend.less(i, start)

        def body_fn(i, array):
            # Wrap the current block in one more ring whose value is the
            # diagonal element kernel[start + i, start + i].
            return i + 1, tf.pad(
                array, padding, constant_values=kernel[start + i, start + i]
            )

        _, kernel_new = tf.compat.v1.while_loop(
            while_condition,
            body_fn,
            [index, kernel_new],
            # The block grows every iteration, so its shape is left unknown.
            shape_invariants=[index.get_shape(), tf.TensorShape([None, None])],
        )
        return kernel_new
# Function-style aliases for the constraint classes. The first three also
# keep their legacy spellings without underscores (maxnorm, nonneg, unitnorm).
max_norm = maxnorm = MaxNorm
non_neg = nonneg = NonNeg
unit_norm = unitnorm = UnitNorm
min_max_norm = MinMaxNorm
radial_constraint = RadialConstraint
@keras_export("keras.constraints.serialize")
def serialize(constraint, use_legacy_format=False):
    """Serializes a constraint into a config.

    Args:
        constraint: The constraint to serialize.
        use_legacy_format: If True, use the legacy serialization helper;
            otherwise use `serialize_keras_object`.

    Returns:
        The serialized representation produced by the selected helper.
    """
    serializer = (
        legacy_serialization.serialize_keras_object
        if use_legacy_format
        else serialize_keras_object
    )
    return serializer(constraint)
@keras_export("keras.constraints.deserialize")
def deserialize(config, custom_objects=None, use_legacy_format=False):
    """Rebuilds a constraint from its serialized config.

    Names in this module's globals (e.g. the constraint classes and their
    aliases) are available for lookup during deserialization.

    Args:
        config: The serialized constraint config.
        custom_objects: Optional mapping of extra names to objects to
            consider during deserialization.
        use_legacy_format: If True, use the legacy deserialization helper;
            otherwise use `deserialize_keras_object`.

    Returns:
        The object produced by the selected deserialization helper.
    """
    deserializer = (
        legacy_serialization.deserialize_keras_object
        if use_legacy_format
        else deserialize_keras_object
    )
    return deserializer(
        config,
        module_objects=globals(),
        custom_objects=custom_objects,
        printable_module_name="constraint",
    )
@keras_export("keras.constraints.get")
def get(identifier):
    """Retrieves a Keras constraint function.

    Args:
        identifier: `None`, a config dict, a string name, or a callable.

    Returns:
        `None` for `None`; the deserialized constraint for a dict or string;
        the identifier itself for any other callable.

    Raises:
        ValueError: If `identifier` is none of the accepted kinds.
    """
    if identifier is None:
        return None
    if isinstance(identifier, str):
        # A bare name is treated as a config with no constructor arguments.
        identifier = {"class_name": str(identifier), "config": {}}
    if isinstance(identifier, dict):
        # Configs lacking a "module" key are deserialized with the legacy
        # format.
        return deserialize(
            identifier, use_legacy_format="module" not in identifier
        )
    if callable(identifier):
        return identifier
    raise ValueError(
        f"Could not interpret constraint function identifier: {identifier}"
    )