15 November 2020

The Problem

  • Forecasting multiple time steps ahead is hard
  • Many methods lack explainability
  • Complex interactions of known and unknown inputs
  • Traditional explainable methods often do not handle time ordering
  • Forecasts are very important ($)

The Data Model

The Temporal Fusion Transformer (TFT) - Novelties

  • Static covariate encoders
  • Gating mechanisms to allow for less complex models
  • A sequence-to-sequence layer to process known and observed inputs
  • A temporal self-attention mechanism
  • Explainability gives:
  • Globally important variables
  • Persistent temporal patterns
  • Significant events

Multi-horizon Forecasting - The problem structure

A data set has \(I\) unique entities (e.g. stores in a chain) and \(t \in [0,T_i]\) timesteps

  • \(\mathbf{s}_i \in \mathbb{R}^{m_s}\) are static covariates
  • \(\chi_{i,t} \in \mathbb{R}^{m_{\chi}}\) are time-dependent, entity-dependent inputs
  • \(\chi_{i,t} = [z_{i,t}^T, x_{i,t}^T]^T\), where \(z_{i,t} \in \mathbb{R}^{m_z}\) are unknown beforehand and \(x_{i,t} \in \mathbb{R}^{m_x}\) are known
  • \(y_{i,t} \in \mathbb{R}\) are targets

Forecasting quantiles

\[\hat{y}_i(q,t,\tau) = f_q\left(\tau, y_{i, t-k:t}, \mathbf{z}_{i,t-k:t}, \mathbf{x}_{i,t-k:(t+\tau)},\mathbf{s}_i\right)\]

where

  • \(q\) is the quantile of the prediction
  • \(\tau \in \{1,\dots, \tau_{max}\}\) is how many steps ahead we want to predict
  • \(t\) is current time
  • Note that \(\mathbf{x}\) and \(\mathbf{z}\) enter with different time windows: the known inputs \(\mathbf{x}\) extend into the forecast window up to \(t+\tau\), while the observed inputs \(\mathbf{z}\) are only available up to \(t\).
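
A minimal sketch of the interface this definition implies, with a naive placeholder body (the function name, argument names and the empirical-quantile baseline are my own, not part of the paper):

    import numpy as np

    def forecast_quantile(q, tau, y_past, z_past, x_known, s):
        """Stand-in for the q-quantile forecast hat{y}_i(q, t, tau).

        y_past:  targets y_{i, t-k:t},          shape (k+1,)
        z_past:  observed inputs z_{i, t-k:t},  shape (k+1, m_z)
        x_known: known inputs x_{i, t-k:t+tau}, shape (k+1+tau, m_x)
        s:       static covariates s_i,         shape (m_s,)
        """
        # Naive baseline: ignore all covariates and return the empirical
        # q-quantile of the past targets; the TFT replaces this with the learned f_q.
        return np.quantile(y_past, q)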

Model Architecture

Major Components

Static Variable encoder

Four context vectors are produced; they feed into:

  • Variable selection
  • Local processing of temporal features (two contexts)
  • Enriching temporal features with static information
  • (practically anywhere)

GRN

Gating mechanisms for adaptive depth/complexity

Without an external context vector, \(c=0\). The unit's input and output have the same dimension \(d_{\mbox{model}}\), which is shared across the whole TFT.
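
A minimal PyTorch sketch of such a gated residual unit, based on my reading of the paper's GRN (ELU non-linearity, GLU gating, residual connection and layer norm); the layer names and the optional-context handling are mine:

    import torch
    import torch.nn as nn

    class GRN(nn.Module):
        """Gated residual network with optional static context, d_model -> d_model."""
        def __init__(self, d_model: int):
            super().__init__()
            self.fc_input = nn.Linear(d_model, d_model)
            self.fc_context = nn.Linear(d_model, d_model, bias=False)  # used only if c is given
            self.fc_hidden = nn.Linear(d_model, d_model)
            self.gate = nn.Linear(d_model, 2 * d_model)  # two branches for the GLU
            self.glu = nn.GLU(dim=-1)
            self.norm = nn.LayerNorm(d_model)

        def forward(self, a, c=None):
            # eta_2 = ELU(W_2 a + W_3 c); omitting the context corresponds to c = 0
            eta2 = self.fc_input(a)
            if c is not None:
                eta2 = eta2 + self.fc_context(c)
            eta2 = torch.nn.functional.elu(eta2)
            eta1 = self.fc_hidden(eta2)
            # The GLU gate can suppress the whole non-linear branch (adaptive complexity)
            return self.norm(a + self.glu(self.gate(eta1)))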

Variable selection networks

Variable selection networks (cont.)

  • Inputs are transformed to \(d_{model}\) dimensions: linearly for continuous variables, via entity embeddings for categoricals
  • Each feature gets its own non-linear processing step (the tilde is missing in the picture), with weights shared across time
  • \(\nu_{X,t}\) is itself interesting for interpretation
  • Noisy variables can already be silenced early on
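
A rough sketch of a variable selection network, reusing the imports and the GRN class from the sketch above; this is my simplified reading rather than the authors' code (in particular, the weight network here routes through an extra linear projection so the GRN can keep a single width):

    class VariableSelection(nn.Module):
        def __init__(self, d_model: int, num_vars: int):
            super().__init__()
            self.var_grns = nn.ModuleList([GRN(d_model) for _ in range(num_vars)])
            self.flat_proj = nn.Linear(num_vars * d_model, d_model)
            self.weight_grn = GRN(d_model)
            self.to_weights = nn.Linear(d_model, num_vars)

        def forward(self, xi, c=None):
            # xi: (batch, num_vars, d_model), each variable already embedded to d_model
            flat = self.flat_proj(xi.flatten(start_dim=1))
            nu = torch.softmax(self.to_weights(self.weight_grn(flat, c)), dim=-1)
            processed = torch.stack([grn(xi[:, j]) for j, grn in enumerate(self.var_grns)], dim=1)
            # nu is the per-variable weight vector that is later inspected for interpretation
            return (nu.unsqueeze(-1) * processed).sum(dim=1), nu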

Local Encoder

Local patterns can be learned by a sequence-to-sequence encoder

  • Uses two LSTMs, but these could be replaced by other models
  • The new state of the model is a "sentence" of the form \[\phi(t,n) \in \{\phi(t,-k), \dots, \phi(t,\tau_{max})\}\]
  • No clear distinction between past and future at this point
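
A small sketch of this layer under my reading: one LSTM encodes the past inputs, a second LSTM runs over the known future inputs starting from the encoder's final state (the TFT additionally initialises these states from static context, which is omitted here):

    import torch
    import torch.nn as nn

    class LocalEncoder(nn.Module):
        def __init__(self, d_model: int):
            super().__init__()
            self.encoder = nn.LSTM(d_model, d_model, batch_first=True)
            self.decoder = nn.LSTM(d_model, d_model, batch_first=True)

        def forward(self, past, future):
            # past:   (batch, k+1, d_model)     selected past features
            # future: (batch, tau_max, d_model) selected known future features
            enc_out, state = self.encoder(past)
            dec_out, _ = self.decoder(future, state)
            # One "sentence" phi(t, n), n = -k..tau_max, with no hard past/future split
            return torch.cat([enc_out, dec_out], dim=1)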

Multihead attention

Attention introduces additional dimensions \(d_{\mbox{attn}}\) and \(d_{\mbox{value}}\)

\[\mbox{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V}) = A\left(\mathbf{Q},\mathbf{K}\right)\mathbf{V}\]

  • Per head \[H_h = \mbox{Attention}\left(\mathbf{Q}\mathbf{W}_Q^{(h)},\mathbf{K}\mathbf{W}_K^{(h)},\mathbf{V}\mathbf{W}_V^{(h)}\right)\]
  • \[\mbox{Multihead}(\mathbf{Q},\mathbf{K},\mathbf{V}) = \left[ H_1,\dots,H_{m_H}\right] \mathbf{W}_H\] With \(\mathbf{W}_Q,\mathbf{W}_K \in \mathbb{R}^{d_{model}\times d_{attn}}\), \(\mathbf{W}_V \in \mathbb{R}^{d_{model}\times d_{V}}\) and \(\mathbf{W}_H \in \mathbb{R}^{(m_Hd_V)\times d_{model}}\)

Interpretable

Multi-head attention is less interpretable, so the TFT uses a value matrix shared across heads and averages the attention weights

\[\tilde{H} = \left(\frac{1}{m_H} \sum_{h=1}^{m_H} A\left(\mathbf{Q}\mathbf{W}_Q^{(h)},\mathbf{K}\mathbf{W}_K^{(h)}\right)\right) \mathbf{V}\mathbf{W}_{V}\]

  • I don't understand why \(m_H\) attention matrices are less interpretable

Here: self-attention

  • \(\mathbf{Q} = \mathbf{K} = \mathbf{V} = \mathbf{\Phi}\) with the additional choices \(d_V = d_{attn} = d_{model}/m_H\)
  • Decoder masking is applied
  • Not sure whether this is necessary, as only known future state flows into the layer
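
Putting the last two slides together, a minimal sketch of interpretable masked self-attention: per-head query/key projections, one value projection shared across heads, attention averaged over heads, and a causal mask. Shapes and names are my own simplification, not the reference implementation:

    import math
    import torch
    import torch.nn as nn

    class InterpretableSelfAttention(nn.Module):
        def __init__(self, d_model: int, n_heads: int):
            super().__init__()
            d_attn = d_model // n_heads  # d_V = d_attn = d_model / m_H as above
            self.q_proj = nn.ModuleList([nn.Linear(d_model, d_attn) for _ in range(n_heads)])
            self.k_proj = nn.ModuleList([nn.Linear(d_model, d_attn) for _ in range(n_heads)])
            self.v_proj = nn.Linear(d_model, d_attn)   # value matrix W_V shared by all heads
            self.out = nn.Linear(d_attn, d_model)      # W_H
            self.scale = math.sqrt(d_attn)

        def forward(self, phi):
            # Self-attention: Q = K = V = phi, shape (batch, seq_len, d_model)
            n = phi.size(1)
            mask = torch.triu(torch.ones(n, n, device=phi.device), diagonal=1).bool()
            value = self.v_proj(phi)
            heads = []
            for q_lin, k_lin in zip(self.q_proj, self.k_proj):
                scores = q_lin(phi) @ k_lin(phi).transpose(-2, -1) / self.scale
                scores = scores.masked_fill(mask, float("-inf"))  # decoder masking
                heads.append(torch.softmax(scores, dim=-1))
            attn = torch.stack(heads).mean(dim=0)   # average the attention matrices
            return self.out(attn @ value), attn     # attn is what gets interpreted later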

Prediction intervals

The last step is a position-wise feed-forward layer with weights shared across time steps

  • At each point different quantiles are predicted, \(q = 0.1, 0.5, 0.9\)
  • Forecast only the future

The Loss

\[L\left(\Omega, \mathbf{W}\right) = \sum_{y_t \in \Omega} \sum_{q \in \mathcal{Q}}\sum_{\tau=1}^{\tau_{max}} \frac{QL\left(y_t, \hat{y}(q,t-\tau, \tau),q\right)}{M\tau_{max}}\]

\[QL(y,\hat{y},q) = q\left(y-\hat{y}\right)_++(1-q)\left(\hat{y}-y\right)_+\]

for \(M\) samples

  • The true quantile \(\hat{y} = y_q\) minimises \(\mathbb{E}_{y}\left[QL(y,\hat{y}, q)\right]\)
  • The test loss is reported only for \(q=0.5,0.9\): \[\mbox{q-Risk} = \frac{2\sum_{y_t \in \tilde{\Omega}}\sum_{\tau=1}^{\tau_{max}}QL(y_t,\hat{y}(q,t-\tau,\tau),q)}{\sum_{y_t\in\tilde{\Omega}}\sum_{\tau = 1}^{\tau_{max}}|y_t|}\]
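
A small numpy sketch of the quantile (pinball) loss and the normalised q-Risk as written above; this is my own transcription of the formulas, not the authors' code:

    import numpy as np

    def quantile_loss(y, y_hat, q):
        """QL(y, y_hat, q) = q*(y - y_hat)_+ + (1 - q)*(y_hat - y)_+ (elementwise)."""
        diff = y - y_hat
        return q * np.maximum(diff, 0.0) + (1.0 - q) * np.maximum(-diff, 0.0)

    def q_risk(y, y_hat, q):
        """Normalised test loss, summed over all samples and forecast horizons."""
        return 2.0 * quantile_loss(y, y_hat, q).sum() / np.abs(y).sum()

    # A forecast biased 0.5 too high is cheap for q = 0.9 but expensive for q = 0.5
    y = np.array([1.0, 2.0, 3.0])
    print(q_risk(y, y + 0.5, q=0.9), q_risk(y, y + 0.5, q=0.5))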

Data sets

  • Electricity: Benchmarking, 1 week hourly customer data to forecast next 24 hours
  • Traffic: Benchmarking, 1 week data to forecast 24hr occupancy rate of SF freeways
  • Retail: Full complexity, sales 30-day forecast per product/store, 90 days of past data
  • Volatility: Small and noisy data, daily volatility values of 31 stocks, 5 day forecast using 252 day information

Training Procedure

  • 3-way train-validation-test split
  • Hyperparameters were tuned by random search over 60 iterations
  • Hyperparameters: \(d_{model}\), dropout rate, minibatch size, learning rate, max gradient norm, num. heads
  • Single GPU training, Electricity took 6 hours
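
A toy sketch of what this random search could look like; the candidate ranges and the training stub are purely illustrative, not the values used in the paper:

    import random

    # Hypothetical stand-in for "train the TFT and return the validation loss"
    def train_and_validate(params):
        return random.random()

    search_space = {           # candidate values are illustrative only
        "d_model": [80, 160, 240, 320],
        "dropout": [0.1, 0.2, 0.3],
        "minibatch_size": [64, 128, 256],
        "learning_rate": [1e-4, 1e-3, 1e-2],
        "max_gradient_norm": [0.01, 1.0, 100.0],
        "num_heads": [1, 4],
    }

    best_params, best_loss = None, float("inf")
    for _ in range(60):        # 60 random-search iterations, as above
        params = {k: random.choice(v) for k, v in search_space.items()}
        loss = train_and_validate(params)
        if loss < best_loss:
            best_params, best_loss = params, loss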

Hyperparameter results

Benchmark results

Difficult dataset results

Ablation results

Interpretability

  • Individual Variable Importance

  • Persistent Temporal Patterns

  • Identifying interesting regime changes

Variable Importance

Extracted from the Variable Selection step by sampling \(\nu_i\) and recording quantiles
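
A tiny numpy sketch of that idea: record the selection weights \(\nu\) over many evaluation samples and summarise them per variable by their quantiles (the array name and shape are assumptions):

    import numpy as np

    def importance_quantiles(nu_samples, qs=(0.1, 0.5, 0.9)):
        # nu_samples: (num_samples, num_vars) selection weights recorded at evaluation time
        return {q: np.quantile(nu_samples, q, axis=0) for q in qs}

    # e.g. importance_quantiles(recorded_nu)[0.5] gives the median weight per variable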

Temporal Patterns

Patterns are established by sampling and recording the quantiles of attention layer

1 Step ahead and multi-horizon forecast

Patterns vary between data sets

Identifying Regimes

Calculate the average attention per entity position and forecast horizon \[\bar{\alpha}(n,\tau) = \frac{1}{T}\sum_{t=1}^T \alpha(t,n,\tau)\] These form a distribution over the positions because \(\sum_n \bar{\alpha}(n,\tau) =1\).

  • Use a distance metric for distributions \[\kappa(\mathbf{p},\mathbf{q}) = \sqrt{1-\rho(\mathbf{p},\mathbf{q})}\] where \(\rho(\mathbf{p},\mathbf{q}) = \sum_i \sqrt{p_iq_i}\) is the Bhattacharyya coefficient.

Identifying Regimes (cont.)

Calculate the "distance" to the long term attention average, averaged over the forecast window

\[\mbox{dist}(t) = \frac{1}{\tau_{max}}\sum_{\tau=1}^{\tau_{max}}\kappa(\bar{\alpha}(\tau),\alpha(t,\tau))\]

  • This measure is interesting because it does not involve the directly observed values at a point in time, only the attention over the forecast window; it is probably a lagging indicator.
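
A numpy sketch of this measure under my reading, with the attention distributions stored as rows of arrays (names and shapes are assumptions):

    import numpy as np

    def bhattacharyya_kappa(p, q):
        """kappa(p, q) = sqrt(1 - sum_i sqrt(p_i * q_i)) for distributions over positions."""
        rho = np.sum(np.sqrt(p * q))
        return np.sqrt(max(1.0 - rho, 0.0))  # clip tiny negative values from rounding

    def regime_distance(alpha_t, alpha_bar):
        # alpha_t[tau], alpha_bar[tau]: attention distributions over positions n,
        # both of shape (tau_max, num_positions); dist(t) averages kappa over tau
        tau_max = alpha_t.shape[0]
        return sum(bhattacharyya_kappa(alpha_bar[i], alpha_t[i]) for i in range(tau_max)) / tau_max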

Volatility

Conclusion - TFTs

  • Complex model containing seq2seq and self-attention
  • Has the ability to adjust its complexity via gated residual networks
  • Can learn short- and long-term trends
  • It's interpretable-ish
  • Very competitive for forecasting