Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/constraints.py: 54%
107 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16# pylint: disable=g-classes-have-attributes
17"""Constraints: functions that impose constraints on weight values."""
19from tensorflow.python.framework import tensor_shape
20from tensorflow.python.keras import backend
21from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
22from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import array_ops_stack
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops import while_loop
27from tensorflow.python.util.tf_export import keras_export
28from tensorflow.tools.docs import doc_controls
@keras_export('keras.constraints.Constraint')
class Constraint:
  """Base class for weight constraints.

  A `Constraint` instance behaves like a stateless function. Subclasses
  should override `__call__`, which receives a single weight parameter and
  returns a projected version of that parameter (e.g. normalized or
  clipped). Constraints can be used with various Keras layers via the
  `kernel_constraint` or `bias_constraint` arguments.

  Here's a simple example of a non-negative weight constraint:

  >>> class NonNegative(tf.keras.constraints.Constraint):
  ...
  ...  def __call__(self, w):
  ...    return w * tf.cast(tf.math.greater_equal(w, 0.), w.dtype)

  >>> weight = tf.constant((-1.0, 1.0))
  >>> NonNegative()(weight)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 1.], dtype=float32)>

  >>> tf.keras.layers.Dense(4, kernel_constraint=NonNegative())
  """

  def __call__(self, w):
    """Applies the constraint to the input weight variable.

    The base implementation is the identity: the weight is handed back
    untouched. Subclasses override this method with their own projection.

    Args:
      w: Input weight variable.

    Returns:
      The projected variable; for this base class, `w` itself.
    """
    return w

  def get_config(self):
    """Returns the configuration of this constraint as a Python dict.

    The config is a JSON-serializable dictionary from which an equivalent
    constraint object can be re-created. The base class has no settings,
    so its config is empty.

    Returns:
      Python dict containing the configuration of the constraint object.
    """
    return {}
@keras_export('keras.constraints.MaxNorm', 'keras.constraints.max_norm')
class MaxNorm(Constraint):
  """MaxNorm weight constraint.

  Rescales the weights incident to each hidden unit so that their norm
  does not exceed a given value.

  Also available via the shortcut function `tf.keras.constraints.max_norm`.

  Args:
    max_value: the maximum norm value for the incoming weights.
    axis: integer (or list of integers), axis along which to calculate
      weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.

  """

  def __init__(self, max_value=2, axis=0):
    self.axis = axis
    self.max_value = max_value

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    """Returns `w` scaled so its norm along `axis` is at most `max_value`."""
    squared_sum = math_ops.reduce_sum(
        math_ops.square(w), axis=self.axis, keepdims=True)
    norms = backend.sqrt(squared_sum)
    # Norms already inside [0, max_value] are kept; larger ones are capped.
    target_norms = backend.clip(norms, 0, self.max_value)
    # Epsilon keeps the division finite when a norm is zero.
    scale = target_norms / (backend.epsilon() + norms)
    return w * scale

  @doc_controls.do_not_generate_docs
  def get_config(self):
    """Returns the constructor arguments as a dict."""
    config = {'max_value': self.max_value, 'axis': self.axis}
    return config
@keras_export('keras.constraints.NonNeg', 'keras.constraints.non_neg')
class NonNeg(Constraint):
  """Constrains the weights to be non-negative.

  Also available via the shortcut function `tf.keras.constraints.non_neg`.
  """

  def __call__(self, w):
    """Zeroes out the negative entries of `w`.

    Args:
      w: Input weight variable.

    Returns:
      `w` multiplied element-wise by a 0/1 mask that is 1 where `w >= 0`.
    """
    # Cast the mask to the weight's own dtype rather than the global
    # `backend.floatx()`: a weight whose dtype differs from floatx (e.g. a
    # float64 weight under the default float32) would otherwise fail the
    # multiplication with a dtype mismatch.
    non_negative_mask = math_ops.cast(math_ops.greater_equal(w, 0.), w.dtype)
    return w * non_negative_mask
