tfrs.tasks.Retrieval

A factorized retrieval task.

Inherits From: Task

Recommender systems are often composed of two components:

  • a retrieval model, retrieving O(thousands) candidates from a corpus of O(millions) candidates.
  • a ranker model, scoring the candidates retrieved by the retrieval model to return a ranked shortlist of a few dozen candidates.

This task defines models that facilitate efficient retrieval of candidates from large corpora by maintaining a two-tower, factorized structure: separate query and candidate representation towers, joined at the top via a lightweight scoring function.
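The two-tower structure can be illustrated with a minimal NumPy sketch. This is not the TFRS implementation. The towers are assumed to be fixed random linear projections (in practice they are trained models), the lightweight scoring function is assumed to be a dot product, and `embed` and `score` are hypothetical helper names.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical towers: fixed linear projections of raw features into a shared
# embedding space. Trained models would take their place in practice.
QUERY_FEATURES, CANDIDATE_FEATURES, EMBEDDING_DIM = 6, 5, 4
query_tower = rng.normal(size=(QUERY_FEATURES, EMBEDDING_DIM))
candidate_tower = rng.normal(size=(CANDIDATE_FEATURES, EMBEDDING_DIM))


def embed(features: np.ndarray, tower: np.ndarray) -> np.ndarray:
    """Map features of shape [n, num_features] to embeddings of shape [n, EMBEDDING_DIM]."""
    return features @ tower


def score(query_embeddings: np.ndarray, candidate_embeddings: np.ndarray) -> np.ndarray:
    """Lightweight scoring function: dot product of every query with every candidate."""
    return query_embeddings @ candidate_embeddings.T


# The candidate tower runs once over the whole corpus; its output can be cached
# because it does not depend on any query.
corpus_embeddings = embed(rng.normal(size=(100, CANDIDATE_FEATURES)), candidate_tower)

# At query time only the query tower runs, followed by the cheap scoring step.
query_embeddings = embed(rng.normal(size=(3, QUERY_FEATURES)), query_tower)
scores = score(query_embeddings, corpus_embeddings)   # shape [3, 100]
top_5 = np.argsort(-scores, axis=1)[:, :5]            # best 5 candidate indices per query
```

Because neither tower sees the other's inputs, the corpus side can be precomputed once and reused for every query, for example by loading it into a nearest-neighbour index.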

Args
loss Loss function. Defaults to tf.keras.losses.CategoricalCrossentropy.
metrics Object for evaluating top-K metrics over a corpus of candidates. These metrics measure how good the model is at picking the true candidate out of all possible candidates in the system. Note that because these metrics range over the entire candidate set, they are usually much slower to compute. Consider passing compute_metrics=False to call during training to save the time spent computing them.
batch_metrics Metrics measuring how good the model is at picking out the true candidate for a query from other candidates in the batch. For example, a batch AUC metric would measure the probability that the true candidate is scored higher than the other candidates in the batch.
loss_metrics List of Keras metrics used to summarize the loss.
temperature Temperature of the softmax. The logits are divided by this value before the loss is computed, so values below 1 sharpen the softmax distribution and values above 1 flatten it.
num_hard_negatives If positive, only the num_hard_negatives negative examples with the largest logits are kept when computing the cross-entropy loss. If it is larger than the batch size or non-positive, all negative examples are kept.
remove_accidental_hits When True, removes accidental hits among the examples used as negatives. An accidental hit is a candidate that is used as an in-batch negative but has the same id as the positive candidate. Detecting accidental hits requires candidate_ids to be passed to call.
name Optional task name.
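The num_hard_negatives and remove_accidental_hits options both reshape the logits before the cross-entropy is computed. The NumPy sketch below shows one way to do this; it is not the TFRS code. It assumes query i's positive is candidate i, applies masking before mining, expresses the "keep everything" threshold as the number of negatives available per row, and uses hypothetical helpers `remove_accidental_hits`, `hard_negative_mining`, and a hypothetical constant `MASK_VALUE`.

```python
import numpy as np

# Hypothetical large negative value used to mask a logit out of the softmax.
MASK_VALUE = -1e9


def remove_accidental_hits(logits, candidate_ids):
    """Mask in-batch negatives that share an id with the row's positive candidate.

    logits: [num_queries, num_candidates]; query i's positive is column i.
    candidate_ids: [num_candidates] ids aligned with the logits' columns.
    """
    num_queries, num_candidates = logits.shape
    positive_ids = candidate_ids[:num_queries]
    same_id = positive_ids[:, None] == candidate_ids[None, :]
    is_positive = np.eye(num_queries, num_candidates, dtype=bool)
    return np.where(same_id & ~is_positive, MASK_VALUE, logits)


def hard_negative_mining(logits, num_hard_negatives):
    """Keep each row's positive logit plus its num_hard_negatives largest negatives.

    Returns (logits, labels); in the mined case the positive is in column 0.
    """
    num_queries, num_candidates = logits.shape
    is_positive = np.eye(num_queries, num_candidates, dtype=bool)
    if num_hard_negatives <= 0 or num_hard_negatives >= num_candidates - 1:
        # Non-positive or too large: keep every negative.
        return logits, is_positive.astype(float)
    # Rank the positive first (treated as +inf), then negatives by descending logit.
    order = np.argsort(-np.where(is_positive, np.inf, logits), axis=1)
    keep = order[:, : num_hard_negatives + 1]
    return (np.take_along_axis(logits, keep, axis=1),
            np.take_along_axis(is_positive, keep, axis=1).astype(float))


logits = np.array([[2.0, 5.0, 1.0, 4.0],
                   [0.5, 1.0, 3.0, 2.0]])
candidate_ids = np.array([10, 20, 10, 30])  # column 2 repeats query 0's positive id

masked = remove_accidental_hits(logits, candidate_ids)
mined_logits, mined_labels = hard_negative_mining(masked, num_hard_negatives=1)
# mined_logits → [[2.0, 5.0], [1.0, 3.0]]; mined_labels → [[1.0, 0.0], [1.0, 0.0]]
```

Masking first matters: without it, the accidental hit in column 2 could be selected as a "hard" negative even though it is the same item as the positive.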

Attributes
factorized_metrics The metrics object used to compute retrieval metrics.
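The following NumPy sketch illustrates what a factorized top-K metric measures and why it is expensive: every query is scored against the entire corpus. It is not the TFRS metric. It assumes a dot-product score, uses top-K accuracy as the reported quantity, and shows both a score-based check and an id-based check (mirroring the distinction drawn for candidate_ids). `top_k_accuracy` and `top_k_accuracy_by_id` are hypothetical names.

```python
import numpy as np


def top_k_accuracy(query_embeddings, true_candidate_embeddings, corpus_embeddings, k):
    """Score-based: a hit when the true candidate's score is at least the
    k-th highest score in the corpus for that query."""
    corpus_scores = query_embeddings @ corpus_embeddings.T                # [num_queries, corpus_size]
    true_scores = np.sum(query_embeddings * true_candidate_embeddings, axis=1)
    kth_best = np.sort(corpus_scores, axis=1)[:, -k]                      # k-th largest per row
    return float(np.mean(true_scores >= kth_best))


def top_k_accuracy_by_id(query_embeddings, true_ids, corpus_embeddings, corpus_ids, k):
    """Id-based: a hit when the true id is among the ids of the k best-scoring candidates."""
    corpus_scores = query_embeddings @ corpus_embeddings.T
    top_ids = corpus_ids[np.argsort(-corpus_scores, axis=1)[:, :k]]       # [num_queries, k]
    return float(np.mean(np.any(top_ids == true_ids[:, None], axis=1)))


corpus = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7], [-1.0, 0.0]])
corpus_ids = np.array([100, 101, 102, 103])
queries = np.array([[1.0, 0.0], [0.0, 1.0]])
true_rows = [2, 3]  # query 0 picked candidate 102, query 1 picked candidate 103

print(top_k_accuracy(queries, corpus[true_rows], corpus, k=2))                        # 0.5
print(top_k_accuracy_by_id(queries, corpus_ids[true_rows], corpus, corpus_ids, k=2))  # 0.5
```

With a corpus of O(millions) candidates, the `corpus_scores` matrix dominates the cost, which is why skipping these metrics during training saves time.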

Methods

call

Computes the task loss and metrics.

The main arguments are pairs of query and candidate embeddings: the first row of query_embeddings denotes a query for which the candidate from the first row of candidate_embeddings was selected by the user.

The task will try to maximize the affinity of these query-candidate pairs while minimizing the affinity between each query and the candidates belonging to other queries in the batch.
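This objective can be sketched as a softmax cross-entropy over the candidates in the batch, where the diagonal entries are the positives. The NumPy sketch below is not the TFRS implementation: it assumes a dot-product score, applies temperature by dividing the logits, and sums the per-query losses (the reduction used by the task's default loss is not specified here). `in_batch_softmax_loss` is a hypothetical name.

```python
import numpy as np


def in_batch_softmax_loss(query_embeddings, candidate_embeddings,
                          temperature=1.0, sample_weight=None):
    """Softmax cross-entropy where query i's positive is candidate i and every
    other candidate in the batch acts as a negative. Per-query losses are summed."""
    logits = query_embeddings @ candidate_embeddings.T / temperature   # [num_queries, num_candidates]
    # Numerically stable log-softmax across each query's row of candidates.
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    rows = np.arange(query_embeddings.shape[0])
    per_query_loss = -log_probs[rows, rows]                              # diagonal = positives
    if sample_weight is not None:
        per_query_loss = per_query_loss * sample_weight
    return float(per_query_loss.sum())


queries = np.eye(2)
matched = np.eye(2)          # candidate i points the same way as query i
swapped = np.eye(2)[::-1]    # each query's positive points the other way

print(in_batch_softmax_loss(queries, matched))   # ≈ 0.6265
print(in_batch_softmax_loss(queries, swapped))   # ≈ 2.6265
```

Well-aligned pairs give a low loss and misaligned pairs a high one; lowering the temperature sharpens the softmax, so the loss for the matched pairs shrinks further.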

Args
query_embeddings [num_queries, embedding_dim] tensor of query representations.
candidate_embeddings [num_candidates, embedding_dim] tensor of candidate representations. Normally, num_candidates is the same as num_queries: there is a positive candidate for every query. However, num_candidates can also be larger than num_queries. In this case, the extra candidates will be used as extra negatives for all queries.
sample_weight [num_queries] tensor of sample weights.
candidate_sampling_probability Optional tensor of candidate sampling probabilities. When given, it will be used to correct the logits to reflect the sampling probability of negative candidates.
candidate_ids Optional tensor containing candidate ids. When given, factorized top-K evaluation will be id-based rather than score-based.
compute_metrics Whether to compute metrics. Setting this to False during training makes training faster.
compute_batch_metrics Whether to compute batch-level metrics. Even when this is False, the in-batch loss_metrics are still computed.
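The effect of extra candidates and of candidate_sampling_probability on the logits can be sketched in NumPy. This is an illustration, not the TFRS code. It assumes a dot-product score and the common "logQ" correction, which subtracts the log of each candidate's sampling probability; the clipping bounds are also an assumption. `retrieval_logits` and `one_hot_labels` are hypothetical names.

```python
import numpy as np


def retrieval_logits(query_embeddings, candidate_embeddings,
                     candidate_sampling_probability=None):
    """Score every query against every candidate, with an optional logQ correction.

    Column i (i < num_queries) holds query i's positive; any further columns are
    extra negatives shared by all queries.
    """
    logits = query_embeddings @ candidate_embeddings.T
    if candidate_sampling_probability is not None:
        # Candidates appear as negatives in proportion to q, so popular ones would be
        # over-penalised; subtracting log(q) compensates. Clipping avoids log(0).
        logits = logits - np.log(np.clip(candidate_sampling_probability, 1e-6, 1.0))
    return logits


def one_hot_labels(num_queries, num_candidates):
    """Query i's positive is candidate i; columns >= num_queries are never positives."""
    return np.eye(num_queries, num_candidates)


queries = np.array([[1.0, 0.0], [0.0, 1.0]])
candidates = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])  # third row: extra negative
sampling_probability = np.array([0.5, 0.25, 0.25])

raw = retrieval_logits(queries, candidates)
corrected = retrieval_logits(queries, candidates, sampling_probability)
labels = one_hot_labels(*raw.shape)
# labels → [[1., 0., 0.], [0., 1., 0.]]
```

The extra candidate only ever contributes to the softmax denominator, which is how surplus rows of candidate_embeddings act as shared negatives.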

Returns
loss Tensor of loss values.