Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/spectral_normalization.py: 23%

47 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import tensorflow as tf
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class SpectralNormalization(tf.keras.layers.Wrapper):
    """Performs spectral normalization on weights.

    This wrapper controls the Lipschitz constant of the layer by
    constraining its spectral norm, which can stabilize the training of GANs.

    See [Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957).

    Wrap `tf.keras.layers.Conv2D`:

    >>> x = np.random.rand(1, 10, 10, 1)
    >>> conv2d = SpectralNormalization(tf.keras.layers.Conv2D(2, 2))
    >>> y = conv2d(x)
    >>> y.shape
    TensorShape([1, 9, 9, 2])

    Wrap `tf.keras.layers.Dense`:

    >>> x = np.random.rand(1, 10, 10, 1)
    >>> dense = SpectralNormalization(tf.keras.layers.Dense(10))
    >>> y = dense(x)
    >>> y.shape
    TensorShape([1, 10, 10, 10])

    Args:
        layer: A `tf.keras.layers.Layer` instance that
            has either a `kernel` or an `embeddings` attribute.
        power_iterations: `int`, the number of iterations during normalization.
    Raises:
        AssertionError: If not initialized with a `Layer` instance.
        ValueError: If initialized with a non-positive `power_iterations`.
        AttributeError: If `layer` does not have a `kernel` or `embeddings` attribute.
    """

    @typechecked
    def __init__(self, layer: tf.keras.layers, power_iterations: int = 1, **kwargs):
        super().__init__(layer, **kwargs)
        if power_iterations <= 0:
            raise ValueError(
                "`power_iterations` should be greater than zero, got "
                "`power_iterations={}`".format(power_iterations)
            )
        self.power_iterations = power_iterations
        self._initialized = False

    def build(self, input_shape):
        """Build `Layer`"""
        super().build(input_shape)
        input_shape = tf.TensorShape(input_shape)
        self.input_spec = tf.keras.layers.InputSpec(shape=[None] + input_shape[1:])

        if hasattr(self.layer, "kernel"):
            self.w = self.layer.kernel
        elif hasattr(self.layer, "embeddings"):
            self.w = self.layer.embeddings
        else:
            raise AttributeError(
                "{} object has no attribute 'kernel' nor "
                "'embeddings'".format(type(self.layer).__name__)
            )

        self.w_shape = self.w.shape.as_list()

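        # `u` stores the running singular-vector estimate used by the power
        # iteration in `normalize_weights()`; it is kept as a non-trainable
        # weight of shape `(1, self.w_shape[-1])`.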

        self.u = self.add_weight(
            shape=(1, self.w_shape[-1]),
            initializer=tf.initializers.TruncatedNormal(stddev=0.02),
            trainable=False,
            name="sn_u",
            dtype=self.w.dtype,
        )

    def call(self, inputs, training=None):
        """Call `Layer`"""
        if training is None:
            training = tf.keras.backend.learning_phase()

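        # The kernel is re-normalized only in training mode; at inference the
        # previously normalized weights are reused as-is.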

        if training:
            self.normalize_weights()

        output = self.layer(inputs)
        return output

    def compute_output_shape(self, input_shape):
        return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())

    def normalize_weights(self):
        """Generate spectrally normalized weights.

        This method updates the value of `self.w` with the spectrally
        normalized value, so that the layer is ready for `call()`.
        """

        w = tf.reshape(self.w, [-1, self.w_shape[-1]])
        u = self.u

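        # Power-iteration summary (illustrative): each pass below refines `v` and
        # `u` toward the dominant singular vectors of `w`, so
        # `sigma = v @ w @ u^T` approximates the largest singular value and
        # `w / sigma` has spectral norm close to 1.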

        with tf.name_scope("spectral_normalize"):
            for _ in range(self.power_iterations):
                v = tf.math.l2_normalize(tf.matmul(u, w, transpose_b=True))
                u = tf.math.l2_normalize(tf.matmul(v, w))
            u = tf.stop_gradient(u)
            v = tf.stop_gradient(v)
            sigma = tf.matmul(tf.matmul(v, w), u, transpose_b=True)
            self.u.assign(tf.cast(u, self.u.dtype))
            self.w.assign(
                tf.cast(tf.reshape(self.w / sigma, self.w_shape), self.w.dtype)
            )

    def get_config(self):
        config = {"power_iterations": self.power_iterations}
        base_config = super().get_config()
        return {**base_config, **config}
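
Appended usage sketch (not part of the module above): a minimal, hedged example that wraps a `tf.keras.layers.Dense` layer, runs one training-mode call so `normalize_weights()` fires, and uses NumPy to check that the kernel's largest singular value ends up close to 1. The import path is the public `tensorflow_addons.layers` package; the shapes and `power_iterations=5` are arbitrary illustration choices, not recommendations.

import numpy as np
import tensorflow as tf
from tensorflow_addons.layers import SpectralNormalization

# Wrap a Dense layer; more power iterations give a tighter sigma estimate.
sn_dense = SpectralNormalization(tf.keras.layers.Dense(16), power_iterations=5)

x = np.random.rand(4, 8).astype("float32")

# A training-mode call triggers the power iteration and rescales the kernel.
_ = sn_dense(x, training=True)

# The kernel should now have spectral norm close to 1 (the power-iteration
# estimate of sigma is approximate, so expect a value near, not exactly, 1.0).
kernel = sn_dense.layer.kernel.numpy().reshape(-1, sn_dense.w_shape[-1])
print(np.linalg.svd(kernel, compute_uv=False)[0])

# Because the wrapper is registered as a Keras serializable ("Addons" package),
# its configuration, including `power_iterations`, is recoverable for saving.
print(sn_dense.get_config()["power_iterations"])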