Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/spectral_normalization.py: 23%
47 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
16import tensorflow as tf
17from typeguard import typechecked
@tf.keras.utils.register_keras_serializable(package="Addons")
class SpectralNormalization(tf.keras.layers.Wrapper):
    """Performs spectral normalization on weights.

    This wrapper controls the Lipschitz constant of the layer by
    constraining its spectral norm, which can stabilize the training of GANs.

    See [Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957).

    Wrap `tf.keras.layers.Conv2D`:

    >>> x = np.random.rand(1, 10, 10, 1)
    >>> conv2d = SpectralNormalization(tf.keras.layers.Conv2D(2, 2))
    >>> y = conv2d(x)
    >>> y.shape
    TensorShape([1, 9, 9, 2])

    Wrap `tf.keras.layers.Dense`:

    >>> x = np.random.rand(1, 10, 10, 1)
    >>> dense = SpectralNormalization(tf.keras.layers.Dense(10))
    >>> y = dense(x)
    >>> y.shape
    TensorShape([1, 10, 10, 10])

    Normalization only happens when the layer is called in training mode.
    Each such call runs `power_iterations` steps of power iteration to
    estimate the largest singular value `sigma` of the wrapped weight. It
    then overwrites that weight in place with `weight / sigma`. Calls outside
    training mode use whatever value the weight currently holds.

    Args:
      layer: A `tf.keras.layers.Layer` instance that
        has either a `kernel` or an `embeddings` attribute.
      power_iterations: `int`, the number of power-iteration steps run each
        time the weights are normalized. Must be greater than zero.
      **kwargs: Forwarded to the `tf.keras.layers.Wrapper` constructor.
    Raises:
      AssertionError: If not initialized with a `Layer` instance. This is
        raised by the base `Wrapper` constructor; the exact exception type
        may depend on the Keras version in use.
      ValueError: If `power_iterations` is zero or negative.
      AttributeError: If `layer` has neither a `kernel` nor an `embeddings`
        attribute. This is raised from `build()`, i.e. on first use, not at
        construction time.
    """

    @typechecked
    def __init__(self, layer: tf.keras.layers, power_iterations: int = 1, **kwargs):
        # NOTE(review): `tf.keras.layers` is a module, not a class, so this
        # annotation does not express the intended `tf.keras.layers.Layer`
        # type. It is deliberately left unchanged: tightening it would let
        # `typechecked` reject bad input with a different exception than the
        # one callers currently get from the base `Wrapper` check.
        super().__init__(layer, **kwargs)
        if power_iterations <= 0:
            raise ValueError(
                "`power_iterations` should be greater than zero, got "
                "`power_iterations={}`".format(power_iterations)
            )
        self.power_iterations = power_iterations
        # Not read anywhere in this class; kept so external code that may
        # inspect the attribute keeps working.
        self._initialized = False

    def build(self, input_shape):
        """Build `Layer`.

        Resolves which weight of the wrapped layer to normalize (`kernel`,
        falling back to `embeddings`). Also creates the non-trainable `sn_u`
        vector that carries the power-iteration state between calls.

        Args:
          input_shape: Shape of the layer input, including the leading batch
            dimension.

        Raises:
          AttributeError: If the wrapped layer has neither a `kernel` nor an
            `embeddings` attribute.
        """
        # Presumably this also builds the wrapped layer, so that `kernel` /
        # `embeddings` exist below. Confirm against the Keras `Wrapper`
        # version in use.
        super().build(input_shape)
        input_shape = tf.TensorShape(input_shape)
        # Leave the batch dimension free; every other dimension must match
        # the shape the layer was built with.
        self.input_spec = tf.keras.layers.InputSpec(shape=[None] + input_shape[1:])

        if hasattr(self.layer, "kernel"):
            self.w = self.layer.kernel
        elif hasattr(self.layer, "embeddings"):
            self.w = self.layer.embeddings
        else:
            raise AttributeError(
                "{} object has no attribute 'kernel' nor "
                "'embeddings'".format(type(self.layer).__name__)
            )

        # Static shape, used to flatten `w` to 2-D and to restore it afterwards.
        self.w_shape = self.w.shape.as_list()

        # A single row vector whose length is the last axis of `w`. It is
        # stored (and updated by `normalize_weights`) so each normalization
        # continues the power iteration from the previous estimate instead
        # of restarting from a random vector.
        # NOTE: variable creation order is kept as-is. Keras weight ordering
        # (e.g. for get_weights/set_weights) presumably depends on it.
        self.u = self.add_weight(
            shape=(1, self.w_shape[-1]),
            initializer=tf.initializers.TruncatedNormal(stddev=0.02),
            trainable=False,
            name="sn_u",
            dtype=self.w.dtype,
        )

    def call(self, inputs, training=None):
        """Call `Layer`.

        In training mode, re-normalizes the wrapped weight before forwarding
        `inputs` to the wrapped layer.

        Args:
          inputs: Input passed unchanged to the wrapped layer.
          training: Optional training flag. When `None`, the global Keras
            learning phase (`tf.keras.backend.learning_phase()`) is used.

        Returns:
          The output of the wrapped layer.
        """
        if training is None:
            training = tf.keras.backend.learning_phase()

        if training:
            # Mutates the wrapped layer's weight variable in place.
            self.normalize_weights()

        output = self.layer(inputs)
        return output

    def compute_output_shape(self, input_shape):
        """Return the wrapped layer's output shape as a `tf.TensorShape`."""
        return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())

    def normalize_weights(self):
        """Generate spectral normalized weights.

        This method will update the value of `self.w` with the
        spectral normalized value, so that the layer is ready for `call()`.

        Steps:
        1. Flatten `self.w` to a 2-D matrix `W` with `w_shape[-1]` columns.
        2. Run `power_iterations` steps of power iteration starting from the
           stored `self.u`.
        3. Compute `sigma = v @ W @ u^T`, the estimate of the largest
           singular value of `W`.
        4. Write the new `u` back into `self.u`.
        5. Overwrite `self.w` in place with `self.w / sigma`.

        Because `self.w` is the wrapped layer's own `kernel` / `embeddings`
        variable, the wrapped layer sees the normalized value directly.
        """

        w = tf.reshape(self.w, [-1, self.w_shape[-1]])
        u = self.u

        with tf.name_scope("spectral_normalize"):
            for _ in range(self.power_iterations):
                v = tf.math.l2_normalize(tf.matmul(u, w, transpose_b=True))
                u = tf.math.l2_normalize(tf.matmul(v, w))
            # The singular-vector estimates are treated as constants, so no
            # gradient flows through the power iteration.
            u = tf.stop_gradient(u)
            v = tf.stop_gradient(v)
            # Shape (1, 1); broadcasts over every element of `self.w` below.
            sigma = tf.matmul(tf.matmul(v, w), u, transpose_b=True)
            self.u.assign(tf.cast(u, self.u.dtype))
            self.w.assign(
                tf.cast(tf.reshape(self.w / sigma, self.w_shape), self.w.dtype)
            )

    def get_config(self):
        """Return the base `Wrapper` config extended with `power_iterations`."""
        config = {"power_iterations": self.power_iterations}
        base_config = super().get_config()
        return {**base_config, **config}