# deep_l2gd_ema.py
# License: MIT
#
# DeepL2GD with EMA (Exponential Moving Average) smoothing.
#
# This variant keeps an exponential moving average (EMA) of each layer's
# input Gram matrix (X^T X / B) and projects the weights with the smoothed
# matrix. This is useful when minibatches are noisy or small, as the EMA
# averages out batch-to-batch fluctuations in the input statistics.
#
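# Update rule (beta is the weight on the newest batch; smaller beta = longer memory):
#   gram_t = X_t^T X_t / B                          (input Gram matrix of batch t)
#   G_t    = (1 - beta) * G_{t-1} + beta * gram_t   (EMA update; G_1 = gram_1)
#   W'     = W_{t-1} (I - alpha * G_t)              (projection)
#   W_t    = base optimizer step applied to W', using dL/dW computed at W_{t-1}
#            (for plain SGD: W_t = W' - eta * dL/dW)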

import torch


class DeepL2GDEMA(torch.optim.Optimizer):
    """
    Deep L2 Gradient Descent with EMA-smoothed Gram matrix.

    This wraps a base optimizer. On every ``step`` it first right-multiplies
    each eligible 2-D weight ``W`` by ``(I - alpha * G_ema)``, where
    ``G_ema`` is an exponential moving average of the input Gram matrices
    ``X^T X / B`` recorded for that weight. It then runs the base
    optimizer's step. Smoothing the Gram matrix damps noise from small or
    biased minibatches, which can improve stability in continual-learning
    scenarios.

    A parameter is projected only if all of the following hold:

    * it is 2-D and requires grad;
    * it has a ``.grad``;
    * it carries an ``_nl_owner`` attribute whose ``_last_input`` is not
      None.

    Every other parameter is left entirely to the base optimizer. The
    owner/input convention is set up outside this class. ``_last_input`` is
    presumably the input most recently fed to the module owning the weight;
    confirm this against the code that sets it.

    Args:
        params: Parameters (or param-group dicts) to optimize.
        base_optim_ctor: Constructor for the base optimizer
            (e.g., ``torch.optim.AdamW``).
        alpha: Strength of the projection (typically 1e-3). Must be >= 0.
        beta: EMA update rate in ``(0, 1]``. It is the weight given to the
            current batch's Gram matrix:
            ``G_ema <- (1 - beta) * G_ema + beta * G_batch``.
            Smaller beta means a longer memory of past Gram matrices.
            ``beta = 1`` disables smoothing entirely. Typically 0.01 to 0.1.
        **base_kwargs: Keyword arguments forwarded to the base optimizer.

    Raises:
        ValueError: If ``alpha < 0`` or ``beta`` is outside ``(0, 1]``.

    NOTE(review): ``torch.optim.Optimizer.__init__`` gives this wrapper its
    own ``state``, separate from the base optimizer's. The EMA buffers live
    in ``_ema``. So ``state_dict()`` on this wrapper presumably captures
    neither the base optimizer's per-parameter state nor the EMA. Confirm
    the checkpointing path before relying on it.

    Example:
        >>> optim = DeepL2GDEMA(
        ...     model.parameters(),
        ...     base_optim_ctor=torch.optim.AdamW,
        ...     alpha=1e-3,
        ...     beta=0.05,
        ...     lr=1e-3
        ... )
    """

    def __init__(self, params, base_optim_ctor, alpha: float = 1e-3, beta: float = 0.05, **base_kwargs):
        # Validate hyperparameters before building anything.
        if alpha < 0 or not (0.0 < beta <= 1.0):
            raise ValueError("alpha >= 0, 0 < beta <= 1 required")

        self.alpha = alpha
        self.beta = beta

        # Build the base optimizer first so its param_groups can be registered
        # with this wrapper. torch's Optimizer.add_param_group stores the given
        # dicts themselves, so both optimizers share the same group dicts
        # (e.g. an LR change made through one is visible to the other).
        # alpha/beta are added to each group as defaults, but step() reads
        # self.alpha / self.beta.
        self._base = base_optim_ctor(params, **base_kwargs)
        super().__init__(self._base.param_groups, dict(alpha=alpha, beta=beta))

        # EMA Gram buffers keyed by parameter tensor: self._ema[p] = G_ema.
        self._ema = {}

    @staticmethod
    def _batch_gram(p):
        """Return ``X^T X / B`` for the input cached on ``p``'s owner, or None.

        ``X`` is ``p._nl_owner._last_input``. A 1-D input is treated as a
        single sample. Inputs with more than two dims have all leading dims
        folded into the batch dim, giving shape ``(B, X.shape[-1])``. ``X``
        is cast to ``p``'s device and dtype so the result can be multiplied
        with ``p`` directly. Returns None when there is no owner or no
        cached input.
        """
        owner = getattr(p, "_nl_owner", None)
        # getattr(None, ...) also yields None, covering a missing owner.
        X = getattr(owner, "_last_input", None)
        if X is None:
            return None

        if X.dim() == 1:
            # Unbatched input: treat it as a batch of one sample.
            X = X.unsqueeze(0)
        elif X.dim() > 2:
            # e.g. (batch, seq, features): fold the leading dims into the batch.
            X = X.reshape(-1, X.shape[-1])

        X = X.to(device=p.device, dtype=p.dtype)
        # max(1, B) guards against division by zero for an empty batch.
        return (X.transpose(0, 1) @ X) / max(1, X.shape[0])

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one optimization step.

        1. For each eligible weight, compute the current batch's Gram
           matrix ``G_batch = X^T X / B`` from its owner's cached input.
        2. Update the EMA: ``G_ema <- (1 - beta) * G_ema + beta * G_batch``.
           The first observation initializes ``G_ema = G_batch``.
        3. Project in place: ``W <- W (I - alpha * G_ema)``.
        4. Delegate to the base optimizer's ``step``.

        Args:
            closure: Optional closure, forwarded to the base optimizer.

        Returns:
            Whatever the base optimizer's ``step`` returns.
        """
        for group in self._base.param_groups:
            for p in group["params"]:
                # Skip parameters that are frozen or have no gradient.
                if not p.requires_grad or p.grad is None:
                    continue
                # Only weight matrices (2-D tensors) are projected.
                if p.ndim != 2:
                    continue

                gram = self._batch_gram(p)
                if gram is None:
                    continue

                ema = self._ema.get(p)
                if ema is None:
                    # gram is a freshly allocated tensor, so store it directly.
                    ema = self._ema[p] = gram
                else:
                    # lerp: ema + beta * (gram - ema)
                    #     == (1 - beta) * ema + beta * gram
                    ema.lerp_(gram, self.beta)

                # W <- W - alpha * (W @ G_ema). The product is computed out of
                # place on purpose. The former p.addmm_(mat1=p, ...) made the
                # output alias an input, which BLAS gemm does not support and
                # which can silently corrupt the result.
                p.add_(p @ ema, alpha=-self.alpha)

        # Call the base optimizer's step (e.g., AdamW).
        return self._base.step(closure=closure)

    def zero_grad(self, set_to_none=False):
        """Forward zero_grad to the base optimizer."""
        return self._base.zero_grad(set_to_none=set_to_none)

    def __getattr__(self, name):
        """Forward unknown attribute lookups to the base optimizer.

        Python calls this only when normal lookup fails. The names listed
        below are resolved on the wrapper itself, so a missing one raises
        AttributeError instead of recursing through ``self._base``. This
        matters, for example, before ``_base`` has been set.
        """
        if name in ("_base", "alpha", "beta", "_ema", "param_groups", "state", "defaults"):
            return super().__getattribute__(name)
        return getattr(self._base, name)
