# nl_transformer_tiny.py
# License: MIT
#
# Tiny Transformer wired with Nested Learning clocks.
#
# This module demonstrates how to apply multi-timescale (multi-clock)
# Nested Learning to a Transformer. Different components update at
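# different frequencies:
#   - Q/K/V projections and the LM head: every step (fast, frequent updates)
#   - Output projection (O): every 8 steps (medium)
#   - FFN: every 64 steps (slow, infrequent updates)
# Embeddings and LayerNorm parameters are not assigned a clock by
# wire_clocks(), so no optimizer created in this file updates them.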
#
# The intuition: sequence-local operations (attention heads) should adapt
# quickly to new token patterns, while knowledge-heavy paths (FFN) should
# update slowly to preserve learned feature distributions.

import math
import torch
import torch.nn as nn


class NLLinear(nn.Linear):
    """
    ``nn.Linear`` that remembers the most recent tensor it was called with.

    The remembered tensor is exposed as ``_last_input`` (detached from the
    autograd graph) so that other code can build a Gram matrix from it.
    It is registered as a non-persistent buffer, so it is left out of
    ``state_dict()``.

    Args:
        in_features: Size of each input sample
        out_features: Size of each output sample
        bias: Whether the layer has an additive bias (default True)
    """
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        # Starts empty; filled in on the first forward pass.
        self.register_buffer("_last_input", None, persistent=False)

    def forward(self, x):
        """Cache a detached view of ``x``, then apply the linear map."""
        cached = x.detach()
        self._last_input = cached
        return nn.Linear.forward(self, x)


def attach_param_owners(module: nn.Module):
    """
    Tag every parameter with the sub-module that directly holds it.

    Walks ``module`` and all of its descendants; each parameter registered
    directly on a module gets an ``_nl_owner`` attribute pointing at that
    module. The optimizer in this file reads ``_nl_owner`` to reach the
    owner's cached ``_last_input``.

    Args:
        module: Root module whose parameters should be tagged
    """
    for owner in module.modules():
        # recurse=False: only parameters held by `owner` itself, so a
        # parent module never overrides the tag set for a child's parameter.
        for param in owner.parameters(recurse=False):
            param._nl_owner = owner


class MHA(nn.Module):
    """
    Multi-Head self-attention using NLLinear layers.

    Applies Q, K, V projections independently, then combines with
    scaled dot-product attention, and projects output via O.
    All linear transformations use NLLinear for input caching.

    Args:
        d_model: Model dimension (must be divisible by n_heads)
        n_heads: Number of attention heads
        causal: If True (default), position ``t`` may only attend to
            positions ``<= t``. This is required for next-token
            prediction; without it each position can read the token it is
            asked to predict. Pass False for bidirectional attention.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """
    def __init__(self, d_model, n_heads, causal=True):
        super().__init__()
        # Explicit raise (not assert) so the check survives `python -O`.
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")

        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.causal = causal

        # Four linear projections (all NLLinear for caching)
        self.q = NLLinear(d_model, d_model)
        self.k = NLLinear(d_model, d_model)
        self.v = NLLinear(d_model, d_model)
        self.o = NLLinear(d_model, d_model)

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            output: (batch, seq_len, d_model)
        """
        B, T, C = x.shape
        H = self.n_heads
        Ch = C // H

        def split(h):
            """Reshape (B, T, C) to (B, H, T, Ch) for multi-head processing."""
            return h.view(B, T, H, Ch).transpose(1, 2)

        # Project and split into heads
        Q = split(self.q(x))  # (B, H, T, Ch)
        K = split(self.k(x))  # (B, H, T, Ch)
        V = split(self.v(x))  # (B, H, T, Ch)

        # Scaled dot-product attention
        # scores: (B, H, T, T)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(Ch)

        if self.causal:
            # True strictly above the diagonal == "key is in the future".
            # The diagonal stays unmasked, so every row keeps at least one
            # finite score and the softmax below is always well defined.
            future = torch.ones(
                T, T, dtype=torch.bool, device=x.device
            ).triu(diagonal=1)
            scores = scores.masked_fill(future, float("-inf"))

        att = scores.softmax(dim=-1)

        # Combine values with attention weights
        ctx = att @ V  # (B, H, T, Ch)

        # Merge heads: (B, H, T, Ch) -> (B, T, C)
        ctx = ctx.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection
        return self.o(ctx)


class Block(nn.Module):
    """
    Pre-norm Transformer block.

    Two residual branches are applied in sequence:
        x <- x + MHA(LayerNorm(x))
        x <- x + FFN(LayerNorm(x))
    where the FFN is NLLinear -> GELU -> NLLinear.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads passed to MHA
        d_ff: Hidden width of the feed-forward network
    """
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        # Attention branch
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MHA(d_model, n_heads)

        # Feed-forward branch: widen to d_ff, non-linearity, back to d_model
        self.ln2 = nn.LayerNorm(d_model)
        widen = NLLinear(d_model, d_ff)
        narrow = NLLinear(d_ff, d_model)
        self.ff = nn.Sequential(widen, nn.GELU(), narrow)

    def forward(self, x):
        """Run the attention branch, then the FFN branch, each with a skip."""
        attended = self.attn(self.ln1(x))
        hidden = x + attended
        transformed = self.ff(self.ln2(hidden))
        return hidden + transformed


class TinyTransformer(nn.Module):
    """
    A minimal Transformer for language modeling.

    Args:
        vocab: Vocabulary size
        d_model: Model dimension (default 128)
        n_heads: Number of attention heads (default 4)
        d_ff: FFN hidden dimension (default 256)
        n_layers: Number of Transformer blocks (default 2)
        max_len: Size of the learned position table, i.e. the longest
            sequence the model accepts (default 512)
    """
    def __init__(self, vocab, d_model=128, n_heads=4, d_ff=256, n_layers=2,
                 max_len=512):
        super().__init__()
        self.max_len = max_len

        # Token and position embeddings
        self.tok = nn.Embedding(vocab, d_model)
        self.pos = nn.Embedding(max_len, d_model)

        # Stack of Transformer blocks
        self.blocks = nn.ModuleList([
            Block(d_model, n_heads, d_ff)
            for _ in range(n_layers)
        ])

        # Final layer norm and output projection
        self.ln = nn.LayerNorm(d_model)
        self.head = NLLinear(d_model, vocab)

    def forward(self, idx):
        """
        Args:
            idx: (batch, seq_len) token indices
        Returns:
            logits: (batch, seq_len, vocab) prediction logits
        Raises:
            IndexError: If seq_len exceeds ``max_len``.
        """
        _, T = idx.shape

        # Fail early with a readable message instead of an out-of-range
        # lookup inside the position embedding.
        if T > self.max_len:
            raise IndexError(
                f"sequence length {T} exceeds max_len={self.max_len}"
            )

        # Embed tokens and add positional encodings (broadcast over batch)
        positions = torch.arange(T, device=idx.device)
        x = self.tok(idx) + self.pos(positions)[None, :, :]

        # Apply Transformer blocks
        for b in self.blocks:
            x = b(x)

        # Final layer norm and output projection
        x = self.ln(x)
        return self.head(x)


class DeepL2GD(torch.optim.Optimizer):
    """
    Wrapper optimizer: Gram-matrix weight projection, then a base update.

    On every ``step()`` each 2-D parameter ``W`` whose ``_nl_owner`` module
    has a cached ``_last_input`` ``X`` is first shrunk along the directions
    spanned by ``X``::

        G = X^T X / n_rows
        W <- W - alpha * (W @ G)        # == W @ (I - alpha * G)

    after which the wrapped base optimizer performs its normal update.

    Args:
        params: Parameters or param groups, forwarded to the base optimizer
        base_optim_ctor: Callable building the base optimizer, called as
            ``base_optim_ctor(params, **base_kwargs)``
        alpha: Projection strength, must be >= 0 (default 1e-3)
        **base_kwargs: Extra keyword arguments for ``base_optim_ctor``

    Raises:
        ValueError: If ``alpha`` is negative.

    NOTE(review): per-parameter optimizer state lives in ``self._base``.
    ``state_dict()`` is inherited from ``torch.optim.Optimizer`` and
    presumably serializes this wrapper's own ``state`` rather than the
    base optimizer's. Confirm before using it for checkpointing.
    """
    def __init__(self, params, base_optim_ctor, alpha: float = 1e-3, **base_kwargs):
        if alpha < 0:
            raise ValueError("alpha must be >= 0")
        self.alpha = alpha
        self._base = base_optim_ctor(params, **base_kwargs)
        # Reuse the base optimizer's group dicts so both objects see the
        # same hyper-parameters.
        super().__init__(self._base.param_groups, dict(alpha=alpha))

    @torch.no_grad()
    def step(self, closure=None):
        """
        Gram-matrix projection followed by base optimizer step.

        Parameters are skipped by the projection (but still handed to the
        base optimizer) when they have no gradient, are not 2-D, have no
        ``_nl_owner``, or their owner has no cached input.

        Args:
            closure: Optional closure forwarded to the base optimizer
        Returns:
            Whatever the base optimizer's ``step`` returns.
        """
        for g in self._base.param_groups:
            for p in g["params"]:
                if p.grad is None or not p.requires_grad:
                    continue
                if p.ndim != 2:
                    continue

                # Find the cached input of the owning module, if any
                owner = getattr(p, "_nl_owner", None)
                X = getattr(owner, "_last_input", None)
                if X is None:
                    continue

                # Flatten all leading dims into rows. This also turns a
                # single unbatched 1-D input into one row.
                X = X.reshape(-1, X.shape[-1])

                # Gram matrix, averaged over rows
                gram = (X.transpose(0, 1) @ X) / max(1, X.shape[0])

                # Compute the product into a temporary first. Using `p` as
                # both the output and an operand of one addmm_ call aliases
                # the GEMM result with an input, which BLAS does not permit.
                p.sub_(p @ gram, alpha=self.alpha)

        return self._base.step(closure=closure)

    def zero_grad(self, set_to_none=False):
        """Clear gradients through the base optimizer."""
        return self._base.zero_grad(set_to_none=set_to_none)

    def __getattr__(self, name):
        # Only reached when normal lookup fails. The wrapper's own core
        # attributes must raise AttributeError rather than fall through to
        # `self._base` (avoids recursion while `_base` is not yet set).
        if name in ("_base", "alpha", "param_groups", "state", "defaults"):
            return super().__getattribute__(name)
        return getattr(self._base, name)


def wire_clocks(model: TinyTransformer, lr=3e-4, alpha=5e-4):
    """
    Give each clocked component its update period and build the optimizers.

    Periods assigned per Transformer block:
    - Attention Q/K/V: 1 (every step)
    - Attention O: 8
    - NLLinear layers inside the FFN: 64
    The output head gets period 1.

    Modules not listed above (token/position embeddings, LayerNorms)
    receive no optimizer from this function.

    Args:
        model: TinyTransformer model
        lr: Learning rate used by every optimizer
        alpha: Gram-projection strength passed to DeepL2GD

    Returns:
        periods: Dict mapping module to its update period
        sched: Dict mapping period to the DeepL2GD (wrapping AdamW with
            weight_decay=0.01) that owns all parameters with that period
    """
    periods = {}
    for block in model.blocks:
        attn = block.attn
        periods.update({attn.q: 1, attn.k: 1, attn.v: 1, attn.o: 8})
        periods.update(
            (layer, 64) for layer in block.ff if isinstance(layer, NLLinear)
        )
    periods[model.head] = 1

    # Bucket parameters by period, tagging each one with its owner module
    # so the optimizer can find the cached input.
    params_by_period = {}
    for module, period in periods.items():
        attach_param_owners(module)
        bucket = params_by_period.setdefault(period, [])
        bucket.extend(module.parameters())

    # One optimizer per distinct period
    sched = {
        period: DeepL2GD(
            params,
            base_optim_ctor=torch.optim.AdamW,
            lr=lr,
            weight_decay=0.01,
            alpha=alpha,
        )
        for period, params in params_by_period.items()
    }

    return periods, sched


def step_with_clocks(loss, periods, sched, global_step):
    """
    Backpropagate ``loss`` and fire every optimizer whose clock is due.

    An optimizer with period ``C`` is stepped when ``global_step + 1`` is a
    multiple of ``C``. Its gradients are cleared right after it steps;
    optimizers that are not due keep their gradients untouched.

    Args:
        loss: Loss tensor to backprop
        periods: Dict mapping module to period (not used here)
        sched: Dict mapping period C to its optimizer
        global_step: Number of steps completed before this call

    Returns:
        global_step + 1 (updated counter)
    """
    loss.backward()

    next_step = global_step + 1
    for period, optimizer in sched.items():
        if next_step % period == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

    return next_step


def make_synth_bigrams(T=128, N=5000, vocab=32, seed=0):
    """
    Generate synthetic bigram data for language modeling.

    Two random row-stochastic transition matrices are drawn. Sequences are
    then sampled as first-order Markov chains: the first token is uniform
    over the vocabulary and each later token is drawn from the row of the
    transition matrix selected by the previous token.

    All randomness comes from one generator seeded with ``seed``, so the
    same arguments always give the same data.

    Args:
        T: Sequence length
        N: Number of sequences per regime
        vocab: Vocabulary size
        seed: Random seed

    Returns:
        (seqs_A, seqs_B): Two ``(N, T)`` long tensors, one per regime
    """
    g = torch.Generator().manual_seed(seed)

    def transition_matrix():
        """Random (vocab, vocab) matrix with each row normalized to sum 1."""
        M = torch.rand(vocab, vocab, generator=g)
        return M / M.sum(dim=1, keepdim=True)

    A = transition_matrix()
    B = transition_matrix()

    def sample(M, n):
        """Sample n sequences from bigram matrix M, all chains in lockstep."""
        seq = torch.zeros(n, T, dtype=torch.long)
        if n == 0 or T == 0:
            return seq

        # Start every sequence with a random token
        seq[:, 0] = torch.randint(0, vocab, (n,), generator=g)

        # One multinomial draw per time step covers all n sequences. The
        # generator is passed explicitly so sampling honours `seed`.
        for j in range(1, T):
            rows = M[seq[:, j - 1]]  # (n, vocab)
            seq[:, j] = torch.multinomial(rows, 1, generator=g).squeeze(1)
        return seq

    return sample(A, N), sample(B, N)


def tiny_train_demo(device="cpu"):
    """
    Train a TinyTransformer on bigram regime A, then on regime B.

    Builds a 32-token model, wires the multi-clock optimizers, generates
    the two synthetic datasets and runs 100 next-token-prediction steps on
    each in turn, printing progress. No evaluation is performed.

    Args:
        device: 'cpu' or 'cuda'

    Returns:
        Trained model
    """
    vocab = 32
    model = TinyTransformer(vocab).to(device)

    # Per-component update periods and their optimizers
    periods, sched = wire_clocks(model)

    # Two bigram regimes, 256 sequences of length 64 each
    X_A, X_B = make_synth_bigrams(N=256, T=64, vocab=vocab)

    ce = nn.CrossEntropyLoss()
    gs = 0  # Global step counter shared across both tasks

    def run_on(data, steps=200, gs=0):
        """Run `steps` updates on random batches of `data`; return the counter."""
        n_rows = data.size(0)
        counter = gs
        for _ in range(steps):
            # Random batch of 32 sequences (drawn with replacement)
            rows = torch.randint(0, n_rows, (32,))
            batch = data[rows, :].to(device)

            # Inputs are tokens 0..T-2, targets are tokens 1..T-1
            inputs, targets = batch[:, :-1], batch[:, 1:]
            logits = model(inputs)
            loss = ce(logits.reshape(-1, vocab), targets.reshape(-1))

            # NOTE(review): clearing every gradient here means an optimizer
            # with period 8 or 64 only sees the current batch's gradient
            # when it fires, not a sum over its window. Confirm this is
            # the intended behaviour.
            model.zero_grad(set_to_none=True)
            counter = step_with_clocks(loss, periods, sched, counter)

        return counter

    # Sequential task learning
    print("Training on Task A (100 steps)...")
    gs = run_on(X_A, steps=100, gs=gs)

    print("Training on Task B (100 steps)...")
    gs = run_on(X_B, steps=100, gs=gs)

    print(f"Training complete. Total steps: {gs}")
    return model


if __name__ == "__main__":
    # Script entry point: run the two-task training demo on CPU.
    tiny_train_demo(device="cpu")
