##### Copyright 2020 The TensorFlow Authors.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Multi-task recommenders

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/recommenders/examples/multitask"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/recommenders/blob/main/docs/examples/multitask.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/recommenders/blob/main/docs/examples/multitask.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/recommenders/docs/examples/multitask.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

In the [basic retrieval tutorial](basic_retrieval) we built a retrieval system using movie watches as positive interaction signals.

In many applications, however, there are multiple rich sources of feedback to draw upon. For example, an e-commerce site may record user visits to product pages (abundant, but relatively low signal), image clicks, adding to cart, and, finally, purchases. It may even record post-purchase signals such as reviews and returns.

Integrating all these different forms of feedback is critical to building systems that users love to use, and that do not optimize for any one metric at the expense of overall performance.

In addition, building a joint model for multiple tasks may produce better results than building a number of task-specific models. This is especially true where some data is abundant (for example, clicks), and some data is sparse (purchases, returns, manual reviews). In those scenarios, a joint model may be able to use representations learned from the abundant task to improve its predictions on the sparse task via a phenomenon known as [transfer learning](https://en.wikipedia.org/wiki/Transfer_learning). For example, [this paper](https://openreview.net/pdf?id=SJxPVcSonN) shows that a model predicting explicit user ratings from sparse user surveys can be substantially improved by adding an auxiliary task that uses abundant click log data.

In this tutorial, we are going to build a multi-objective recommender for MovieLens, using both implicit signals (movie watches) and explicit signals (ratings).

## Imports


Let's first get our imports out of the way.


In [2]:
!pip install -q tensorflow-recommenders
!pip install -q --upgrade tensorflow-datasets

In [3]:
import os
import pprint
import tempfile

from typing import Dict, Text

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds



In [4]:
import tensorflow_recommenders as tfrs

## Preparing the dataset

We're going to use the MovieLens 100K dataset.

In [5]:
ratings = tfds.load('movielens/100k-ratings', split="train")
movies = tfds.load('movielens/100k-movies', split="train")

# Select the basic features.
ratings = ratings.map(lambda x: {
    "movie_title": x["movie_title"],
    "user_id": x["user_id"],
    "user_rating": x["user_rating"],
})
movies = movies.map(lambda x: x["movie_title"])



As before, we build the vocabularies and split the data into a train and a test set:

In [6]:
# Randomly shuffle data and split between train and test.
tf.random.set_seed(42)
shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)

train = shuffled.take(80_000)
test = shuffled.skip(80_000).take(20_000)

movie_titles = movies.batch(1_000)
user_ids = ratings.batch(1_000_000).map(lambda x: x["user_id"])

unique_movie_titles = np.unique(np.concatenate(list(movie_titles)))
unique_user_ids = np.unique(np.concatenate(list(user_ids)))
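
The split above relies on a fixed shuffle order (`reshuffle_each_iteration=False`), so that `take` and `skip` see the same permutation and the train and test sets do not overlap. A minimal pure-Python sketch of the same idea follows; `split_train_test` and the toy records are hypothetical stand-ins, not part of the `tf.data` API:

```python
import random


def split_train_test(records, train_size, seed=42):
    """Shuffle once with a fixed seed, then slice into disjoint train/test lists."""
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)  # one fixed permutation, reused for both slices
    return shuffled[:train_size], shuffled[train_size:]


# Toy stand-in for the ratings dataset: (user_id, movie_title) pairs.
toy_ratings = [(str(u), f"Movie {m}") for u in range(10) for m in range(10)]
train_toy, test_toy = split_train_test(toy_ratings, train_size=80)

# Vocabulary of unique user ids, analogous to np.unique over the user_id column.
toy_user_ids = sorted({user_id for user_id, _ in toy_ratings})

print(len(train_toy), len(test_toy))        # 80 20
print(set(train_toy).isdisjoint(test_toy))  # True
```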

## A multi-task model

There are two critical parts to multi-task recommenders:

1. They optimize for two or more objectives, and so have two or more losses.
2. They share variables between the tasks, allowing for transfer learning.

In this tutorial, we will define our models as before, but instead of having a single task, we will have two tasks: one that predicts ratings, and one that predicts movie watches.

The user and movie models are as before:

```python
user_model = tf.keras.Sequential([
  tf.keras.layers.StringLookup(
      vocabulary=unique_user_ids, mask_token=None),
  # We add 1 to account for the unknown token.
  tf.keras.layers.Embedding(len(unique_user_ids) + 1, embedding_dimension)
])

movie_model = tf.keras.Sequential([
  tf.keras.layers.StringLookup(
      vocabulary=unique_movie_titles, mask_token=None),
  tf.keras.layers.Embedding(len(unique_movie_titles) + 1, embedding_dimension)
])
```
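
The `+ 1` in the embedding size reserves a row for identifiers that are not in the vocabulary. A rough pure-Python sketch of that lookup is shown below. It assumes that, with `mask_token=None` and a single out-of-vocabulary bucket, index 0 stands for the unknown token and known tokens start at 1; `build_lookup` is a hypothetical helper, not a Keras function:

```python
def build_lookup(vocabulary):
    """Map each known token to 1..N, reserving index 0 for unknown tokens."""
    token_to_index = {token: i + 1 for i, token in enumerate(vocabulary)}

    def lookup(token):
        return token_to_index.get(token, 0)  # 0 is the out-of-vocabulary bucket

    return lookup


toy_vocabulary = ["1", "2", "42"]
lookup = build_lookup(toy_vocabulary)

# The embedding table needs one row per known token plus one for index 0.
num_embedding_rows = len(toy_vocabulary) + 1

print(lookup("42"))        # 3
print(lookup("unseen"))    # 0
print(num_embedding_rows)  # 4
```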

However, now we will have two tasks. The first is the rating task:

```python
tfrs.tasks.Ranking(
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=[tf.keras.metrics.RootMeanSquaredError()],
)
```

Its goal is to predict the ratings as accurately as possible.
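
To make the loss and the metric concrete, here is a small sketch of mean squared error and root mean squared error in plain Python. The helper names and the toy ratings are made up for illustration; the Keras classes above compute the same quantities on tensors:

```python
import math


def mean_squared_error(labels, predictions):
    """Average of squared differences; this is what the ranking loss minimizes."""
    return sum((y - p) ** 2 for y, p in zip(labels, predictions)) / len(labels)


def root_mean_squared_error(labels, predictions):
    """Square root of the MSE, expressed in the same units as the ratings."""
    return math.sqrt(mean_squared_error(labels, predictions))


toy_labels = [4.0, 3.0, 5.0, 1.0]
toy_predictions = [3.0, 3.0, 4.0, 3.0]

print(mean_squared_error(toy_labels, toy_predictions))       # 1.5
print(root_mean_squared_error(toy_labels, toy_predictions))  # about 1.2247
```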

The second is the retrieval task:

```python
tfrs.tasks.Retrieval(
    metrics=tfrs.metrics.FactorizedTopK(
        candidates=movies.batch(128).map(movie_model)
    )
)
```

As before, this task's goal is to predict which movies the user will or will not watch.
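
The top-k accuracy reported for this task can be pictured as follows: score every candidate movie by the dot product of its embedding with the user embedding, then check whether the movie the user actually watched is among the k best-scoring candidates. The sketch below is a simplified pure-Python illustration with made-up 2-dimensional embeddings; `top_k_hit` is a hypothetical helper and not the `FactorizedTopK` implementation:

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def top_k_hit(user_embedding, candidate_embeddings, true_title, k):
    """Return True if the true title is among the k highest-scoring candidates."""
    ranked = sorted(
        candidate_embeddings,
        key=lambda title: dot(user_embedding, candidate_embeddings[title]),
        reverse=True,
    )
    return true_title in ranked[:k]


# Toy 2-dimensional embeddings; the values are made up for illustration.
toy_candidates = {
    "Movie A": [1.0, 0.0],
    "Movie B": [0.0, 1.0],
    "Movie C": [0.7, 0.7],
}
toy_user = [0.9, 0.2]

print(top_k_hit(toy_user, toy_candidates, "Movie C", k=1))  # False
print(top_k_hit(toy_user, toy_candidates, "Movie C", k=2))  # True
```

Averaging such hits over all test examples gives a top-k accuracy.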

### Putting it together

We put it all together in a model class.

The new component here is that, since we have two tasks and two losses, we need to decide how important each loss is. We can do this by giving each loss a weight and treating these weights as hyperparameters. If we assign a large loss weight to the rating task, our model is going to focus on predicting ratings (but still use some information from the retrieval task); if we assign a large loss weight to the retrieval task, it will focus on retrieval instead.
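
A toy sketch can show how loss weights steer a shared parameter. It assumes a single shared scalar `w` and two made-up quadratic losses that prefer opposite values, which exaggerates the conflict between tasks; `train_shared_parameter` is a hypothetical helper, not part of TFRS:

```python
def train_shared_parameter(rating_weight, retrieval_weight, steps=200, lr=0.1):
    """Gradient descent on w for rating_weight*(w-1)^2 + retrieval_weight*(w+1)^2."""
    w = 0.0
    for _ in range(steps):
        rating_grad = 2.0 * (w - 1.0)     # toy rating loss prefers w = +1
        retrieval_grad = 2.0 * (w + 1.0)  # toy retrieval loss prefers w = -1
        # The gradient of the weighted sum is the weighted sum of the gradients.
        w -= lr * (rating_weight * rating_grad + retrieval_weight * retrieval_grad)
    return w


print(round(train_shared_parameter(1.0, 0.0), 3))  # 1.0
print(round(train_shared_parameter(0.0, 1.0), 3))  # -1.0
print(round(train_shared_parameter(1.0, 1.0), 3))  # 0.0
print(round(train_shared_parameter(3.0, 1.0), 3))  # 0.5
```

With a weight of zero, a task has no influence on the shared parameter; with unequal positive weights, the solution lands closer to the optimum of the more heavily weighted task.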

In [7]:
class MovielensModel(tfrs.models.Model):

  def __init__(self, rating_weight: float, retrieval_weight: float) -> None:
    # We take the loss weights in the constructor: this allows us to instantiate
    # several model objects with different loss weights.

    super().__init__()

    embedding_dimension = 32

    # User and movie models.
    self.movie_model: tf.keras.layers.Layer = tf.keras.Sequential([
      tf.keras.layers.StringLookup(
        vocabulary=unique_movie_titles, mask_token=None),
      tf.keras.layers.Embedding(len(unique_movie_titles) + 1, embedding_dimension)
    ])
    self.user_model: tf.keras.layers.Layer = tf.keras.Sequential([
      tf.keras.layers.StringLookup(
        vocabulary=unique_user_ids, mask_token=None),
      tf.keras.layers.Embedding(len(unique_user_ids) + 1, embedding_dimension)
    ])

    # A small model to take in user and movie embeddings and predict ratings.
    # We can make this as complicated as we want as long as we output a scalar
    # as our prediction.
    self.rating_model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(1),
    ])

    # The tasks.
    self.rating_task: tf.keras.layers.Layer = tfrs.tasks.Ranking(
        loss=tf.keras.losses.MeanSquaredError(),
        metrics=[tf.keras.metrics.RootMeanSquaredError()],
    )
    self.retrieval_task: tf.keras.layers.Layer = tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(
            candidates=movies.batch(128).map(self.movie_model)
        )
    )

    # The loss weights.
    self.rating_weight = rating_weight
    self.retrieval_weight = retrieval_weight

  def call(self, features: Dict[Text, tf.Tensor]) -> tf.Tensor:
    # We pick out the user features and pass them into the user model.
    user_embeddings = self.user_model(features["user_id"])
    # And pick out the movie features and pass them into the movie model.
    movie_embeddings = self.movie_model(features["movie_title"])
    
    return (
        user_embeddings,
        movie_embeddings,
        # We apply the multi-layered rating model to a concatenation of
        # user and movie embeddings.
        self.rating_model(
            tf.concat([user_embeddings, movie_embeddings], axis=1)
        ),
    )

  def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:

    ratings = features.pop("user_rating")

    user_embeddings, movie_embeddings, rating_predictions = self(features)

    # We compute the loss for each task.
    rating_loss = self.rating_task(
        labels=ratings,
        predictions=rating_predictions,
    )
    retrieval_loss = self.retrieval_task(user_embeddings, movie_embeddings)

    # And combine them using the loss weights.
    return (self.rating_weight * rating_loss
            + self.retrieval_weight * retrieval_loss)

### Rating-specialized model

Depending on the weights we assign, the model will encode a different balance of the tasks. Let's start with a model that only considers ratings.

In [8]:
model = MovielensModel(rating_weight=1.0, retrieval_weight=0.0)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

In [9]:
cached_train = train.shuffle(100_000).batch(8192).cache()
cached_test = test.batch(4096).cache()

In [10]:
model.fit(cached_train, epochs=3)
metrics = model.evaluate(cached_test, return_dict=True)

print(f"Retrieval top-100 accuracy: {metrics['factorized_top_k/top_100_categorical_accuracy']:.3f}.")
print(f"Ranking RMSE: {metrics['root_mean_squared_error']:.3f}.")

Retrieval top-100 accuracy: 0.060.
Ranking RMSE: 1.113.


The model does OK at predicting ratings (with an RMSE of around 1.11), but performs poorly at predicting which movies will be watched: its top-100 accuracy is almost 4 times worse than that of a model trained solely to predict watches.

### Retrieval-specialized model

Let's now try a model that focuses on retrieval only.

In [11]:
model = MovielensModel(rating_weight=0.0, retrieval_weight=1.0)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

In [12]:
model.fit(cached_train, epochs=3)
metrics = model.evaluate(cached_test, return_dict=True)

print(f"Retrieval top-100 accuracy: {metrics['factorized_top_k/top_100_categorical_accuracy']:.3f}.")
print(f"Ranking RMSE: {metrics['root_mean_squared_error']:.3f}.")

Retrieval top-100 accuracy: 0.233.
Ranking RMSE: 3.688.


We get the opposite result: a model that does well on retrieval (a top-100 accuracy of around 0.23), but poorly on predicting ratings (an RMSE of around 3.69).

### Joint model

Let's now train a model that assigns positive weights to both tasks.

In [13]:
model = MovielensModel(rating_weight=1.0, retrieval_weight=1.0)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

In [14]:
model.fit(cached_train, epochs=3)
metrics = model.evaluate(cached_test, return_dict=True)

print(f"Retrieval top-100 accuracy: {metrics['factorized_top_k/top_100_categorical_accuracy']:.3f}.")
print(f"Ranking RMSE: {metrics['root_mean_squared_error']:.3f}.")

Retrieval top-100 accuracy: 0.235.
Ranking RMSE: 1.110.


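The result is a model that performs roughly as well on both tasks as each specialized model: its top-100 retrieval accuracy (0.235) and ranking RMSE (1.110) are close to those of the retrieval-only model (0.233) and the rating-only model (1.113), respectively.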
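
The three runs can be compared side by side. The numbers below are copied from the evaluation output printed in this notebook; the dictionary layout and variable names are only for illustration:

```python
results = {
    "rating-only": {"top_100": 0.060, "rmse": 1.113},
    "retrieval-only": {"top_100": 0.233, "rmse": 3.688},
    "joint": {"top_100": 0.235, "rmse": 1.110},
}

for name, row in results.items():
    print(f"{name:>15}: top-100 accuracy {row['top_100']:.3f}, RMSE {row['rmse']:.3f}")

# Ratio behind the "almost 4 times" remark about the rating-only model.
ratio = results["retrieval-only"]["top_100"] / results["rating-only"]["top_100"]
print(f"{ratio:.2f}")  # 3.88
```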

### Making predictions

We can use the trained multitask model to get trained user and movie embeddings, as well as the predicted rating:

In [15]:
# call() returns (user_embeddings, movie_embeddings, rating_predictions), in that order.
trained_user_embeddings, trained_movie_embeddings, predicted_rating = model({
    "user_id": np.array(["42"]),
    "movie_title": np.array(["Dances with Wolves (1990)"])
})
print("Predicted rating:")
print(predicted_rating)

Predicted rating:
tf.Tensor([[4.604047]], shape=(1, 1), dtype=float32)


Although the results here do not show a clear accuracy benefit from a joint model, multi-task learning is in general an extremely useful tool. We can expect better results when we can transfer knowledge from a data-abundant task (such as clicks) to a closely related data-sparse task (such as purchases).