Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/normalization/spectral_normalization.py: 27%
45 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The Keras Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16import tensorflow.compat.v2 as tf
18from keras.src.initializers import TruncatedNormal
19from keras.src.layers.rnn import Wrapper
21# isort: off
22from tensorflow.python.util.tf_export import keras_export
25# Adapted from TF-Addons implementation
@keras_export("keras.layers.SpectralNormalization", v1=[])
class SpectralNormalization(Wrapper):
    """Performs spectral normalization on the weights of a target layer.

    This wrapper controls the Lipschitz constant of the weights of a layer by
    constraining their spectral norm, which can stabilize the training of GANs.

    The largest singular value of the wrapped kernel is estimated by power
    iteration, and the kernel is divided by that estimate in place. This only
    happens when the wrapper is called with a truthy `training` argument; at
    inference the wrapped layer runs with its current kernel unchanged.

    Args:
        layer: A `keras.layers.Layer` instance that
            has either a `kernel` (e.g. `Conv2D`, `Dense`...)
            or an `embeddings` attribute (`Embedding` layer).
        power_iterations: int, the number of iterations during normalization.
            Must be greater than zero.
        **kwargs: Additional keyword arguments forwarded to `Wrapper`.

    Raises:
        ValueError: If `power_iterations` is not greater than zero, or, when
            the wrapper is built, if `layer` has neither a `kernel` nor an
            `embeddings` attribute.

    Examples:

    Wrap `keras.layers.Conv2D`:
    >>> x = np.random.rand(1, 10, 10, 1)
    >>> conv2d = SpectralNormalization(tf.keras.layers.Conv2D(2, 2))
    >>> y = conv2d(x)
    >>> y.shape
    TensorShape([1, 9, 9, 2])

    Wrap `keras.layers.Dense`:
    >>> x = np.random.rand(1, 10, 10, 1)
    >>> dense = SpectralNormalization(tf.keras.layers.Dense(10))
    >>> y = dense(x)
    >>> y.shape
    TensorShape([1, 10, 10, 10])

    Reference:

    - [Spectral Normalization for GAN](https://arxiv.org/abs/1802.05957).
    """

    def __init__(self, layer, power_iterations=1, **kwargs):
        super().__init__(layer, **kwargs)
        # At least one power-iteration step is required: `normalize_weights`
        # only defines `vector_v` inside that loop.
        if power_iterations <= 0:
            raise ValueError(
                "`power_iterations` should be greater than zero. Received: "
                f"`power_iterations={power_iterations}`"
            )
        self.power_iterations = power_iterations

    def build(self, input_shape):
        """Build the wrapped layer and create the power-iteration state.

        After the wrapped layer is built, its `kernel` (or, failing that, its
        `embeddings`) variable is aliased as `self.kernel`, and a
        non-trainable `vector_u` of shape `(1, kernel_shape[-1])` is created
        with the kernel's dtype.

        Raises:
            ValueError: If the wrapped layer has neither a `kernel` nor an
                `embeddings` attribute.
        """
        super().build(input_shape)
        input_shape = tf.TensorShape(input_shape)
        # Leave the batch dimension unconstrained.
        self.input_spec = tf.keras.layers.InputSpec(
            shape=[None] + input_shape[1:]
        )

        # `self.kernel` refers to the same variable object as the wrapped
        # layer's, so assigning to it in `normalize_weights` updates the
        # wrapped layer's weights in place.
        if hasattr(self.layer, "kernel"):
            self.kernel = self.layer.kernel
        elif hasattr(self.layer, "embeddings"):
            self.kernel = self.layer.embeddings
        else:
            raise ValueError(
                f"{type(self.layer).__name__} object has no attribute 'kernel' "
                "nor 'embeddings'"
            )

        self.kernel_shape = self.kernel.shape.as_list()

        # Running estimate of the leading right singular vector; it is updated
        # by assignment during training, never by gradients.
        self.vector_u = self.add_weight(
            shape=(1, self.kernel_shape[-1]),
            initializer=TruncatedNormal(stddev=0.02),
            trainable=False,
            name="vector_u",
            dtype=self.kernel.dtype,
        )

    def call(self, inputs, training=False):
        """Run the wrapped layer, normalizing its kernel first when training.

        Args:
            inputs: Input passed unchanged to the wrapped layer.
            training: When truthy, `normalize_weights` is called before the
                wrapped layer is applied.

        Returns:
            The wrapped layer's output.
        """
        if training:
            self.normalize_weights()

        output = self.layer(inputs)
        return output

    def compute_output_shape(self, input_shape):
        """Return the wrapped layer's output shape as a `tf.TensorShape`.

        `call` returns the wrapped layer's output directly, so this defers to
        `self.layer.compute_output_shape`. The result is passed straight to
        `tf.TensorShape`, which accepts a `TensorShape` as well as a tuple or
        list. Unlike `TensorShape.as_list()`, it does not raise for an
        unknown-rank shape.
        """
        return tf.TensorShape(self.layer.compute_output_shape(input_shape))

    def normalize_weights(self):
        """Generate spectral normalized weights.

        This method will update the value of `self.kernel` with the
        spectral normalized value, so that the layer is ready for `call()`.
        `self.vector_u` is also updated with the new power-iteration estimate.
        """
        # View the kernel as a 2-D matrix W of shape (m, n), with
        # n = kernel_shape[-1].
        weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]])
        vector_u = self.vector_u

        # An all-zero kernel would give sigma == 0 and divide by zero below,
        # so skip normalization entirely in that case.
        if not tf.reduce_all(tf.equal(weights, 0.0)):
            for _ in range(self.power_iterations):
                # v <- normalize(u W^T), shape (1, m)
                vector_v = tf.math.l2_normalize(
                    tf.matmul(vector_u, weights, transpose_b=True)
                )
                # u <- normalize(v W), shape (1, n)
                vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights))
            # The singular-vector estimates are treated as constants, so no
            # gradient flows back through the power iteration.
            vector_u = tf.stop_gradient(vector_u)
            vector_v = tf.stop_gradient(vector_v)
            # sigma = v W u^T, shape (1, 1): estimate of the largest singular
            # value of W.
            sigma = tf.matmul(
                tf.matmul(vector_v, weights), vector_u, transpose_b=True
            )
            self.vector_u.assign(tf.cast(vector_u, self.vector_u.dtype))
            self.kernel.assign(
                tf.cast(
                    tf.reshape(self.kernel / sigma, self.kernel_shape),
                    self.kernel.dtype,
                )
            )

    def get_config(self):
        """Return the base wrapper config extended with `power_iterations`."""
        config = {"power_iterations": self.power_iterations}
        base_config = super().get_config()
        return {**base_config, **config}