{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "X80i_girFR2o" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2022-12-14T12:16:59.353356Z", "iopub.status.busy": "2022-12-14T12:16:59.352837Z", "iopub.status.idle": "2022-12-14T12:16:59.356595Z", "shell.execute_reply": "2022-12-14T12:16:59.356067Z" }, "id": "bB8gHCR3FVC0" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "kCeYA79m1DEX" }, "source": [ "# Recommending movies: ranking\n", "\n", "\n", " \n", " \n", " \n", " \n", "
 \n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "gf2jMHkZQYB5" }, "source": [ "Real-world recommender systems are often composed of two stages:\n", "\n", "1. The retrieval stage is responsible for selecting an initial set of hundreds of candidates from all possible candidates. The main objective of this model is to efficiently weed out all candidates that the user is not interested in. Because the retrieval model may be dealing with millions of candidates, it has to be computationally efficient.\n", "2. The ranking stage takes the outputs of the retrieval model and scores them more precisely to select the best possible handful of recommendations. Its task is to narrow down the set of items the user may be interested in to a shortlist of likely candidates.\n", "\n", "We're going to focus on the second stage, ranking. If you are interested in the retrieval stage, have a look at our [retrieval](basic_retrieval) tutorial.\n", "\n", "In this tutorial, we're going to:\n", "\n", "1. Get our data and split it into a training set and a test set.\n", "2. Implement a ranking model.\n", "3. Fit and evaluate it.\n", "\n", "\n", "## Imports\n", "\n", "\n", "Let's first get our imports out of the way."
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:16:59.360049Z", "iopub.status.busy": "2022-12-14T12:16:59.359723Z", "iopub.status.idle": "2022-12-14T12:17:02.748600Z", "shell.execute_reply": "2022-12-14T12:17:02.747520Z" }, "id": "9gG3jLOGbaUv" }, "outputs": [], "source": [ "!pip install -q tensorflow-recommenders\n", "!pip install -q --upgrade tensorflow-datasets" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:02.753172Z", "iopub.status.busy": "2022-12-14T12:17:02.752481Z", "iopub.status.idle": "2022-12-14T12:17:05.122484Z", "shell.execute_reply": "2022-12-14T12:17:05.121813Z" }, "id": "SZGYDaF-m5wZ" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 12:17:03.715935: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 12:17:03.716032: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2022-12-14 12:17:03.716042: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. 
If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import os\n", "import pprint\n", "import tempfile\n", "\n", "from typing import Dict, Text\n", "\n", "import numpy as np\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:05.126358Z", "iopub.status.busy": "2022-12-14T12:17:05.125942Z", "iopub.status.idle": "2022-12-14T12:17:05.137364Z", "shell.execute_reply": "2022-12-14T12:17:05.136712Z" }, "id": "BxQ_hy7xPH3N" }, "outputs": [], "source": [ "import tensorflow_recommenders as tfrs" ] }, { "cell_type": "markdown", "metadata": { "id": "5PAqjR4a1RR4" }, "source": [ "## Preparing the dataset\n", "\n", "We're going to use the same data as the [retrieval](basic_retrieval) tutorial. This time, we're also going to keep the ratings: these are the objectives we are trying to predict." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:05.140718Z", "iopub.status.busy": "2022-12-14T12:17:05.140152Z", "iopub.status.idle": "2022-12-14T12:17:09.222468Z", "shell.execute_reply": "2022-12-14T12:17:09.221811Z" }, "id": "aaQhqcLGP0jL" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.\n", "Instructions for updating:\n", "Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. 
https://github.com/tensorflow/tensorflow/issues/56089\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.\n", "Instructions for updating:\n", "Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089\n" ] } ], "source": [ "ratings = tfds.load(\"movielens/100k-ratings\", split=\"train\")\n", "\n", "ratings = ratings.map(lambda x: {\n", " \"movie_title\": x[\"movie_title\"],\n", " \"user_id\": x[\"user_id\"],\n", " \"user_rating\": x[\"user_rating\"]\n", "})" ] }, { "cell_type": "markdown", "metadata": { "id": "Iu4XSa_G1nyN" }, "source": [ "As before, we'll split the data by putting 80% of the ratings in the train set, and 20% in the test set." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:09.225910Z", "iopub.status.busy": "2022-12-14T12:17:09.225557Z", "iopub.status.idle": "2022-12-14T12:17:09.239263Z", "shell.execute_reply": "2022-12-14T12:17:09.238653Z" }, "id": "rS0eDfkjnjJL" }, "outputs": [], "source": [ "tf.random.set_seed(42)\n", "shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)\n", "\n", "train = shuffled.take(80_000)\n", "test = shuffled.skip(80_000).take(20_000)" ] }, { "cell_type": "markdown", "metadata": { "id": "gVi1HJfR9D7H" }, "source": [ "Let's also figure out unique user ids and movie titles present in the data. \n", "\n", "This is important because we need to be able to map the raw values of our categorical features to embedding vectors in our models. 
To do that, we need a vocabulary that maps a raw feature value to an integer in a contiguous range: this allows us to look up the corresponding embeddings in our embedding tables." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:09.242577Z", "iopub.status.busy": "2022-12-14T12:17:09.242176Z", "iopub.status.idle": "2022-12-14T12:17:15.787770Z", "shell.execute_reply": "2022-12-14T12:17:15.787008Z" }, "id": "MKROCiPo_5LJ" }, "outputs": [], "source": [ "movie_titles = ratings.batch(1_000_000).map(lambda x: x[\"movie_title\"])\n", "user_ids = ratings.batch(1_000_000).map(lambda x: x[\"user_id\"])\n", "\n", "unique_movie_titles = np.unique(np.concatenate(list(movie_titles)))\n", "unique_user_ids = np.unique(np.concatenate(list(user_ids)))" ] }, { "cell_type": "markdown", "metadata": { "id": "4-Vj9nHb48pn" }, "source": [ "## Implementing a model" ] }, { "cell_type": "markdown", "metadata": { "id": "eCi-seR86qqa" }, "source": [ "### Architecture\n", "\n", "Ranking models do not face the same efficiency constraints as retrieval models do, and so we have a little bit more freedom in our choice of architectures.\n", "\n", "A model composed of multiple stacked dense layers is a relatively common architecture for ranking tasks. 
We can implement it as follows:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:15.792411Z", "iopub.status.busy": "2022-12-14T12:17:15.791782Z", "iopub.status.idle": "2022-12-14T12:17:15.798089Z", "shell.execute_reply": "2022-12-14T12:17:15.797522Z" }, "id": "fAk0y0Yf1eGh" }, "outputs": [], "source": [ "class RankingModel(tf.keras.Model):\n", "\n", "  def __init__(self):\n", "    super().__init__()\n", "    embedding_dimension = 32\n", "\n", "    # Compute embeddings for users.\n", "    # The +1 reserves a row for the out-of-vocabulary index that StringLookup adds.\n", "    self.user_embeddings = tf.keras.Sequential([\n", "      tf.keras.layers.StringLookup(\n", "        vocabulary=unique_user_ids, mask_token=None),\n", "      tf.keras.layers.Embedding(len(unique_user_ids) + 1, embedding_dimension)\n", "    ])\n", "\n", "    # Compute embeddings for movies.\n", "    self.movie_embeddings = tf.keras.Sequential([\n", "      tf.keras.layers.StringLookup(\n", "        vocabulary=unique_movie_titles, mask_token=None),\n", "      tf.keras.layers.Embedding(len(unique_movie_titles) + 1, embedding_dimension)\n", "    ])\n", "\n", "    # Compute predictions.\n", "    self.ratings = tf.keras.Sequential([\n", "      # Learn multiple dense layers.\n", "      tf.keras.layers.Dense(256, activation=\"relu\"),\n", "      tf.keras.layers.Dense(64, activation=\"relu\"),\n", "      # Make rating predictions in the final layer.\n", "      tf.keras.layers.Dense(1)\n", "    ])\n", "\n", "  def call(self, inputs):\n", "\n", "    user_id, movie_title = inputs\n", "\n", "    user_embedding = self.user_embeddings(user_id)\n", "    movie_embedding = self.movie_embeddings(movie_title)\n", "\n", "    # Concatenate both embeddings and score the (user, movie) pair.\n", "    return self.ratings(tf.concat([user_embedding, movie_embedding], axis=1))" ] }, { "cell_type": "markdown", "metadata": { "id": "g76wZt-s2WmS" }, "source": [ "This model takes user IDs and movie titles, and outputs a predicted rating:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:15.801487Z", "iopub.status.busy": "2022-12-14T12:17:15.800914Z",
"iopub.status.idle": "2022-12-14T12:17:16.216373Z", "shell.execute_reply": "2022-12-14T12:17:16.215745Z" }, "id": "YVxiAsRE2I8J" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs=['42']. Consider rewriting this model with the Functional API.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs=['42']. Consider rewriting this model with the Functional API.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs=[\"One Flew Over the Cuckoo's Nest (1975)\"]. Consider rewriting this model with the Functional API.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs=[\"One Flew Over the Cuckoo's Nest (1975)\"]. Consider rewriting this model with the Functional API.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "RankingModel()(([\"42\"], [\"One Flew Over the Cuckoo's Nest (1975)\"]))" ] }, { "cell_type": "markdown", "metadata": { "id": "nCaCqJsXSkCo" }, "source": [ "### Loss and metrics\n", "\n", "The next component is the loss used to train our model. TFRS has several loss layers and tasks to make this easy.\n", "\n", "In this instance, we'll make use of the `Ranking` task object: a convenience wrapper that bundles together the loss function and metric computation. \n", "\n", "We'll use it together with the `MeanSquaredError` Keras loss in order to predict the ratings." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:16.220332Z", "iopub.status.busy": "2022-12-14T12:17:16.219705Z", "iopub.status.idle": "2022-12-14T12:17:16.229607Z", "shell.execute_reply": "2022-12-14T12:17:16.228933Z" }, "id": "tJ61Iz2QTBw3" }, "outputs": [], "source": [ "task = tfrs.tasks.Ranking(\n", "  loss=tf.keras.losses.MeanSquaredError(),\n", "  metrics=[tf.keras.metrics.RootMeanSquaredError()]\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "9-3xFC-1cbz0" }, "source": [ "The task itself is a Keras layer that takes the true labels and the predictions as arguments, returns the computed loss, and updates its metrics. We'll call it from the model's `compute_loss` method, which drives the training loop." ] }, { "cell_type": "markdown", "metadata": { "id": "FZUFeSlWRHGx" }, "source": [ "### The full model\n", "\n", "We can now put it all together into a model. TFRS exposes a base model class (`tfrs.models.Model`) which streamlines building models: all we need to do is to set up the components in the `__init__` method, and implement the `compute_loss` method, taking in the raw features and returning a loss value.\n", "\n", "The base model will then take care of creating the appropriate training loop to fit our model."
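] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "For intuition, the next cell sketches roughly what one training step of such a base model involves. It computes the loss under a `tf.GradientTape`, adds any regularization losses, and applies the gradients with the compiled optimizer. This is a simplified illustration, not the TFRS source. `sketch_train_step` is a hypothetical helper, and the library's actual `train_step` may differ between versions; for example, it also reports metric values. The training log further down reports the same three quantities (`loss`, `regularization_loss`, `total_loss`). The cell only defines the function and never calls it, so it does not affect the training below."
 ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "from typing import Dict, Text\n",
 "\n",
 "import tensorflow as tf\n",
 "import tensorflow_recommenders as tfrs\n",
 "\n",
 "\n",
 "def sketch_train_step(model: tfrs.models.Model,\n",
 "                      batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:\n",
 "  \"\"\"Hypothetical helper: one gradient step for a compiled TFRS model.\"\"\"\n",
 "  with tf.GradientTape() as tape:\n",
 "    # Task loss returned by the user-defined `compute_loss` method.\n",
 "    loss = model.compute_loss(batch, training=True)\n",
 "    # Losses added by layers, e.g. weight regularizers (none in this tutorial).\n",
 "    regularization_loss = sum(model.losses)\n",
 "    total_loss = loss + regularization_loss\n",
 "\n",
 "  # Differentiate the total loss and update the weights with the optimizer\n",
 "  # passed to `model.compile`.\n",
 "  gradients = tape.gradient(total_loss, model.trainable_variables)\n",
 "  model.optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
 "\n",
 "  return {\n",
 "      \"loss\": loss,\n",
 "      \"regularization_loss\": regularization_loss,\n",
 "      \"total_loss\": total_loss,\n",
 "  }"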
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:16.233213Z", "iopub.status.busy": "2022-12-14T12:17:16.232649Z", "iopub.status.idle": "2022-12-14T12:17:16.238007Z", "shell.execute_reply": "2022-12-14T12:17:16.237455Z" }, "id": "8n7c5CHFp0ow" }, "outputs": [], "source": [ "class MovielensModel(tfrs.models.Model):\n", "\n", "  def __init__(self):\n", "    super().__init__()\n", "    self.ranking_model: tf.keras.Model = RankingModel()\n", "    self.task: tf.keras.layers.Layer = tfrs.tasks.Ranking(\n", "      loss=tf.keras.losses.MeanSquaredError(),\n", "      metrics=[tf.keras.metrics.RootMeanSquaredError()]\n", "    )\n", "\n", "  def call(self, features: Dict[Text, tf.Tensor]) -> tf.Tensor:\n", "    return self.ranking_model(\n", "        (features[\"user_id\"], features[\"movie_title\"]))\n", "\n", "  def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:\n", "    # Separate the label from the input features.\n", "    labels = features.pop(\"user_rating\")\n", "\n", "    rating_predictions = self(features)\n", "\n", "    # The task computes the loss and the metrics.\n", "    return self.task(labels=labels, predictions=rating_predictions)" ] }, { "cell_type": "markdown", "metadata": { "id": "yDN_LJGlnRGo" }, "source": [ "## Fitting and evaluating\n", "\n", "After defining the model, we can use the standard Keras `fit` and `evaluate` routines to train and evaluate it.\n", "\n", "Let's first instantiate the model."
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:16.241123Z", "iopub.status.busy": "2022-12-14T12:17:16.240693Z", "iopub.status.idle": "2022-12-14T12:17:16.273429Z", "shell.execute_reply": "2022-12-14T12:17:16.272854Z" }, "id": "aW63YaqP2wCf" }, "outputs": [], "source": [ "model = MovielensModel()\n", "model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))" ] }, { "cell_type": "markdown", "metadata": { "id": "Nma0vc2XdN5g" }, "source": [ "Then shuffle, batch, and cache the training and evaluation data." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:16.276458Z", "iopub.status.busy": "2022-12-14T12:17:16.276247Z", "iopub.status.idle": "2022-12-14T12:17:16.285651Z", "shell.execute_reply": "2022-12-14T12:17:16.285015Z" }, "id": "53QJwY1gUnfv" }, "outputs": [], "source": [ "cached_train = train.shuffle(100_000).batch(8192).cache()\n", "cached_test = test.batch(4096).cache()" ] }, { "cell_type": "markdown", "metadata": { "id": "u8mHTxKAdTJO" }, "source": [ "Then train the model:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:16.288813Z", "iopub.status.busy": "2022-12-14T12:17:16.288387Z", "iopub.status.idle": "2022-12-14T12:17:20.892207Z", "shell.execute_reply": "2022-12-14T12:17:20.891544Z" }, "id": "ZxPntlT8EFOZ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/10 [==>...........................] 
- ETA: 26s - root_mean_squared_error: 3.6917 - loss: 13.6284 - regularization_loss: 0.0000e+00 - total_loss: 13.6284" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 2/10 [=====>........................] - ETA: 1s - root_mean_squared_error: 3.2406 - loss: 10.5015 - regularization_loss: 0.0000e+00 - total_loss: 10.5015 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 3/10 [========>.....................] - ETA: 1s - root_mean_squared_error: 2.8550 - loss: 8.1512 - regularization_loss: 0.0000e+00 - total_loss: 8.1512 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 4/10 [===========>..................] 
- ETA: 1s - root_mean_squared_error: 2.8543 - loss: 8.1470 - regularization_loss: 0.0000e+00 - total_loss: 8.1470" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 5/10 [==============>...............] - ETA: 1s - root_mean_squared_error: 2.6761 - loss: 7.1615 - regularization_loss: 0.0000e+00 - total_loss: 7.1615" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 6/10 [=================>............] - ETA: 0s - root_mean_squared_error: 2.4891 - loss: 6.1955 - regularization_loss: 0.0000e+00 - total_loss: 6.1955" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 7/10 [====================>.........] 
- ETA: 0s - root_mean_squared_error: 2.3489 - loss: 5.5174 - regularization_loss: 0.0000e+00 - total_loss: 5.5174" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "10/10 [==============================] - ETA: 0s - root_mean_squared_error: 2.0902 - loss: 4.2994 - regularization_loss: 0.0000e+00 - total_loss: 4.2994" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "10/10 [==============================] - 4s 166ms/step - root_mean_squared_error: 2.0902 - loss: 4.0368 - regularization_loss: 0.0000e+00 - total_loss: 4.0368\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/10 [==>...........................] 
- ETA: 0s - root_mean_squared_error: 1.1879 - loss: 1.4112 - regularization_loss: 0.0000e+00 - total_loss: 1.4112" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "10/10 [==============================] - 0s 4ms/step - root_mean_squared_error: 1.1613 - loss: 1.3426 - regularization_loss: 0.0000e+00 - total_loss: 1.3426\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/10 [==>...........................] - ETA: 0s - root_mean_squared_error: 1.1158 - loss: 1.2451 - regularization_loss: 0.0000e+00 - total_loss: 1.2451" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "10/10 [==============================] - 0s 4ms/step - root_mean_squared_error: 1.1140 - loss: 1.2414 - regularization_loss: 0.0000e+00 - total_loss: 1.2414\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit(cached_train, epochs=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "YsluR8audV9W" }, "source": [ "As the model trains, the loss is falling and the RMSE metric is improving." 
] }, { "cell_type": "markdown", "metadata": { "id": "7Gxp5RLFcv64" }, "source": [ "Finally, we can evaluate our model on the test set:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:20.895936Z", "iopub.status.busy": "2022-12-14T12:17:20.895377Z", "iopub.status.idle": "2022-12-14T12:17:22.471892Z", "shell.execute_reply": "2022-12-14T12:17:22.471070Z" }, "id": "W-zu6HLODNeI" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/5 [=====>........................] - ETA: 6s - root_mean_squared_error: 1.1041 - loss: 1.2191 - regularization_loss: 0.0000e+00 - total_loss: 1.2191" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "5/5 [==============================] - 2s 9ms/step - root_mean_squared_error: 1.1009 - loss: 1.2072 - regularization_loss: 0.0000e+00 - total_loss: 1.2072\n" ] }, { "data": { "text/plain": [ "{'root_mean_squared_error': 1.100862741470337,\n", " 'loss': 1.1866925954818726,\n", " 'regularization_loss': 0,\n", " 'total_loss': 1.1866925954818726}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.evaluate(cached_test, return_dict=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "JKZyP9A1dxit" }, "source": [ "The lower the RMSE metric, the more accurate our model is at predicting ratings." 
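] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "To put this number in context, it can help to compare it against a trivial baseline. The next cell is a sketch that is not part of the original tutorial. It predicts the mean training rating for every test example and reports the RMSE of that constant predictor on the same test split. If the trained model's test RMSE is not clearly lower, the model has learned little beyond the average rating. The cell reuses the `train` and `test` datasets defined above; the other variable names are illustrative."
 ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "import numpy as np\n",
 "\n",
 "# Collect the ratings from the train and test splits created earlier.\n",
 "train_labels = np.concatenate(\n",
 "    [batch[\"user_rating\"].numpy() for batch in train.batch(10_000)])\n",
 "test_labels = np.concatenate(\n",
 "    [batch[\"user_rating\"].numpy() for batch in test.batch(10_000)])\n",
 "\n",
 "# Constant baseline: always predict the mean training rating.\n",
 "mean_rating = train_labels.mean()\n",
 "baseline_rmse = np.sqrt(np.mean((test_labels - mean_rating) ** 2))\n",
 "\n",
 "print(f\"Mean training rating: {mean_rating:.3f}\")\n",
 "print(f\"Constant-baseline test RMSE: {baseline_rmse:.3f}\")"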
] }, { "cell_type": "markdown", "metadata": { "id": "hcK4WKmKTE3A" }, "source": [ "## Testing the ranking model\n", "\n", "Now we can test the ranking model by computing predictions for a set of movies and then ranking these movies based on the predictions:\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:22.475717Z", "iopub.status.busy": "2022-12-14T12:17:22.475181Z", "iopub.status.idle": "2022-12-14T12:17:22.498033Z", "shell.execute_reply": "2022-12-14T12:17:22.497237Z" }, "id": "6oB5DzrsTTrA" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Ratings:\n", "Dances with Wolves (1990): [[3.539769]]\n", "M*A*S*H (1970): [[3.5356772]]\n", "Speed (1994): [[3.4501984]]\n" ] } ], "source": [ "test_ratings = {}\n", "test_movie_titles = [\"M*A*S*H (1970)\", \"Dances with Wolves (1990)\", \"Speed (1994)\"]\n", "for movie_title in test_movie_titles:\n", "  test_ratings[movie_title] = model({\n", "      \"user_id\": np.array([\"42\"]),\n", "      \"movie_title\": np.array([movie_title])\n", "  })\n", "\n", "print(\"Ratings:\")\n", "for title, score in sorted(test_ratings.items(), key=lambda x: x[1], reverse=True):\n", "  print(f\"{title}: {score}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "hfedFnhBZiGw" }, "source": [ "## Exporting for serving\n", "\n", "The model can be easily exported for serving:\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:22.501457Z", "iopub.status.busy": "2022-12-14T12:17:22.501048Z", "iopub.status.idle": "2022-12-14T12:17:23.897316Z", "shell.execute_reply": "2022-12-14T12:17:23.896417Z" }, "id": "qjLDKn5VZqm8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Unsupported signature for serialization: ((IndexedSlicesSpec(TensorShape([None, 32]), tf.float32, tf.int64, tf.int32, TensorShape([None])), , 140544556312496), {}).\n" ] }, { "name":
"stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Unsupported signature for serialization: ((IndexedSlicesSpec(TensorShape([None, 32]), tf.float32, tf.int64, tf.int32, TensorShape([None])), , 140544556312496), {}).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Unsupported signature for serialization: ((IndexedSlicesSpec(TensorShape([None, 32]), tf.float32, tf.int64, tf.int32, TensorShape([None])), , 140544555465184), {}).\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Unsupported signature for serialization: ((IndexedSlicesSpec(TensorShape([None, 32]), tf.float32, tf.int64, tf.int32, TensorShape([None])), , 140544555465184), {}).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 256), dtype=tf.float32, name='gradient'), , 140544555465504), {}).\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(64, 256), dtype=tf.float32, name='gradient'), , 140544555465504), {}).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), , 140544555464064), {}).\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256,), dtype=tf.float32, name='gradient'), , 140544555464064), {}).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 64), dtype=tf.float32, name='gradient'), , 140544555331632), {}).\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(256, 64), dtype=tf.float32, name='gradient'), , 140544555331632), {}).\n" ] }, { 
"name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as ranking_1_layer_call_fn, ranking_1_layer_call_and_return_conditional_losses, _update_step_xla while saving (showing 3 of 3). 
These functions will not be directly callable after loading.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: export/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: export/assets\n" ] } ], "source": [ "tf.saved_model.save(model, \"export\")" ] }, { "cell_type": "markdown", "metadata": { "id": "sia3ezFPZy1v" }, "source": [ "We can now load it back and perform predictions:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:23.902073Z", "iopub.status.busy": "2022-12-14T12:17:23.901486Z", "iopub.status.idle": "2022-12-14T12:17:24.354853Z", "shell.execute_reply": "2022-12-14T12:17:24.353999Z" }, "id": "owetAuj0Z1ny" }, "outputs": [ { "data": { "text/plain": [ "array([[3.4501984]], dtype=float32)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loaded = tf.saved_model.load(\"export\")\n", "\n", "loaded({\"user_id\": np.array([\"42\"]), \"movie_title\": np.array([\"Speed (1994)\"])}).numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "Qwj7cOOwjGzO" }, "source": [ "## Convert the model to TensorFlow Lite\n", "\n", "Although TensorFlow Recommenders is primarily designed to perform server-side recommendations, you can still convert the trained ranking model to TensorFlow Lite and run it on-device (for better user privacy and lower latency).\n" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:24.358735Z", "iopub.status.busy": "2022-12-14T12:17:24.358065Z", "iopub.status.idle": "2022-12-14T12:17:25.204455Z", "shell.execute_reply": "2022-12-14T12:17:25.203619Z" }, "id": "YEDKKtuQjDb3" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-14 12:17:24.837136: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored 
output_format.\n", "2022-12-14 12:17:24.837175: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.\n" ] }, { "data": { "text/plain": [ "544480" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "converter = tf.lite.TFLiteConverter.from_saved_model(\"export\")\n", "tflite_model = converter.convert()\n", "open(\"converted_model.tflite\", \"wb\").write(tflite_model)" ] }, { "cell_type": "markdown", "metadata": { "id": "XLbzmVMLjuIy" }, "source": [ "Once the model is converted, you can run it like any other TensorFlow Lite model. See the [TensorFlow Lite documentation](https://www.tensorflow.org/lite/guide/inference) to learn more." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2022-12-14T12:17:25.208080Z", "iopub.status.busy": "2022-12-14T12:17:25.207405Z", "iopub.status.idle": "2022-12-14T12:17:25.216032Z", "shell.execute_reply": "2022-12-14T12:17:25.215257Z" }, "id": "RASq6xAvjEiH" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[3.450199]]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO: Created TensorFlow Lite XNNPACK delegate for CPU.\n" ] } ], "source": [ "interpreter = tf.lite.Interpreter(model_path=\"converted_model.tflite\")\n", "interpreter.allocate_tensors()\n", "\n", "# Get input and output tensors.\n", "input_details = interpreter.get_input_details()\n", "output_details = interpreter.get_output_details()\n", "\n", "# Test the model. The interpreter addresses inputs by index, and the input order\n", "# is not guaranteed, so check which input is the movie title before feeding them.\n", "if input_details[0][\"name\"] == \"serving_default_movie_title:0\":\n", " interpreter.set_tensor(input_details[0][\"index\"], np.array([\"Speed (1994)\"]))\n", " interpreter.set_tensor(input_details[1][\"index\"], np.array([\"42\"]))\n", "else:\n", " interpreter.set_tensor(input_details[0][\"index\"], np.array([\"42\"]))\n", " interpreter.set_tensor(input_details[1][\"index\"], np.array([\"Speed (1994)\"]))\n", 
"\n", "interpreter.invoke()\n", "\n", "rating = interpreter.get_tensor(output_details[0][\"index\"])\n", "print(rating)" ] }, { "cell_type": "markdown", "metadata": { "id": "efApI0Ii6srB" }, "source": [ "## Next steps\n", "\n", "The model above gives us a decent start towards building a ranking system.\n", "\n", "Of course, making a practical ranking system requires much more effort.\n", "\n", "In most cases, a ranking model can be substantially improved by using more features rather than just user and candidate identifiers. To see how to do that, have a look at the [side features](featurization) tutorial.\n", "\n", "A careful understanding of the objectives worth optimizing is also necessary. To get started on building a recommender that optimizes multiple objectives, have a look at our [multitask](multitask) tutorial." ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "basic_ranking.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 0 }