{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "_dEaVsqSgNyQ" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2024-03-19T11:34:21.595264Z", "iopub.status.busy": "2024-03-19T11:34:21.594985Z", "iopub.status.idle": "2024-03-19T11:34:21.599146Z", "shell.execute_reply": "2024-03-19T11:34:21.598488Z" }, "id": "4FyfuZX-gTKS" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "sT8AyHRMNh41" }, "source": [ "# Recommend movies for users with TensorFlow Ranking\n", "\n", "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "8f-reQ11gbLB" }, "source": [ "In this tutorial, we build a simple two-tower ranking model using the [MovieLens 100K dataset](https://grouplens.org/datasets/movielens/100k/) with TF-Ranking. We can use this model to rank and recommend movies for a given user according to their predicted user ratings." ] }, { "cell_type": "markdown", "metadata": { "id": "qA00wBE2Ntdm" }, "source": [ "## Setup\n", "\n", "Install and import the TF-Ranking library:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-03-19T11:34:21.602793Z", "iopub.status.busy": "2024-03-19T11:34:21.602283Z", "iopub.status.idle": "2024-03-19T11:34:49.220904Z", "shell.execute_reply": "2024-03-19T11:34:49.219735Z" }, "id": "6yzAaM85Z12D" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\r\n", "tf-keras 2.16.0 requires tensorflow<2.17,>=2.16, but you have tensorflow 2.15.1 which is incompatible.\u001b[0m\u001b[31m\r\n", "\u001b[0m" ] } ], "source": [ "!pip install -q tensorflow-ranking\n", "!pip install -q --upgrade tensorflow-datasets" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-03-19T11:34:49.225468Z", "iopub.status.busy": "2024-03-19T11:34:49.225153Z", "iopub.status.idle": "2024-03-19T11:34:52.504946Z", "shell.execute_reply": "2024-03-19T11:34:52.503816Z" }, "id": "n3oYt3R6Nr9l" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-03-19 11:34:49.704174: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-03-19 11:34:49.704225: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-03-19 11:34:49.705795: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "from typing import Dict, Tuple\n", "\n", "import tensorflow as tf\n", "\n", "import tensorflow_datasets as tfds\n", "import tensorflow_ranking as tfr" ] }, { "cell_type": "markdown", "metadata": { "id": "zCxQ1CZcO2wh" }, "source": [ "## Read the data" ] }, { "cell_type": "markdown", "metadata": { "id": "A0sY6-Rtt_Co" }, "source": [ "Prepare to train a model by creating a ratings dataset and movies dataset. Use `user_id` as the query input feature, `movie_title` as the document input feature, and `user_rating` as the label to train the ranking model." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-03-19T11:34:52.509622Z", "iopub.status.busy": "2024-03-19T11:34:52.509144Z", "iopub.status.idle": "2024-03-19T11:34:54.215772Z", "shell.execute_reply": "2024-03-19T11:34:54.214821Z" }, "id": "M-mxBYjdO5m7" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-03-19 11:34:53.385017: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n" ] } ], "source": [ "%%capture --no-display\n", "# Ratings data.\n", "ratings = tfds.load('movielens/100k-ratings', split=\"train\")\n", "# Features of all the available movies.\n", "movies = tfds.load('movielens/100k-movies', split=\"train\")\n", "\n", "# Select the basic features.\n", "ratings = ratings.map(lambda x: {\n", " \"movie_title\": x[\"movie_title\"],\n", " \"user_id\": x[\"user_id\"],\n", " \"user_rating\": x[\"user_rating\"]\n", "})" ] }, { "cell_type": "markdown", "metadata": { "id": "5W0HSfmSNCWm" }, "source": [ "Build vocabularies to convert all user IDs and all movie titles into integer indices for embedding layers:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-03-19T11:34:54.220528Z", "iopub.status.busy": "2024-03-19T11:34:54.219873Z", "iopub.status.idle": "2024-03-19T11:34:56.813110Z", "shell.execute_reply": "2024-03-19T11:34:56.812268Z" }, "id": "9I1VTEjHzpfX" }, "outputs": [], "source": [ "movies = movies.map(lambda x: x[\"movie_title\"])\n", "users = ratings.map(lambda x: x[\"user_id\"])\n", "\n", "# With mask_token=None and the default num_oov_indices=1, StringLookup\n", "# reserves index 0 for out-of-vocabulary strings.\n", "user_ids_vocabulary = tf.keras.layers.StringLookup(mask_token=None)\n", "user_ids_vocabulary.adapt(users.batch(1000))\n", "\n", "movie_titles_vocabulary = tf.keras.layers.StringLookup(mask_token=None)\n", "movie_titles_vocabulary.adapt(movies.batch(1000))" ] }, { 
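"cell_type": "markdown", "metadata": { "id": "sketch-string-lookup" }, "source": [ "To see what these vocabulary layers produce, here is a minimal sketch on a few hypothetical user IDs. The toy values and the names `toy_lookup` and `toy_embed` are illustrative and not part of the tutorial. The sketch assumes the default `num_oov_indices=1`, so with `mask_token=None` index 0 is the out-of-vocabulary bucket. It also shows why the embedding tables in the model below are sized with `vocabulary_size()`.

```python
import tensorflow as tf

# Hypothetical toy IDs; '405' appears twice, so adapt() ranks it first.
toy_lookup = tf.keras.layers.StringLookup(mask_token=None)
toy_lookup.adapt(tf.constant(['405', '655', '13', '405']))

# With mask_token=None and the default num_oov_indices=1, index 0 is the
# out-of-vocabulary bucket and known tokens start at index 1.
ids = toy_lookup(tf.constant(['405', '13', 'unseen-user']))

# One embedding row per index, so the table size is vocabulary_size().
toy_embed = tf.keras.layers.Embedding(toy_lookup.vocabulary_size(), 8)
vectors = toy_embed(ids)

print(toy_lookup.vocabulary_size(), ids.numpy(), vectors.shape)
```

Because `vocabulary_size()` also counts the out-of-vocabulary bucket (4 here, for three distinct toy IDs), a user ID or movie title that `adapt` never saw still maps to a valid embedding row." ] }, { 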
"cell_type": "markdown", "metadata": { "id": "zMsmoqWTOTKo" }, "source": [ "Group by `user_id` to form lists for ranking models:\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-03-19T11:34:56.817807Z", "iopub.status.busy": "2024-03-19T11:34:56.817133Z", "iopub.status.idle": "2024-03-19T11:34:56.854424Z", "shell.execute_reply": "2024-03-19T11:34:56.853687Z" }, "id": "lXY7kX7nOSwH" }, "outputs": [], "source": [ "key_func = lambda x: user_ids_vocabulary(x[\"user_id\"])\n", "reduce_func = lambda key, dataset: dataset.batch(100)\n", "ds_train = ratings.group_by_window(\n", " key_func=key_func, reduce_func=reduce_func, window_size=100)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-03-19T11:34:56.858574Z", "iopub.status.busy": "2024-03-19T11:34:56.857931Z", "iopub.status.idle": "2024-03-19T11:34:57.326088Z", "shell.execute_reply": "2024-03-19T11:34:57.325304Z" }, "id": "57r87tdQlkcT" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape of movie_title: (100,)\n", "Example values of movie_title: [b'Man Who Would Be King, The (1975)' b'Silence of the Lambs, The (1991)'\n", " b'Next Karate Kid, The (1994)' b'2001: A Space Odyssey (1968)'\n", " b'Usual Suspects, The (1995)']\n", "\n", "Shape of user_id: (100,)\n", "Example values of user_id: [b'405' b'405' b'405' b'405' b'405']\n", "\n", "Shape of user_rating: (100,)\n", "Example values of user_rating: [1. 4. 1. 5. 
5.]\n", "\n" ] } ], "source": [ "for x in ds_train.take(1):\n", " for key, value in x.items():\n", " print(f\"Shape of {key}: {value.shape}\")\n", " print(f\"Example values of {key}: {value[:5].numpy()}\")\n", " print()" ] }, { "cell_type": "markdown", "metadata": { "id": "YcZJf2qxOeWU" }, "source": [ "Generate batched features and labels:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-03-19T11:34:57.330162Z", "iopub.status.busy": "2024-03-19T11:34:57.329450Z", "iopub.status.idle": "2024-03-19T11:34:57.381947Z", "shell.execute_reply": "2024-03-19T11:34:57.381178Z" }, "id": "ctq2RTOqOfAo" }, "outputs": [], "source": [ "def _features_and_labels(\n", " x: Dict[str, tf.Tensor]) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:\n", " # Move the rating out of the features dict and return it as the label.\n", " labels = x.pop(\"user_rating\")\n", " return x, labels\n", "\n", "\n", "ds_train = ds_train.map(_features_and_labels)\n", "\n", "# Stack up to 32 variable-length lists per batch into RaggedTensors.\n", "ds_train = ds_train.ragged_batch(batch_size=32)" ] }, { "cell_type": "markdown", "metadata": { "id": "RJUU3mv-_VdQ" }, "source": [ "The `user_id` and `movie_title` tensors generated in `ds_train` have shape `[32, None]`. The second dimension is 100 in most cases, but it is smaller for lists that group fewer than 100 items. The model therefore has to work on ragged tensors." 
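] }, { "cell_type": "markdown", "metadata": { "id": "sketch-group-and-ragged-batch" }, "source": [ "The grouping and batching steps above can be reproduced on a tiny hypothetical dataset to show where the ragged `None` dimension comes from. To keep the sketch self-contained, it replaces the vocabulary-based `key_func` with a precomputed integer `user_key` column. This is an assumption of the sketch, not how the tutorial keys its lists. It then stacks the lists with `tf.data.Dataset.ragged_batch`.

```python
import tensorflow as tf

# Hypothetical toy ratings: user 0 has three rows and user 1 has one.
rows = tf.data.Dataset.from_tensor_slices({
    'user_key': tf.constant([0, 1, 0, 0], dtype=tf.int64),
    'user_rating': tf.constant([5.0, 2.0, 4.0, 3.0]),
})

# Rows with the same key share a window; each window is batched into one list
# of at most 100 rows. Here the windows are flushed when the input runs out.
lists = rows.group_by_window(
    key_func=lambda x: x['user_key'],
    reduce_func=lambda key, window: window.batch(100),
    window_size=100)

# ragged_batch stacks the variable-length lists into tf.RaggedTensor values.
batch = next(iter(lists.ragged_batch(batch_size=32)))
ratings = batch['user_rating']

print(ratings.shape, sorted(ratings.row_lengths().numpy().tolist()))
```

The two toy lists have 3 and 1 ratings, so the batch has shape `[2, None]`. The sketch does not rely on the order in which `group_by_window` emits the two windows. It relies only on each window keeping its rows in input order." 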
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-03-19T11:34:57.385828Z", "iopub.status.busy": "2024-03-19T11:34:57.385533Z", "iopub.status.idle": "2024-03-19T11:34:58.410928Z", "shell.execute_reply": "2024-03-19T11:34:58.409898Z" }, "id": "GTquqk1GkIfd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape of movie_title: (32, None)\n", "Example values of movie_title: [[b'Man Who Would Be King, The (1975)'\n", " b'Silence of the Lambs, The (1991)' b'Next Karate Kid, The (1994)']\n", " [b'Flower of My Secret, The (Flor de mi secreto, La) (1995)'\n", " b'Little Princess, The (1939)' b'Time to Kill, A (1996)']\n", " [b'Kundun (1997)' b'Scream (1996)' b'Power 98 (1995)']]\n", "\n", "Shape of user_id: (32, None)\n", "Example values of user_id: [[b'405' b'405' b'405']\n", " [b'655' b'655' b'655']\n", " [b'13' b'13' b'13']]\n", "\n", "Shape of label: (32, None)\n", "Example values of label: [[1. 4. 1.]\n", " [3. 3. 3.]\n", " [5. 1. 
1.]]\n" ] } ], "source": [ "for x, label in ds_train.take(1):\n", " for key, value in x.items():\n", " print(f\"Shape of {key}: {value.shape}\")\n", " print(f\"Example values of {key}: {value[:3, :3].numpy()}\")\n", " print()\n", " print(f\"Shape of label: {label.shape}\")\n", " print(f\"Example values of label: {label[:3, :3].numpy()}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Lrch6rVBOB9Q" }, "source": [ "## Define a model\n", "\n", "Define a ranking model by inheriting from `tf.keras.Model` and implementing the `call` method:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-03-19T11:34:58.415106Z", "iopub.status.busy": "2024-03-19T11:34:58.414364Z", "iopub.status.idle": "2024-03-19T11:34:58.420571Z", "shell.execute_reply": "2024-03-19T11:34:58.419802Z" }, "id": "e5dNbDZwOIHR" }, "outputs": [], "source": [ "class MovieLensRankingModel(tf.keras.Model):\n", "\n", " def __init__(self, user_vocab, movie_vocab):\n", " super().__init__()\n", "\n", " # Set up user and movie vocabulary and embedding.\n", " self.user_vocab = user_vocab\n", " self.movie_vocab = movie_vocab\n", " self.user_embed = tf.keras.layers.Embedding(user_vocab.vocabulary_size(),\n", " 64)\n", " self.movie_embed = tf.keras.layers.Embedding(movie_vocab.vocabulary_size(),\n", " 64)\n", "\n", " def call(self, features: Dict[str, tf.Tensor]) -> tf.Tensor:\n", " # Define how the ranking scores are computed: \n", " # Take the dot-product of the user embeddings with the movie embeddings.\n", "\n", " user_embeddings = self.user_embed(self.user_vocab(features[\"user_id\"]))\n", " movie_embeddings = self.movie_embed(\n", " self.movie_vocab(features[\"movie_title\"]))\n", "\n", " return tf.reduce_sum(user_embeddings * movie_embeddings, axis=2)" ] }, { "cell_type": "markdown", "metadata": { "id": "BMV0HpzmJGWk" }, "source": [ "Create the model, and then compile it with ranking `tfr.keras.losses` and `tfr.keras.metrics`, which are the core 
of the TF-Ranking package. \n", "\n", "This example uses a ranking-specific **softmax loss**, a listwise loss designed to give all relevant items in the list a better chance of being ranked above the irrelevant ones. Unlike the softmax loss in multi-class classification, where exactly one class is positive and the rest are negative, the TF-Ranking library supports multiple relevant documents in a query list and non-binary relevance labels.\n", "\n", "For ranking metrics, this example specifically uses **Normalized Discounted Cumulative Gain (NDCG)** and **Mean Reciprocal Rank (MRR)**, which measure the user utility of a ranked query list with position-based discounts. For more details about ranking metrics, see the [offline metrics](https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Offline_metrics) section of the Wikipedia article on evaluation measures in information retrieval." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-03-19T11:34:58.424051Z", "iopub.status.busy": "2024-03-19T11:34:58.423563Z", "iopub.status.idle": "2024-03-19T11:34:58.457080Z", "shell.execute_reply": "2024-03-19T11:34:58.456368Z" }, "id": "H2tQDhqkOKf1" }, "outputs": [], "source": [ "# Create the ranking model, trained with a ranking loss and evaluated with\n", "# ranking metrics.\n", "model = MovieLensRankingModel(user_ids_vocabulary, movie_titles_vocabulary)\n", "optimizer = tf.keras.optimizers.Adagrad(0.5)\n", "loss = tfr.keras.losses.get(\n", " loss=tfr.keras.losses.RankingLossKey.SOFTMAX_LOSS, ragged=True)\n", "eval_metrics = [\n", " tfr.keras.metrics.get(key=\"ndcg\", name=\"metric/ndcg\", ragged=True),\n", " tfr.keras.metrics.get(key=\"mrr\", name=\"metric/mrr\", ragged=True)\n", "]\n", "model.compile(optimizer=optimizer, loss=loss, metrics=eval_metrics)" ] }, { "cell_type": "markdown", "metadata": { "id": "NeBnBFMfVLzP" }, "source": [ "## Train and evaluate the model\n", "\n", "Train the model with `model.fit`." 
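] }, { "cell_type": "markdown", "metadata": { "id": "sketch-loss-and-metrics" }, "source": [ "Before running `model.fit`, it can help to see what the softmax loss and the NDCG and MRR metrics described earlier compute on a single list. The following simplified sketch uses plain TensorFlow on one hypothetical list of four movies, with common textbook definitions. Labels are normalized into a distribution for the softmax cross-entropy. NDCG uses a gain of `2^label - 1` with a `log2(rank + 1)` discount. For MRR, an item counts as relevant when its label is at least 1. These definitions are assumptions of the sketch; TF-Ranking's `SoftmaxLoss` and metrics may weight, threshold, or reduce values differently, so the numbers are not meant to match the training log.

```python
import tensorflow as tf

# One hypothetical list of 4 movies for a single user: ratings and model scores.
labels = tf.constant([[5.0, 3.0, 0.0, 1.0]])
scores = tf.constant([[2.0, 0.5, 1.0, -1.0]])

# Simplified listwise softmax loss: cross-entropy between the label
# distribution (labels normalized to sum to 1) and softmax(scores).
label_dist = labels / tf.reduce_sum(labels, axis=-1, keepdims=True)
softmax_loss = tf.nn.softmax_cross_entropy_with_logits(
    labels=label_dist, logits=scores)

# Labels in model order (descending score) and in ideal order (descending label).
order = tf.argsort(scores, axis=-1, direction='DESCENDING')
ranked_labels = tf.gather(labels, order, batch_dims=1)
ideal_labels = tf.sort(labels, axis=-1, direction='DESCENDING')

# Ranks 1..n as floats, shared by both metrics.
ranks = tf.cast(tf.range(1, labels.shape[-1] + 1), tf.float32)

def dcg(sorted_labels):
    # Gain 2^label - 1, discounted by log2(rank + 1).
    gains = 2.0 ** sorted_labels - 1.0
    discounts = tf.math.log(ranks + 1.0) / tf.math.log(2.0)
    return tf.reduce_sum(gains / discounts, axis=-1)

ndcg = dcg(ranked_labels) / dcg(ideal_labels)

# Reciprocal rank of the first relevant item: the max of relevant / rank.
relevant = tf.cast(ranked_labels >= 1.0, tf.float32)
mrr = tf.reduce_max(relevant / ranks, axis=-1)

print(softmax_loss.numpy(), ndcg.numpy(), mrr.numpy())
```

In this toy list the top-scored movie has a rating of 5, so MRR is 1.0. NDCG is slightly below 1 because the movie rated 3 is ranked below one rated 0. Every MovieLens rating lies between 1 and 5, so the top-ranked movie in any list is relevant under the threshold assumed here. That would be consistent with the constant `metric/mrr: 1.0000` in the training log." 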
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-03-19T11:34:58.461501Z", "iopub.status.busy": "2024-03-19T11:34:58.460904Z", "iopub.status.idle": "2024-03-19T11:35:13.416826Z", "shell.execute_reply": "2024-03-19T11:35:13.415859Z" }, "id": "bzGm7WqSVNyP" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/Unknown - 5s 5s/step - loss: 1477.8344 - metric/ndcg: 0.7807 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 2/Unknown - 5s 273ms/step - loss: 1527.9211 - metric/ndcg: 0.7945 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 3/Unknown - 5s 222ms/step - loss: 1564.8604 - metric/ndcg: 0.8016 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 4/Unknown - 5s 200ms/step - loss: 1584.4081 - metric/ndcg: 0.8058 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 5/Unknown - 5s 184ms/step - loss: 1575.7542 - metric/ndcg: 0.8057 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 6/Unknown - 6s 176ms/step - loss: 1579.3936 - metric/ndcg: 0.8049 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 7/Unknown - 6s 167ms/step - loss: 1584.7850 - metric/ndcg: 0.8059 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 8/Unknown - 6s 161ms/step - loss: 1593.7064 - metric/ndcg: 0.8073 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 9/Unknown - 6s 152ms/step - loss: 1590.5203 - metric/ndcg: 0.8075 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 10/Unknown - 6s 148ms/step - loss: 1597.3547 - metric/ndcg: 0.8104 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 11/Unknown - 6s 142ms/step - loss: 1599.4286 - metric/ndcg: 0.8108 - metric/mrr: 1.0000" ] }, { "name": 
"stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 12/Unknown - 6s 140ms/step - loss: 1604.5621 - metric/ndcg: 0.8123 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 13/Unknown - 6s 141ms/step - loss: 1605.0614 - metric/ndcg: 0.8127 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 14/Unknown - 6s 140ms/step - loss: 1605.7787 - metric/ndcg: 0.8136 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 15/Unknown - 7s 143ms/step - loss: 1606.2090 - metric/ndcg: 0.8142 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 16/Unknown - 7s 140ms/step - loss: 1606.6212 - metric/ndcg: 0.8144 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 17/Unknown - 7s 137ms/step - loss: 1603.9541 - 
metric/ndcg: 0.8141 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 18/Unknown - 7s 134ms/step - loss: 1601.2360 - metric/ndcg: 0.8149 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 19/Unknown - 7s 131ms/step - loss: 1579.0735 - metric/ndcg: 0.8142 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 30/Unknown - 7s 83ms/step - loss: 1231.8003 - metric/ndcg: 0.8279 - metric/mrr: 1.0000 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 39/Unknown - 7s 65ms/step - loss: 1152.8252 - metric/ndcg: 0.8254 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "48/48 [==============================] - 7s 56ms/step - loss: 998.7637 - metric/ndcg: 0.8213 - metric/mrr: 1.0000\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 1/48 [..............................] 
- ETA: 52s - loss: 1476.4519 - metric/ndcg: 0.9138 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 2/48 [>.............................] - ETA: 8s - loss: 1526.5396 - metric/ndcg: 0.9113 - metric/mrr: 1.0000 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 3/48 [>.............................] - ETA: 7s - loss: 1563.4860 - metric/ndcg: 0.9130 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 4/48 [=>............................] - ETA: 6s - loss: 1583.0378 - metric/ndcg: 0.9161 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 5/48 [==>...........................] - ETA: 6s - loss: 1574.3459 - metric/ndcg: 0.9128 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 6/48 [==>...........................] 
- ETA: 6s - loss: 1577.8793 - metric/ndcg: 0.9145 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 7/48 [===>..........................] - ETA: 5s - loss: 1583.1188 - metric/ndcg: 0.9151 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 8/48 [====>.........................] - ETA: 5s - loss: 1591.9490 - metric/ndcg: 0.9154 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", " 9/48 [====>.........................] - ETA: 5s - loss: 1588.6942 - metric/ndcg: 0.9136 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "10/48 [=====>........................] - ETA: 5s - loss: 1595.4871 - metric/ndcg: 0.9148 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "11/48 [=====>........................] 
- ETA: 5s - loss: 1597.5161 - metric/ndcg: 0.9142 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "12/48 [======>.......................] - ETA: 4s - loss: 1602.5718 - metric/ndcg: 0.9158 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "13/48 [=======>......................] - ETA: 4s - loss: 1603.0027 - metric/ndcg: 0.9161 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "14/48 [=======>......................] - ETA: 4s - loss: 1603.6213 - metric/ndcg: 0.9167 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "15/48 [========>.....................] - ETA: 4s - loss: 1604.0188 - metric/ndcg: 0.9169 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "16/48 [=========>....................] 
- ETA: 4s - loss: 1604.3984 - metric/ndcg: 0.9169 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "48/48 [==============================] - 4s 53ms/step - loss: 997.1824 - metric/ndcg: 0.9161 - metric/mrr: 1.0000\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "19/48 [==========>...................] 
- ETA: 3s - loss: 1573.0365 - metric/ndcg: 0.9319 - metric/mrr: 1.0000" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "48/48 [==============================] - 4s 53ms/step - loss: 994.8384 - metric/ndcg: 0.9383 - metric/mrr: 1.0000\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit(ds_train, epochs=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "V5uuSRXZoOKW" }, "source": [ "Generate predictions for a single user and inspect the top-ranked movies." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-03-19T11:35:13.421260Z", "iopub.status.busy": "2024-03-19T11:35:13.420612Z", "iopub.status.idle": "2024-03-19T11:35:13.505324Z", "shell.execute_reply": "2024-03-19T11:35:13.504340Z" }, "id": "6Hryvj3cPnvK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Top 5 recommendations for user 42: [b'Star Wars (1977)' b'Liar Liar (1997)' b'Toy Story (1995)'\n", " b'Raiders of the Lost Ark (1981)' b'Sound of Music, The (1965)']\n" ] } ], "source": [ "# Get movie title candidate list.\n", "for movie_titles in movies.batch(2000):\n", " break\n", "\n", "# Generate the input for user 42.\n", "inputs = {\n", " \"user_id\":\n", " tf.expand_dims(tf.repeat(\"42\", repeats=movie_titles.shape[0]), axis=0),\n", " \"movie_title\":\n", " tf.expand_dims(movie_titles, axis=0)\n", "}\n", "\n", "# Get movie recommendations for user 42.\n", "scores = model(inputs)\n", "titles = tfr.utils.sort_by_scores(scores,\n", " [tf.expand_dims(movie_titles, axis=0)])[0]\n", "print(f\"Top 5 recommendations for user 42: {titles[0, :5]}\")" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "quickstart.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }