{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "RYmPh1qB_KO2" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:22:49.501002Z", "iopub.status.busy": "2024-07-30T11:22:49.500314Z", "iopub.status.idle": "2024-07-30T11:22:49.505059Z", "shell.execute_reply": "2024-07-30T11:22:49.504300Z" }, "id": "oMRm3czy9tLh" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "ooXoR4kx_YL9" }, "source": [ "# TF Lattice Aggregate Function Models" ] }, { "cell_type": "markdown", "metadata": { "id": "BR6XNYEXEgSU" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "-ZfQWUmfEsyZ" }, "source": [ "## Overview\n", "\n", "TFL Premade Aggregate Function Models are quick and easy ways to build TFL `keras.Model` instances for learning complex aggregation functions. This guide outlines the steps needed to construct a TFL Premade Aggregate Function Model and train/test it." ] }, { "cell_type": "markdown", "metadata": { "id": "L0lgWoB6Gmk1" }, "source": [ "## Setup\n", "\n", "Installing TF Lattice package:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:22:49.509218Z", "iopub.status.busy": "2024-07-30T11:22:49.508678Z", "iopub.status.idle": "2024-07-30T11:22:51.405228Z", "shell.execute_reply": "2024-07-30T11:22:51.404242Z" }, "id": "ivwKrEdLGphZ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: tensorflow in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (2.17.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: tf-keras in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (2.17.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting tensorflow-lattice\r\n", " Using cached tensorflow_lattice-2.1.0-py2.py3-none-any.whl.metadata (1.8 kB)\r\n", "Requirement already satisfied: pydot in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (3.0.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting graphviz\r\n", " Using cached graphviz-0.20.3-py3-none-any.whl.metadata (12 kB)\r\n", "Requirement already satisfied: absl-py>=1.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (2.1.0)\r\n", "Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (1.6.3)\r\n", "Requirement already satisfied: flatbuffers>=24.3.25 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) 
(24.3.25)\r\n", "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (0.6.0)\r\n", "Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (0.2.0)\r\n", "Requirement already satisfied: h5py>=3.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (3.11.0)\r\n", "Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (18.1.1)\r\n", "Requirement already satisfied: ml-dtypes<0.5.0,>=0.3.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (0.4.0)\r\n", "Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (3.3.0)\r\n", "Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (24.1)\r\n", "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (3.20.3)\r\n", "Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (2.32.3)\r\n", "Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (72.1.0)\r\n", "Requirement already satisfied: six>=1.12.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (1.16.0)\r\n", "Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (2.4.0)\r\n", "Requirement already satisfied: typing-extensions>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (4.12.2)\r\n", "Requirement already satisfied: wrapt>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (1.16.0)\r\n", 
"Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (1.65.1)\r\n", "Requirement already satisfied: tensorboard<2.18,>=2.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (2.17.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: keras>=3.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (3.4.1)\r\n", "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (0.37.1)\r\n", "Requirement already satisfied: numpy<2.0.0,>=1.23.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (1.26.4)\r\n", "Requirement already satisfied: matplotlib in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-lattice) (3.9.1)\r\n", "Requirement already satisfied: pandas in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-lattice) (2.2.2)\r\n", "Requirement already satisfied: scikit-learn in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-lattice) (1.5.1)\r\n", "Requirement already satisfied: pyparsing>=3.0.9 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pydot) (3.1.2)\r\n", "Requirement already satisfied: wheel<1.0,>=0.23.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from astunparse>=1.6.0->tensorflow) (0.43.0)\r\n", "Requirement already satisfied: rich in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.2.0->tensorflow) (13.7.1)\r\n", "Requirement already satisfied: namex in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.2.0->tensorflow) (0.0.8)\r\n", "Requirement already satisfied: optree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.2.0->tensorflow) (0.12.1)\r\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from 
requests<3,>=2.21.0->tensorflow) (3.3.2)\r\n", "Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow) (3.7)\r\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow) (2.2.2)\r\n", "Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow) (2024.7.4)\r\n", "Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.18,>=2.17->tensorflow) (3.6)\r\n", "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.18,>=2.17->tensorflow) (0.7.2)\r\n", "Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.18,>=2.17->tensorflow) (3.0.3)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: contourpy>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from matplotlib->tensorflow-lattice) (1.2.1)\r\n", "Requirement already satisfied: cycler>=0.10 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from matplotlib->tensorflow-lattice) (0.12.1)\r\n", "Requirement already satisfied: fonttools>=4.22.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from matplotlib->tensorflow-lattice) (4.53.1)\r\n", "Requirement already satisfied: kiwisolver>=1.3.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from matplotlib->tensorflow-lattice) (1.4.5)\r\n", "Requirement already satisfied: pillow>=8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from matplotlib->tensorflow-lattice) (10.4.0)\r\n", "Requirement already satisfied: python-dateutil>=2.7 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from matplotlib->tensorflow-lattice) 
(2.9.0.post0)\r\n", "Requirement already satisfied: importlib-resources>=3.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from matplotlib->tensorflow-lattice) (6.4.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: pytz>=2020.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow-lattice) (2024.1)\r\n", "Requirement already satisfied: tzdata>=2022.7 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow-lattice) (2024.1)\r\n", "Requirement already satisfied: scipy>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from scikit-learn->tensorflow-lattice) (1.13.1)\r\n", "Requirement already satisfied: joblib>=1.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from scikit-learn->tensorflow-lattice) (1.4.2)\r\n", "Requirement already satisfied: threadpoolctl>=3.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from scikit-learn->tensorflow-lattice) (3.5.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: zipp>=3.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-resources>=3.2.0->matplotlib->tensorflow-lattice) (3.19.2)\r\n", "Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown>=2.6.8->tensorboard<2.18,>=2.17->tensorflow) (8.2.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.18,>=2.17->tensorflow) (2.1.5)\r\n", "Requirement already satisfied: markdown-it-py>=2.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.2.0->tensorflow) (3.0.0)\r\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.2.0->tensorflow) (2.18.0)\r\n", 
"Requirement already satisfied: mdurl~=0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown-it-py>=2.2.0->rich->keras>=3.2.0->tensorflow) (0.1.2)\r\n", "Using cached tensorflow_lattice-2.1.0-py2.py3-none-any.whl (216 kB)\r\n", "Using cached graphviz-0.20.3-py3-none-any.whl (47 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: graphviz, tensorflow-lattice\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed graphviz-0.20.3 tensorflow-lattice-2.1.0\r\n" ] } ], "source": [ "#@test {\"skip\": true}\n", "!pip install -U tensorflow tf-keras tensorflow-lattice pydot graphviz" ] }, { "cell_type": "markdown", "metadata": { "id": "VQsRKS4wGrMu" }, "source": [ "Importing required packages:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:22:51.409762Z", "iopub.status.busy": "2024-07-30T11:22:51.409407Z", "iopub.status.idle": "2024-07-30T11:22:54.375068Z", "shell.execute_reply": "2024-07-30T11:22:54.374170Z" }, "id": "j41-kd4MGtDS" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-07-30 11:22:51.711202: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-07-30 11:22:51.735645: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-07-30 11:22:51.742800: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import tensorflow as tf\n", "\n", "import collections\n", "import logging\n", "import numpy as np\n", "import pandas as pd\n", "import sys\n", 
"import tensorflow_lattice as tfl\n", "logging.disable(sys.maxsize)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:22:54.379451Z", "iopub.status.busy": "2024-07-30T11:22:54.378950Z", "iopub.status.idle": "2024-07-30T11:22:54.383085Z", "shell.execute_reply": "2024-07-30T11:22:54.382432Z" }, "id": "HlJH1SMx3Vul" }, "outputs": [], "source": [ "# Use Keras 2.\n", "version_fn = getattr(tf.keras, \"version\", None)\n", "if version_fn and version_fn().startswith(\"3.\"):\n", " import tf_keras as keras\n", "else:\n", " keras = tf.keras" ] }, { "cell_type": "markdown", "metadata": { "id": "ZHPohKjBIFG5" }, "source": [ "Downloading the Puzzles dataset:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:22:54.386944Z", "iopub.status.busy": "2024-07-30T11:22:54.386266Z", "iopub.status.idle": "2024-07-30T11:22:54.586459Z", "shell.execute_reply": "2024-07-30T11:22:54.585785Z" }, "id": "VjYHpw2dSfHH" }, "outputs": [ { "data": { "text/html": [ "
<table>\n", "<thead>\n",
"<tr><th></th><th>puzzle_name</th><th>star_rating</th><th>word_count</th><th>is_amazon</th><th>includes_photo</th><th>num_helpful</th><th>Sales12-18MonthsAgo</th></tr>\n",
"</thead>\n", "<tbody>\n",
"<tr><th>0</th><td>NightHawks</td><td>5</td><td>108</td><td>1</td><td>0</td><td>6</td><td>81</td></tr>\n",
"<tr><th>1</th><td>NightHawks</td><td>5</td><td>31</td><td>1</td><td>0</td><td>5</td><td>81</td></tr>\n",
"<tr><th>2</th><td>NightHawks</td><td>5</td><td>51</td><td>1</td><td>0</td><td>4</td><td>81</td></tr>\n",
"<tr><th>3</th><td>NightHawks</td><td>5</td><td>88</td><td>1</td><td>0</td><td>4</td><td>81</td></tr>\n",
"<tr><th>4</th><td>NightHawks</td><td>4</td><td>37</td><td>1</td><td>0</td><td>2</td><td>81</td></tr>\n",
"</tbody>\n", "</table>\n", "
" ], "text/plain": [ " puzzle_name star_rating word_count is_amazon includes_photo \\\n", "0 NightHawks 5 108 1 0 \n", "1 NightHawks 5 31 1 0 \n", "2 NightHawks 5 51 1 0 \n", "3 NightHawks 5 88 1 0 \n", "4 NightHawks 4 37 1 0 \n", "\n", " num_helpful Sales12-18MonthsAgo \n", "0 6 81 \n", "1 5 81 \n", "2 4 81 \n", "3 4 81 \n", "4 2 81 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_dataframe = pd.read_csv(\n", " 'https://raw.githubusercontent.com/wbakst/puzzles_data/master/train.csv')\n", "train_dataframe.head()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:22:54.590058Z", "iopub.status.busy": "2024-07-30T11:22:54.589816Z", "iopub.status.idle": "2024-07-30T11:22:54.788977Z", "shell.execute_reply": "2024-07-30T11:22:54.788306Z" }, "id": "UOsgu3eIEur6" }, "outputs": [ { "data": { "text/html": [ "
<table>\n", "<thead>\n",
"<tr><th></th><th>puzzle_name</th><th>star_rating</th><th>word_count</th><th>is_amazon</th><th>includes_photo</th><th>num_helpful</th><th>SalesLastSixMonths</th></tr>\n",
"</thead>\n", "<tbody>\n",
"<tr><th>0</th><td>NightHawks</td><td>4</td><td>138</td><td>0</td><td>0</td><td>0</td><td>40</td></tr>\n",
"<tr><th>1</th><td>NightHawks</td><td>5</td><td>115</td><td>0</td><td>0</td><td>0</td><td>40</td></tr>\n",
"<tr><th>2</th><td>NightHawks</td><td>5</td><td>127</td><td>0</td><td>0</td><td>0</td><td>40</td></tr>\n",
"<tr><th>3</th><td>NightHawks</td><td>5</td><td>108</td><td>1</td><td>0</td><td>6</td><td>40</td></tr>\n",
"<tr><th>4</th><td>NightHawks</td><td>5</td><td>31</td><td>1</td><td>0</td><td>5</td><td>40</td></tr>\n",
"</tbody>\n", "</table>" ], "text/plain": [ " puzzle_name star_rating word_count is_amazon includes_photo \\\n", "0 NightHawks 4 138 0 0 \n", "1 NightHawks 5 115 0 0 \n", "2 NightHawks 5 127 0 0 \n", "3 NightHawks 5 108 1 0 \n", "4 NightHawks 5 31 1 0 \n", "\n", " num_helpful SalesLastSixMonths \n", "0 0 40 \n", "1 0 40 \n", "2 0 40 \n", "3 6 40 \n", "4 5 40 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [
"test_dataframe = pd.read_csv(\n",
"    'https://raw.githubusercontent.com/wbakst/puzzles_data/master/test.csv')\n",
"test_dataframe.head()" ] }, { "cell_type": "markdown", "metadata": { "id": "XG7MPCyzVr22" }, "source": [ "Extracting and converting features and labels:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:22:54.793035Z", "iopub.status.busy": "2024-07-30T11:22:54.792430Z", "iopub.status.idle": "2024-07-30T11:22:54.796341Z", "shell.execute_reply": "2024-07-30T11:22:54.795708Z" }, "id": "bYdJicq5bBuz" }, "outputs": [], "source": [
"# Features:\n",
"# - star_rating     rating out of 5 stars (1-5)\n",
"# - word_count      number of words in the review\n",
"# - is_amazon       1 = reviewed on Amazon; 0 = reviewed on artifact website\n",
"# - includes_photo  if the review includes a photo of the puzzle\n",
"# - num_helpful     number of people that found this review helpful\n",
"# - num_reviews     total number of reviews for this puzzle (we construct)\n",
"#\n",
"# This ordering of feature names will be the exact same order that we construct\n",
"# our model to expect.\n",
"feature_names = [\n",
"    'star_rating', 'word_count', 'is_amazon', 'includes_photo', 'num_helpful',\n",
"    'num_reviews'\n",
"]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:22:54.799946Z", "iopub.status.busy": "2024-07-30T11:22:54.799406Z", "iopub.status.idle": "2024-07-30T11:22:54.809672Z", "shell.execute_reply": "2024-07-30T11:22:54.808999Z" }, "id": "kx0ZX2HR-4qb" }, "outputs": [], "source": [
"def extract_features(dataframe, label_name):\n",
"  # First we extract flattened features.\n",
"  flattened_features = {\n",
"      feature_name: dataframe[feature_name].values.astype(float)\n",
"      for feature_name in feature_names[:-1]\n",
"  }\n",
"\n",
"  # Construct mapping from puzzle name to feature.\n",
"  star_rating = collections.defaultdict(list)\n",
"  word_count = collections.defaultdict(list)\n",
"  is_amazon = collections.defaultdict(list)\n",
"  includes_photo = collections.defaultdict(list)\n",
"  num_helpful = collections.defaultdict(list)\n",
"  labels = {}\n",
"\n",
"  # Extract each review.\n",
"  for i in range(len(dataframe)):\n",
"    row = dataframe.iloc[i]\n",
"    puzzle_name = row['puzzle_name']\n",
"    star_rating[puzzle_name].append(float(row['star_rating']))\n",
"    word_count[puzzle_name].append(float(row['word_count']))\n",
"    is_amazon[puzzle_name].append(float(row['is_amazon']))\n",
"    includes_photo[puzzle_name].append(float(row['includes_photo']))\n",
"    num_helpful[puzzle_name].append(float(row['num_helpful']))\n",
"    labels[puzzle_name] = float(row[label_name])\n",
"\n",
"  # Organize data into list of list of features.\n",
"  names = list(star_rating.keys())\n",
"  star_rating = [star_rating[name] for name in names]\n",
"  word_count = [word_count[name] for name in names]\n",
"  is_amazon = [is_amazon[name] for name in names]\n",
"  includes_photo = [includes_photo[name] for name in names]\n",
"  num_helpful = [num_helpful[name] for name in names]\n",
"  num_reviews = [[len(ratings)] * len(ratings) for ratings in star_rating]\n",
"  labels = [labels[name] for name in names]\n",
"\n",
"  # Flatten num_reviews\n",
"  flattened_features['num_reviews'] = [len(reviews) for reviews in num_reviews]\n",
"\n",
"  # Convert data into ragged tensors.\n",
"  star_rating = tf.ragged.constant(star_rating)\n",
"  word_count = tf.ragged.constant(word_count)\n",
"  is_amazon = tf.ragged.constant(is_amazon)\n",
"  includes_photo = tf.ragged.constant(includes_photo)\n",
"  num_helpful = tf.ragged.constant(num_helpful)\n",
"  num_reviews = tf.ragged.constant(num_reviews)\n",
"  labels = tf.constant(labels)\n",
"\n",
"  # Now we can return our extracted data.\n",
"  return (star_rating, word_count, is_amazon, includes_photo, num_helpful,\n",
"          num_reviews), labels, flattened_features" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:22:54.813182Z", "iopub.status.busy": "2024-07-30T11:22:54.812628Z", "iopub.status.idle": "2024-07-30T11:22:55.131115Z", "shell.execute_reply": "2024-07-30T11:22:55.130248Z" }, "id": "Nd6j_J5CbNiz" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-07-30 11:22:54.958021: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n" ] } ], "source": [
"train_xs, train_ys, flattened_features = extract_features(train_dataframe, 'Sales12-18MonthsAgo')\n",
"test_xs, test_ys, _ = extract_features(test_dataframe, 'SalesLastSixMonths')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:22:55.135201Z", "iopub.status.busy": "2024-07-30T11:22:55.134615Z", "iopub.status.idle": "2024-07-30T11:22:55.139627Z", "shell.execute_reply": "2024-07-30T11:22:55.138811Z" }, "id": "KfHHhCRsHejl" }, "outputs": [], "source": [
"# Let's define our label minimum and maximum.\n",
"min_label, max_label = float(np.min(train_ys)), float(np.max(train_ys))" ] }, { "cell_type": "markdown", "metadata": { "id": "9TwqlRirIhAq" }, "source": [ "Setting the default values used for training in this guide:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:22:55.143412Z", "iopub.status.busy": "2024-07-30T11:22:55.142905Z", "iopub.status.idle": "2024-07-30T11:22:55.146714Z", "shell.execute_reply": "2024-07-30T11:22:55.146043Z" }, "id": "GckmXFzRIhdD" }, "outputs": [], "source": [
"LEARNING_RATE = 0.1\n",
"BATCH_SIZE = 128\n",
"NUM_EPOCHS = 500\n",
"MIDDLE_DIM = 3\n",
"MIDDLE_LATTICE_SIZE = 2\n",
"MIDDLE_KEYPOINTS = 16\n",
"OUTPUT_KEYPOINTS = 8" ] }, { "cell_type": "markdown", "metadata": { "id": "TpDKon4oIh2W" }, "source": [ "## Feature Configs\n", "\n", "Feature calibration and per-feature configurations are set using [tfl.configs.FeatureConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/FeatureConfig). Feature configurations include monotonicity constraints, per-feature regularization (see [tfl.configs.RegularizerConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/RegularizerConfig)), and lattice sizes for lattice models.\n", "\n", "Note that we must fully specify the feature config for any feature that we want our model to recognize. Otherwise the model will have no way of knowing that such a feature exists. For aggregation models, these features will automatically be considered and properly handled as ragged." ] }, { "cell_type": "markdown", "metadata": { "id": "_IMwcDh7Xs5n" }, "source": [ "### Compute Quantiles\n", "\n", "Although the default setting for `pwl_calibration_input_keypoints` in `tfl.configs.FeatureConfig` is 'quantiles', for premade models we have to manually define the input keypoints. To do so, we first define our own helper function for computing quantiles."
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:22:55.150555Z", "iopub.status.busy": "2024-07-30T11:22:55.150055Z", "iopub.status.idle": "2024-07-30T11:22:55.155739Z", "shell.execute_reply": "2024-07-30T11:22:55.155106Z" }, "id": "l0uYl9ZpXtW1" }, "outputs": [], "source": [
"def compute_quantiles(features,\n",
"                      num_keypoints=10,\n",
"                      clip_min=None,\n",
"                      clip_max=None,\n",
"                      missing_value=None):\n",
"  # Clip min and max if desired.\n",
"  if clip_min is not None:\n",
"    features = np.maximum(features, clip_min)\n",
"    features = np.append(features, clip_min)\n",
"  if clip_max is not None:\n",
"    features = np.minimum(features, clip_max)\n",
"    features = np.append(features, clip_max)\n",
"  # Make features unique.\n",
"  unique_features = np.unique(features)\n",
"  # Remove missing values if specified.\n",
"  if missing_value is not None:\n",
"    unique_features = np.delete(unique_features,\n",
"                                np.where(unique_features == missing_value))\n",
"  # Compute and return quantiles over unique non-missing feature values.\n",
"  # `method=` replaces the deprecated `interpolation=` argument (NumPy >= 1.22).\n",
"  return np.quantile(\n",
"      unique_features,\n",
"      np.linspace(0., 1., num=num_keypoints),\n",
"      method='nearest').astype(float)" ] }, { "cell_type": "markdown", "metadata": { "id": "9oYZdVeWEhf2" }, "source": [ "### Defining Our Feature Configs\n", "\n", "Now that we can compute our quantiles, we define a feature config for each feature that we want our model to take as input."
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:22:55.159411Z", "iopub.status.busy": "2024-07-30T11:22:55.158911Z", "iopub.status.idle": "2024-07-30T11:22:55.167552Z", "shell.execute_reply": "2024-07-30T11:22:55.166888Z" }, "id": "rEYlSXhTEmoh" }, "outputs": [], "source": [
"# Feature configs are used to specify how each feature is calibrated and used.\n",
"feature_configs = [\n",
"    tfl.configs.FeatureConfig(\n",
"        name='star_rating',\n",
"        lattice_size=2,\n",
"        monotonicity='increasing',\n",
"        pwl_calibration_num_keypoints=5,\n",
"        pwl_calibration_input_keypoints=compute_quantiles(\n",
"            flattened_features['star_rating'], num_keypoints=5),\n",
"    ),\n",
"    tfl.configs.FeatureConfig(\n",
"        name='word_count',\n",
"        lattice_size=2,\n",
"        monotonicity='increasing',\n",
"        pwl_calibration_num_keypoints=5,\n",
"        pwl_calibration_input_keypoints=compute_quantiles(\n",
"            flattened_features['word_count'], num_keypoints=5),\n",
"    ),\n",
"    tfl.configs.FeatureConfig(\n",
"        name='is_amazon',\n",
"        lattice_size=2,\n",
"        num_buckets=2,\n",
"    ),\n",
"    tfl.configs.FeatureConfig(\n",
"        name='includes_photo',\n",
"        lattice_size=2,\n",
"        num_buckets=2,\n",
"    ),\n",
"    tfl.configs.FeatureConfig(\n",
"        name='num_helpful',\n",
"        lattice_size=2,\n",
"        monotonicity='increasing',\n",
"        pwl_calibration_num_keypoints=5,\n",
"        pwl_calibration_input_keypoints=compute_quantiles(\n",
"            flattened_features['num_helpful'], num_keypoints=5),\n",
"        # Larger num_helpful indicating more trust in star_rating.\n",
"        reflects_trust_in=[\n",
"            tfl.configs.TrustConfig(\n",
"                feature_name=\"star_rating\", trust_type=\"trapezoid\"),\n",
"        ],\n",
"    ),\n",
"    tfl.configs.FeatureConfig(\n",
"        name='num_reviews',\n",
"        lattice_size=2,\n",
"        monotonicity='increasing',\n",
"        pwl_calibration_num_keypoints=5,\n",
"        pwl_calibration_input_keypoints=compute_quantiles(\n",
"            flattened_features['num_reviews'], num_keypoints=5),\n",
"    )\n",
"]" ] }, { "cell_type": "markdown", "metadata": { "id": "9zoPJRBvPdcH" }, "source": [ "## Aggregate Function Model\n", "\n", "To construct a TFL premade model, first construct a model configuration from
[tfl.configs](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs). An aggregate function model is constructed using the [tfl.configs.AggregateFunctionConfig](https://www.tensorflow.org/lattice/api_docs/python/tfl/configs/AggregateFunctionConfig). It applies piecewise-linear and categorical calibration, followed by a lattice model on each dimension of the ragged input. It then applies an aggregation layer over the output for each dimension. This is then followed by an optional output piecewise-linear calibration." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:22:55.171373Z", "iopub.status.busy": "2024-07-30T11:22:55.170835Z", "iopub.status.idle": "2024-07-30T11:22:59.989231Z", "shell.execute_reply": "2024-07-30T11:22:59.988393Z" }, "id": "l_4J7EjSPiP3" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Model config defines the model structure for the aggregate function model.\n", "aggregate_function_model_config = tfl.configs.AggregateFunctionConfig(\n", " feature_configs=feature_configs,\n", " middle_dimension=MIDDLE_DIM,\n", " middle_lattice_size=MIDDLE_LATTICE_SIZE,\n", " middle_calibration=True,\n", " middle_calibration_num_keypoints=MIDDLE_KEYPOINTS,\n", " middle_monotonicity='increasing',\n", " output_min=min_label,\n", " output_max=max_label,\n", " output_calibration=True,\n", " output_calibration_num_keypoints=OUTPUT_KEYPOINTS,\n", " output_initialization=np.linspace(\n", " min_label, max_label, num=OUTPUT_KEYPOINTS))\n", "# An AggregateFunction premade model constructed from the given model config.\n", "aggregate_function_model = tfl.premade.AggregateFunction(\n", " aggregate_function_model_config)\n", "# Let's plot our model.\n", "keras.utils.plot_model(\n", " aggregate_function_model, show_layer_names=False, rankdir='LR')" ] }, { "cell_type": "markdown", "metadata": 
{ "id": "4F7AwiXgWhe2" }, "source": [ "The output of each Aggregation layer is the averaged output of a calibrated lattice over the ragged inputs. Here is the model used inside the first Aggregation layer:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:22:59.993926Z", "iopub.status.busy": "2024-07-30T11:22:59.993307Z", "iopub.status.idle": "2024-07-30T11:23:00.084438Z", "shell.execute_reply": "2024-07-30T11:23:00.083602Z" }, "id": "UM7XF6UIWo4T" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "aggregation_layers = [\n", " layer for layer in aggregate_function_model.layers\n", " if isinstance(layer, tfl.layers.Aggregation)\n", "]\n", "keras.utils.plot_model(\n", " aggregation_layers[0].model, show_layer_names=False, rankdir='LR')" ] }, { "cell_type": "markdown", "metadata": { "id": "0ohYOftgTZhq" }, "source": [ "Now, as with any other [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model), we compile and fit the model to our data." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:23:00.089044Z", "iopub.status.busy": "2024-07-30T11:23:00.088264Z", "iopub.status.idle": "2024-07-30T11:23:21.286871Z", "shell.execute_reply": "2024-07-30T11:23:21.286174Z" }, "id": "uB9di3-lTfMy" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "aggregate_function_model.compile(\n", " loss='mae',\n", " optimizer=keras.optimizers.Adam(LEARNING_RATE))\n", "aggregate_function_model.fit(\n", " train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "pwZtGDR-Tzur" }, "source": [ "After training our model, we can evaluate it on our test set." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2024-07-30T11:23:21.290783Z", "iopub.status.busy": "2024-07-30T11:23:21.290195Z", "iopub.status.idle": "2024-07-30T11:23:22.850932Z", "shell.execute_reply": "2024-07-30T11:23:22.850256Z" }, "id": "RWj1YfubT0NE" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test Set Evaluation...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/7 [===>..........................] - ETA: 9s - loss: 110.6181" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "7/7 [==============================] - 2s 3ms/step - loss: 50.3869\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "50.38688659667969\n" ] } ], "source": [ "print('Test Set Evaluation...')\n", "print(aggregate_function_model.evaluate(test_xs, test_ys))" ] } ], "metadata": { "colab": { "name": "aggregate_function_models.ipynb", "private_outputs": true, "provenance": [ { "file_id": "1ohMV9lhzSWZq3aH27fBAZ1Oj3wy19PI0", "timestamp": 1588637142053 } ], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 0 }