Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/stochastic_depth.py: 38%

24 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1import tensorflow as tf 

2from typeguard import typechecked 

3 

4 

@tf.keras.utils.register_keras_serializable(package="Addons")
class StochasticDepth(tf.keras.layers.Layer):
    """Stochastic Depth layer.

    Implements Stochastic Depth as described in
    [Deep Networks with Stochastic Depth](https://arxiv.org/abs/1603.09382), to randomly drop residual branches
    in residual architectures.

    Usage:
    Residual architectures with fixed depth use residual branches that are merged back into the main network
    by adding the residual branch back to the input:

    >>> input = np.ones((1, 3, 3, 1), dtype = np.float32)
    >>> residual = tf.keras.layers.Conv2D(1, 1)(input)
    >>> output = tf.keras.layers.Add()([input, residual])
    >>> output.shape
    TensorShape([1, 3, 3, 1])

    StochasticDepth acts as a drop-in replacement for the addition:

    >>> input = np.ones((1, 3, 3, 1), dtype = np.float32)
    >>> residual = tf.keras.layers.Conv2D(1, 1)(input)
    >>> output = tfa.layers.StochasticDepth()([input, residual])
    >>> output.shape
    TensorShape([1, 3, 3, 1])

    At train time, StochasticDepth returns:

    $$
    x[0] + b_l * x[1],
    $$

    where $b_l$ is a random Bernoulli variable with probability $P(b_l = 1) = p_l$.
    A single scalar $b_l$ is drawn per call, so the residual branch is either kept
    or dropped for the whole batch at once.

    At test time, StochasticDepth rescales the activations of the residual branch based on the survival probability ($p_l$):

    $$
    x[0] + p_l * x[1]
    $$

    Args:
        survival_probability: float in `[0, 1]`, the probability of the residual branch being kept.

    Call Args:
        inputs: List of `[shortcut, residual]` where `shortcut` and `residual` are tensors of equal shape.

    Output shape:
        Equal to the shape of inputs `shortcut` and `residual`.

    Raises:
        ValueError: if `survival_probability` is outside `[0, 1]` (or is NaN),
            or if the layer is called with anything other than a list of
            exactly two tensors.
    """

    @typechecked
    def __init__(self, survival_probability: float = 0.5, **kwargs):
        # A probability outside [0, 1] makes both the Bernoulli draw and the
        # test-time rescaling meaningless. The chained comparison also
        # rejects NaN, because every comparison with NaN is False.
        if not 0.0 <= survival_probability <= 1.0:
            raise ValueError(
                "survival_probability must be in [0, 1], got "
                f"{survival_probability}."
            )

        super().__init__(**kwargs)

        self.survival_probability = survival_probability

    def call(self, x, training=None):
        """Merge `shortcut` and `residual`, randomly dropping `residual` in training.

        Args:
            x: list `[shortcut, residual]` of tensors with equal shape.
            training: Python bool or boolean tensor. When None, the Keras
                learning phase decides (via `in_train_phase`).

        Returns:
            `shortcut + b_l * residual` in training mode, or
            `shortcut + survival_probability * residual` in inference mode.

        Raises:
            ValueError: if `x` is not a list of length 2.
        """
        if not isinstance(x, list) or len(x) != 2:
            raise ValueError("input must be a list of length 2.")

        shortcut, residual = x

        def _call_train():
            # Scalar Bernoulli variable indicating whether the residual branch
            # is kept (1) or dropped (0) for this call. It is sampled only
            # inside the training branch, so inference builds no unused
            # random op.
            b_l = tf.keras.backend.random_bernoulli(
                [], p=self.survival_probability, dtype=self._compute_dtype_object
            )
            return shortcut + b_l * residual

        def _call_test():
            # Use the expected value of the training-time output
            # (E[b_l] = p_l), so activations keep the same scale.
            return shortcut + self.survival_probability * residual

        return tf.keras.backend.in_train_phase(
            _call_train, _call_test, training=training
        )

    def compute_output_shape(self, input_shape):
        """Return the shape of the first input (`shortcut`), which equals the output shape."""
        return input_shape[0]

    def get_config(self):
        """Return the base layer config extended with `survival_probability`."""
        base_config = super().get_config()

        config = {"survival_probability": self.survival_probability}

        return {**base_config, **config}