# nl_conv.py
# License: MIT
#
# Nested Learning applied to Conv2d layers.
#
# Conv2d layers are more complex than Dense layers because:
# - Input shape: (batch, in_channels, height, width)
# - Weight shape: (out_channels, in_channels, kernel_h, kernel_w)
#
# This module applies a channel-covariance approximation:
# instead of forming a full Gram matrix over every (input channel,
# kernel offset) pair, we pool input statistics over the batch and all
# spatial locations into a single (in_channels x in_channels) Gram matrix
# and apply it to the kernel weights as a channel-level projection.

import torch
import torch.nn as nn


class NLConv2d(nn.Conv2d):
    """
    ``nn.Conv2d`` that keeps a handle to the most recent input it saw.

    ``conv_channel_projection_step`` reads that handle to build the
    channel-level Gram matrix used by the Nested Learning projection.
    All constructor arguments are forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Non-persistent buffer: excluded from state_dict(). It holds None
        # until the first forward pass replaces it.
        self.register_buffer("_last_input", None, persistent=False)

    def forward(self, x):
        """
        Cache ``x`` (detached from autograd), then return the convolution.

        ``detach()`` does not copy: the cached tensor shares storage with
        ``x``. The cache is overwritten on every call, before the
        convolution runs.
        """
        detached_input = x.detach()
        self._last_input = detached_input
        output = super().forward(x)
        return output


def attach_param_owners(module: nn.Module) -> None:
    """
    Tag every parameter in ``module``'s tree with the submodule that owns it.

    After this call each parameter carries a ``_nl_owner`` attribute that
    points to the module it was registered on directly (non-recursive
    ownership). Projection code can use it to get from a parameter back to
    the layer holding a cached input, such as an ``NLConv2d``.

    If a parameter is shared between several submodules, the submodule
    visited last in ``module.modules()`` order ends up as its owner.

    Args:
        module: Root module whose whole submodule tree (including itself)
            is processed.
    """
    owned_pairs = (
        (owner, param)
        for owner in module.modules()
        for param in owner.parameters(recurse=False)
    )
    for owner, param in owned_pairs:
        param._nl_owner = owner


@torch.no_grad()
def conv_channel_projection_step(conv: NLConv2d, alpha: float) -> None:
    """
    Apply a channel-covariance-based projection to a Conv2d layer in place.

    Approach:
    1. Take the cached input ``x`` of shape (batch, channels, height, width).
       An unbatched (channels, height, width) input is treated as batch 1.
    2. Reshape to (batch*height*width, channels) and compute the channel
       Gram matrix ``G = X^T X / N`` (one per group for grouped convs).
    3. Right-multiply every spatial kernel tap by ``P = I - alpha * G``:
       ``W[:, :, kh, kw] <- W[:, :, kh, kw] @ P``.

    Because the projection is linear, the spatial average of the updated
    kernel is exactly ``mean_k(W) @ P``. That is the channel-level update
    ``W_channel @ (I - alpha * G_channel)``. It is reached without dividing
    by the per-channel kernel mean, which can be (near) zero and would blow
    the weights up to inf/NaN.

    Grouped convolutions (``groups > 1``) get one Gram matrix per group,
    built from that group's input channels only. This mirrors which inputs
    each group's filters actually see.

    Args:
        conv: NLConv2d layer to update. Anything that is not an NLConv2d,
            or a layer that has not run a forward pass yet, is left
            untouched.
        alpha: Projection strength (typically 1e-3 to 1e-4).

    Raises:
        ValueError: If the cached input's channel count does not match the
            layer's weight, e.g. the weight was replaced after the last
            forward pass.
    """
    # Skip if not an NLConv2d or no cached input
    if not isinstance(conv, NLConv2d):
        return

    x = conv._last_input
    if x is None:
        return

    weight = conv.weight  # (out_channels, in_channels // groups, kH, kW)
    out_c, in_per_group, k_h, k_w = weight.shape
    groups = conv.groups
    out_per_group = out_c // groups

    # nn.Conv2d also accepts unbatched (C, H, W) input.
    if x.dim() == 3:
        x = x.unsqueeze(0)

    channels = x.shape[1]
    if channels != in_per_group * groups:
        raise ValueError(
            f"cached input has {channels} channels but the conv weight "
            f"expects {in_per_group * groups} ({groups} group(s) of "
            f"{in_per_group})"
        )

    # (B, C, H, W) -> (B*H*W, groups, C_per_group). Channels are contiguous
    # per group, matching nn.Conv2d's grouping. Match the weight's
    # dtype/device so the matmuls below work under mixed precision.
    xg = x.permute(0, 2, 3, 1).reshape(-1, groups, in_per_group)
    xg = xg.to(device=weight.device, dtype=weight.dtype)

    # Per-group channel Gram matrix: G_g = (1/N) * X_g^T @ X_g
    n_samples = max(1, xg.shape[0])
    gram = torch.einsum("ngi,ngj->gij", xg, xg) / n_samples  # (groups, c, c)

    # Projection P_g = I - alpha * G_g (identity broadcasts over groups)
    eye = torch.eye(in_per_group, device=weight.device, dtype=weight.dtype)
    proj = eye - alpha * gram

    # Apply P_g to every spatial tap of the filters belonging to group g.
    # Output channels are also contiguous per group in nn.Conv2d.
    w = weight.reshape(groups, out_per_group, in_per_group, k_h * k_w)
    w_new = torch.einsum("goik,gij->gojk", w, proj)

    # Write back the updated weights
    weight.copy_(w_new.reshape_as(weight))


# Example usage (not run in this standalone module):
# ================================================
# model = nn.Sequential(
#     NLConv2d(3, 32, kernel_size=3, padding=1),
#     nn.ReLU(),
#     NLConv2d(32, 64, kernel_size=3, padding=1),
#     nn.MaxPool2d(2),
#     nn.Flatten(),
#     nn.Linear(64 * 16 * 16, 10)  # assumes 32x32 inputs (pooled to 16x16)
# )
# attach_param_owners(model)
#
# # During training:
# for inputs, labels in dataloader:
#     optimizer.zero_grad()
#     outputs = model(inputs)  # forward pass refreshes each NLConv2d's cached input
#     loss = criterion(outputs, labels)
#     loss.backward()
#
#     # Apply the projection to each NLConv2d layer (here: on every step)
#     conv_channel_projection_step(model[0], alpha=1e-3)
#     conv_channel_projection_step(model[2], alpha=1e-3)
#
#     optimizer.step()
