Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/layers/stochastic_depth.py: 38%
24 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1import tensorflow as tf
2from typeguard import typechecked
@tf.keras.utils.register_keras_serializable(package="Addons")
class StochasticDepth(tf.keras.layers.Layer):
    """Stochastic Depth layer.

    Implements Stochastic Depth as described in
    [Deep Networks with Stochastic Depth](https://arxiv.org/abs/1603.09382), to randomly drop residual branches
    in residual architectures.

    Usage:
    Residual architectures with fixed depth use residual branches that are merged back into the main network
    by adding the residual branch back to the input:

    >>> input = np.ones((1, 3, 3, 1), dtype = np.float32)
    >>> residual = tf.keras.layers.Conv2D(1, 1)(input)
    >>> output = tf.keras.layers.Add()([input, residual])
    >>> output.shape
    TensorShape([1, 3, 3, 1])

    StochasticDepth acts as a drop-in replacement for the addition:

    >>> input = np.ones((1, 3, 3, 1), dtype = np.float32)
    >>> residual = tf.keras.layers.Conv2D(1, 1)(input)
    >>> output = tfa.layers.StochasticDepth()([input, residual])
    >>> output.shape
    TensorShape([1, 3, 3, 1])

    At train time, StochasticDepth returns:

    $$
    x[0] + b_l * x[1],
    $$

    where $b_l$ is a random Bernoulli variable with probability $P(b_l = 1) = p_l$.
    A single scalar $b_l$ is drawn per call, so the residual branch is either kept
    or dropped for the whole batch at once.

    At test time, StochasticDepth rescales the activations of the residual branch based on the survival probability ($p_l$):

    $$
    x[0] + p_l * x[1]
    $$

    Args:
        survival_probability: float in `[0, 1]`, the probability of the residual branch being kept.

    Call Args:
        inputs: List (or tuple) of `[shortcut, residual]` where `shortcut` and `residual` are tensors of equal shape.
        training: Forwarded to `tf.keras.backend.in_train_phase` to select the
            train-time or test-time behavior; if `None`, the Keras learning phase is used.

    Output shape:
        Equal to the shape of inputs `shortcut` and `residual`.

    Raises:
        ValueError: If `survival_probability` is not within `[0, 1]`, or if the layer
            is called with anything other than a list/tuple of exactly two tensors.
    """

    @typechecked
    def __init__(self, survival_probability: float = 0.5, **kwargs):
        # Outside [0, 1] the value is not a probability: the test-time scaling
        # would no longer equal the expected train-time output. NaN also fails
        # this chained comparison and is rejected.
        if not 0.0 <= survival_probability <= 1.0:
            raise ValueError(
                "survival_probability must be in [0, 1], "
                f"got {survival_probability}."
            )
        super().__init__(**kwargs)

        self.survival_probability = survival_probability

    def call(self, x, training=None):
        # Tuples unpack exactly like lists, so accept both.
        if not isinstance(x, (list, tuple)) or len(x) != 2:
            raise ValueError("input must be a list of length 2.")

        shortcut, residual = x

        def _call_train():
            # Random Bernoulli variable indicating whether the residual branch
            # is kept. It is sampled here, on the training path only, so that
            # inference does not draw a random value it never uses. Shape []
            # means one draw shared by the entire batch.
            b_l = tf.keras.backend.random_bernoulli(
                [], p=self.survival_probability, dtype=self._compute_dtype_object
            )
            return shortcut + b_l * residual

        def _call_test():
            # Deterministic rescaling by p_l, the expected value of b_l.
            return shortcut + self.survival_probability * residual

        return tf.keras.backend.in_train_phase(
            _call_train, _call_test, training=training
        )

    def compute_output_shape(self, input_shape):
        # The output has the shape of the shortcut (first) input.
        return input_shape[0]

    def get_config(self):
        base_config = super().get_config()

        config = {"survival_probability": self.survival_probability}

        return {**base_config, **config}