Recommending movies: retrieval using a sequential model


In this tutorial, we are going to build a sequential retrieval model. Sequential recommendation is a popular approach that looks at the sequence of items a user has previously interacted with and then predicts the next item. Because the order of the items within each sequence matters, we are going to use a recurrent neural network to model the sequential relationship. For more details, please refer to the GRU4Rec paper.

Imports

First let's get our dependencies and imports out of the way.

pip install -q tensorflow-recommenders
pip install -q --upgrade tensorflow-datasets
import os
import pprint
import tempfile

from typing import Dict, Text

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_recommenders as tfrs

Preparing the dataset

Next, we need to prepare our dataset. We are going to use the data generation utility from the TensorFlow Lite On-device Recommendation reference app.

MovieLens 1M data contains ratings.dat (columns: UserID, MovieID, Rating, Timestamp) and movies.dat (columns: MovieID, Title, Genres). The example generation script downloads the 1M dataset, reads both files, keeps only ratings higher than 2, forms user movie interaction timelines, and samples activities as labels, using the 10 previous user activities as the context for prediction.
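
To make the timeline step concrete, here is a minimal, dependency-free sketch of how one user's time-ordered movie IDs could be turned into (context, label) examples. It is not the code of the example generation script: the function name build_examples, the padding value 0, and padding on the right are assumptions made for illustration only.

```python
from typing import Dict, List

PAD_ID = 0  # Assumed padding value for missing context positions.


def build_examples(timeline: List[int],
                   max_context_length: int = 10,
                   min_timeline_length: int = 3) -> List[Dict[str, List[int]]]:
    """Turns one user's time-ordered movie IDs into (context, label) examples."""
    examples: List[Dict[str, List[int]]] = []
    if len(timeline) < min_timeline_length:
        return examples  # Timeline too short to be useful.
    # Every position after the first becomes a label; the items before it
    # (at most max_context_length of them) become the context.
    for label_idx in range(1, len(timeline)):
        start = max(0, label_idx - max_context_length)
        context = timeline[start:label_idx]
        # Pad short contexts so that every example has the same length.
        context = context + [PAD_ID] * (max_context_length - len(context))
        examples.append({
            "context_movie_id": context,
            "label_movie_id": [timeline[label_idx]],
        })
    return examples


examples = build_examples([908, 1086, 1252, 2871], max_context_length=3)
for example in examples:
    print(example)
```

With a context length of 3, the four-movie timeline yields three examples; the earliest ones are padded because fewer than three movies precede the label.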

wget -nc https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/recommendation/ml/data/example_generation_movielens.py
python -m example_generation_movielens  --data_dir=data/raw  --output_dir=data/examples  --min_timeline_length=3  --max_context_length=10  --max_context_movie_genre_length=10  --min_rating=2  --train_data_fraction=0.9  --build_vocabs=False

I1214 12:39:51.542600 139789676263232 example_generation_movielens.py:460] Downloading and extracting data.
Downloading data from https://files.grouplens.org/datasets/movielens/ml-1m.zip
5917549/5917549 [==============================] - 0s 0us/step
I1214 12:39:52.073689 139789676263232 example_generation_movielens.py:406] Reading data to dataframes.
/tmpfs/src/temp/docs/examples/example_generation_movielens.py:132: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.
  ratings_df = pd.read_csv(
/tmpfs/src/temp/docs/examples/example_generation_movielens.py:140: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.
  movies_df = pd.read_csv(
I1214 12:39:56.795858 139789676263232 example_generation_movielens.py:408] Generating movie rating user timelines.
I1214 12:39:59.942978 139789676263232 example_generation_movielens.py:410] Generating train and test examples.
6040/6040 [==============================] - 72s 12ms/step
93799/93799 [==============================] - 2s 17us/step
I1214 12:41:38.911691 139789676263232 example_generation_movielens.py:473] Generated dataset: {'train_size': 844195, 'test_size': 93799, 'train_file': 'data/examples/train_movielens_1m.tfrecord', 'test_file': 'data/examples/test_movielens_1m.tfrecord'}

Here is a sample of the generated dataset.

0 : {
  features: {
    feature: {
      key  : "context_movie_id"
      value: { int64_list: { value: [ 1124, 2240, 3251, ..., 1268 ] } }
    }
    feature: {
      key  : "context_movie_rating"
      value: { float_list: {value: [ 3.0, 3.0, 4.0, ..., 3.0 ] } }
    }
    feature: {
      key  : "context_movie_year"
      value: { int64_list: { value: [ 1981, 1980, 1985, ..., 1990 ] } }
    }
    feature: {
      key  : "context_movie_genre"
      value: { bytes_list: { value: [ "Drama", "Drama", "Mystery", ..., "UNK" ] } }
    }
    feature: {
      key  : "label_movie_id"
      value: { int64_list: { value: [ 3252 ] }  }
    }
  }
}

You can see that it includes a sequence of context movie IDs, and a label movie ID (next movie), plus context features such as movie year, rating and genre.

In our case we will only be using the sequence of context movie IDs and the label movie ID. You can refer to the Leveraging context features tutorial to learn more about adding additional context features.

train_filename = "./data/examples/train_movielens_1m.tfrecord"
train = tf.data.TFRecordDataset(train_filename)

test_filename = "./data/examples/test_movielens_1m.tfrecord"
test = tf.data.TFRecordDataset(test_filename)

feature_description = {
    'context_movie_id': tf.io.FixedLenFeature([10], tf.int64, default_value=np.repeat(0, 10)),
    'context_movie_rating': tf.io.FixedLenFeature([10], tf.float32, default_value=np.repeat(0, 10)),
    'context_movie_year': tf.io.FixedLenFeature([10], tf.int64, default_value=np.repeat(1980, 10)),
    'context_movie_genre': tf.io.FixedLenFeature([10], tf.string, default_value=np.repeat("Drama", 10)),
    'label_movie_id': tf.io.FixedLenFeature([1], tf.int64, default_value=0),
}

def _parse_function(example_proto):
  return tf.io.parse_single_example(example_proto, feature_description)

train_ds = train.map(_parse_function).map(lambda x: {
    "context_movie_id": tf.strings.as_string(x["context_movie_id"]),
    "label_movie_id": tf.strings.as_string(x["label_movie_id"])
})

test_ds = test.map(_parse_function).map(lambda x: {
    "context_movie_id": tf.strings.as_string(x["context_movie_id"]),
    "label_movie_id": tf.strings.as_string(x["label_movie_id"])
})

for x in train_ds.take(1).as_numpy_iterator():
  pprint.pprint(x)
{'context_movie_id': array([b'908', b'1086', b'1252', b'2871', b'3551', b'593', b'247', b'608',
       b'1358', b'866'], dtype=object),
 'label_movie_id': array([b'190'], dtype=object)}

Now our train/test datasets include only a sequence of historical movie IDs and the label of the next movie ID. Note that we use [10] as the shape of the features during tf.Example parsing because we specified 10 as the length of the context features in the example generation step.

We need one more thing before we can start building the model - the vocabulary for our movie IDs.

movies = tfds.load("movielens/1m-movies", split='train')
movies = movies.map(lambda x: x["movie_id"])
movie_ids = movies.batch(1_000)
unique_movie_ids = np.unique(np.concatenate(list(movie_ids)))
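
To make the role of the vocabulary concrete, here is a small, dependency-free sketch of what a string lookup with a single out-of-vocabulary (OOV) slot does. It assumes the behaviour of a lookup layer configured with no mask token and one OOV index, where index 0 is reserved for unknown tokens and known tokens start at 1. The helper names build_lookup and lookup are hypothetical and the vocabulary is made up.

```python
from typing import Dict, List


def build_lookup(vocabulary: List[str]) -> Dict[str, int]:
    """Maps each known token to an index, starting at 1."""
    # Index 0 is reserved for out-of-vocabulary tokens.
    return {token: index for index, token in enumerate(vocabulary, start=1)}


def lookup(table: Dict[str, int], tokens: List[str]) -> List[int]:
    """Converts tokens to indices; unknown tokens fall back to index 0."""
    return [table.get(token, 0) for token in tokens]


vocabulary = ["1", "2", "3", "5"]  # Toy movie IDs as strings.
table = build_lookup(vocabulary)

# "0" (for example a padded context position) and "42" are not in the
# vocabulary, so both land in the OOV slot.
indices = lookup(table, ["3", "5", "0", "42"])
print(indices)

# An embedding table indexed by these values needs one extra row for OOV.
num_embedding_rows = len(vocabulary) + 1
print(num_embedding_rows)
```

Under this assumption, an embedding table that consumes the indices needs one row more than the vocabulary size, so that the OOV index has its own row.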

Implementing a sequential model

In our basic retrieval tutorial, we use one query tower for the user and one candidate tower for the candidate movie. However, the two-tower architecture is generalizable and not limited to user-item pairs. You can also use it for item-to-item recommendation, as we note in the basic retrieval tutorial.

Here we are still going to use the two-tower architecture. Specifically, we use a query tower with a Gated Recurrent Unit (GRU) layer to encode the sequence of historical movies, and keep the same candidate tower for the candidate movie.

embedding_dimension = 32

query_model = tf.keras.Sequential([
    tf.keras.layers.StringLookup(
      vocabulary=unique_movie_ids, mask_token=None),
    tf.keras.layers.Embedding(len(unique_movie_ids) + 1, embedding_dimension), 
    tf.keras.layers.GRU(embedding_dimension),
])

candidate_model = tf.keras.Sequential([
  tf.keras.layers.StringLookup(
      vocabulary=unique_movie_ids, mask_token=None),
  tf.keras.layers.Embedding(len(unique_movie_ids) + 1, embedding_dimension)
])
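
To illustrate what the GRU in the query tower contributes, here is a dependency-free sketch of a GRU cell that folds a sequence of embedding vectors into a single vector. It uses the textbook GRU equations without biases and with random, untrained weights, so it is not the exact variant or initialisation that Keras uses. The class TinyGRU and the toy embeddings are hypothetical.

```python
import math
import random
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(matrix: Matrix, vector: Vector) -> Vector:
    """Multiplies a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


class TinyGRU:
    """A bias-free GRU cell with randomly initialised, untrained weights."""

    def __init__(self, input_dim: int, units: int, seed: int = 0) -> None:
        rng = random.Random(seed)

        def init(rows: int, cols: int) -> Matrix:
            return [[rng.uniform(-0.5, 0.5) for _ in range(cols)]
                    for _ in range(rows)]

        self.units = units
        # One input matrix (w) and one recurrent matrix (u) per gate:
        # z = update gate, r = reset gate, h = candidate state.
        self.w: Dict[str, Matrix] = {g: init(units, input_dim) for g in "zrh"}
        self.u: Dict[str, Matrix] = {g: init(units, units) for g in "zrh"}

    def _gate(self, name: str, x: Vector, h: Vector) -> Vector:
        """Pre-activation W x + U h for the named gate."""
        return [a + b for a, b in zip(matvec(self.w[name], x),
                                      matvec(self.u[name], h))]

    def step(self, x: Vector, h: Vector) -> Vector:
        z = [sigmoid(v) for v in self._gate("z", x, h)]
        r = [sigmoid(v) for v in self._gate("r", x, h)]
        reset_h = [ri * hi for ri, hi in zip(r, h)]
        candidate = [math.tanh(v) for v in self._gate("h", x, reset_h)]
        # Blend the old state and the candidate state using the update gate.
        return [zi * hi + (1.0 - zi) * ci
                for zi, hi, ci in zip(z, h, candidate)]

    def encode(self, sequence: List[Vector]) -> Vector:
        """Runs the cell over the sequence and returns the final state."""
        h = [0.0] * self.units
        for x in sequence:
            h = self.step(x, h)
        return h


# Three hypothetical 2-dimensional movie embeddings.
movie_a, movie_b, movie_c = [1.0, 0.0], [0.0, 1.0], [0.5, -0.5]
gru = TinyGRU(input_dim=2, units=4)
forward = gru.encode([movie_a, movie_b, movie_c])
backward = gru.encode([movie_c, movie_b, movie_a])
print(forward)
print(backward)
```

The same three movies in a different order produce a different final state, which is the property that lets the query tower take viewing order into account.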

The metrics, task and full model are defined similarly to the basic retrieval model.

metrics = tfrs.metrics.FactorizedTopK(
  candidates=movies.batch(128).map(candidate_model)
)

task = tfrs.tasks.Retrieval(
  metrics=metrics
)

class Model(tfrs.Model):

    def __init__(self, query_model, candidate_model):
        super().__init__()
        self._query_model = query_model
        self._candidate_model = candidate_model

        self._task = task

    def compute_loss(self, features, training=False):
        watch_history = features["context_movie_id"]
        watch_next_label = features["label_movie_id"]

        query_embedding = self._query_model(watch_history)       
        candidate_embedding = self._candidate_model(watch_next_label)

        return self._task(query_embedding, candidate_embedding, compute_metrics=not training)
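
As a rough illustration of what the retrieval task optimizes, the sketch below computes an in-batch softmax loss in plain Python. It assumes that each query is scored against every candidate in the batch by dot product, that the candidate at the same batch position is the positive, and that per-example cross-entropy losses are summed. We believe this matches the default behaviour of the retrieval task, but treat it as an assumption; the function name in_batch_softmax_loss is hypothetical.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def in_batch_softmax_loss(queries: List[Vector],
                          candidates: List[Vector]) -> float:
    """Summed softmax cross-entropy where candidates[i] is the positive for queries[i]."""
    total = 0.0
    for i, query in enumerate(queries):
        logits = [dot(query, candidate) for candidate in candidates]
        # Numerically stable log-sum-exp over all in-batch candidates.
        max_logit = max(logits)
        log_sum_exp = max_logit + math.log(
            sum(math.exp(logit - max_logit) for logit in logits))
        # Cross-entropy for the positive: -log softmax(logits)[i].
        total += log_sum_exp - logits[i]
    return total


queries = [[1.0, 0.0], [0.0, 1.0]]
aligned_loss = in_batch_softmax_loss(queries, [[1.0, 0.0], [0.0, 1.0]])
swapped_loss = in_batch_softmax_loss(queries, [[0.0, 1.0], [1.0, 0.0]])
print(aligned_loss, swapped_loss)
```

The loss is lower when each query embedding points towards its own candidate than when the candidates are swapped. Under the summed-reduction assumption the loss also grows with the batch size, which would be consistent with the large loss values reported during training.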

Fitting and evaluating

We can now compile, train and evaluate our sequential retrieval model.

model = Model(query_model, candidate_model)
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))
cached_train = train_ds.shuffle(10_000).batch(12800).cache()
cached_test = test_ds.batch(2560).cache()
model.fit(cached_train, epochs=3)
Epoch 1/3
67/67 [==============================] - 18s 220ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 108359.6299 - regularization_loss: 0.0000e+00 - total_loss: 108359.6299
Epoch 2/3
67/67 [==============================] - 3s 38ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 101734.1007 - regularization_loss: 0.0000e+00 - total_loss: 101734.1007
Epoch 3/3
67/67 [==============================] - 3s 38ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 99763.0675 - regularization_loss: 0.0000e+00 - total_loss: 99763.0675
model.evaluate(cached_test, return_dict=True)
37/37 [==============================] - 10s 221ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0144 - factorized_top_k/top_5_categorical_accuracy: 0.0775 - factorized_top_k/top_10_categorical_accuracy: 0.1347 - factorized_top_k/top_50_categorical_accuracy: 0.3712 - factorized_top_k/top_100_categorical_accuracy: 0.5030 - loss: 15530.8248 - regularization_loss: 0.0000e+00 - total_loss: 15530.8248
{'factorized_top_k/top_1_categorical_accuracy': 0.014403138309717178,
 'factorized_top_k/top_5_categorical_accuracy': 0.07749549299478531,
 'factorized_top_k/top_10_categorical_accuracy': 0.13472424447536469,
 'factorized_top_k/top_50_categorical_accuracy': 0.37120863795280457,
 'factorized_top_k/top_100_categorical_accuracy': 0.5029690861701965,
 'loss': 9413.7470703125,
 'regularization_loss': 0,
 'total_loss': 9413.7470703125}
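
The top-k accuracy figures above can be read as follows: for each test example, all candidate movies are scored against the query embedding, and the example counts as a hit if the true next movie is among the k highest-scoring candidates. Below is a minimal sketch of that idea in plain Python, with made-up embeddings. It illustrates the concept only and is not how the FactorizedTopK metric is implemented; the function name top_k_accuracy is hypothetical.

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def top_k_accuracy(queries: List[Vector],
                   candidates: List[Vector],
                   true_indices: List[int],
                   k: int) -> float:
    """Fraction of queries whose true candidate ranks within the top k."""
    hits = 0
    for query, true_index in zip(queries, true_indices):
        scores = [dot(query, candidate) for candidate in candidates]
        # Candidate indices sorted from highest to lowest score.
        ranked = sorted(range(len(candidates)),
                        key=lambda j: scores[j], reverse=True)
        if true_index in ranked[:k]:
            hits += 1
    return hits / len(queries)


# Made-up candidate and query embeddings for illustration.
candidates = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
queries = [[1.0, 0.1], [0.2, 1.0], [1.0, 0.9]]
true_indices = [0, 2, 1]

for k in (1, 2, 3):
    print(k, top_k_accuracy(queries, candidates, true_indices, k))
```

As in the evaluation output, accuracy can only stay the same or increase as k grows, because a larger k admits more candidates as hits.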

This concludes the sequential retrieval tutorial.