{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a36f8457",
   "metadata": {},
   "source": [
    "# Gemma 3 Fine-tuning using LoRA\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5f36f85",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b1f2b50",
   "metadata": {},
   "source": [
    "### Install packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71e8d7c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install jax jaxlib keras keras-hub ipywidgets tensorflow libtpu tpu-info py-cpuinfo psutil\n",
    "#!pip install jax jaxlib keras keras-hub ipywidgets py-cpuinfo psutil   # if GPU only, you can leave out some TensorFlow and TPU modules"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e224a4e4",
   "metadata": {},
   "source": [
    "### Colab server specs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d55ad088",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cpuinfo\n",
    "import psutil\n",
    "\n",
    "# get system cpu details\n",
    "print(\"----- CPU details -----\")\n",
    "print(f\"Model:        {cpuinfo.get_cpu_info().get('brand_raw', 'N/A')}\")\n",
    "print(f\"Core count:   {psutil.cpu_count(logical=False)}\")\n",
    "print(f\"Thread count: {psutil.cpu_count(logical=True)}\")\n",
    "\n",
    "# Get system memory details\n",
    "memory_info = psutil.virtual_memory()\n",
    "print(\"\\n----- Memory details -----\")\n",
    "print(f\"Total Memory:     {memory_info.total / (1024**3):.2f} GB\")\n",
    "print(f\"Available Memory: {memory_info.available / (1024**3):.2f} GB\")\n",
    "#print(f\"Used Memory: {memory_info.used / (1024**3):.2f} GB\")\n",
    "#print(f\"Free Memory: {memory_info.free / (1024**3):.2f} GB\")\n",
    "#print(f\"Memory Usage Percentage: {memory_info.percent}%\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c9cfd73",
   "metadata": {},
   "source": [
    "#### Accelerator (GPU/TPU) specs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00771927",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"\\n----- Accelerator details -----\")\n",
    "!nvidia-smi 2> /dev/null\n",
    "!tpu-info 2> /dev/null"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9a25615",
   "metadata": {},
   "source": [
    "### Getting secrets/env var setup\n",
    "Current workaround as the user data/secrets is not yet accessible through the extension"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2548881b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ipywidgets as widgets\n",
    "\n",
    "widgets.FileUpload(\n",
    "    accept='.env',  # .env file with 1 environment variable per line\n",
    "    multiple=False  # True to accept multiple files upload else False\n",
    ")\n",
    "\n",
    "uploader = widgets.FileUpload()\n",
    "display(uploader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "281a9d35",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# dictionary key of uploaded file wil be the uploaded filename\n",
    "# which we are extracting here\n",
    "filename_dict_key = list(uploader.value)[0]\n",
    "\n",
    "env_content_string = uploader.value[filename_dict_key]['content'].decode('utf-8')\n",
    "\n",
    "for line in env_content_string.splitlines():\n",
    "    if '=' in line and line.strip() and not line.strip().startswith('#'):\n",
    "        key, value = line.split('=', 1)\n",
    "        key = key.strip()\n",
    "        value = value.strip().strip(\"'\").strip('\"')\n",
    "        os.environ[key] = value\n",
    "\n",
    "for key, value in os.environ.items():\n",
    "    if key.startswith('KAGGLE_'):\n",
    "        value_first_char = value[0]\n",
    "        value_last_char = value[-1]\n",
    "        masked_chars = 'x' * (len(value) - 2)\n",
    "        masked_value = value_first_char + masked_chars + value_last_char\n",
    "        print(f\"{key}: {masked_value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ff21cbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\" # jax, torch, or tensorflow\n",
    "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.00\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20821fc9",
   "metadata": {},
   "source": [
    "## Download Gemma3 model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "661cd20a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import keras\n",
    "import keras_hub\n",
    "\n",
    "gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(\"gemma3_instruct_1b\")    # instruction-tuned open model\n",
    "gemma_lm.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6638d7f",
   "metadata": {},
   "source": [
    "### Before fine-tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "857629ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n",
    "\n",
    "prompt = template.format(\n",
    "    instruction=\"What should I do on a trip to Spain?\",\n",
    "    response=\"\",\n",
    ")\n",
    "sampler = keras_hub.samplers.TopKSampler(k=5, seed=2)\n",
    "gemma_lm.compile(sampler=sampler)\n",
    "print(gemma_lm.generate(prompt, max_length=512))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5becd895",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = template.format(\n",
    "    instruction=\"Explain the process of photosynthesis in a way that a child could understand.\",\n",
    "    response=\"\",\n",
    ")\n",
    "print(gemma_lm.generate(prompt, max_length=256))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14acebf8",
   "metadata": {},
   "source": [
    "## LoRA Fine-tuning"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22d993f5",
   "metadata": {},
   "source": [
    "### Preparing training data\n",
    "Downloading a large dataset (15k) of prompts and concise responses.\n",
    "\n",
    "Will limit the training examples to 1000 just to reduce to training time, but definitely trying increasing this value to get better results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77c1df40",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e012541",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "prompts = []\n",
    "responses = []\n",
    "line_count = 0\n",
    "\n",
    "with open(\"databricks-dolly-15k.jsonl\") as file:\n",
    "    for line in file:\n",
    "        if line_count >= 1000: \n",
    "            break  # Limit the training examples, to reduce execution time.\n",
    "\n",
    "        examples = json.loads(line)\n",
    "        # Filter out examples with context, to keep it simple.\n",
    "        if examples[\"context\"]:\n",
    "            continue\n",
    "        # Format data into prompts and response lists.\n",
    "        prompts.append(examples[\"instruction\"])\n",
    "        responses.append(examples[\"response\"])\n",
    "\n",
    "        line_count += 1\n",
    "\n",
    "data = {\n",
    "    \"prompts\": prompts,\n",
    "    \"responses\": responses\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "faba975a",
   "metadata": {},
   "source": [
    "### Configuring LoRA tuning\n",
    "Some parameters you can tune to affect the results:\n",
    "- `rank`: Configures the number of trainable parameters. You typically start small (i.e. 4) for better effieciency but gradually increase for subsequent tunings to try to find the best bang for your buck.\n",
    "- `learning_rate`: How big of a step the model takes towards the solution (towards min. loss). Too big of a step and your loss might get worse, too small and your model will take forever to train.\n",
    "- `weight_decay`: A penalty added towards large weights. Used to encourage smaller, distributed weights.\n",
    "- `epochs`: How many times the learning algorithm will work through the dataset. More epochs generally leads to better learning (but will also take longer proportional to the number of epochs). Too many epochs can lead to overfitting (where model is just remember the training data rather than generalizing).\n",
    "- `batch_size`: The number of data samples to load at once. Training/tuning is a memory-intensive process, so you want to load small samples from your dataset at a time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33e2d7bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable LoRA for the model and set the LoRA rank to 4.\n",
    "# if you wish to set change the ranks, you will have to redownload the model from keras-hub\n",
    "gemma_lm.backbone.enable_lora(rank=4)\n",
    "gemma_lm.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3623c2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Limit the input sequence length to 256 (to control memory usage).\n",
    "gemma_lm.preprocessor.sequence_length = 256\n",
    "# Use AdamW (a common optimizer for transformer models).\n",
    "optimizer = keras.optimizers.AdamW(\n",
    "    learning_rate=5e-5,\n",
    "    weight_decay=0.01,\n",
    ")\n",
    "# Exclude layernorm and bias terms from decay.\n",
    "optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n",
    "\n",
    "gemma_lm.compile(\n",
    "    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    optimizer=optimizer,\n",
    "    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a958bc93",
   "metadata": {},
   "source": [
    "#### Run fine-tune process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e905d7dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "gemma_lm.fit(data, epochs=1, batch_size=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4831475b",
   "metadata": {},
   "source": [
    "### Post fine-tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fff7f65b",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n",
    "\n",
    "prompt = template.format(\n",
    "    instruction=\"What should I do on a trip to Spain?\",\n",
    "    response=\"\",\n",
    ")\n",
    "sampler = keras_hub.samplers.TopKSampler(k=5, seed=2)\n",
    "gemma_lm.compile(sampler=sampler)\n",
    "print(gemma_lm.generate(prompt, max_length=512))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ad38a7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = template.format(\n",
    "    instruction=\"Explain the process of photosynthesis in a way that a child could understand.\",\n",
    "    response=\"\",\n",
    ")\n",
    "print(gemma_lm.generate(prompt, max_length=256))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94d292fa",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
