목록별 순위

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기 노트북 다운로드

에서 기본 순위 튜토리얼 , 우리는 사용자 / 영화 쌍에 대한 등급을 예측할 수있는 모델을 훈련했다. 모델은 예측된 등급의 평균 제곱 오차를 최소화하도록 훈련되었습니다.

그러나 개별 영화에 대한 모델의 예측을 최적화하는 것이 반드시 순위 모델을 훈련하는 가장 좋은 방법은 아닙니다. 높은 정확도로 점수를 예측하기 위해 순위 모델이 필요하지 않습니다. 대신 사용자의 기본 설정 순서와 일치하는 항목의 정렬된 목록을 생성하는 모델의 기능에 더 관심이 있습니다.

개별 쿼리/항목 쌍에 대한 모델의 예측을 최적화하는 대신 전체 목록의 모델 순위를 최적화할 수 있습니다. 이 방법은 순위 listwise이라고합니다.

이 튜토리얼에서는 TensorFlow Recommenders를 사용하여 목록별 순위 모델을 빌드합니다. 이렇게하려면, 우리가 제공하는 손실과 통계 순위를 사용하게됩니다 TensorFlow 순위 에 초점을 맞추고하는 TensorFlow 패키지 순위에 학습 .

예선

TensorFlow 순위가 런타임 환경에서 사용할 수없는 경우 사용하여 설치할 수 있습니다 pip :

pip install -q tensorflow-recommenders
pip install -q --upgrade tensorflow-datasets
pip install -q tensorflow-ranking

그런 다음 필요한 모든 패키지를 가져올 수 있습니다.

import pprint

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/pkg_resources/__init__.py:119: PkgResourcesDeprecationWarning: 0.18ubuntu0.18.04.1 is an invalid version and will not be supported in a future release
  PkgResourcesDeprecationWarning,
import tensorflow_ranking as tfr
import tensorflow_recommenders as tfrs
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_addons/utils/ensure_tf_install.py:67: UserWarning: Tensorflow Addons supports using Python ops for all Tensorflow versions above or equal to 2.4.0 and strictly below 2.7.0 (nightly versions are not supported). 
 The versions of TensorFlow you are currently using is 2.7.0 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons
  UserWarning,

MovieLens 100K 데이터 세트를 계속 사용할 것입니다. 이전과 마찬가지로 데이터 세트를 로드하고 이 자습서의 사용자 ID, 영화 제목 및 사용자 평가 기능만 유지합니다. 우리는 또한 우리의 어휘를 준비하기 위해 약간의 하우스 키핑을 합니다.

ratings = tfds.load("movielens/100k-ratings", split="train")
movies = tfds.load("movielens/100k-movies", split="train")

ratings = ratings.map(lambda x: {
    "movie_title": x["movie_title"],
    "user_id": x["user_id"],
    "user_rating": x["user_rating"],
})
movies = movies.map(lambda x: x["movie_title"])

unique_movie_titles = np.unique(np.concatenate(list(movies.batch(1000))))
unique_user_ids = np.unique(np.concatenate(list(ratings.batch(1_000).map(
    lambda x: x["user_id"]))))

데이터 전처리

그러나 목록 최적화를 위해 MovieLens 데이터 세트를 직접 사용할 수는 없습니다. 목록별 최적화를 수행하려면 각 사용자가 평가한 영화 목록에 액세스할 수 있어야 하지만 MovieLens 100K 데이터 세트의 각 예에는 단일 영화의 평가만 포함되어 있습니다.

이 문제를 해결하기 위해 각 예제에 사용자 ID와 해당 사용자가 평가한 영화 목록이 포함되도록 데이터 세트를 변환합니다. 목록에 있는 일부 영화는 다른 영화보다 순위가 더 높습니다. 우리 모델의 목표는 이 순서와 일치하는 예측을 하는 것입니다.

이를 위해, 우리는 사용 tfrs.examples.movielens.movielens_to_listwise 도우미 함수를. MovieLens 100K 데이터 세트를 사용하고 위에서 설명한 목록 예제가 포함된 데이터 세트를 생성합니다. 구현 세부 사항은에서 찾을 수 있습니다 소스 코드 .

tf.random.set_seed(42)

# Split between train and tests sets, as before.
shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)

train = shuffled.take(80_000)
test = shuffled.skip(80_000).take(20_000)