@keras_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm')
class UnitNorm(Constraint):
  """Constrains the weights incident to each hidden unit to have unit norm.

  Also available via the shortcut function `tf.keras.constraints.unit_norm`.

  Args:
    axis: integer (or list of integers), axis along which to calculate
      weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, axis=0):
    self.axis = axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    """Returns `w` divided by its norm along `axis` (plus epsilon)."""
    squared_sum = math_ops.reduce_sum(
        math_ops.square(w), axis=self.axis, keepdims=True)
    # Epsilon keeps the division finite when a norm is zero.
    denominator = backend.epsilon() + backend.sqrt(squared_sum)
    return w / denominator

  @doc_controls.do_not_generate_docs
  def get_config(self):
    """Returns the constructor arguments as a dict."""
    config = {'axis': self.axis}
    return config
@keras_export('keras.constraints.MinMaxNorm', 'keras.constraints.min_max_norm')
class MinMaxNorm(Constraint):
  """MinMaxNorm weight constraint.

  Rescales the weights incident to each hidden unit so that their norm
  lies between a lower bound and an upper bound.

  Also available via the shortcut function `tf.keras.constraints.min_max_norm`.

  Args:
    min_value: the minimum norm for the incoming weights.
    max_value: the maximum norm for the incoming weights.
    rate: rate for enforcing the constraint: weights will be
      rescaled to yield
      `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
      Effectively, this means that rate=1.0 stands for strict
      enforcement of the constraint, while rate<1.0 means that
      weights will be rescaled at each step to slowly move
      towards a value inside the desired interval.
    axis: integer (or list of integers), axis along which to calculate
      weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
    self.axis = axis
    self.rate = rate
    self.min_value = min_value
    self.max_value = max_value

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    """Returns `w` rescaled towards a norm within `[min_value, max_value]`."""
    squared_sum = math_ops.reduce_sum(
        math_ops.square(w), axis=self.axis, keepdims=True)
    norms = backend.sqrt(squared_sum)
    clipped_norms = backend.clip(norms, self.min_value, self.max_value)
    # Blend the clipped norm with the current norm according to `rate`.
    target_norms = self.rate * clipped_norms + (1 - self.rate) * norms
    # Epsilon keeps the division finite when a norm is zero.
    scale = target_norms / (backend.epsilon() + norms)
    return w * scale

  @doc_controls.do_not_generate_docs
  def get_config(self):
    """Returns the constructor arguments as a dict."""
    config = {
        'min_value': self.min_value,
        'max_value': self.max_value,
        'rate': self.rate,
        'axis': self.axis,
    }
    return config
@keras_export('keras.constraints.RadialConstraint',
              'keras.constraints.radial_constraint')
class RadialConstraint(Constraint):
  """Constrains `Conv2D` kernel weights to be the same for each radius.

  Also available via the shortcut function
  `tf.keras.constraints.radial_constraint`.

  For example, the desired output for the following 4-by-4 kernel:

  ```
      kernel = [[v_00, v_01, v_02, v_03],
                [v_10, v_11, v_12, v_13],
                [v_20, v_21, v_22, v_23],
                [v_30, v_31, v_32, v_33]]
  ```

  is this:

  ```
      kernel = [[v_11, v_11, v_11, v_11],
                [v_11, v_33, v_33, v_11],
                [v_11, v_33, v_33, v_11],
                [v_11, v_11, v_11, v_11]]
  ```

  This constraint can be applied to any `Conv2D` layer version, including
  `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"` or
  `"channels_first"` data format. The method assumes the weight tensor is of
  shape `(rows, cols, input_depth, output_depth)`.
  """

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    """Applies the radial constraint to every 2-D slice of a rank-4 weight.

    Args:
      w: Weight tensor of static rank 4, unpacked below as
        `(height, width, channels, kernels)`.

    Returns:
      A tensor reshaped back to `(height, width, channels, kernels)` in
      which each `(height, width)` slice has been passed through
      `_kernel_constraint`.

    Raises:
      ValueError: If the static rank of `w` is unknown or is not 4.
    """
    w_shape = w.shape
    if w_shape.rank is None or w_shape.rank != 4:
      raise ValueError(
          'The weight tensor must be of rank 4, but is of shape: %s' % w_shape)

    height, width, channels, kernels = w_shape
    # Fold the two depth axes into one so each 2-D slice can be mapped over.
    w = backend.reshape(w, (height, width, channels * kernels))
    # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once
    #   backend.switch is supported.
    # Move the folded depth axis to the front: map_fn then hands one
    # (height, width) slice at a time to `_kernel_constraint`.
    w = backend.map_fn(
        self._kernel_constraint,
        backend.stack(array_ops_stack.unstack(w, axis=-1), axis=0))
    # Move the mapped axis back to the end and restore the rank-4 shape.
    return backend.reshape(
        backend.stack(array_ops_stack.unstack(w, axis=0), axis=-1),
        (height, width, channels, kernels))

  def _kernel_constraint(self, kernel):
    """Radially constrains a single 2-D kernel slice.

    The result is grown outwards from a small centre block by repeatedly
    padding one cell on every side with a value read from the diagonal of
    `kernel`.

    Args:
      kernel: A rank-2 slice as produced by `__call__`. Only the size of its
        first dimension is read and it is used for both axes, so the slice
        is presumably expected to be square -- TODO confirm.

    Returns:
      A rank-2 tensor built by the padding loop below.
    """
    # One cell of padding before and after each of the two axes.
    padding = backend.constant([[1, 1], [1, 1]], dtype='int32')

    kernel_shape = backend.shape(kernel)[0]
    # `/` yields a float; the cast to int32 turns it into an integer index.
    start = backend.cast(kernel_shape / 2, 'int32')

    # Seed block: the condition is true for an odd size (size mod 2 == 1).
    # One branch keeps the 1x1 slice; the other broadcasts that same 1x1
    # slice against a (2, 2) zero tensor to obtain a 2x2 block.
    kernel_new = backend.switch(
        backend.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
        lambda: kernel[start - 1:start, start - 1:start],
        lambda: kernel[start - 1:start, start - 1:start] + backend.zeros(  # pylint: disable=g-long-lambda
            (2, 2), dtype=kernel.dtype))
    # Loop counter starts at 0 or 1, chosen by the same parity condition.
    index = backend.switch(
        backend.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
        lambda: backend.constant(0, dtype='int32'),
        lambda: backend.constant(1, dtype='int32'))
    while_condition = lambda index, *args: backend.less(index, start)

    def body_fn(i, array):
      # Surround the current block with a one-cell ring whose value is the
      # diagonal entry kernel[start + i, start + i].
      return i + 1, array_ops.pad(
          array,
          padding,
          constant_values=kernel[start + i, start + i])

    # NOTE(review): tracing a 4x4 kernel through this code, the centre 2x2
    # block is filled from kernel[1, 1] and the outer ring from kernel[3, 3],
    # which appears to be the reverse of the example in the class docstring.
    # Confirm which of the two is intended.
    # The array grows on every iteration, so its shape invariant is left
    # fully unknown in both dimensions.
    _, kernel_new = while_loop.while_loop(
        while_condition,
        body_fn, [index, kernel_new],
        shape_invariants=[
            index.get_shape(),
            tensor_shape.TensorShape([None, None])
        ])
    return kernel_new
# Snake-case shortcut names for the constraint classes.

max_norm = MaxNorm
non_neg = NonNeg
unit_norm = UnitNorm
min_max_norm = MinMaxNorm
radial_constraint = RadialConstraint

# Legacy aliases (spelled without the underscore).
maxnorm = MaxNorm
nonneg = NonNeg
unitnorm = UnitNorm
@keras_export('keras.constraints.serialize')
def serialize(constraint):
  """Serializes a constraint by delegating to `serialize_keras_object`.

  Args:
    constraint: The constraint object to serialize.

  Returns:
    Whatever `serialize_keras_object` returns for `constraint`.
  """
  serialized = serialize_keras_object(constraint)
  return serialized
@keras_export('keras.constraints.deserialize')
def deserialize(config, custom_objects=None):
  """Rebuilds a constraint from `config` via `deserialize_keras_object`.

  Args:
    config: The serialized constraint description to resolve.
    custom_objects: Optional value forwarded unchanged as the
      `custom_objects` argument of `deserialize_keras_object`.

  Returns:
    Whatever `deserialize_keras_object` returns for `config`.
  """
  # This module's globals are passed as `module_objects`, so the names
  # defined in this file are what the lookup is given to resolve against.
  lookup_kwargs = {
      'module_objects': globals(),
      'custom_objects': custom_objects,
      'printable_module_name': 'constraint',
  }
  return deserialize_keras_object(config, **lookup_kwargs)
@keras_export('keras.constraints.get')
def get(identifier):
  """Resolves `identifier` to a constraint.

  Args:
    identifier: One of `None`, a config dict, a string name, or a callable.

  Returns:
    `None` if `identifier` is `None`; the result of `deserialize` for a
    dict or a string (a string is wrapped as a class name with an empty
    config); the identifier itself if it is callable.

  Raises:
    ValueError: If `identifier` is none of the supported kinds.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  if isinstance(identifier, str):
    # A bare name stands for that class constructed with default arguments.
    return deserialize({'class_name': str(identifier), 'config': {}})
  if callable(identifier):
    return identifier
  raise ValueError('Could not interpret constraint identifier: ' +
                   str(identifier))