When is a deep network trainable?

  • Why is the initialisation of a network so important?
  • Deeper networks generally seem to work better, provided they can be trained

An abstraction of neural networks

Consider a randomly initialised neural network.

For each layer \(l\)

\[h^l_i = \sum_j W^l_{ij} y^l_j + b_i^l\] \[ y^{l+1}_i = \phi\left(h^l_i\right) \]

where \(W^{l}_{ij} \sim N\left(0, \frac{\sigma_w^2}{N_{l-1}}\right)\), with \(N_{l-1}\) the number of inputs \(y^l_j\) feeding into layer \(l\), and \(b^l_i \sim N(0, \sigma^2_b)\).

The input signal is \(y^0_i = x_i\). The variances \(\sigma_w^2\) and \(\sigma_b^2\) are the same for every layer of the network.
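
The setup can be simulated directly. The following is a minimal sketch in plain Python. It assumes \(\phi = \tanh\) and a constant input vector, and the function `random_tanh_network_q` and its parameters are hypothetical names for this illustration. It samples one random network and propagates a single input. Along the way it records the empirical variance \(q^l\) of each layer's pre-activations, which is the quantity the mean-field approximation tries to predict.

```python
import math
import random


def random_tanh_network_q(x, widths, sigma_w2, sigma_b2, seed=0):
    """Empirical pre-activation variance q^l = mean_i (h^l_i)^2 for each layer l.

    x: input signal y^0 (list of floats)
    widths: widths[l] is the number of units in layer l
    sigma_w2: weights are drawn as W^l_ij ~ N(0, sigma_w2 / fan_in)
    sigma_b2: biases are drawn as b^l_i ~ N(0, sigma_b2)
    """
    rng = random.Random(seed)
    b_std = math.sqrt(sigma_b2)
    y = list(x)  # y^0 = x
    qs = []
    for n_units in widths:
        w_std = math.sqrt(sigma_w2 / len(y))  # fan-in = number of inputs y^l_j
        # h^l_i = sum_j W^l_ij y^l_j + b^l_i with freshly sampled weights and biases
        h = [
            sum(rng.gauss(0.0, w_std) * y_j for y_j in y) + rng.gauss(0.0, b_std)
            for _ in range(n_units)
        ]
        qs.append(sum(h_i * h_i for h_i in h) / n_units)
        y = [math.tanh(h_i) for h_i in h]  # y^{l+1} = phi(h^l)
    return qs


if __name__ == "__main__":
    q_per_layer = random_tanh_network_q([1.0] * 300, [300] * 5, sigma_w2=1.5, sigma_b2=0.05)
    for layer, q in enumerate(q_per_layer):
        print(f"layer {layer}: q = {q:.3f}")
```

With a few hundred units per layer the values still fluctuate noticeably between seeds. Wider layers concentrate them more sharply, and that is the regime the mean-field picture assumes.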

The mean-field approximation

Each pre-activation \(h^l_i\) is assumed to be approximately Gaussian.

  • This is self-consistent in a wide network, since each \(h^l_i\) is a sum of many random variables (central limit theorem)
  • The Gaussians are centered at zero and have a layer-specific variance \(q^l\), \(\left<h^l_i h^l_j\right> = q^l \delta_{ij}\)
  • All activations in a layer are statistically the same type of variable
  • Idea: ignore the microscopic details of each layer and instead understand how the parameters of the statistical distribution change from layer to layer
  • Each layer is statistically described by a very small set of numbers

Recursion relations

Find relations between the statistical parameters of consecutive layers.

  • \[q^l = \frac{1}{N_{l}} \sum_{i = 1}^{N_l} \left(h^l_i\right)^2\]
  • Plugging in the definitions (cross terms vanish because weights and biases are independent with zero mean) \[q^l = \left< \left(h^l_i\right)^2 \right> = \left< \left(W^l_i \cdot \phi\left(h^{l-1}\right)\right)^2 \right> + \left<\left(b^l_i\right)^2\right>\]
  • \[= \frac{\sigma^2_w}{N_{l-1}} \sum_{j = 1}^{N_{l-1}} \left<\left(\phi\left(h_j^{l-1}\right)\right)^2 \right> + \sigma_b^2\]

Recursion relations (cont.)

But all the pre-activations in the previous layer are identically distributed Gaussians with variance \(q^{l-1}\), so the average becomes a one-dimensional Gaussian integral

\[\frac{1}{N_{l-1}} \sum_{i=1}^{N_{l-1}}\left<\left(\phi\left(h_i^{l-1}\right)\right)^2\right> = \int_{-\infty}^{\infty} \frac{dz}{\sqrt{2\pi}} \mathrm{e}^{-\frac{z^2}{2}} \phi^2\left(\sqrt{q^{l-1}}z\right)\]

  • with \(Dz = \frac{dz}{\sqrt{2\pi}} \mathrm{e}^{-\frac{z^2}{2}}\) we get a recursion relation \[q^l = \sigma_w^2 \int Dz\, \phi^2\left(\sqrt{q^{l-1}}z\right) + \sigma_b^2\]
  • \[q^l = f(q^{l-1}| \sigma^2_w, \sigma^2_b)\]
  • The details depend on the activation function \(\phi\)
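
The recursion is easy to evaluate numerically. The following sketch assumes \(\phi = \tanh\) and approximates the Gaussian measure \(Dz\) with a plain trapezoidal rule on \([-10, 10]\). The helpers `gauss_expect`, `variance_map` and `iterate_variance` are hypothetical names. A linear \(\phi\) serves as a sanity check, because then \(f(q) = \sigma_w^2 q + \sigma_b^2\) exactly.

```python
import math


def gauss_expect(g, n=2001, z_max=10.0):
    """Approximate the Gaussian integral ∫ Dz g(z) with the trapezoidal rule on [-z_max, z_max]."""
    dz = 2.0 * z_max / (n - 1)
    total = 0.0
    for k in range(n):
        z = -z_max + k * dz
        weight = 0.5 if k in (0, n - 1) else 1.0
        total += weight * g(z) * math.exp(-0.5 * z * z)
    return total * dz / math.sqrt(2.0 * math.pi)


def variance_map(q, sigma_w2, sigma_b2, phi=math.tanh):
    """One layer of the recursion q^l = sigma_w^2 ∫ Dz phi^2(sqrt(q^{l-1}) z) + sigma_b^2."""
    s = math.sqrt(q)
    return sigma_w2 * gauss_expect(lambda z: phi(s * z) ** 2) + sigma_b2


def iterate_variance(q0, sigma_w2, sigma_b2, depth, phi=math.tanh):
    """Return [q^0, q^1, ..., q^depth] by applying the variance map repeatedly."""
    qs = [q0]
    for _ in range(depth):
        qs.append(variance_map(qs[-1], sigma_w2, sigma_b2, phi))
    return qs


if __name__ == "__main__":
    for q in iterate_variance(q0=3.0, sigma_w2=1.5, sigma_b2=0.05, depth=10):
        print(f"{q:.4f}")
```

The trapezoidal rule is a convenient choice here because the Gaussian weight makes the integrand decay quickly. Gauss-Hermite quadrature would be a common alternative.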

Fixed point Analysis

At a fixed point \(q^*\), the variance maps to itself

\[q^* = f\left(q^* | \sigma^2_w, \sigma^2_b\right)\]

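The stability of a fixed point depends on the slope of \(f\) at \(q^*\). It is stable if \(\left|\frac{\partial f}{\partial q}(q^*)\right| < 1\), so that small deviations shrink from layer to layer. It is unstable if the slope exceeds 1 in magnitude.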
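
A sketch of locating \(q^*\) and checking its stability numerically, assuming \(\phi = \tanh\) by default. It iterates the map until it stops changing, then estimates the slope \(\partial f / \partial q\) at \(q^*\) with a finite difference. The helper names are hypothetical, and the trapezoidal Gaussian integral is an assumed discretisation.

```python
import math


def gauss_expect(g, n=2001, z_max=10.0):
    """Trapezoidal approximation of ∫ Dz g(z) on [-z_max, z_max]."""
    dz = 2.0 * z_max / (n - 1)
    total = 0.0
    for k in range(n):
        z = -z_max + k * dz
        weight = 0.5 if k in (0, n - 1) else 1.0
        total += weight * g(z) * math.exp(-0.5 * z * z)
    return total * dz / math.sqrt(2.0 * math.pi)


def variance_map(q, sigma_w2, sigma_b2, phi=math.tanh):
    """f(q) = sigma_w^2 ∫ Dz phi^2(sqrt(q) z) + sigma_b^2."""
    s = math.sqrt(q)
    return sigma_w2 * gauss_expect(lambda z: phi(s * z) ** 2) + sigma_b2


def fixed_point_q(sigma_w2, sigma_b2, phi=math.tanh, q0=1.0, tol=1e-12, max_iter=10_000):
    """Iterate q -> f(q) from q0 until successive values differ by less than tol."""
    q = q0
    for _ in range(max_iter):
        q_next = variance_map(q, sigma_w2, sigma_b2, phi)
        if abs(q_next - q) < tol:
            return q_next
        q = q_next
    raise RuntimeError("variance map did not converge")


def slope_at(q, sigma_w2, sigma_b2, phi=math.tanh, h=1e-5):
    """Finite-difference estimate of df/dq at q (one-sided near 0 so that sqrt stays real)."""
    lo, hi = max(q - h, 0.0), q + h
    f_lo = variance_map(lo, sigma_w2, sigma_b2, phi)
    f_hi = variance_map(hi, sigma_w2, sigma_b2, phi)
    return (f_hi - f_lo) / (hi - lo)


if __name__ == "__main__":
    sigma_w2, sigma_b2 = 1.5, 0.05
    q_star = fixed_point_q(sigma_w2, sigma_b2)
    slope = slope_at(q_star, sigma_w2, sigma_b2)
    print(f"q* = {q_star:.6f}, slope = {slope:.4f}", "stable" if abs(slope) < 1 else "unstable")
```

For a linear \(\phi\) the same code recovers \(q^* = \sigma_b^2 / (1 - \sigma_w^2)\) with slope \(\sigma_w^2\), which makes a quick consistency check.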

Stable Networks?

For \(\tanh\) activation, \(q^l\) always converges to a stable fixed point \(q^*\). Depending on the values of \(\sigma^2_w\) and \(\sigma^2_b\), this fixed point is nonzero. It is zero only when \(\sigma^2_b = 0\) and \(\sigma^2_w \leq 1\).

More than just variance

The covariance between the pre-activations for two inputs measures how similar the two signals are. We would expect similar inputs to have similar outputs.

\[q^l_{ab} = \frac{1}{N_l}\sum_{i=1}^{N_l} h^l_i(x_a^0)\,h^l_i(x_b^0)\]

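  • Correlation, the normalised covariance \(c^l_{12} = q^l_{12} / \sqrt{q^l_{11} q^l_{22}}\), also has a simple recursion once both inputs have reached the variance fixed point \(q^*\) \[c^l_{12} = \frac{1}{q^*}\left[\sigma_w^2\int Dz_1\, Dz_2\, \phi(u_1)\phi(u_2) + \sigma_b^2\right]\] \[u_1 = \sqrt{q^*}\, z_1, \qquad u_2 = \sqrt{q^*}\left[c^{l-1}_{12} z_1 + \sqrt{1-\left(c^{l-1}_{12}\right)^2}\, z_2\right]\]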
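
A sketch of one step of the correlation map for \(\phi = \tanh\). It assumes both inputs already sit at the variance fixed point \(q^*\). The double Gaussian integral over \(z_1, z_2\) is approximated on a tensor-product trapezoidal grid, which is an assumed discretisation. The function names are hypothetical.

```python
import math


def trapezoid_nodes(n, z_max):
    """Nodes z_k and weights w_k such that sum_k w_k g(z_k) ≈ ∫ Dz g(z) on [-z_max, z_max]."""
    dz = 2.0 * z_max / (n - 1)
    nodes, weights = [], []
    for k in range(n):
        z = -z_max + k * dz
        end = 0.5 if k in (0, n - 1) else 1.0
        nodes.append(z)
        weights.append(end * dz * math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi))
    return nodes, weights


def fixed_point_q(sigma_w2, sigma_b2, phi=math.tanh, tol=1e-12, max_iter=10_000):
    """Stable fixed point q* of q -> sigma_w^2 ∫ Dz phi^2(sqrt(q) z) + sigma_b^2, iterated from q = 1."""
    nodes, weights = trapezoid_nodes(2001, 10.0)
    q = 1.0
    for _ in range(max_iter):
        s = math.sqrt(q)
        q_next = sigma_w2 * sum(w * phi(s * z) ** 2 for z, w in zip(nodes, weights)) + sigma_b2
        if abs(q_next - q) < tol:
            return q_next
        q = q_next
    raise RuntimeError("variance map did not converge")


def correlation_map(c, q_star, sigma_w2, sigma_b2, phi=math.tanh, n=201, z_max=8.0):
    """c^l = (sigma_w^2 ∫ Dz1 Dz2 phi(u1) phi(u2) + sigma_b^2) / q*, with c = c^{l-1}."""
    nodes, weights = trapezoid_nodes(n, z_max)
    s = math.sqrt(q_star)
    a, b = c, math.sqrt(max(0.0, 1.0 - c * c))  # u2 = sqrt(q*) (c z1 + sqrt(1 - c^2) z2)
    total = 0.0
    for z1, w1 in zip(nodes, weights):
        inner = sum(w2 * phi(s * (a * z1 + b * z2)) for z2, w2 in zip(nodes, weights))
        total += w1 * phi(s * z1) * inner  # u1 = sqrt(q*) z1
    return (sigma_w2 * total + sigma_b2) / q_star


if __name__ == "__main__":
    sigma_w2, sigma_b2 = 0.5, 0.05
    q_star = fixed_point_q(sigma_w2, sigma_b2)
    c = 0.0
    for layer in range(10):
        c = correlation_map(c, q_star, sigma_w2, sigma_b2)
        print(layer, round(c, 4))
```

For these parameters the map contracts on \([0, 1]\), so two initially uncorrelated inputs should become fully correlated with depth. Note that \(c = 1\) is mapped to itself, which serves as a check on the discretisation.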

Phase Transition!

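One can show that \(c_{12} = 1\) is always a fixed point, but it is not always stable. Its stability is set by the slope of the correlation map at \(c_{12} = 1\)

\[\frac{\partial c^l_{12}}{\partial c^{l-1}_{12}} = \sigma_w^2 \int D z_1\, Dz_2\, \phi'(u_1)\phi'(u_2), \qquad \chi_1 = \left.\frac{\partial c^l_{12}}{\partial c^{l-1}_{12}}\right|_{c_{12} = 1} = \sigma_w^2 \int Dz \left[\phi'\left(\sqrt{q^*}z\right)\right]^2\]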

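  • If \(\chi_1 < 1\), the fixed point \(c^*_{12} = 1\) is stable and similar inputs have similar outputs (ordered phase)
  • If \(\chi_1 > 1\), it is unstable and \(c^l_{12}\) flows to a fixed point \(c^*_{12} < 1\), which approaches 0 deep in the chaotic phase. Similar inputs then get more and more dissimilar outputs: chaos and the butterfly effect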
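
A sketch that computes \(\chi_1\) for \(\phi = \tanh\), where \(\phi'(u) = 1 - \tanh^2 u\), and classifies an initialisation as ordered or chaotic. The helper names are hypothetical, and the Gaussian integrals use an assumed trapezoidal discretisation.

```python
import math


def gauss_expect(g, n=2001, z_max=10.0):
    """Trapezoidal approximation of ∫ Dz g(z) on [-z_max, z_max]."""
    dz = 2.0 * z_max / (n - 1)
    total = 0.0
    for k in range(n):
        z = -z_max + k * dz
        weight = 0.5 if k in (0, n - 1) else 1.0
        total += weight * g(z) * math.exp(-0.5 * z * z)
    return total * dz / math.sqrt(2.0 * math.pi)


def fixed_point_q(sigma_w2, sigma_b2, tol=1e-12, max_iter=10_000):
    """Stable fixed point q* of the tanh variance map, found by iteration from q = 1."""
    q = 1.0
    for _ in range(max_iter):
        s = math.sqrt(q)
        q_next = sigma_w2 * gauss_expect(lambda z: math.tanh(s * z) ** 2) + sigma_b2
        if abs(q_next - q) < tol:
            return q_next
        q = q_next
    raise RuntimeError("variance map did not converge")


def chi_1(sigma_w2, sigma_b2):
    """chi_1 = sigma_w^2 ∫ Dz [phi'(sqrt(q*) z)]^2 with phi = tanh, phi' = 1 - tanh^2."""
    s = math.sqrt(fixed_point_q(sigma_w2, sigma_b2))
    return sigma_w2 * gauss_expect(lambda z: (1.0 - math.tanh(s * z) ** 2) ** 2)


def phase(sigma_w2, sigma_b2, eps=1e-6):
    """Classify an initialisation by comparing chi_1 with 1 (eps is an arbitrary tolerance)."""
    chi = chi_1(sigma_w2, sigma_b2)
    if chi < 1.0 - eps:
        return "ordered"
    if chi > 1.0 + eps:
        return "chaotic"
    return "critical"


if __name__ == "__main__":
    for sigma_w2 in (0.5, 1.0, 2.0, 4.0):
        print(sigma_w2, round(chi_1(sigma_w2, 0.05), 4), phase(sigma_w2, 0.05))
```

The tolerance `eps` is a free choice, since a numerically computed \(\chi_1\) will essentially never be exactly 1. Scanning a grid of \((\sigma_w^2, \sigma_b^2)\) with `phase` traces out the ordered/chaotic boundary.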

Correlation Map

Chaos Transition

  • They also show that deeper architectures are better at disentangling complicated manifolds

Phase Diagram

Two length scales

\[\left|q^l - q^*\right|\sim \mathrm{e}^{-l/\xi_q} \hspace{1cm} \left|c^l - c^*\right|\sim \mathrm{e}^{-l/\xi_c}\]

  • \(\xi_q\): length scale of how deep a single signal can penetrate the network \[\xi_q^{-1} = - \log \left[ \chi_1 + \sigma_w^2 \int Dz\, \phi''\left(\sqrt{q^*}z\right)\phi\left(\sqrt{q^*}z\right) \right] \]
  • \(\xi_c\): length scale over which the correlation between two signals relaxes; in the ordered phase (\(c^* = 1\)) \[\xi_c^{-1} = - \log \left[ \sigma_w^2 \int Dz \left[\phi'\left(\sqrt{q^*}z\right)\right]^2 \right] = -\log \chi_1 \]
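
A sketch that evaluates both length scales for \(\phi = \tanh\), using \(\phi' = 1 - \tanh^2\) and \(\phi'' = -2\tanh\,(1 - \tanh^2)\). By Gaussian integration by parts, the bracket in \(\xi_q^{-1}\) equals the slope \(\partial f / \partial q\) of the variance map at \(q^*\), which gives a convenient numerical check. The helper names and the trapezoidal integration are assumptions of this sketch, and \(\xi_c\) is only computed on the ordered side, where the formula above applies.

```python
import math


def gauss_expect(g, n=2001, z_max=10.0):
    """Trapezoidal approximation of ∫ Dz g(z) on [-z_max, z_max]."""
    dz = 2.0 * z_max / (n - 1)
    total = 0.0
    for k in range(n):
        z = -z_max + k * dz
        weight = 0.5 if k in (0, n - 1) else 1.0
        total += weight * g(z) * math.exp(-0.5 * z * z)
    return total * dz / math.sqrt(2.0 * math.pi)


def fixed_point_q(sigma_w2, sigma_b2, tol=1e-12, max_iter=10_000):
    """Stable fixed point q* of the tanh variance map, found by iteration from q = 1."""
    q = 1.0
    for _ in range(max_iter):
        s = math.sqrt(q)
        q_next = sigma_w2 * gauss_expect(lambda z: math.tanh(s * z) ** 2) + sigma_b2
        if abs(q_next - q) < tol:
            return q_next
        q = q_next
    raise RuntimeError("variance map did not converge")


def length_scales(sigma_w2, sigma_b2):
    """Return (xi_q, xi_c) for phi = tanh; xi_c uses the ordered-phase formula -1 / log(chi_1)."""
    s = math.sqrt(fixed_point_q(sigma_w2, sigma_b2))

    def phi(z):  # phi(u) evaluated at u = sqrt(q*) z
        return math.tanh(s * z)

    def dphi(z):  # phi'(u) = 1 - tanh(u)^2
        return 1.0 - math.tanh(s * z) ** 2

    def ddphi(z):  # phi''(u) = -2 tanh(u) (1 - tanh(u)^2)
        t = math.tanh(s * z)
        return -2.0 * t * (1.0 - t * t)

    chi1 = sigma_w2 * gauss_expect(lambda z: dphi(z) ** 2)
    xi_q = -1.0 / math.log(chi1 + sigma_w2 * gauss_expect(lambda z: ddphi(z) * phi(z)))
    if chi1 > 1.0:
        raise ValueError("chaotic phase: xi_c needs the fixed point c* < 1, not covered here")
    xi_c = math.inf if chi1 == 1.0 else -1.0 / math.log(chi1)
    return xi_q, xi_c


if __name__ == "__main__":
    for sigma_w2, sigma_b2 in ((0.5, 0.05), (0.9, 0.0)):
        print(sigma_w2, sigma_b2, length_scales(sigma_w2, sigma_b2))
```

For \(\sigma_b^2 = 0\) and \(\sigma_w^2 < 1\) the fixed point is \(q^* = 0\), so \(\chi_1 = \sigma_w^2\) and \(\xi_c = -1/\log\sigma_w^2\). This grows without bound as \(\sigma_w^2 \to 1\).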

Length scales diverge

For tanh activation, \(\xi_q\) always stays finite, but it can become large close to a transition. At a transition (\(\chi_1 = 1\)), \(\xi_c\) always diverges.

Interlude: Critical Opalescence

  • Observed in nature
  • Transparent liquids or gases become milky at special points in temperature/pressure space
  • Explanation: close to the transition, droplets on many length scales coexist and scatter visible light of all wavelengths

Dropout destroys the phase transition

Even small amounts of dropout make the \(c = 1\) fixed point unstable

  • The signal-to-noise ratio gets worse and worse with depth

Backpropagation Goldilocks

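Applying the same mean-field reasoning to the backpropagated gradients gives

\[\tilde{q}_{aa}^l = \tilde{q}_{aa}^L \mathrm{e}^{-\frac{L-l}{\xi_\nabla}}, \hspace{1cm} \xi^{-1}_\nabla = -\log \chi_1\]

where \(\tilde{q}^l_{aa}\) is the variance of the gradient of the loss with respect to the pre-activations of layer \(l\).

  • In the ordered phase (\(\chi_1 < 1\)), \(0 < \xi_\nabla < \infty\) and gradients vanish exponentially on their way back to the input
  • In the chaotic phase (\(\chi_1 > 1\)), \(\xi_\nabla < 0\) and gradients explode
  • At criticality (\(\chi_1 = 1\)), \(\xi_\nabla \to \infty\) and gradients propagate without exponential decay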
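
The gradient formula is simple enough to evaluate directly. The following sketch assumes \(\chi_1\) is already known, for example from a numerical integration over \(Dz\). Since \(\mathrm{e}^{-(L-l)/\xi_\nabla} = \chi_1^{\,L-l}\), the three regimes correspond to \(\chi_1 < 1\), \(\chi_1 > 1\) and \(\chi_1 = 1\). The function names are hypothetical.

```python
import math


def xi_grad(chi1):
    """Gradient length scale xi_nabla = -1 / log(chi_1); infinite at criticality (chi_1 = 1)."""
    if chi1 <= 0.0:
        raise ValueError("chi_1 must be positive")
    log_chi = math.log(chi1)
    return math.inf if log_chi == 0.0 else -1.0 / log_chi


def gradient_variance_ratio(chi1, depth_gap):
    """Mean-field ratio q~^l / q~^L = exp(-(L - l) / xi_nabla), with depth_gap = L - l."""
    return math.exp(-depth_gap / xi_grad(chi1))


if __name__ == "__main__":
    for chi1 in (0.9, 1.0, 1.1):
        print(chi1, xi_grad(chi1), gradient_variance_ratio(chi1, depth_gap=50))
```

Even a modest distance from \(\chi_1 = 1\) compounds quickly over tens of layers. This is why the gradient signal is only well behaved close to criticality.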

Backprop Experiments

Covariance hypothesis

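Observation: the covariance between the gradients for two different inputs can be shown to decay on the same length scale \(\xi_c\)

  • Hypothesis: \(\xi_c\) determines whether a network can learn or not
  • The ratio \(\frac{\xi_c}{L}\) decides: networks much deeper than \(\xi_c\) are expected to be untrainable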
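
To make the hypothesis concrete, here is a sketch that turns \((\sigma_w^2, \sigma_b^2)\) into a predicted maximum trainable depth for \(\phi = \tanh\). It assumes a cutoff of the form "trainable if and only if \(L \le \text{scale} \cdot \xi_c\)", where the hypothetical `scale` factor defaults to 1. It only covers the ordered phase, where \(\xi_c = -1/\log\chi_1\); the chaotic side would need the correlation fixed point \(c^* < 1\). The function names are hypothetical.

```python
import math


def gauss_expect(g, n=2001, z_max=10.0):
    """Trapezoidal approximation of ∫ Dz g(z) on [-z_max, z_max]."""
    dz = 2.0 * z_max / (n - 1)
    total = 0.0
    for k in range(n):
        z = -z_max + k * dz
        weight = 0.5 if k in (0, n - 1) else 1.0
        total += weight * g(z) * math.exp(-0.5 * z * z)
    return total * dz / math.sqrt(2.0 * math.pi)


def chi_1_tanh(sigma_w2, sigma_b2, tol=1e-12, max_iter=10_000):
    """chi_1 = sigma_w^2 ∫ Dz [phi'(sqrt(q*) z)]^2 for phi = tanh, with q* found by iteration."""
    q = 1.0
    for _ in range(max_iter):
        s = math.sqrt(q)
        q_next = sigma_w2 * gauss_expect(lambda z: math.tanh(s * z) ** 2) + sigma_b2
        converged = abs(q_next - q) < tol
        q = q_next
        if converged:
            break
    else:
        raise RuntimeError("variance map did not converge")
    s = math.sqrt(q)
    return sigma_w2 * gauss_expect(lambda z: (1.0 - math.tanh(s * z) ** 2) ** 2)


def predicted_max_depth(sigma_w2, sigma_b2, scale=1.0):
    """Largest integer L with L <= scale * xi_c (assumed cutoff); None outside the ordered phase."""
    chi1 = chi_1_tanh(sigma_w2, sigma_b2)
    if chi1 >= 1.0:
        return None  # critical (xi_c infinite) or chaotic: not handled by this sketch
    return math.floor(scale * (-1.0 / math.log(chi1)))


if __name__ == "__main__":
    for sigma_w2 in (0.5, 0.9):
        print(sigma_w2, predicted_max_depth(sigma_w2, 0.0 if sigma_w2 == 0.9 else 0.05))
```

The proportionality between \(\xi_c\) and the largest trainable depth is an empirical question. Exposing it as `scale` keeps that assumption explicit rather than hard-coded.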

Experiments

The colour indicates the training accuracy on MNIST; red means high accuracy

With dropout

Summary

  • Mean-field theory is an approximation for studying the qualitative behaviour of neural networks
  • Ordered and chaotic phases exist, depending on the initialisation parameters \(\sigma_w^2\) and \(\sigma_b^2\)
  • Phase transitions are accompanied by a diverging length scale \(\xi_c\)
  • Training is possible when \(\xi_c \geq L\)