# We sample 50 lists for each user for the training data. For each list we
# sample 5 movies from the movies the user rated.
train = tfrs.examples.movielens.sample_listwise(
    train,
    num_list_per_user=50,
    num_examples_per_list=5,
    seed=42
)
test = tfrs.examples.movielens.sample_listwise(
    test,
    num_list_per_user=1,
    num_examples_per_list=5,
    seed=42
)

훈련 데이터에서 예제를 검사할 수 있습니다. 예제에는 사용자 ID, 10개의 영화 ID 목록 및 사용자의 등급이 포함됩니다.

for example in train.take(1):
  pprint.pprint(example)
{'movie_title': <tf.Tensor: shape=(5,), dtype=string, numpy=
array([b'Postman, The (1997)', b'Liar Liar (1997)', b'Contact (1997)',
       b'Welcome To Sarajevo (1997)',
       b'I Know What You Did Last Summer (1997)'], dtype=object)>,
 'user_id': <tf.Tensor: shape=(), dtype=string, numpy=b'681'>,
 'user_rating': <tf.Tensor: shape=(5,), dtype=float32, numpy=array([4., 5., 1., 4., 1.], dtype=float32)>}

모델 정의

우리는 세 가지 다른 손실로 동일한 모델을 훈련할 것입니다:

  • 평균 제곱 오차,
  • 쌍으로 된 힌지 손실,
  • 목록별 ListMLE 손실.

이 세 가지 손실은 pointwise, pairwise 및 listwise 최적화에 해당합니다.

우리가 사용하는 모델을 평가하기 위해 정규화를 누적 이득 (NDCG)를 할인 . NDCG는 각 후보자의 실제 평점을 가중치로 합산하여 예측 순위를 측정합니다. 모델별로 순위가 낮은 영화의 등급은 더 할인됩니다. 결과적으로 높은 평가를 받은 영화를 상위에 배치하는 좋은 모델은 높은 NDCG 결과를 얻을 수 있습니다. 이 측정항목은 각 후보자의 순위를 고려하므로 목록별 측정항목입니다.

class RankingModel(tfrs.Model):

  def __init__(self, loss):
    super().__init__()
    embedding_dimension = 32

    # Compute embeddings for users.
    self.user_embeddings = tf.keras.Sequential([
      tf.keras.layers.StringLookup(
        vocabulary=unique_user_ids),
      tf.keras.layers.Embedding(len(unique_user_ids) + 2, embedding_dimension)
    ])

    # Compute embeddings for movies.
    self.movie_embeddings = tf.keras.Sequential([
      tf.keras.layers.StringLookup(
        vocabulary=unique_movie_titles),
      tf.keras.layers.Embedding(len(unique_movie_titles) + 2, embedding_dimension)
    ])

    # Compute predictions.
    self.score_model = tf.keras.Sequential([
      # Learn multiple dense layers.
      tf.keras.layers.Dense(256, activation="relu"),
      tf.keras.layers.Dense(64, activation="relu"),
      # Make rating predictions in the final layer.
      tf.keras.layers.Dense(1)
    ])

    self.task = tfrs.tasks.Ranking(
      loss=loss,
      metrics=[
        tfr.keras.metrics.NDCGMetric(name="ndcg_metric"),
        tf.keras.metrics.RootMeanSquaredError()
      ]
    )

  def call(self, features):
    # We first convert the id features into embeddings.
    # User embeddings are a [batch_size, embedding_dim] tensor.
    user_embeddings = self.user_embeddings(features["user_id"])

    # Movie embeddings are a [batch_size, num_movies_in_list, embedding_dim]
    # tensor.
    movie_embeddings = self.movie_embeddings(features["movie_title"])

    # We want to concatenate user embeddings with movie emebeddings to pass
    # them into the ranking model. To do so, we need to reshape the user
    # embeddings to match the shape of movie embeddings.
    list_length = features["movie_title"].shape[1]
    user_embedding_repeated = tf.repeat(
        tf.expand_dims(user_embeddings, 1), [list_length], axis=1)

    # Once reshaped, we concatenate and pass into the dense layers to generate
    # predictions.
    concatenated_embeddings = tf.concat(
        [user_embedding_repeated, movie_embeddings], 2)

    return self.score_model(concatenated_embeddings)

  def compute_loss(self, features, training=False):
    labels = features.pop("user_rating")

    scores = self(features)

    return self.task(
        labels=labels,
        predictions=tf.squeeze(scores, axis=-1),
    )

모델 훈련

이제 세 가지 모델을 각각 훈련할 수 있습니다.

epochs = 30

cached_train = train.shuffle(100_000).batch(8192).cache()
cached_test = test.batch(4096).cache()

평균 제곱 오차 모델

