Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/normalization/spectral_normalization.py: 27%

45 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The Keras Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16import tensorflow.compat.v2 as tf 

17 

18from keras.src.initializers import TruncatedNormal 

19from keras.src.layers.rnn import Wrapper 

20 

21# isort: off 

22from tensorflow.python.util.tf_export import keras_export 

23 

24 

25# Adapted from TF-Addons implementation 

@keras_export("keras.layers.SpectralNormalization", v1=[])
class SpectralNormalization(Wrapper):
    """Wraps a layer and rescales its weight tensor by its spectral norm.

    Dividing the wrapped layer's weights by an estimate of their largest
    singular value bounds the layer's Lipschitz constant, which can help
    stabilize the training of GANs. The singular value is estimated with
    power iteration, and the wrapped layer's weights are renormalized in
    place on every call made with `training=True`.

    Args:
        layer: A `keras.layers.Layer` instance exposing either a `kernel`
            attribute (e.g. `Conv2D`, `Dense`...) or an `embeddings`
            attribute (`Embedding` layer).
        power_iterations: int, how many power-iteration steps to run each
            time the weights are normalized. Must be greater than zero.

    Examples:

    Wrap `keras.layers.Conv2D`:
    >>> x = np.random.rand(1, 10, 10, 1)
    >>> conv2d = SpectralNormalization(tf.keras.layers.Conv2D(2, 2))
    >>> y = conv2d(x)
    >>> y.shape
    TensorShape([1, 9, 9, 2])

    Wrap `keras.layers.Dense`:
    >>> x = np.random.rand(1, 10, 10, 1)
    >>> dense = SpectralNormalization(tf.keras.layers.Dense(10))
    >>> y = dense(x)
    >>> y.shape
    TensorShape([1, 10, 10, 10])

    Reference:

    - [Spectral Normalization for GAN](https://arxiv.org/abs/1802.05957).
    """

    def __init__(self, layer, power_iterations=1, **kwargs):
        # The base wrapper is initialized first; validation happens after.
        super().__init__(layer, **kwargs)
        if power_iterations <= 0:
            raise ValueError(
                "`power_iterations` should be greater than zero. Received: "
                f"`power_iterations={power_iterations}`"
            )
        self.power_iterations = power_iterations

    def build(self, input_shape):
        """Builds the wrapped layer and creates the singular-vector state.

        Raises:
            ValueError: if the wrapped layer has neither a `kernel` nor an
                `embeddings` attribute.
        """
        super().build(input_shape)
        full_shape = tf.TensorShape(input_shape)
        # Accept any batch size; only the non-batch dimensions are pinned.
        self.input_spec = tf.keras.layers.InputSpec(
            shape=[None] + full_shape[1:]
        )

        # Pick the weight tensor to normalize; `kernel` wins over
        # `embeddings` when a layer happens to expose both.
        for attr_name in ("kernel", "embeddings"):
            if hasattr(self.layer, attr_name):
                self.kernel = getattr(self.layer, attr_name)
                break
        else:
            raise ValueError(
                f"{type(self.layer).__name__} object has no attribute 'kernel' "
                "nor 'embeddings'"
            )

        self.kernel_shape = self.kernel.shape.as_list()

        # Running singular-vector estimate over the kernel's last axis. It is
        # stored as a non-trainable weight so each normalization resumes
        # power iteration from the previous estimate.
        self.vector_u = self.add_weight(
            shape=(1, self.kernel_shape[-1]),
            initializer=TruncatedNormal(stddev=0.02),
            trainable=False,
            name="vector_u",
            dtype=self.kernel.dtype,
        )

    def call(self, inputs, training=False):
        # Renormalize only in training mode; otherwise the wrapped layer
        # runs with the kernel as last written by `normalize_weights`.
        if training:
            self.normalize_weights()
        return self.layer(inputs)

    def compute_output_shape(self, input_shape):
        inner_shape = self.layer.compute_output_shape(input_shape)
        return tf.TensorShape(inner_shape.as_list())

    def normalize_weights(self):
        """Divides the wrapped layer's kernel by its estimated spectral norm.

        The kernel is viewed as a 2-D matrix whose columns are its last axis.
        `power_iterations` steps of power iteration refine `self.vector_u`.
        The resulting estimate of the largest singular value, `sigma`, is
        then used to rescale `self.kernel` in place, so the wrapped layer
        picks up the normalized values on its next call. An all-zero kernel
        is left unchanged.
        """
        w_mat = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]])
        u_hat = self.vector_u

        # An all-zero kernel has no spectral norm to divide by (sigma would
        # be 0), so the whole update is skipped in that case.
        if not tf.reduce_all(tf.equal(w_mat, 0.0)):
            for _ in range(self.power_iterations):
                v_hat = tf.math.l2_normalize(
                    tf.matmul(u_hat, w_mat, transpose_b=True)
                )
                u_hat = tf.math.l2_normalize(tf.matmul(v_hat, w_mat))
            # The singular-vector estimates are treated as constants when
            # computing sigma from the kernel.
            u_hat = tf.stop_gradient(u_hat)
            v_hat = tf.stop_gradient(v_hat)
            v_times_w = tf.matmul(v_hat, w_mat)
            sigma = tf.matmul(v_times_w, u_hat, transpose_b=True)
            self.vector_u.assign(tf.cast(u_hat, self.vector_u.dtype))
            normalized = tf.reshape(self.kernel / sigma, self.kernel_shape)
            self.kernel.assign(tf.cast(normalized, self.kernel.dtype))

    def get_config(self):
        return {
            **super().get_config(),
            "power_iterations": self.power_iterations,
        }

142