# nested_learning_minimal.py
# License: MIT
# 
# Minimal PyTorch reference implementation for Nested Learning (NL).
#
# This module demonstrates the core concepts:
# - DeepL2GD optimizer: applies a Gram-matrix projection before gradient update
# - CMSSequential: Continuum Memory System with multi-timescale learning
# - NLTrainer: trains on sequential tasks to show continual learning properties
#
# The key insight is that components (layers, optimizers) update at different
# frequencies ("clocks"), enabling both plasticity and memory retention.

from __future__ import annotations
import math
from dataclasses import dataclass
from typing import List, Dict, Tuple
import torch
import torch.nn as nn


class NLLinear(nn.Linear):
    """
    A Linear layer that caches its last input for the Gram-matrix computation.

    The cached input is used by the DeepL2GD optimizer to compute the projection
    P = I - alpha * x @ x^T (or batched: I - alpha * (1/B) * X^T @ X).

    Args:
        in_features: Size of each input sample (as for ``nn.Linear``).
        out_features: Size of each output sample (as for ``nn.Linear``).
        bias: Whether to learn an additive bias (as for ``nn.Linear``).
        device: Optional device for the parameters, forwarded to ``nn.Linear``.
        dtype: Optional dtype for the parameters, forwarded to ``nn.Linear``.
    """
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        # Non-persistent buffer: it is tracked by the module but excluded from
        # state_dict(), so checkpoints never contain a stale activation.
        self.register_buffer("_last_input", None, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cache a detached copy so the cache never keeps the autograd graph alive.
        self._last_input = x.detach()
        return super().forward(x)


def replace_linears_with_nl(module: nn.Module) -> nn.Module:
    """
    Recursively replace plain nn.Linear layers in a module with NLLinear.

    This enables input caching for every linear layer, allowing DeepL2GD
    to apply the Gram-matrix projection. The module is modified in place and
    also returned for convenience.

    Layers that are already NLLinear are left untouched. Re-creating them would
    swap in fresh Parameter objects and silently detach any optimizer (and any
    ``_nl_owner`` link) that already references the originals.

    Each replacement copies the original layer's weight/bias values, device,
    dtype, ``requires_grad`` flags and train/eval mode. The replacement Parameters
    are new objects, so optimizers created before this call will not track them.
    Call this before building optimizers.
    """
    # Snapshot the children so replacing entries while iterating is unambiguous.
    for name, child in list(module.named_children()):
        if not isinstance(child, nn.Linear):
            # Recurse into nested containers/modules.
            replace_linears_with_nl(child)
            continue
        if isinstance(child, NLLinear):
            # Already caches inputs; keep the existing Parameters.
            continue

        has_bias = child.bias is not None
        nl = NLLinear(child.in_features, child.out_features, bias=has_bias).to(
            device=child.weight.device, dtype=child.weight.dtype
        )
        with torch.no_grad():
            nl.weight.copy_(child.weight)
            if has_bias:
                nl.bias.copy_(child.bias)
        # Preserve frozen/trainable status and the module's train/eval mode.
        nl.weight.requires_grad_(child.weight.requires_grad)
        if has_bias:
            nl.bias.requires_grad_(child.bias.requires_grad)
        nl.train(child.training)
        setattr(module, name, nl)
    return module


class DeepL2GD(torch.optim.Optimizer):
    """
    Deep L2 Gradient Descent optimizer.

    This optimizer implements the update rule from the Nested Learning paper:
        W_{t+1} = W_t * (I - alpha * x_t @ x_t^T) - eta * dL/dW_t

    It wraps a base optimizer (e.g., AdamW) and applies an additional
    Gram-matrix projection step that "forgets" directions aligned with the
    recent input, while preserving orthogonal directions.

    The projection is applied only to 2-D weights that satisfy all of:
    - their ``_nl_owner`` (set by ``attach_param_owners``) is an NLLinear,
    - the owner has a cached input,
    - the weight has a gradient.
    Every other parameter receives only the base optimizer's update.

    The wrapper shares ``param_groups`` and ``state`` with the base optimizer.
    ``state_dict``/``load_state_dict`` round-trip the base optimizer's state
    (e.g. AdamW moment estimates).

    Args:
        params: Parameters (or param-group dicts) to optimize
        base_optim_ctor: Constructor for the base optimizer (e.g., torch.optim.AdamW)
        alpha: Default strength of the Gram projection (typically 1e-3). A param
            group may override it with its own ``"alpha"`` entry.
        **base_kwargs: Arguments passed to the base optimizer

    Raises:
        ValueError: if ``alpha`` is negative.
    """
    def __init__(self, params, base_optim_ctor, alpha: float = 1e-3, **base_kwargs):
        if alpha < 0:
            raise ValueError("alpha must be >= 0")
        self.alpha = alpha
        # Instantiate the base optimizer with provided parameters and kwargs
        self._base = base_optim_ctor(params, **base_kwargs)
        # Registers the base's group dicts (adding an "alpha" default to each).
        super().__init__(self._base.param_groups, dict(alpha=alpha))
        # Share the base optimizer's containers so the wrapper's view (used by LR
        # schedulers, state_dict() and introspection) never drifts from the
        # optimizer that actually performs the update.
        self.param_groups = self._base.param_groups
        self.state = self._base.state

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one optimization step:
        1. Apply the Gram-matrix projection W <- W (I - alpha * G) to eligible weights
        2. Call the base optimizer's step (e.g., AdamW)

        Returns:
            Whatever the base optimizer's ``step`` returns (the closure's loss, if given).
        """
        for group in self._base.param_groups:
            alpha = group.get("alpha", self.alpha)
            if alpha == 0:
                continue  # the projection is the identity
            for p in group["params"]:
                # Skip parameters without gradients or non-trainable params
                if not p.requires_grad or p.grad is None:
                    continue

                # Retrieve the module that owns this parameter (set by attach_param_owners)
                owner = getattr(p, "_nl_owner", None)
                if owner is None:
                    continue

                # Only apply projection to weight matrices in NLLinear layers
                if not isinstance(owner, NLLinear) or p.ndim != 2:
                    continue
                X = owner._last_input
                if X is None or X.numel() == 0:
                    continue

                # Flatten leading dims (and promote a single 1-D sample) to
                # (batch, in_features), matching the weight's dtype.
                X = X.reshape(-1, X.shape[-1]).to(dtype=p.dtype)
                # Gram matrix: G = (1/B) * X^T @ X, shape (in_features, in_features)
                gram = (X.transpose(0, 1) @ X) / X.shape[0]

                # W <- W - alpha * (W @ G). W @ G is materialised first, so the
                # in-place update never reads from the tensor it is writing.
                # An addmm_ with mat1=p would alias its own output.
                p.sub_(p @ gram, alpha=alpha)

        # Call the base optimizer's step (e.g., AdamW update)
        return self._base.step(closure=closure)

    def zero_grad(self, set_to_none: bool = False):
        """Forward zero_grad to the base optimizer."""
        return self._base.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        """Return the base optimizer's state dict (its state plus param groups)."""
        return self._base.state_dict()

    def load_state_dict(self, state_dict):
        """Load state into the base optimizer and re-share its containers."""
        self._base.load_state_dict(state_dict)
        # The base rebuilds its param_groups/state on load; keep the wrapper in sync.
        self.param_groups = self._base.param_groups
        self.state = self._base.state

    def __getattr__(self, name):
        """Forward attribute access to the base optimizer for compatibility."""
        if name in ("_base", "alpha", "param_groups", "state", "defaults"):
            return super().__getattribute__(name)
        return getattr(self._base, name)


