Recommend movies for users with TensorFlow Ranking

In this tutorial, we build a simple two-tower ranking model using the MovieLens 100K dataset with TF-Ranking. We can use this model to rank and recommend movies for a given user according to their predicted ratings.

Setup

Install and import the TF-Ranking library:

pip install -q tensorflow-ranking
pip install -q --upgrade tensorflow-datasets
from typing import Dict, Tuple

import tensorflow as tf

import tensorflow_datasets as tfds
import tensorflow_ranking as tfr

Read the data

Prepare to train a model by creating a ratings dataset and movies dataset. Use user_id as the query input feature, movie_title as the document input feature, and user_rating as the label to train the ranking model.

%%capture --no-display
# Ratings data.
ratings = tfds.load('movielens/100k-ratings', split="train")
# Features of all the available movies.
movies = tfds.load('movielens/100k-movies', split="train")

# Select the basic features.
ratings = ratings.map(lambda x: {
    "movie_title": x["movie_title"],
    "user_id": x["user_id"],
    "user_rating": x["user_rating"]
})

Build vocabularies to convert all user ids and all movie titles into integer indices for embedding layers:

movies = movies.map(lambda x: x["movie_title"])
users = ratings.map(lambda x: x["user_id"])

user_ids_vocabulary = tf.keras.layers.StringLookup(mask_token=None)
user_ids_vocabulary.adapt(users.batch(1000))

movie_titles_vocabulary = tf.keras.layers.StringLookup(mask_token=None)
movie_titles_vocabulary.adapt(movies.batch(1000))
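
As a quick optional check, the adapted vocabularies map raw strings to integer indices (with mask_token=None, index 0 is reserved for out-of-vocabulary values); the exact indices depend on the adapted data and are illustrative only:

# Illustrative only: look up two raw user ids as integer indices.
print(user_ids_vocabulary(tf.constant(["42", "405"])))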

Group by user_id to form lists for ranking models:

key_func = lambda x: user_ids_vocabulary(x["user_id"])
reduce_func = lambda key, dataset: dataset.batch(100)
ds_train = ratings.group_by_window(
    key_func=key_func, reduce_func=reduce_func, window_size=100)
for x in ds_train.take(1):
  for key, value in x.items():
    print(f"Shape of {key}: {value.shape}")
    print(f"Example values of {key}: {value[:5].numpy()}")
    print()
Shape of movie_title: (100,)
Example values of movie_title: [b'Man Who Would Be King, The (1975)' b'Silence of the Lambs, The (1991)'
 b'Next Karate Kid, The (1994)' b'2001: A Space Odyssey (1968)'
 b'Usual Suspects, The (1995)']

Shape of user_id: (100,)
Example values of user_id: [b'405' b'405' b'405' b'405' b'405']

Shape of user_rating: (100,)
Example values of user_rating: [1. 4. 1. 5. 5.]

Generate batched features and labels:

def _features_and_labels(
    x: Dict[str, tf.Tensor]) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
  labels = x.pop("user_rating")
  return x, labels


ds_train = ds_train.map(_features_and_labels)

ds_train = ds_train.ragged_batch(batch_size=32)

The user_id and movie_title tensors generated in ds_train have shape [32, None], where the second dimension is 100 in most cases, except for batches where fewer than 100 items were grouped into a list. The model therefore works on ragged tensors.

for x, label in ds_train.take(1):
  for key, value in x.items():
    print(f"Shape of {key}: {value.shape}")
    print(f"Example values of {key}: {value[:3, :3].numpy()}")
    print()
  print(f"Shape of label: {label.shape}")
  print(f"Example values of label: {label[:3, :3].numpy()}")
Shape of movie_title: (32, None)
Example values of movie_title: [[b'Man Who Would Be King, The (1975)'
  b'Silence of the Lambs, The (1991)' b'Next Karate Kid, The (1994)']
 [b'Flower of My Secret, The (Flor de mi secreto, La) (1995)'
  b'Little Princess, The (1939)' b'Time to Kill, A (1996)']
 [b'Kundun (1997)' b'Scream (1996)' b'Power 98 (1995)']]

Shape of user_id: (32, None)
Example values of user_id: [[b'405' b'405' b'405']
 [b'655' b'655' b'655']
 [b'13' b'13' b'13']]

Shape of label: (32, None)
Example values of label: [[1. 4. 1.]
 [3. 3. 3.]
 [5. 1. 1.]]

Define a model

Define a ranking model by inheriting from tf.keras.Model and implementing the call method:

class MovieLensRankingModel(tf.keras.Model):

  def __init__(self, user_vocab, movie_vocab):
    super().__init__()

    # Set up user and movie vocabulary and embedding.
    self.user_vocab = user_vocab
    self.movie_vocab = movie_vocab
    self.user_embed = tf.keras.layers.Embedding(user_vocab.vocabulary_size(),
                                                64)
    self.movie_embed = tf.keras.layers.Embedding(movie_vocab.vocabulary_size(),
                                                 64)

  def call(self, features: Dict[str, tf.Tensor]) -> tf.Tensor:
    # Define how the ranking scores are computed: 
    # Take the dot-product of the user embeddings with the movie embeddings.

    user_embeddings = self.user_embed(self.user_vocab(features["user_id"]))
    movie_embeddings = self.movie_embed(
        self.movie_vocab(features["movie_title"]))

    return tf.reduce_sum(user_embeddings * movie_embeddings, axis=2)

Create the model, then compile it with a ranking loss from tfr.keras.losses and ranking metrics from tfr.keras.metrics, which are the core of the TF-Ranking package.

This example uses a ranking-specific softmax loss, a listwise loss that promotes all relevant items in the ranking list toward the top, ahead of the irrelevant ones. In contrast to the softmax loss for multi-class classification, where exactly one class is positive and the rest are negative, the TF-Ranking library supports multiple relevant documents per query list and non-binary relevance labels.
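
To make this concrete, here is a minimal sketch of the listwise softmax loss for a single query list with toy scores and labels; the real tfr.keras.losses implementation additionally handles ragged inputs, masking, and reduction across lists:

# Per-list softmax loss: -sum_i y_i * log_softmax(s)_i, with the relevance
# labels acting as (unnormalized) targets over the list items.
labels = tf.constant([[1., 4., 1., 5., 5.]])        # user ratings for one list
scores = tf.constant([[0.2, 1.1, -0.3, 0.9, 1.5]])  # model scores for the list
loss = -tf.reduce_sum(labels * tf.nn.log_softmax(scores, axis=1), axis=1)
print(loss.numpy())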

For ranking metrics, this example uses Normalized Discounted Cumulative Gain (NDCG) and Mean Reciprocal Rank (MRR), which measure the utility of a ranked query list to the user, with position discounts. For more details about ranking metrics, review the offline metrics described in evaluation measures for information retrieval.
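
As a hand-rolled illustration of the position discount in NDCG (the library metrics compute this for you, including over ragged inputs), the sketch below scores one toy list using the common 2^rel - 1 gain and 1/log2(rank + 1) discount:

import numpy as np

def dcg(relevances):
  # Gain (2^rel - 1) discounted by log2(rank + 1).
  ranks = np.arange(1, len(relevances) + 1)
  return np.sum((2.0 ** np.asarray(relevances) - 1) / np.log2(ranks + 1))

ranked = [5, 3, 4]                    # labels in predicted order
ideal = sorted(ranked, reverse=True)  # labels in the best possible order
print(dcg(ranked) / dcg(ideal))       # NDCG for this single list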

# Create the ranking model, trained with a ranking loss and evaluated with
# ranking metrics.
model = MovieLensRankingModel(user_ids_vocabulary, movie_titles_vocabulary)
optimizer = tf.keras.optimizers.Adagrad(0.5)
loss = tfr.keras.losses.get(
    loss=tfr.keras.losses.RankingLossKey.SOFTMAX_LOSS, ragged=True)
eval_metrics = [
    tfr.keras.metrics.get(key="ndcg", name="metric/ndcg", ragged=True),
    tfr.keras.metrics.get(key="mrr", name="metric/mrr", ragged=True)
]
model.compile(optimizer=optimizer, loss=loss, metrics=eval_metrics)
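
As an optional sanity check before training (not part of the original recipe), a single forward pass on one batch should yield one score per list item, matching the ragged label shape:

# One forward pass on a single batch of ragged features.
for x, label in ds_train.take(1):
  scores = model(x)
  print(scores.shape, label.shape)  # both ragged, shape [32, None]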

Train and evaluate the model

Train the model with model.fit.

model.fit(ds_train, epochs=3)
Epoch 1/3
48/48 [==============================] - 7s 56ms/step - loss: 998.7637 - metric/ndcg: 0.8213 - metric/mrr: 1.0000
Epoch 2/3
48/48 [==============================] - 4s 53ms/step - loss: 997.1824 - metric/ndcg: 0.9161 - metric/mrr: 1.0000
Epoch 3/3
48/48 [==============================] - 4s 53ms/step - loss: 994.8384 - metric/ndcg: 0.9383 - metric/mrr: 1.0000
<keras.src.callbacks.History at 0x7f666424d700>

Generate predictions and evaluate.

# Get the movie title candidate list as a single batch.
movie_titles = next(iter(movies.batch(2000)))

# Generate the input for user 42.
inputs = {
    "user_id":
        tf.expand_dims(tf.repeat("42", repeats=movie_titles.shape[0]), axis=0),
    "movie_title":
        tf.expand_dims(movie_titles, axis=0)
}

# Get movie recommendations for user 42.
scores = model(inputs)
titles = tfr.utils.sort_by_scores(scores,
                                  [tf.expand_dims(movie_titles, axis=0)])[0]
print(f"Top 5 recommendations for user 42: {titles[0, :5]}")
Top 5 recommendations for user 42: [b'Star Wars (1977)' b'Liar Liar (1997)' b'Toy Story (1995)'
 b'Raiders of the Lost Ark (1981)' b'Sound of Music, The (1965)']