import torch
from torch import nn
from torchvision import datasets    # there are different libraries/datasets such as TorchText, TorchVision and TorchAudio
from torchvision.transforms import ToTensor

import matplotlib.pyplot as plt


# Define device: prefer MPS, then CUDA, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device. \n")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device.\n")
else:
    device = torch.device("cpu")
    # This branch runs only when neither MPS nor CUDA is available, so the
    # message names both rather than blaming MPS alone.
    print("No GPU backend (MPS or CUDA) available, using CPU.\n")


def _fetch_fashion_mnist(train):
    """Return one FashionMNIST split: training if ``train`` is True, else test.

    The files live under ./data and are fetched when not already present
    (download=True). ToTensor() converts each stored image (PIL image or
    NumPy array) into a torch tensor.
    """
    return datasets.FashionMNIST(
        root="data",
        train=train,
        download=True,
        transform=ToTensor(),
    )


# Download (if needed) and load both splits from the open FashionMNIST dataset.
training_data = _fetch_fashion_mnist(train=True)
test_data = _fetch_fashion_mnist(train=False)


#----------------------
# define model
#----------------------
class NeuralNetwork(nn.Module):
    """Fully connected classifier for 28x28 FashionMNIST images.

    __init__ builds the machinery (the layers); forward() is the conveyor
    belt that pushes a batch through them.

    Default architecture:
        Flatten -> Linear(784, 512) -> ReLU -> Linear(512, 512) -> ReLU
        -> Linear(512, 10)
    producing one raw score (logit) per clothing class. The layer sizes are
    constructor arguments; their defaults reproduce exactly this network, so
    parameter names and shapes match checkpoints saved from it.
    """

    def __init__(self, in_features=28 * 28, hidden_size=512, num_classes=10):
        """Build the layer stack.

        - Flatten(): unrolls each 28x28 image into a 784-element vector
        - Linear(in, out): affine transform y = x @ W.T + b
        - ReLU(): activation that lets the model learn non-linear patterns

        Args:
            in_features: values per flattened input (default 28*28 = 784).
            hidden_size: width of each of the two hidden layers (default 512).
            num_classes: number of output logits (default 10 clothing
                classes, labels 0-9).
        """
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_size),  # maps the 784 pixel values to 512 hidden units
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),  # one logit per class
        )

    def forward(self, x):
        """Compute class logits for a batch of images.

        nn.Flatten() (default start_dim=1) flattens every dimension after the
        first, so the first dimension of ``x`` is treated as the batch. The
        flattened batch then goes through Linear -> ReLU -> Linear -> ReLU ->
        Linear.

        Args:
            x: batch of images, e.g. shape (batch, 1, 28, 28); everything after
                the first dimension must flatten to ``in_features`` values.

        Returns:
            logits: tensor of shape (batch, num_classes) with raw,
            unnormalized scores for each class.
        """
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


#-----------------
# load & test
#-----------------
model = NeuralNetwork().to(device)
# weights_only=True is for security: it restricts unpickling to tensors and
# plain containers, blocking execution of arbitrary Python functions or
# classes that may be hidden in the .pth file.
# map_location=device remaps stored tensors onto the device chosen above, so a
# checkpoint saved on CUDA/MPS still loads on a machine without that device.
model.load_state_dict(
    torch.load("01_mnist_model.pth", map_location=device, weights_only=True)
)

# Human-readable class names, indexed by the integer label (0-9) that the
# dataset returns and the model's output logits correspond to.
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
# Fetch the first test sample once (each dataset index re-reads and
# re-transforms the image); x is the image tensor, y its integer label.
x, y = test_data[0]
with torch.no_grad():
    x = x.to(device)
    # Add an explicit batch dimension so the model sees [batch=1, ...] instead
    # of relying on Flatten treating the image's channel dim as the batch.
    pred = model(x.unsqueeze(0))
    probabilities = torch.nn.functional.softmax(pred, dim=1)  # dim=1 because tensor is [batch, classes]
    conf, predicted_idx = torch.max(probabilities, dim=1)
    predicted = classes[predicted_idx.item()]  # index with a plain Python int
    actual = classes[y]
    confidence = conf.item() * 100
    print(f'Predicted: "{predicted}" ({confidence:.2f}%), Actual: "{actual}"')


#-----------------
# plotting
#-----------------
plt.figure(figsize=(6, 3))

# Left panel: the input image. squeeze() drops the singleton dimensions so
# imshow receives a 2-D 28x28 array. "gray" is the canonical colormap name;
# the "grey" spelling is not recognized by older matplotlib releases.
plt.subplot(1, 2, 1)
plt.imshow(x.cpu().squeeze(), cmap="gray")
plt.title(f"Actual: {actual}")
plt.axis("off")

# Right panel: softmax probability for every class of the single sample.
plt.subplot(1, 2, 2)
plt.bar(range(len(classes)), probabilities[0].cpu().numpy())
plt.xticks(range(len(classes)), classes, rotation=90)
plt.title(f"Predicted: {predicted}\n{conf.item() * 100:.2f}%")
plt.ylim([0, 1])  # probabilities lie in [0, 1]; keep the scale fixed at 0-100%

plt.tight_layout()
plt.show()