이 모델에서 모델과 매우 유사 기본 순위 튜토리얼 . 실제 등급과 예측 등급 간의 평균 제곱 오차를 최소화하도록 모델을 훈련합니다. 따라서 이 손실은 각 영화에 대해 개별적으로 계산되며 교육은 포인트별로 수행됩니다.

mse_model = RankingModel(tf.keras.losses.MeanSquaredError())
mse_model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))
mse_model.fit(cached_train, epochs=epochs, verbose=False)
<keras.callbacks.History at 0x7f64791a5d10>

쌍별 힌지 손실 모델

쌍별 힌지 손실을 최소화함으로써 모델은 높은 평가 항목과 낮은 평점 항목에 대한 모델 예측 간의 차이를 최대화하려고 시도합니다. 차이가 클수록 모델 손실이 낮아집니다. 그러나 차이가 충분히 크면 손실이 0이 되어 모델이 이 특정 쌍을 더 이상 최적화하지 못하고 잘못 순위가 지정된 다른 쌍에 집중할 수 있습니다.

이 손실은 개별 영화에 대해 계산되는 것이 아니라 영화 쌍에 대해 계산됩니다. 따라서 이 손실을 사용한 훈련은 pairwise입니다.

hinge_model = RankingModel(tfr.keras.losses.PairwiseHingeLoss())
hinge_model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))
hinge_model.fit(cached_train, epochs=epochs, verbose=False)
<keras.callbacks.History at 0x7f647914f190>

목록별 모델

ListMLE TensorFlow 순위 표현의 손실은 최대 우도 추정을 나열합니다. ListMLE 손실을 계산하기 위해 먼저 사용자 등급을 사용하여 최적의 순위를 생성합니다. 그런 다음 예측된 점수를 사용하여 최적 순위에서 각 후보가 그 아래 항목보다 순위가 낮을 가능성을 계산합니다. 모델은 이러한 가능성을 최소화하여 높은 평가를 받은 후보자가 낮은 평가를 받은 후보자보다 앞서지 않도록 합니다. 당신은 종이의 2.2 절에 ListMLE의 세부 사항에 대해 자세히 알아볼 수 있습니다 위치 인식 ListMLE : 순차적 인 학습 과정을 .

우도는 최적 순위에서 후보 및 그 아래의 모든 후보에 대해 계산되므로 손실은 쌍별이 아니라 목록별입니다. 따라서 훈련은 목록 최적화를 사용합니다.

listwise_model = RankingModel(tfr.keras.losses.ListMLELoss())
listwise_model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))
listwise_model.fit(cached_train, epochs=epochs, verbose=False)
<keras.callbacks.History at 0x7f647b35f350>

모델 비교

mse_model_result = mse_model.evaluate(cached_test, return_dict=True)
print("NDCG of the MSE Model: {:.4f}".format(mse_model_result["ndcg_metric"]))
1/1 [==============================] - 0s 405ms/step - ndcg_metric: 0.9053 - root_mean_squared_error: 0.9671 - loss: 0.9354 - regularization_loss: 0.0000e+00 - total_loss: 0.9354
NDCG of the MSE Model: 0.9053
hinge_model_result = hinge_model.evaluate(cached_test, return_dict=True)
print("NDCG of the pairwise hinge loss model: {:.4f}".format(hinge_model_result["ndcg_metric"]))
1/1 [==============================] - 0s 457ms/step - ndcg_metric: 0.9058 - root_mean_squared_error: 3.8330 - loss: 1.0180 - regularization_loss: 0.0000e+00 - total_loss: 1.0180
NDCG of the pairwise hinge loss model: 0.9058
listwise_model_result = listwise_model.evaluate(cached_test, return_dict=True)
print("NDCG of the ListMLE model: {:.4f}".format(listwise_model_result["ndcg_metric"]))
1/1 [==============================] - 0s 432ms/step - ndcg_metric: 0.9071 - root_mean_squared_error: 2.7224 - loss: 4.5401 - regularization_loss: 0.0000e+00 - total_loss: 4.5401
NDCG of the ListMLE model: 0.9071

세 가지 모델 중 ListMLE를 사용하여 훈련된 모델이 가장 높은 NDCG 메트릭을 가지고 있습니다. 이 결과는 목록별 최적화가 순위 모델을 훈련하는 데 어떻게 사용될 수 있고 잠재적으로 포인트별 또는 쌍별 방식으로 최적화된 모델보다 더 나은 성능을 보이는 모델을 생성할 수 있는지 보여줍니다.