{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2022 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-04-20T11:08:54.061808Z", "iopub.status.busy": "2024-04-20T11:08:54.061554Z", "iopub.status.idle": "2024-04-20T11:08:54.065665Z", "shell.execute_reply": "2024-04-20T11:08:54.065098Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "36EdAGhThQov" }, "source": [ "# Learning to Rank with Decision Forests\n", "\n", "\n", " \n", " \n", " \n", " \n", "
\n", "
\n" ] }, { "cell_type": "markdown", "metadata": { "id": "kvvDY0LVhuaW" }, "source": [ "Welcome to the **Learning to Rank Colab** for **TensorFlow Decision Forests** (**TF-DF**).\n", "In this colab, you will learn how to use **TF-DF** for ranking.\n", "\n", "This colab assumes you are familiar with the concepts presented in the [Beginner colab](beginner_colab.ipynb), notably the installation of TF-DF.\n", "\n", "In this colab, you will:\n", "\n", "1. Learn what a ranking model is.\n", "1. Train a Gradient Boosted Trees model on the LETOR3 dataset.\n", "1. Evaluate the quality of this model." ] }, { "cell_type": "markdown", "metadata": { "id": "jK9tCTcwqq4k" }, "source": [ "## Installing TensorFlow Decision Forests\n", "\n", "Install TF-DF by running the following cell." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:08:54.069169Z", "iopub.status.busy": "2024-04-20T11:08:54.068765Z", "iopub.status.idle": "2024-04-20T11:08:57.618544Z", "shell.execute_reply": "2024-04-20T11:08:57.617595Z" }, "id": "Pa1Pf37RhEYN" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting tensorflow_decision_forests\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading tensorflow_decision_forests-1.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.0 kB)\r\n", "Requirement already satisfied: numpy in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.26.4)\r\n", "Requirement already satisfied: pandas in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.2.2)\r\n", "Requirement already satisfied: tensorflow~=2.16.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.16.1)\r\n", "Requirement already satisfied: six in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.16.0)\r\n", "Requirement 
already satisfied: absl-py in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.4.0)\r\n", "Requirement already satisfied: wheel in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (0.41.2)\r\n", "Collecting wurlitzer (from tensorflow_decision_forests)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading wurlitzer-3.0.3-py3-none-any.whl.metadata (1.9 kB)\r\n", "Requirement already satisfied: tf-keras~=2.16 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.16.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (1.6.3)\r\n", "Requirement already satisfied: flatbuffers>=23.5.26 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (24.3.25)\r\n", "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (0.5.4)\r\n", "Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (0.2.0)\r\n", "Requirement already satisfied: h5py>=3.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (3.11.0)\r\n", "Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (18.1.1)\r\n", "Requirement already satisfied: ml-dtypes~=0.3.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (0.3.2)\r\n", "Requirement already satisfied: opt-einsum>=2.3.2 in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (3.3.0)\r\n", "Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (24.0)\r\n", "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (3.20.3)\r\n", "Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (2.31.0)\r\n", "Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (69.5.1)\r\n", "Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (2.4.0)\r\n", "Requirement already satisfied: typing-extensions>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (4.11.0)\r\n", "Requirement already satisfied: wrapt>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (1.16.0)\r\n", "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (1.63.0rc2)\r\n", "Requirement already satisfied: tensorboard<2.17,>=2.16 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (2.16.2)\r\n", "Requirement already satisfied: keras>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (3.2.1)\r\n", "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.16.1->tensorflow_decision_forests) (0.36.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: python-dateutil>=2.8.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2.9.0.post0)\r\n", "Requirement already satisfied: pytz>=2020.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2024.1)\r\n", "Requirement already satisfied: tzdata>=2022.7 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2024.1)\r\n", "Requirement already satisfied: rich in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (13.7.1)\r\n", "Requirement already satisfied: namex in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (0.0.8)\r\n", "Requirement already satisfied: optree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (0.11.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: charset-normalizer<4,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow~=2.16.1->tensorflow_decision_forests) (3.3.2)\r\n", "Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow~=2.16.1->tensorflow_decision_forests) (3.7)\r\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow~=2.16.1->tensorflow_decision_forests) (2.2.1)\r\n", "Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from 
requests<3,>=2.21.0->tensorflow~=2.16.1->tensorflow_decision_forests) (2024.2.2)\r\n", "Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (3.6)\r\n", "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (0.7.2)\r\n", "Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (3.0.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown>=2.6.8->tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (7.1.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (2.1.5)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: markdown-it-py>=2.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (3.0.0)\r\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (2.17.2)\r\n", "Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard<2.17,>=2.16->tensorflow~=2.16.1->tensorflow_decision_forests) (3.18.1)\r\n", "Requirement already satisfied: mdurl~=0.1 in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown-it-py>=2.2.0->rich->keras>=3.0.0->tensorflow~=2.16.1->tensorflow_decision_forests) (0.1.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading tensorflow_decision_forests-1.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.5 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading wurlitzer-3.0.3-py3-none-any.whl (7.3 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: wurlitzer, tensorflow_decision_forests\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed tensorflow_decision_forests-1.9.0 wurlitzer-3.0.3\r\n" ] } ], "source": [ "!pip install tensorflow_decision_forests\n" ] }, { "cell_type": "markdown", "metadata": { "id": "vZGda2dOe-hH" }, "source": [ "[Wurlitzer](https://pypi.org/project/wurlitzer/) is needed to display the detailed training logs in Colabs (when using `verbose=2` in the model constructor)." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:08:57.623108Z", "iopub.status.busy": "2024-04-20T11:08:57.622824Z", "iopub.status.idle": "2024-04-20T11:08:59.602638Z", "shell.execute_reply": "2024-04-20T11:08:59.601816Z" }, "id": "lk26uBSCe8Du" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: wurlitzer in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (3.0.3)\r\n" ] } ], "source": [ "!pip install wurlitzer" ] }, { "cell_type": "markdown", "metadata": { "id": "3oinwbhXlggd" }, "source": [ "## Importing libraries" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:08:59.607082Z", "iopub.status.busy": "2024-04-20T11:08:59.606785Z", "iopub.status.idle": "2024-04-20T11:09:02.035993Z", "shell.execute_reply": "2024-04-20T11:09:02.035306Z" }, "id": "52W45tmDjD64" }, "outputs": [], "source": [ "import os\n", "# Keep using Keras 2\n", "os.environ['TF_USE_LEGACY_KERAS'] = '1'\n", "\n", "import tensorflow_decision_forests as tfdf\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import tensorflow as tf\n", "import tf_keras\n", "import math" ] }, { "cell_type": "markdown", "metadata": { "id": "0LPPwWxYxtDM" }, "source": [ "The hidden code cell limits the output height in colab.\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-04-20T11:09:02.040722Z", "iopub.status.busy": "2024-04-20T11:09:02.039932Z", "iopub.status.idle": "2024-04-20T11:09:02.044296Z", "shell.execute_reply": "2024-04-20T11:09:02.043692Z" }, "id": "2AhqJz3VmQM-" }, "outputs": [], "source": [ "#@title\n", "\n", "from IPython.core.magic import register_line_magic\n", "from IPython.display import Javascript\n", "from IPython.display import display as ipy_display\n", "\n", "# Some of the model training logs can cover the full\n", "# 
screen if not compressed to a smaller viewport.\n", "# This magic allows setting a max height for a cell.\n", "@register_line_magic\n", "def set_cell_height(size):\n", " ipy_display(\n", " Javascript(\"google.colab.output.setIframeHeight(0, true, {maxHeight: \" +\n", " str(size) + \"})\"))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:09:02.047413Z", "iopub.status.busy": "2024-04-20T11:09:02.046987Z", "iopub.status.idle": "2024-04-20T11:09:02.050488Z", "shell.execute_reply": "2024-04-20T11:09:02.049863Z" }, "id": "8gVQ-txtjFU4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found TensorFlow Decision Forests v1.9.0\n" ] } ], "source": [ "# Check the version of TensorFlow Decision Forests\n", "print(\"Found TensorFlow Decision Forests v\" + tfdf.__version__)" ] }, { "cell_type": "markdown", "metadata": { "id": "S54mR6i9jkhp" }, "source": [ "## What is a ranking model?\n", "\n", "The goal of a [ranking](https://en.wikipedia.org/wiki/Learning_to_rank) model is to **correctly order** items. For example, ranking can be used to select the best *documents* to retrieve following a user *query*.\n", "\n", "A common way to represent a ranking dataset is with a \"relevance\" score: the order of the elements is defined by their relevance, and items of greater relevance should come before items of lower relevance. The cost of a mistake is defined by the difference between the relevance of the predicted item and the relevance of the correct item. 
For example, misordering two items with respective relevance 3 and 4 is not as bad as misordering two items with respective relevance 1 and 5.\n", "\n", "TF-DF expects ranking datasets to be presented in a \"flat\" format.\n", "A dataset of queries and corresponding documents might look like this:\n", "\n", "query | document_id | feature_1 | feature_2 | relevance\n", "----- | ----------- | --------- | --------- | ---------------\n", "cat | 1 | 0.1 | blue | 4\n", "cat | 2 | 0.5 | green | 1\n", "cat | 3 | 0.2 | red | 2\n", "dog | 4 | NA | red | 0\n", "dog | 5 | 0.2 | red | 0\n", "dog | 6 | 0.6 | green | 1\n", "\n", "\n", "The *relevance/label* is a floating point numerical value between 0 and 5\n", "(generally between 0 and 4) where 0 means \"completely unrelated\", 4 means \"very\n", "relevant\", and 5 means \"same as the query\".\n", "\n", "In this example, document 1 is very relevant to the query \"cat\", while document 2 is only \"related\" to cats. No document is really about \"dog\" (the highest relevance is 1, for document 6). However, the dog query is still expected to return document 6, since it is the document that talks the \"most\" about dogs.\n", "\n", "Interestingly, decision forests are often good rankers, and many\n", "state-of-the-art ranking models are decision forests." ] }, { "cell_type": "markdown", "metadata": { "id": "QvcRIhFF0BoN" }, "source": [ "## Let's train a ranking model\n", "\n", "In this example, we use a sample of the\n", "[LETOR3](https://www.microsoft.com/en-us/research/project/letor-learning-rank-information-retrieval/#!letor-3-0)\n", "dataset. More precisely, we want to download `OHSUMED.zip` from [the LETOR3 repo](https://onedrive.live.com/?authkey=%21ACnoZZSZVfHPJd0&id=8FEADC23D838BDA8%21107&cid=8FEADC23D838BDA8). This dataset is stored in the\n", "libsvm format, so we will need to convert it to CSV."
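] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
As a minimal sketch of the libsvm ranking layout (assuming the LETOR3 line structure visible in the preview of the dataset's first lines), each line holds one query/document pair: the relevance label, a `qid:...` token that identifies the query, a series of `index:value` feature pairs and, in LETOR3, a trailing `#docid = ...` comment. The helper `parse_libsvm_ranking_line` is hypothetical and not part of TF-DF; the example line is an abbreviated copy of the first OHSUMED training line, keeping only three of its features.

```python
def parse_libsvm_ranking_line(line):
    # Hypothetical helper: returns (relevance, query_id, features) for one
    # LETOR-style libsvm line. Everything after '#' is treated as a comment.
    content = line.split('#', 1)[0]
    tokens = content.split()
    relevance = float(tokens[0])
    # The second token has the form qid:ID.
    query_id = tokens[1].split(':', 1)[1]
    # The remaining tokens are index:value feature pairs.
    features = {}
    for token in tokens[2:]:
        index, value = token.split(':', 1)
        features[int(index)] = float(value)
    return relevance, query_id, features


# Abbreviated version of the first OHSUMED training line (first 3 features only).
example = '2 qid:1 1:3.00000000 2:2.07944154 3:0.27272727 #docid = 40626'
relevance, query_id, features = parse_libsvm_ranking_line(example)
print(relevance, query_id, features)
# prints: 2.0 1 {1: 3.0, 2: 2.07944154, 3: 0.27272727}
```

Calling `split()` without arguments splits on any run of whitespace and drops trailing line endings, so no explicit stripping is needed. The conversion function used later in this colab writes rows to a CSV file instead of returning Python objects.
"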
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:09:02.053782Z", "iopub.status.busy": "2024-04-20T11:09:02.053297Z", "iopub.status.idle": "2024-04-20T11:09:11.825968Z", "shell.execute_reply": "2024-04-20T11:09:11.825197Z" }, "id": "axD6x1ZivHCS" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://download.microsoft.com/download/E/7/E/E7EABEF1-4C7B-4E31-ACE5-73927950ED5E/Letor.zip\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 8192/61824018 [..............................] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 98304/61824018 [..............................] - ETA: 44s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 278528/61824018 [..............................] - ETA: 29s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 729088/61824018 [..............................] - ETA: 15s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 1753088/61824018 [..............................] - ETA: 8s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 2105344/61824018 [>.............................] 
- ETA: 12s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 4202496/61824018 [=>............................] - ETA: 8s " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 6299648/61824018 [==>...........................] - ETA: 7s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 8396800/61824018 [===>..........................] - ETA: 7s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "10493952/61824018 [====>.........................] - ETA: 6s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "12591104/61824018 [=====>........................] - ETA: 6s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "14688256/61824018 [======>.......................] - ETA: 5s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "16785408/61824018 [=======>......................] - ETA: 5s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "18882560/61824018 [========>.....................] 
- ETA: 5s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "20979712/61824018 [=========>....................] - ETA: 4s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "23076864/61824018 [==========>...................] - ETA: 4s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "25174016/61824018 [===========>..................] - ETA: 4s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "27271168/61824018 [============>.................] - ETA: 4s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "29368320/61824018 [=============>................] - ETA: 3s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "31465472/61824018 [==============>...............] - ETA: 3s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "32702464/61824018 [==============>...............] - ETA: 3s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "33562624/61824018 [===============>..............] 
- ETA: 3s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "35659776/61824018 [================>.............] - ETA: 3s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "37756928/61824018 [=================>............] - ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "39854080/61824018 [==================>...........] - ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "41418752/61824018 [===================>..........] - ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "41951232/61824018 [===================>..........] - ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "44048384/61824018 [====================>.........] - ETA: 2s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "45809664/61824018 [=====================>........] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "46145536/61824018 [=====================>........] 
- ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "48242688/61824018 [======================>.......] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "49905664/61824018 [=======================>......] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "50339840/61824018 [=======================>......] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "52436992/61824018 [========================>.....] - ETA: 1s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "54050816/61824018 [=========================>....] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "54534144/61824018 [=========================>....] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "56631296/61824018 [==========================>...] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "58236928/61824018 [===========================>..] 
- ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "58728448/61824018 [===========================>..] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "60825600/61824018 [============================>.] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "61824018/61824018 [==============================] - 7s 0us/step\n" ] } ], "source": [ "archive_path = tf_keras.utils.get_file(\"letor.zip\",\n", " \"https://download.microsoft.com/download/E/7/E/E7EABEF1-4C7B-4E31-ACE5-73927950ED5E/Letor.zip\",\n", " extract=True)\n", "\n", "# Path to a ranking dataset in libsvm format.\n", "raw_dataset_path = os.path.join(os.path.dirname(archive_path),\"OHSUMED/Data/Fold1/trainingset.txt\")" ] }, { "cell_type": "markdown", "metadata": { "id": "f5UB-Kwn6JKK" }, "source": [ "Here are the first lines of the dataset:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:09:11.830056Z", "iopub.status.busy": "2024-04-20T11:09:11.829797Z", "iopub.status.idle": "2024-04-20T11:09:11.955416Z", "shell.execute_reply": "2024-04-20T11:09:11.954656Z" }, "id": "pDjamyHv6K7B" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2 qid:1 1:3.00000000 2:2.07944154 3:0.27272727 4:0.26103413 5:37.33056511 6:11.43124125 7:37.29975005 8:1.13865735 9:15.52428944 10:8.83129655 11:12.00000000 12:5.37527841 13:0.08759124 14:0.08649364 15:28.30306459 16:9.34002375 17:24.80878473 18:0.39309068 19:57.41651698 20:3.29489291 21:25.02310000 22:3.21979940 23:-3.87098000 24:-3.90273000 
25:-3.87512000 #docid = 40626\r", "\r\n", "0 qid:1 1:3.00000000 2:2.07944154 3:0.42857143 4:0.40059418 5:37.33056511 6:11.43124125 7:37.29975005 8:1.81447983 9:17.45499227 10:11.61793065 11:10.00000000 12:5.19295685 13:0.08547009 14:0.08453711 15:28.30306459 16:9.34002375 17:24.80878473 18:0.34920457 19:43.24062605 20:2.65472417 21:23.49030000 22:3.15658757 23:-3.96838000 24:-4.00865000 25:-3.98670000 #docid = 11852\r", "\r\n", "2 qid:1 1:0.00000000 2:0.00000000 3:0.00000000 4:0.00000000 5:37.33056511 6:11.43124125 7:37.29975005 8:0.00000000 9:0.00000000 10:0.00000000 11:8.00000000 12:4.38202663 13:0.07692308 14:0.07601813 15:28.30306459 16:9.34002375 17:24.80878473 18:0.24031887 19:25.81698944 20:1.55134225 21:15.86500000 22:2.76411543 23:-4.28166000 24:-4.33313000 25:-4.44161000 #docid = 12693\r", "\r\n", "2 qid:1 1:4.00000000 2:2.77258872 3:0.33333333 4:0.32017083 5:37.33056511 6:11.43124125 7:37.29975005 8:1.26080803 9:17.97524177 10:8.86378153 11:3.00000000 12:1.79175947 13:0.03409091 14:0.03377241 15:28.30306459 16:9.34002375 17:24.80878473 18:0.11149640 19:10.09242586 20:0.64975836 21:14.27780000 22:2.65870588 23:-4.77772000 24:-4.73563000 25:-4.86759000 #docid = 12694\r", "\r\n", "0 qid:1 1:0.00000000 2:0.00000000 3:0.00000000 4:0.00000000 5:37.33056511 6:11.43124125 7:37.29975005 8:0.00000000 9:0.00000000 10:0.00000000 11:6.00000000 12:3.87120101 13:0.04761905 14:0.04736907 15:28.30306459 16:9.34002375 17:24.80878473 18:0.18210403 19:23.54629629 20:1.62139253 21:15.27640000 22:2.72630915 23:-4.43073000 24:-4.45985000 25:-4.57053000 #docid = 15450\r", "\r\n", "1 qid:1 1:1.00000000 2:0.69314718 3:0.14285714 4:0.13353139 5:37.33056511 6:11.43124125 7:37.29975005 8:0.62835774 9:6.12170704 10:4.15689134 11:10.00000000 12:4.43081680 13:0.08333333 14:0.08191707 15:28.30306459 16:9.34002375 17:24.80878473 18:0.32796715 19:43.13226482 20:2.12249256 21:16.33990000 22:2.79360997 23:-4.75652000 24:-4.66814000 25:-4.82965000 #docid = 17665\r", "\r\n", "0 qid:1 
1:0.00000000 2:0.00000000 3:0.00000000 4:0.00000000 5:37.33056511 6:11.43124125 7:37.29975005 8:0.00000000 9:0.00000000 10:0.00000000 11:3.00000000 12:2.07944154 13:0.05357143 14:0.05309873 15:28.30306459 16:9.34002375 17:24.80878473 18:0.25524876 19:14.92698785 20:2.59607100 21:16.00510000 22:2.77290742 23:-4.54349000 24:-4.52334000 25:-4.69865000 #docid = 18432\r", "\r\n", "0 qid:1 1:1.00000000 2:0.69314718 3:0.50000000 4:0.40546511 5:37.33056511 6:11.43124125 7:37.29975005 8:1.07964603 9:3.88727484 10:3.22375250 11:14.00000000 12:5.19295685 13:0.10000000 14:0.09779062 15:28.30306459 16:9.34002375 17:24.80878473 18:0.41939944 19:65.74099134 20:2.81316484 21:20.37810000 22:3.01446079 23:-4.25087000 24:-4.18235000 25:-4.21824000 #docid = 18540\r", "\r\n", "0 qid:1 1:3.00000000 2:2.07944154 3:0.60000000 4:0.54696467 5:37.33056511 6:11.43124125 7:37.29975005 8:2.13084866 9:15.65986863 10:10.90521468 11:9.00000000 12:4.68213123 13:0.13043478 14:0.12827973 15:28.30306459 16:9.34002375 17:24.80878473 18:0.44756051 19:33.23097043 20:2.69791902 21:21.26510000 22:3.05706723 23:-4.18472000 24:-4.18399000 25:-4.03491000 #docid = 44695\r", "\r\n", "0 qid:1 1:0.00000000 2:0.00000000 3:0.00000000 4:0.00000000 5:37.33056511 6:11.43124125 7:37.29975005 8:0.00000000 9:0.00000000 10:0.00000000 11:16.00000000 12:6.10479323 13:0.11510791 14:0.11289797 15:28.30306459 16:9.34002375 17:24.80878473 18:0.37130060 19:55.42216381 20:2.33077909 21:19.21490000 22:2.95568602 23:-4.09988000 24:-4.15679000 25:-4.17349000 #docid = 18541\r", "\r\n" ] } ], "source": [ "!head {raw_dataset_path}" ] }, { "cell_type": "markdown", "metadata": { "id": "rcManr98ZGID" }, "source": [ "The first step is to convert this dataset to the \"flat\" format mentioned above." 
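] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
To make the flat format concrete, here is a small sketch that loads the toy cat/dog table from the ranking section into a pandas DataFrame (values copied from that table, with the NA entry stored as a missing value) and derives the ideal order of documents for each query by sorting on relevance. The names `toy_df` and `ideal_order` are illustrative only. In the converted OHSUMED data, the `group` column plays the role of `query`.

```python
import pandas as pd

# Toy flat ranking dataset: one row per (query, document) pair.
# Values are copied from the cat/dog example table; None marks the NA entry.
toy_df = pd.DataFrame({
    'query': ['cat', 'cat', 'cat', 'dog', 'dog', 'dog'],
    'document_id': [1, 2, 3, 4, 5, 6],
    'feature_1': [0.1, 0.5, 0.2, None, 0.2, 0.6],
    'feature_2': ['blue', 'green', 'red', 'red', 'red', 'green'],
    'relevance': [4, 1, 2, 0, 0, 1],
})

# Ideal ranking: within each query, documents sorted by decreasing relevance.
# document_id is only used as a tie-breaker to make the order deterministic.
ideal = toy_df.sort_values(
    ['query', 'relevance', 'document_id'], ascending=[True, False, True])
ideal_order = {
    query: group['document_id'].tolist()
    for query, group in ideal.groupby('query')
}
print(ideal_order)
# prints: {'cat': [1, 3, 2], 'dog': [6, 4, 5]}
```

Sorting by `document_id` only breaks the tie between documents 4 and 5 (both with relevance 0) so that the printed order is deterministic; relevance-based metrics such as NDCG do not distinguish between items of equal relevance.
"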
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:09:11.959628Z", "iopub.status.busy": "2024-04-20T11:09:11.958952Z", "iopub.status.idle": "2024-04-20T11:09:12.096376Z", "shell.execute_reply": "2024-04-20T11:09:12.095777Z" }, "id": "mkiM9HJox-e8" }, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " relevance group f_1 f_2 f_3 f_4 f_5 f_6 \\\n", "0 2 g_1 3.0 2.079442 0.272727 0.261034 37.330565 11.431241 \n", "1 0 g_1 3.0 2.079442 0.428571 0.400594 37.330565 11.431241 \n", "2 2 g_1 0.0 0.000000 0.000000 0.000000 37.330565 11.431241 \n", "\n", " f_7 f_8 ... f_16 f_17 f_18 f_19 \\\n", "0 37.29975 1.138657 ... 9.340024 24.808785 0.393091 57.416517 \n", "1 37.29975 1.814480 ... 9.340024 24.808785 0.349205 43.240626 \n", "2 37.29975 0.000000 ... 9.340024 24.808785 0.240319 25.816989 \n", "\n", " f_20 f_21 f_22 f_23 f_24 f_25 \n", "0 3.294893 25.0231 3.219799 -3.87098 -3.90273 -3.87512 \n", "1 2.654724 23.4903 3.156588 -3.96838 -4.00865 -3.98670 \n", "2 1.551342 15.8650 2.764115 -4.28166 -4.33313 -4.44161 \n", "\n", "[3 rows x 27 columns]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def convert_libsvm_to_csv(src_path, dst_path):\n", " \"\"\"Converts a libsvm ranking dataset into a flat csv file.\n", " \n", " Note: This code is specific to the LETOR3 dataset.\n", " \"\"\"\n", " dst_handle = open(dst_path, \"w\")\n", " first_line = True\n", " for src_line in open(src_path,\"r\"):\n", " # Note: The last 3 items are comments.\n", " items = src_line.split(\" \")[:-3]\n", " relevance = items[0]\n", " group = items[1].split(\":\")[1]\n", " features = [ item.split(\":\") for item in items[2:]]\n", "\n", " if first_line:\n", " # Csv header\n", " dst_handle.write(\"relevance,group,\" + \",\".join([\"f_\" + feature[0] for feature in features]) + \"\\n\")\n", " first_line = False\n", " dst_handle.write(relevance + \",g_\" + group + \",\" + (\",\".join([feature[1] for feature in features])) + \"\\n\")\n", " dst_handle.close()\n", "\n", "# Convert the dataset.\n", "csv_dataset_path=\"/tmp/ohsumed.csv\"\n", "convert_libsvm_to_csv(raw_dataset_path, csv_dataset_path)\n", "\n", "# Load a dataset into a Pandas Dataframe.\n", "dataset_df = pd.read_csv(csv_dataset_path)\n", "\n", "# Display the first 3 examples.\n", 
"dataset_df.head(3)" ] }, { "cell_type": "markdown", "metadata": { "id": "jdelXEgw6bsq" }, "source": [ "In this dataset, each row represents a query/document pair. The \"group\" column identifies the query, and the \"relevance\" column indicates how well the document matches the query.\n", "\n", "The features of the query and the document are merged together in the columns \"f_1\" to \"f_25\". The exact definition of the features is not known, but they would be something like:\n", "\n", "- Number of words in the query\n", "- Number of common words between the query and the document\n", "- Cosine similarity between an embedding of the query and an embedding of the document.\n", "- ...\n", "\n", "Let's convert the Pandas DataFrame into a TensorFlow Dataset:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:09:12.099664Z", "iopub.status.busy": "2024-04-20T11:09:12.099433Z", "iopub.status.idle": "2024-04-20T11:09:14.480135Z", "shell.execute_reply": "2024-04-20T11:09:14.479406Z" }, "id": "5QMbBkCEXxu_" }, "outputs": [], "source": [ "dataset_ds = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label=\"relevance\", task=tfdf.keras.Task.RANKING)" ] }, { "cell_type": "markdown", "metadata": { "id": "sOPVcfs_7xxf" }, "source": [ "Let's configure and train our ranking model." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:09:14.484063Z", "iopub.status.busy": "2024-04-20T11:09:14.483802Z", "iopub.status.idle": "2024-04-20T11:09:23.351465Z", "shell.execute_reply": "2024-04-20T11:09:23.350733Z" }, "id": "Ba1gb75SX1rr" }, "outputs": [ { "data": { "application/javascript": [ "google.colab.output.setIframeHeight(0, true, {maxHeight: 400})" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. 
Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Use /tmpfs/tmp/tmpzqjjgty3 as temporary training directory\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Reading training dataset...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[WARNING 24-04-20 11:09:14.5069 UTC gradient_boosted_trees.cc:1840] \"goss_alpha\" set but \"sampling_method\" not equal to \"GOSS\".\n", "[WARNING 24-04-20 11:09:14.5069 UTC gradient_boosted_trees.cc:1851] \"goss_beta\" set but \"sampling_method\" not equal to \"GOSS\".\n", "[WARNING 24-04-20 11:09:14.5069 UTC gradient_boosted_trees.cc:1865] \"selective_gradient_boosting_ratio\" set but \"sampling_method\" not equal to \"SELGB\".\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training dataset read in 0:00:03.986733. Found 9219 examples.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training model...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model trained in 0:00:00.757738\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Compiling model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[INFO 24-04-20 11:09:19.2736 UTC kernel.cc:1233] Loading model from path /tmpfs/tmp/tmpzqjjgty3/model/ with prefix fa7585ffd7c24e56\n", "[INFO 24-04-20 11:09:19.2748 UTC quick_scorer_extended.cc:911] The binary was compiled without AVX2 support, but your CPU supports it. 
Enable it for faster model inference.\n", "[INFO 24-04-20 11:09:19.2749 UTC abstract_model.cc:1344] Engine \"GradientBoostedTreesQuickScorerExtended\" built\n", "[INFO 24-04-20 11:09:19.2749 UTC kernel.cc:1061] Use fast generic engine\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model compiled.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%set_cell_height 400\n", "\n", "model = tfdf.keras.GradientBoostedTreesModel(\n", " task=tfdf.keras.Task.RANKING,\n", " ranking_group=\"group\",\n", " num_trees=50)\n", "\n", "model.fit(dataset_ds)" ] }, { "cell_type": "markdown", "metadata": { "id": "kz9kdege8T_y" }, "source": [ "We can now look at the quality of the model on the validation dataset. By default, TF-DF trains ranking models to optimize the [NDCG](https://en.wikipedia.org/wiki/Discounted_cumulative_gain) (here NDCG@5, i.e., computed over the top 5 ranked documents). The NDCG is a value between 0 and 1, where 1 is the perfect score. Because training minimizes the loss while a higher NDCG is better, the model loss is -NDCG." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:09:23.354714Z", "iopub.status.busy": "2024-04-20T11:09:23.354431Z", "iopub.status.idle": "2024-04-20T11:09:24.273912Z", "shell.execute_reply": "2024-04-20T11:09:24.273200Z" }, "id": "lt5qysPs8zex" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "logs = model.make_inspector().training_logs()\n", "\n", "plt.figure(figsize=(12, 4))\n", "\n", "plt.subplot(1, 2, 1)\n", "plt.plot([log.num_trees for log in logs], [log.evaluation.ndcg for log in logs])\n", "plt.xlabel(\"Number of trees\")\n", "plt.ylabel(\"NDCG (validation)\")\n", "\n", "plt.subplot(1, 2, 2)\n", "plt.plot([log.num_trees for log in logs], [log.evaluation.loss for log in logs])\n", "plt.xlabel(\"Number of trees\")\n", "plt.ylabel(\"Loss (validation)\")\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "eq1E_26Y8rtQ" }, "source": [ "As for all TF-DF models, you can also look at the model report (Note: The model report also contains the training logs):" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:09:24.277872Z", "iopub.status.busy": "2024-04-20T11:09:24.277256Z", "iopub.status.idle": "2024-04-20T11:09:24.288341Z", "shell.execute_reply": "2024-04-20T11:09:24.287764Z" }, "id": "L4N1R8fM4jFh" }, "outputs": [ { "data": { "application/javascript": [ "google.colab.output.setIframeHeight(0, true, {maxHeight: 400})" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"gradient_boosted_trees_model\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 1 (1.00 
Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 0 (0.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 1 (1.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Type: \"GRADIENT_BOOSTED_TREES\"\n", "Task: RANKING\n", "Label: \"__LABEL\"\n", "Rank group: \"group\"\n", "\n", "Input Features (25):\n", "\tf_1\n", "\tf_10\n", "\tf_11\n", "\tf_12\n", "\tf_13\n", "\tf_14\n", "\tf_15\n", "\tf_16\n", "\tf_17\n", "\tf_18\n", "\tf_19\n", "\tf_2\n", "\tf_20\n", "\tf_21\n", "\tf_22\n", "\tf_23\n", "\tf_24\n", "\tf_25\n", "\tf_3\n", "\tf_4\n", "\tf_5\n", "\tf_6\n", "\tf_7\n", "\tf_8\n", "\tf_9\n", "\n", "No weights\n", "\n", "Variable Importance: INV_MEAN_MIN_DEPTH:\n", " 1. \"f_9\" 0.326164 ################\n", " 2. \"f_3\" 0.318071 ###############\n", " 3. \"f_8\" 0.308922 #############\n", " 4. \"f_4\" 0.271175 #########\n", " 5. \"f_19\" 0.221570 ###\n", " 6. \"f_10\" 0.215666 ##\n", " 7. \"f_11\" 0.206509 #\n", " 8. \"f_22\" 0.204742 #\n", " 9. \"f_25\" 0.204497 #\n", " 10. \"f_23\" 0.203238 \n", " 11. \"f_21\" 0.200830 \n", " 12. \"f_24\" 0.200445 \n", " 13. \"f_12\" 0.198840 \n", " 14. \"f_18\" 0.197676 \n", " 15. \"f_20\" 0.196634 \n", " 16. \"f_6\" 0.196085 \n", " 17. \"f_16\" 0.196061 \n", " 18. \"f_2\" 0.195683 \n", " 19. \"f_5\" 0.195683 \n", " 20. \"f_13\" 0.195559 \n", " 21. \"f_17\" 0.195559 \n", "\n", "Variable Importance: NUM_AS_ROOT:\n", " 1. \"f_3\" 4.000000 ################\n", " 2. \"f_4\" 4.000000 ################\n", " 3. \"f_8\" 3.000000 ##########\n", " 4. \"f_9\" 1.000000 \n", "\n", "Variable Importance: NUM_NODES:\n", " 1. \"f_8\" 25.000000 ################\n", " 2. \"f_19\" 18.000000 ###########\n", " 3. \"f_10\" 15.000000 #########\n", " 4. \"f_9\" 14.000000 ########\n", " 5. \"f_3\" 13.000000 ########\n", " 6. 
\"f_23\" 7.000000 ####\n", " 7. \"f_24\" 6.000000 ###\n", " 8. \"f_11\" 5.000000 ##\n", " 9. \"f_21\" 5.000000 ##\n", " 10. \"f_25\" 5.000000 ##\n", " 11. \"f_4\" 5.000000 ##\n", " 12. \"f_22\" 4.000000 ##\n", " 13. \"f_12\" 3.000000 #\n", " 14. \"f_20\" 3.000000 #\n", " 15. \"f_16\" 2.000000 \n", " 16. \"f_6\" 2.000000 \n", " 17. \"f_13\" 1.000000 \n", " 18. \"f_17\" 1.000000 \n", " 19. \"f_18\" 1.000000 \n", " 20. \"f_2\" 1.000000 \n", " 21. \"f_5\" 1.000000 \n", "\n", "Variable Importance: SUM_SCORE:\n", " 1. \"f_8\" 10779.340861 ################\n", " 2. \"f_9\" 8831.772410 #############\n", " 3. \"f_3\" 4526.101184 ######\n", " 4. \"f_4\" 4360.245403 ######\n", " 5. \"f_19\" 2325.288894 ###\n", " 6. \"f_10\" 1881.848369 ##\n", " 7. \"f_21\" 1674.980191 ##\n", " 8. \"f_11\" 1127.632256 #\n", " 9. \"f_23\" 1021.834252 #\n", " 10. \"f_24\" 914.851512 #\n", " 11. \"f_22\" 885.619576 #\n", " 12. \"f_25\" 748.665007 #\n", " 13. \"f_20\" 310.610858 \n", " 14. \"f_16\" 298.972842 \n", " 15. \"f_6\" 212.376573 \n", " 16. \"f_12\" 130.725240 \n", " 17. \"f_2\" 112.124991 \n", " 18. \"f_18\" 86.341193 \n", " 19. \"f_5\" 65.103908 \n", " 20. \"f_13\" 57.966947 \n", " 21. 
\"f_17\" 21.930388 \n", "\n", "\n", "\n", "Loss: LAMBDA_MART_NDCG5\n", "Validation loss value: -0.438692\n", "Number of trees per iteration: 1\n", "Node format: NOT_SET\n", "Number of trees: 12\n", "Total number of nodes: 286\n", "\n", "Number of nodes by tree:\n", "Count: 12 Average: 23.8333 StdDev: 3.50793\n", "Min: 17 Max: 29 Ignored: 0\n", "----------------------------------------------\n", "[ 17, 18) 1 8.33% 8.33% ###\n", "[ 18, 19) 0 0.00% 8.33%\n", "[ 19, 20) 1 8.33% 16.67% ###\n", "[ 20, 21) 0 0.00% 16.67%\n", "[ 21, 22) 2 16.67% 33.33% #######\n", "[ 22, 23) 0 0.00% 33.33%\n", "[ 23, 24) 1 8.33% 41.67% ###\n", "[ 24, 25) 0 0.00% 41.67%\n", "[ 25, 26) 3 25.00% 66.67% ##########\n", "[ 26, 27) 0 0.00% 66.67%\n", "[ 27, 28) 3 25.00% 91.67% ##########\n", "[ 28, 29) 0 0.00% 91.67%\n", "[ 29, 29] 1 8.33% 100.00% ###\n", "\n", "Depth by leafs:\n", "Count: 149 Average: 4.14094 StdDev: 1.08696\n", "Min: 1 Max: 5 Ignored: 0\n", "----------------------------------------------\n", "[ 1, 2) 2 1.34% 1.34%\n", "[ 2, 3) 18 12.08% 13.42% ##\n", "[ 3, 4) 13 8.72% 22.15% ##\n", "[ 4, 5) 40 26.85% 48.99% #####\n", "[ 5, 5] 76 51.01% 100.00% ##########\n", "\n", "Number of training obs by leaf:\n", "Count: 149 Average: 673.691 StdDev: 2015.44\n", "Min: 5 Max: 8211 Ignored: 0\n", "----------------------------------------------\n", "[ 5, 415) 127 85.23% 85.23% ##########\n", "[ 415, 825) 6 4.03% 89.26%\n", "[ 825, 1236) 2 1.34% 90.60%\n", "[ 1236, 1646) 0 0.00% 90.60%\n", "[ 1646, 2056) 0 0.00% 90.60%\n", "[ 2056, 2467) 1 0.67% 91.28%\n", "[ 2467, 2877) 0 0.00% 91.28%\n", "[ 2877, 3287) 0 0.00% 91.28%\n", "[ 3287, 3698) 1 0.67% 91.95%\n", "[ 3698, 4108) 0 0.00% 91.95%\n", "[ 4108, 4518) 0 0.00% 91.95%\n", "[ 4518, 4929) 1 0.67% 92.62%\n", "[ 4929, 5339) 0 0.00% 92.62%\n", "[ 5339, 5749) 0 0.00% 92.62%\n", "[ 5749, 6160) 1 0.67% 93.29%\n", "[ 6160, 6570) 0 0.00% 93.29%\n", "[ 6570, 6980) 0 0.00% 93.29%\n", "[ 6980, 7391) 0 0.00% 93.29%\n", "[ 7391, 7801) 8 5.37% 98.66% #\n", "[ 
7801, 8211] 2 1.34% 100.00%\n", "\n", "Attribute in nodes:\n", "\t25 : f_8 [NUMERICAL]\n", "\t18 : f_19 [NUMERICAL]\n", "\t15 : f_10 [NUMERICAL]\n", "\t14 : f_9 [NUMERICAL]\n", "\t13 : f_3 [NUMERICAL]\n", "\t7 : f_23 [NUMERICAL]\n", "\t6 : f_24 [NUMERICAL]\n", "\t5 : f_4 [NUMERICAL]\n", "\t5 : f_25 [NUMERICAL]\n", "\t5 : f_21 [NUMERICAL]\n", "\t5 : f_11 [NUMERICAL]\n", "\t4 : f_22 [NUMERICAL]\n", "\t3 : f_20 [NUMERICAL]\n", "\t3 : f_12 [NUMERICAL]\n", "\t2 : f_6 [NUMERICAL]\n", "\t2 : f_16 [NUMERICAL]\n", "\t1 : f_5 [NUMERICAL]\n", "\t1 : f_2 [NUMERICAL]\n", "\t1 : f_18 [NUMERICAL]\n", "\t1 : f_17 [NUMERICAL]\n", "\t1 : f_13 [NUMERICAL]\n", "\n", "Attribute in nodes with depth <= 0:\n", "\t4 : f_4 [NUMERICAL]\n", "\t4 : f_3 [NUMERICAL]\n", "\t3 : f_8 [NUMERICAL]\n", "\t1 : f_9 [NUMERICAL]\n", "\n", "Attribute in nodes with depth <= 1:\n", "\t11 : f_9 [NUMERICAL]\n", "\t9 : f_8 [NUMERICAL]\n", "\t4 : f_4 [NUMERICAL]\n", "\t4 : f_3 [NUMERICAL]\n", "\t1 : f_25 [NUMERICAL]\n", "\t1 : f_24 [NUMERICAL]\n", "\t1 : f_23 [NUMERICAL]\n", "\t1 : f_22 [NUMERICAL]\n", "\t1 : f_19 [NUMERICAL]\n", "\t1 : f_11 [NUMERICAL]\n", "\n", "Attribute in nodes with depth <= 2:\n", "\t15 : f_8 [NUMERICAL]\n", "\t12 : f_9 [NUMERICAL]\n", "\t11 : f_3 [NUMERICAL]\n", "\t6 : f_19 [NUMERICAL]\n", "\t5 : f_4 [NUMERICAL]\n", "\t2 : f_25 [NUMERICAL]\n", "\t2 : f_11 [NUMERICAL]\n", "\t2 : f_10 [NUMERICAL]\n", "\t1 : f_24 [NUMERICAL]\n", "\t1 : f_23 [NUMERICAL]\n", "\t1 : f_22 [NUMERICAL]\n", "\t1 : f_18 [NUMERICAL]\n", "\t1 : f_17 [NUMERICAL]\n", "\n", "Attribute in nodes with depth <= 3:\n", "\t22 : f_8 [NUMERICAL]\n", "\t13 : f_9 [NUMERICAL]\n", "\t11 : f_3 [NUMERICAL]\n", "\t10 : f_19 [NUMERICAL]\n", "\t9 : f_10 [NUMERICAL]\n", "\t5 : f_4 [NUMERICAL]\n", "\t5 : f_23 [NUMERICAL]\n", "\t5 : f_11 [NUMERICAL]\n", "\t4 : f_25 [NUMERICAL]\n", "\t4 : f_22 [NUMERICAL]\n", "\t4 : f_21 [NUMERICAL]\n", "\t3 : f_24 [NUMERICAL]\n", "\t2 : f_12 [NUMERICAL]\n", "\t1 : f_18 [NUMERICAL]\n", "\t1 : f_17 
[NUMERICAL]\n", "\n", "Attribute in nodes with depth <= 5:\n", "\t25 : f_8 [NUMERICAL]\n", "\t18 : f_19 [NUMERICAL]\n", "\t15 : f_10 [NUMERICAL]\n", "\t14 : f_9 [NUMERICAL]\n", "\t13 : f_3 [NUMERICAL]\n", "\t7 : f_23 [NUMERICAL]\n", "\t6 : f_24 [NUMERICAL]\n", "\t5 : f_4 [NUMERICAL]\n", "\t5 : f_25 [NUMERICAL]\n", "\t5 : f_21 [NUMERICAL]\n", "\t5 : f_11 [NUMERICAL]\n", "\t4 : f_22 [NUMERICAL]\n", "\t3 : f_20 [NUMERICAL]\n", "\t3 : f_12 [NUMERICAL]\n", "\t2 : f_6 [NUMERICAL]\n", "\t2 : f_16 [NUMERICAL]\n", "\t1 : f_5 [NUMERICAL]\n", "\t1 : f_2 [NUMERICAL]\n", "\t1 : f_18 [NUMERICAL]\n", "\t1 : f_17 [NUMERICAL]\n", "\t1 : f_13 [NUMERICAL]\n", "\n", "Condition type in nodes:\n", "\t137 : HigherCondition\n", "Condition type in nodes with depth <= 0:\n", "\t12 : HigherCondition\n", "Condition type in nodes with depth <= 1:\n", "\t34 : HigherCondition\n", "Condition type in nodes with depth <= 2:\n", "\t60 : HigherCondition\n", "Condition type in nodes with depth <= 3:\n", "\t99 : HigherCondition\n", "Condition type in nodes with depth <= 5:\n", "\t137 : HigherCondition\n", "\n", "Training logs:\n", "Number of iteration to final model: 12\n", "\tIter:1 train-loss:-0.346669 valid-loss:-0.262935 train-NDCG@5:0.346669 valid-NDCG@5:0.262935\n", "\tIter:2 train-loss:-0.412635 valid-loss:-0.335301 train-NDCG@5:0.412635 valid-NDCG@5:0.335301\n", "\tIter:3 train-loss:-0.468270 valid-loss:-0.341295 train-NDCG@5:0.468270 valid-NDCG@5:0.341295\n", "\tIter:4 train-loss:-0.481511 valid-loss:-0.301897 train-NDCG@5:0.481511 valid-NDCG@5:0.301897\n", "\tIter:5 train-loss:-0.473165 valid-loss:-0.394670 train-NDCG@5:0.473165 valid-NDCG@5:0.394670\n", "\tIter:6 train-loss:-0.496260 valid-loss:-0.415201 train-NDCG@5:0.496260 valid-NDCG@5:0.415201\n", "\tIter:16 train-loss:-0.526791 valid-loss:-0.380900 train-NDCG@5:0.526791 valid-NDCG@5:0.380900\n", "\tIter:26 train-loss:-0.560398 valid-loss:-0.367496 train-NDCG@5:0.560398 valid-NDCG@5:0.367496\n", "\tIter:36 train-loss:-0.584252 
valid-loss:-0.341845 train-NDCG@5:0.584252 valid-NDCG@5:0.341845\n", "\n" ] } ], "source": [ "%set_cell_height 400\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "gCWwJwkw9aHB" }, "source": [ "And if you are curious, you can also plot the model:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:09:24.292907Z", "iopub.status.busy": "2024-04-20T11:09:24.292662Z", "iopub.status.idle": "2024-04-20T11:09:24.299797Z", "shell.execute_reply": "2024-04-20T11:09:24.299130Z" }, "id": "8wdUhz9X9cbI" }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "
\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "DjViB8SK1BrY" }, "source": [ "## Predicting with a ranking model\n", "\n", "For an incoming query, we can use our ranking model to predict the relevance of\n", "a list of documents. In practice, this means that for each query, we must come\n", "up with a set of documents that may or may not be relevant to the query. We call\n", "these documents our **candidate documents**. For each query/candidate document\n", "pair, we can compute the same features used during training.\n", "This is our **serving dataset**.\n", "\n", "Going back to the example from the beginning of this tutorial, the serving\n", "dataset might look like this:\n", "\n", "query | document_id | feature_1 | feature_2\n", "------ | ----------- | --------- | --------- \n", "fish | 32 | 0.3 | blue \n", "fish | 33 | 1.0 | green \n", "fish | 34 | 0.4 | blue \n", "fish | 35 | NA | brown \n", "\n", "Observe that *relevance* is not part of the serving dataset, since this is what\n", "the model is trying to predict.\n", "\n", "The serving dataset is fed to the TF-DF model, which assigns a relevance score to\n", "each document.\n", "\n", "query | document_id | feature_1 | feature_2 | relevance\n", "------ | ----------- | --------- | --------- | ---------\n", "fish | 32 | 0.3 | blue | 0.325\n", "fish | 33 | 1.0 | green | 0.125\n", "fish | 34 | 0.4 | blue | 0.155\n", "fish | 35 | NA | brown | 0.593\n", "\n", "This means that the document with document_id 35 is predicted to be the most\n", "relevant for query \"fish\"." ] }, { "cell_type": "markdown", "metadata": { "id": "WXBPhvfr1BrY" }, "source": [ "Let's try to do this with our real model." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:09:24.303079Z", "iopub.status.busy": "2024-04-20T11:09:24.302449Z", "iopub.status.idle": "2024-04-20T11:09:24.434750Z", "shell.execute_reply": "2024-04-20T11:09:24.433967Z" }, "id": "cArmmZZ81BrY" }, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " relevance group f_1 f_2 f_3 f_4 f_5 f_6 \\\n", "0 2 g_1 3.0 2.079442 0.272727 0.261034 37.330565 11.431241 \n", "1 0 g_1 3.0 2.079442 0.428571 0.400594 37.330565 11.431241 \n", "2 2 g_1 0.0 0.000000 0.000000 0.000000 37.330565 11.431241 \n", "\n", " f_7 f_8 ... f_16 f_17 f_18 f_19 \\\n", "0 37.29975 1.138657 ... 9.340024 24.808785 0.393091 57.416517 \n", "1 37.29975 1.814480 ... 9.340024 24.808785 0.349205 43.240626 \n", "2 37.29975 0.000000 ... 9.340024 24.808785 0.240319 25.816989 \n", "\n", " f_20 f_21 f_22 f_23 f_24 f_25 \n", "0 3.294893 25.0231 3.219799 -3.87098 -3.90273 -3.87512 \n", "1 2.654724 23.4903 3.156588 -3.96838 -4.00865 -3.98670 \n", "2 1.551342 15.8650 2.764115 -4.28166 -4.33313 -4.44161 \n", "\n", "[3 rows x 27 columns]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Path to a test dataset using libsvm format.\n", "test_dataset_path = os.path.join(os.path.dirname(archive_path),\"OHSUMED/Data/Fold1/testset.txt\")\n", "# Convert the test dataset.\n", "csv_test_dataset_path=\"/tmp/ohsumed_test.csv\"\n", "convert_libsvm_to_csv(test_dataset_path, csv_test_dataset_path)\n", "\n", "# Load the test dataset into a Pandas DataFrame.\n", "test_dataset_df = pd.read_csv(csv_test_dataset_path)\n", "\n", "# Display the first 3 examples.\n", "test_dataset_df.head(3)" ] }, { "cell_type": "markdown", "metadata": { "id": "SXty9G-P1BrY" }, "source": [ "Suppose our query is \"g_5\" and the test dataset already contains the candidate\n", "documents for this query." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:09:24.438041Z", "iopub.status.busy": "2024-04-20T11:09:24.437513Z", "iopub.status.idle": "2024-04-20T11:09:24.482092Z", "shell.execute_reply": "2024-04-20T11:09:24.481442Z" }, "id": "IBfLDHig1BrY" }, "outputs": [], "source": [ "# Filter by \"g_5\"\n", "serving_dataset_df = test_dataset_df[test_dataset_df['group'] == 'g_5']\n", "# Remove the columns for group and relevance, not needed for predictions.\n", "serving_dataset_df = serving_dataset_df.drop(['relevance', 'group'], axis=1)\n", "# Convert to a TensorFlow dataset.\n", "serving_dataset_ds = tfdf.keras.pd_dataframe_to_tf_dataset(serving_dataset_df, task=tfdf.keras.Task.RANKING)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:09:24.485124Z", "iopub.status.busy": "2024-04-20T11:09:24.484901Z", "iopub.status.idle": "2024-04-20T11:09:24.673274Z", "shell.execute_reply": "2024-04-20T11:09:24.672361Z" }, "id": "_EzlVuk01BrY" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 176ms/step\n" ] } ], "source": [ "# Run predictions on all candidate documents.\n", "predictions = model.predict(serving_dataset_ds)" ] }, { "cell_type": "markdown", "metadata": { "id": "LRqRnruM1BrY" }, "source": [ "We can add the predictions to the DataFrame and use them to find the \n", "documents with the highest scores. 
" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2024-04-20T11:09:24.676926Z", "iopub.status.busy": "2024-04-20T11:09:24.676472Z", "iopub.status.idle": "2024-04-20T11:09:24.697665Z", "shell.execute_reply": "2024-04-20T11:09:24.696816Z" }, "id": "IrkLkIWG1BrY" }, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " f_1 f_2 f_3 f_4 f_5 f_6 f_7 \\\n", "642 2.0 1.386294 0.666667 0.575364 29.447117 8.435116 29.448021 \n", "685 3.0 2.079442 0.750000 0.669431 29.447117 8.435116 29.448021 \n", "646 4.0 2.772589 0.285714 0.275971 29.447117 8.435116 29.448021 \n", "684 4.0 2.484907 0.333333 0.314236 29.447117 8.435116 29.448021 \n", "640 3.0 2.079442 0.428571 0.400594 29.447117 8.435116 29.448021 \n", "\n", " f_8 f_9 f_10 ... f_17 f_18 f_19 \\\n", "642 2.207135 12.292170 10.101899 ... 21.208715 0.523845 77.852148 \n", "685 3.060164 21.795657 17.652746 ... 21.208715 0.793681 39.623271 \n", "646 1.421063 24.550338 14.727974 ... 21.208715 0.602963 84.868108 \n", "684 1.730304 29.299744 15.114793 ... 21.208715 0.692899 71.279648 \n", "640 2.107361 21.795657 15.999891 ... 21.208715 0.000000 0.000000 \n", "\n", " f_20 f_21 f_22 f_23 f_24 f_25 prediction_score \n", "642 7.659101 30.2660 3.410025 -3.03908 -3.19282 -2.87112 0.965342 \n", "685 8.513801 33.9830 3.525860 -2.84235 -2.81360 -2.59920 0.893874 \n", "646 7.767931 31.0268 3.434851 -3.19269 -3.31166 -3.14901 0.258856 \n", "684 8.148804 36.5645 3.599078 -2.16625 -2.43823 -1.94658 0.258856 \n", "640 0.000000 30.6422 3.422378 -3.20997 -2.59768 -2.59768 0.258856 \n", "\n", "[5 rows x 26 columns]" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "serving_dataset_df['prediction_score'] = predictions\n", "serving_dataset_df.sort_values(by=['prediction_score'], ascending=False).head()" ] } ], "metadata": { "colab": { "name": "ranking_colab.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 0 }