def attach_param_owners(module: nn.Module):
    """
    Tag every parameter under ``module`` with the module that directly owns it.

    After this call each parameter carries a ``_nl_owner`` attribute pointing at
    the submodule that registered it. DeepL2GD follows that link to read the
    owner's cached input (``_last_input``) when building the Gram matrix.
    If a parameter is registered in several submodules, the last one visited wins.
    """
    ownership = (
        (param, owner)
        for owner in module.modules()
        # recurse=False: only parameters registered directly on this submodule
        for param in owner.parameters(recurse=False)
    )
    for param, owner in ownership:
        param._nl_owner = owner


class MLPBlock(nn.Module):
    """
    Two-layer feedforward block: NLLinear -> GELU -> Dropout -> NLLinear.

    Both projections are NLLinear so DeepL2GD can see their inputs. This is the
    unit CMSSequential stacks to form its multi-level system.

    Args:
        d_in: Input feature width.
        d_hidden: Width of the hidden layer.
        d_out: Output feature width.
        dropout: Dropout probability applied after the activation.
    """
    def __init__(self, d_in, d_hidden, d_out, dropout=0.0):
        super().__init__()
        stages = [
            NLLinear(d_in, d_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            NLLinear(d_hidden, d_out),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        hidden = self.net(x)
        return hidden


class CMSSequential(nn.Module):
    """
    Continuum Memory System: a sequential stack of blocks with per-level update schedules.

    Each block has its own DeepL2GD(AdamW) optimizer and update period ``C``
    (its "clock"). Every call to ``scheduled_step`` adds each block's current
    gradients to a per-block accumulator. A block steps its optimizer only when
    ``global_step + 1`` is a multiple of its period. It then uses the mean of the
    gradients accumulated since its previous update.

    Args:
        dims: List of (in_dim, out_dim) tuples, one per block. Consecutive blocks
            must chain (out_dim of block i == in_dim of block i + 1).
        update_periods: Update period C (>= 1) for each block (e.g., [1, 16, 128])
        lr: Learning rate for all optimizers
        alpha: Gram-projection strength for all DeepL2GD instances

    Raises:
        ValueError: if ``dims`` and ``update_periods`` differ in length, a period
            is < 1, or consecutive dims do not chain.
    """
    def __init__(self, dims: List[Tuple[int, int]], update_periods: List[int], lr=1e-3, alpha=1e-3):
        super().__init__()
        # Explicit checks (not assert) so validation survives `python -O`.
        if len(dims) != len(update_periods):
            raise ValueError("dims and update_periods must have same length")
        if any(c < 1 for c in update_periods):
            # A period of 0 would raise ZeroDivisionError in the modulo check.
            raise ValueError("every update period must be >= 1")
        for i, ((_, prev_out), (next_in, _)) in enumerate(zip(dims, dims[1:])):
            if prev_out != next_in:
                raise ValueError(
                    f"dims do not chain: block {i} outputs {prev_out} features "
                    f"but block {i + 1} expects {next_in}"
                )

        # Create blocks
        blocks = []
        for din, dout in dims:
            # Hidden width: geometric mean of in/out widths, at least 64
            d_hid = max(64, int(math.sqrt(din * dout)))
            blocks.append(MLPBlock(din, d_hid, dout))

        self.blocks = nn.ModuleList(blocks)
        # Copy so later mutation of the caller's list cannot change the schedule.
        self.periods = list(update_periods)
        self.global_step = 0

        # Create one optimizer per block
        self.optims: List[DeepL2GD] = []
        # Gradient accumulators: stores sum of gradients until update is due
        self._accum: List[Dict[nn.Parameter, torch.Tensor]] = []

        for blk in self.blocks:
            # Attach parameter owners for Gram-matrix computation
            attach_param_owners(blk)
            # Create a DeepL2GD optimizer for this block
            opt = DeepL2GD(
                blk.parameters(),
                base_optim_ctor=torch.optim.AdamW,
                lr=lr,
                weight_decay=0.01,
                alpha=alpha
            )
            self.optims.append(opt)
            self._accum.append({})

    def forward(self, x):
        """Forward pass through all blocks sequentially."""
        for blk in self.blocks:
            x = blk(x)
        return x

    def zero_grad_all(self):
        """Zero gradients for all block optimizers."""
        for opt in self.optims:
            opt.zero_grad(set_to_none=True)

    @torch.no_grad()
    def _accumulate_block_grads(self, bidx: int):
        """
        Accumulate gradients for block `bidx` until an update is due.

        Stores the running sum of gradients in self._accum[bidx][param].
        """
        for p in self.blocks[bidx].parameters():
            if p.grad is None:
                continue
            # Initialize or accumulate
            buf = self._accum[bidx].get(p)
            g = p.grad.detach()
            if buf is None:
                self._accum[bidx][p] = g.clone()
            else:
                buf.add_(g)

    @torch.no_grad()
    def _apply_block_step_if_due(self, bidx: int) -> bool:
        """
        Check if block `bidx` is due for an update.

        If (global_step + 1) % period == 0:
            - Average accumulated gradients by the period
            - Assign to parameters
            - Call optimizer.step()
            - Clear gradients

        Note: the DeepL2GD projection inside step() reads each layer's cached
        input, i.e. only the most recent forward pass, not the whole period.

        Returns:
            True if step was applied, False otherwise.
        """
        C = self.periods[bidx]
        # Check if this block's update is due
        if (self.global_step + 1) % C != 0:
            # Not due yet; clear gradients to prepare for next accumulation
            for p in self.blocks[bidx].parameters():
                p.grad = None
            return False

        # Update is due: apply accumulated gradients
        for p in self.blocks[bidx].parameters():
            buf = self._accum[bidx].pop(p, None)
            if buf is None:
                continue
            # Average the accumulated gradient over the period
            p.grad = buf.div(C)

        # Perform optimizer step (DeepL2GD + base optimizer)
        self.optims[bidx].step()

        # Clear gradients after step
        for p in self.blocks[bidx].parameters():
            p.grad = None

        return True

    def scheduled_step(self):
        """
        Called once per training iteration, after backward().

        For each block:
        1. Accumulate its gradients
        2. Apply step if due (based on period)
        Then increment global step counter.
        """
        for bidx in range(len(self.blocks)):
            self._accumulate_block_grads(bidx)
            self._apply_block_step_if_due(bidx)
        self.global_step += 1


# ============================================================================
# Continual Learning Demo: Sequential Task Learning
# ============================================================================

@dataclass
class CLTask:
    """
    Parameters of a synthetic linear classification task.

    A sample ``x`` is labelled ``argmax(x @ W + b)`` over the class dimension.
    """

    # Weight matrix, shape (d_in, d_out)
    W: torch.Tensor
    # Bias vector, shape (d_out,)
    b: torch.Tensor


def make_linear_task(d_in=32, d_out=10, seed=0) -> CLTask:
    """
    Build a random linear classification task, deterministic in ``seed``.

    The weights are standard-normal draws divided by sqrt(d_in). The bias is a
    standard-normal draw scaled by 0.05. Both come from a private generator, so
    the global RNG is untouched.
    """
    rng = torch.Generator()
    rng.manual_seed(seed)
    fan_in_norm = math.sqrt(d_in)
    # Draw order (weights first, then bias) determines the generated values.
    weight = torch.randn(d_in, d_out, generator=rng) / fan_in_norm
    bias = torch.randn(d_out, generator=rng) * 0.05
    return CLTask(W=weight, b=bias)


def sample_data(task: CLTask, n=2048, noise=0.1, seed=123):
    """
    Draw a labelled dataset from a linear task.

    Labels come from the noise-free inputs. Gaussian noise scaled by ``noise``
    is then added to the returned features. All randomness comes from a private
    generator seeded with ``seed``.

    Returns:
        x: shape (n, d_in), input features with additive noise
        y: shape (n,), class labels
    """
    rng = torch.Generator().manual_seed(seed)
    d_in = task.W.shape[0]
    clean = torch.randn(n, d_in, generator=rng)
    labels = (clean @ task.W + task.b).argmax(dim=-1)
    # Perturb the features only after labelling, so labels reflect clean inputs.
    jitter = torch.randn(clean.shape, generator=rng)
    return clean + noise * jitter, labels


class NLTrainer:
    """
    Trains a CMSSequential model + head on sequential tasks.

    Demonstrates continual learning: after training on Task A then Task B,
    we measure accuracy on both to see retention and new-task learning.

    Args:
        model: The CMSSequential backbone; its per-block optimizers are used as-is.
        d_in: Output width of ``model`` (input width of the classification head).
        d_out: Number of classes.
        device: Device to run on (defaults to CUDA when available).
        head_lr: AdamW learning rate for the head's DeepL2GD optimizer.
    """
    def __init__(self, model: CMSSequential, d_in: int, d_out: int,
                 device="cuda" if torch.cuda.is_available() else "cpu",
                 head_lr: float = 3e-3):
        # Only swap layers when the model contains plain nn.Linear modules.
        # Swapping creates new Parameter objects that the model's block
        # optimizers (built in CMSSequential.__init__) do not track, so
        # unconditionally swapping would leave the backbone untrainable.
        needs_swap = any(
            isinstance(m, nn.Linear) and not isinstance(m, NLLinear)
            for m in model.modules()
        )
        if needs_swap:
            import warnings

            replace_linears_with_nl(model)
            warnings.warn(
                "NLTrainer replaced plain nn.Linear layers with NLLinear; optimizers "
                "created before the swap do not track the new parameters.",
                RuntimeWarning,
                stacklevel=2,
            )
        # (Re-)link parameters to their owners so DeepL2GD can find cached inputs.
        attach_param_owners(model)
        self.model = model.to(device)
        self.device = device

        # Classification head: maps model output to class logits
        self.head = NLLinear(d_in, d_out).to(device)
        attach_param_owners(self.head)

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

        # Reuse the backbone's projection strength; fall back to DeepL2GD's
        # default when the model has no blocks.
        alpha = model.optims[0].alpha if model.optims else 1e-3
        # Optimizer for the head (separate from model's block optimizers)
        self.head_optim = DeepL2GD(
            self.head.parameters(),
            base_optim_ctor=torch.optim.AdamW,
            lr=head_lr,
            weight_decay=0.01,
            alpha=alpha
        )

    def _batchify(self, X: torch.Tensor, y: torch.Tensor, bs=128, shuffle=True):
        """Yield (inputs, labels) batches on ``self.device``; shuffled unless ``shuffle=False``."""
        n = X.size(0)
        idx = torch.randperm(n) if shuffle else torch.arange(n)
        for s in range(0, n, bs):
            sel = idx[s:s + bs]
            yield X[sel].to(self.device), y[sel].to(self.device)

    def train_on_task(self, X: torch.Tensor, y: torch.Tensor, epochs=2, bs=128):
        """
        Train on a single task for `epochs` epochs.

        Args:
            X: Input features (n, d_in)
            y: Labels (n,)
            epochs: Number of passes over the dataset
            bs: Batch size
        """
        # Ensure dropout etc. are active during training.
        self.model.train()
        self.head.train()
        for _ in range(epochs):
            for xb, yb in self._batchify(X, y, bs):
                # Zero gradients for all components
                self.model.zero_grad_all()
                self.head_optim.zero_grad(set_to_none=True)

                # Forward pass
                out = self.head(self.model(xb))
                loss = self.criterion(out, yb)

                # Backward pass
                loss.backward()

                # Update steps
                self.model.scheduled_step()  # Applies per-block updates based on period
                self.head_optim.step()

    @torch.no_grad()
    def eval_acc(self, X, y, bs=256):
        """
        Compute classification accuracy on a dataset (0.0 for an empty dataset).

        Runs in eval mode (dropout disabled) and restores the previous
        train/eval mode of the model and head afterwards.
        """
        model_was_training = self.model.training
        head_was_training = self.head.training
        self.model.eval()
        self.head.eval()
        correct = 0
        total = 0
        try:
            # Order is irrelevant for accuracy, so skip the shuffle.
            for xb, yb in self._batchify(X, y, bs, shuffle=False):
                pred = self.head(self.model(xb)).argmax(dim=-1)
                correct += (pred == yb).sum().item()
                total += yb.numel()
        finally:
            self.model.train(model_was_training)
            self.head.train(head_was_training)
        return correct / max(1, total)


if __name__ == "__main__":
    # ========== Continual Learning Experiment ==========
    # Train on Task A, then Task B, and measure plasticity/retention.

    # Seed the global RNG so the run is reproducible. Layer initialisation and
    # the batch shuffling in NLTrainer._batchify (torch.randperm) both draw from
    # it. Task/data generation uses its own seeded generators.
    torch.manual_seed(0)

    # Generate two linear tasks
    d_in, d_out = 64, 10
    taskA = make_linear_task(d_in, d_out, seed=0)
    taskB = make_linear_task(d_in, d_out, seed=1)

    # Sample data
    XA, yA = sample_data(taskA, n=6000, noise=0.05, seed=5)
    XB, yB = sample_data(taskB, n=6000, noise=0.05, seed=6)

    # Create CMSSequential model with 3 blocks and periods [1, 16, 128]
    dims = [(d_in, d_in), (d_in, d_in), (d_in, d_in)]
    periods = [1, 16, 128]
    model = CMSSequential(dims, update_periods=periods, lr=2e-3, alpha=5e-4)

    # Create trainer
    trainer = NLTrainer(model, d_in=d_in, d_out=d_out)

    # Phase 1: train on Task A; Task B accuracy here is the pre-B baseline
    trainer.train_on_task(XA, yA, epochs=3, bs=128)
    accA_before = trainer.eval_acc(XA, yA)
    accB_before = trainer.eval_acc(XB, yB)

    # Phase 2: train on Task B, then re-measure both (learning of B, retention of A)
    trainer.train_on_task(XB, yB, epochs=3, bs=128)
    accB_after = trainer.eval_acc(XB, yB)
    accA_after = trainer.eval_acc(XA, yA)

    # Print results
    print(f"Task A acc before B: {accA_before:.3f}  | after B: {accA_after:.3f}")
    print(f"Task B acc before B: {accB_before:.3f}  | after B: {accB_after:.3f}")
