{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "132569e0-e469-49c4-a396-2f5f1c45fd29",
   "metadata": {},
   "source": [
    "# MNIST handwriting dataset example\n",
    "\n",
    "This notebook runs the \"hello world\" of training TensorFlow neural networks/models: classifying handwritten digits from the MNIST dataset. MNIST is a commonly used benchmark for ML and computer vision.\n",
    "It includes 70,000 greyscale images of handwritten digits (0-9)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7415bbc",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a848f4b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shell out to nvidia-smi so the reader can see which NVIDIA GPU (if any) this machine exposes.\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70667776",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic instead of !pip: %pip installs into the environment of the running kernel,\n",
    "# whereas !pip runs whichever pip executable happens to be first on the shell PATH.\n",
    "%pip install tensorflow keras matplotlib numpy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb015ddc",
   "metadata": {},
   "source": [
    "## Importing data & splitting dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bab554a4-9dde-4168-b154-d7792ec0b246",
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.datasets import mnist\n",
    "\n",
    "\n",
    "def _describe_split(title, images, labels):\n",
    "    \"\"\"Print a heading, then the image array shape, the label count and the labels.\"\"\"\n",
    "    print(title)\n",
    "    print(images.shape)\n",
    "    print(len(labels))\n",
    "    print(labels)\n",
    "\n",
    "\n",
    "# load_data() hands back an (images, labels) pair for the train split and one for the test split.\n",
    "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
    "\n",
    "_describe_split('training data:', train_images, train_labels)\n",
    "_describe_split('testing data:', test_images, test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fa23c09",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13ff4e87-4cec-41ca-9455-fd24dd56fd19",
   "metadata": {},
   "outputs": [],
   "source": [
    "import keras\n",
    "from keras import layers\n",
    "\n",
    "# Small fully connected classifier for 28x28 images.\n",
    "# - keras.Input declares the input shape explicitly (preferred over passing input_shape= to the first layer).\n",
    "# - Rescaling maps the raw 0-255 pixel intensities into [0, 1] inside the model, so fit/evaluate/predict\n",
    "#   can all be fed the images exactly as loaded and still see normalised inputs.\n",
    "model = keras.Sequential(\n",
    "    [\n",
    "        keras.Input(shape=(28, 28)),\n",
    "        layers.Rescaling(1.0 / 255),\n",
    "        layers.Flatten(),                       # 28x28 image -> 784-element vector\n",
    "        layers.Dense(256, activation=\"relu\"),\n",
    "        layers.Dropout(0.2),                    # prevent overfitting\n",
    "        layers.Dense(10, activation=\"softmax\"), # 10-way classification layer, probability scores for 10 outputs which sum up to 1\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24559d17-6911-4874-a3d7-d9ce1fcccde0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure training: Adam optimizer, sparse categorical cross-entropy as the loss\n",
    "# (the labels are passed as loaded, not one-hot encoded), and accuracy as the reported metric.\n",
    "model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0ae1394-79e5-4060-bb75-1f29820f8ef1",
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCHS = 10        # number of passes over the training set\n",
    "BATCH_SIZE = 128   # samples per gradient update\n",
    "\n",
    "# Keep the returned History object: later cells read the per-epoch metrics from it.\n",
    "history = model.fit(\n",
    "    x=train_images,\n",
    "    y=train_labels,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    epochs=EPOCHS,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e776c31c",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a660a88-714f-45ff-a1d7-c4bdb9b1073c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# history.history['accuracy'] is indexed per epoch; the last entry belongs to the final epoch.\n",
    "final_train_accuracy = history.history['accuracy'][-1]\n",
    "final_train_error_rate = 1 - final_train_accuracy\n",
    "\n",
    "print(\n",
    "    f\"final training accuracy (from fit): {final_train_accuracy}\",\n",
    "    f\"final training error rate (from fit): {final_train_error_rate}\",\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2b850fc-d157-49f0-919f-967c4314eef7",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('evaluating the model on new data')\n",
    "\n",
    "# Score the model on the held-out test split; the result unpacks into the loss and the accuracy metric.\n",
    "test_loss, test_acc = model.evaluate(x=test_images, y=test_labels)\n",
    "\n",
    "# Error rate is the complement of accuracy.\n",
    "error_rate = 1 - test_acc\n",
    "print(f\"test_loss: {test_loss}, test_acc: {test_acc}, error_rate: {error_rate}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1544da24",
   "metadata": {},
   "source": [
    "## Predictions & testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f8555f6-0f7c-45b8-9eca-94e2b56c2334",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_digits = test_images[0:10]  # the first 10 images of the test set\n",
    "predictions = model.predict(test_digits)\n",
    "print(f\"number of predictions: {len(predictions)}\")\n",
    "\n",
    "# Walk through the first two predictions. Each row of `predictions` is the softmax output for one\n",
    "# image (one probability per digit class), so the argmax index is the digit the model predicts.\n",
    "# Every value reported for an image is read from that image's own row and label.\n",
    "for position, heading in enumerate(('first test image:', 'second test image:')):\n",
    "    probabilities = predictions[position]\n",
    "    max_confidence = probabilities.argmax()    # index with highest confidence\n",
    "    print(heading)\n",
    "    print(f\"probability matrix for all numbers: {probabilities}\")\n",
    "    print(f\"highest confidence: {probabilities[max_confidence]}\")\n",
    "    print(f\"index with highest confidence: {max_confidence}\")\n",
    "    print(f\"actual answer: {test_labels[position]}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
