A factorized retrieval task.
Inherits From: Task
```python
tfrs.tasks.Retrieval(
    loss: Optional[tf.keras.losses.Loss] = None,
    metrics: Optional[Union[Sequence[tfrs_metrics.Factorized], tfrs_metrics.Factorized]] = None,
    batch_metrics: Optional[List[tf.keras.metrics.Metric]] = None,
    loss_metrics: Optional[List[tf.keras.metrics.Metric]] = None,
    temperature: Optional[float] = None,
    num_hard_negatives: Optional[int] = None,
    remove_accidental_hits: bool = False,
    name: Optional[Text] = None
) -> None
```
Recommender systems are often composed of two components:
- a retrieval model, retrieving O(thousands) candidates from a corpus of O(millions) candidates.
- a ranker model, scoring the candidates retrieved by the retrieval model to return a ranked shortlist of a few dozen candidates.
This task defines models that facilitate efficient retrieval of candidates from large corpora by maintaining a two-tower, factorized structure: separate query and candidate representation towers, joined at the top via a lightweight scoring function.
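For orientation, here is a minimal sketch of how the task is typically wired into a two-tower `tfrs.Model`. The tower models, the feature keys (`"query_id"`, `"candidate_id"`), and the `candidates` dataset are placeholder assumptions for illustration, not part of this API.

```python
import tensorflow as tf
import tensorflow_recommenders as tfrs

class TwoTowerModel(tfrs.Model):
  """Minimal two-tower retrieval model (illustrative names)."""

  def __init__(self, query_model, candidate_model, candidates):
    super().__init__()
    self.query_model = query_model          # query tower
    self.candidate_model = candidate_model  # candidate tower
    # The retrieval task bundles the loss and factorized top-K metrics.
    self.task = tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(
            candidates=candidates.batch(128).map(candidate_model)
        )
    )

  def compute_loss(self, features, training=False):
    query_embeddings = self.query_model(features["query_id"])
    candidate_embeddings = self.candidate_model(features["candidate_id"])
    return self.task(query_embeddings, candidate_embeddings)
```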
Args | |
---|---|
`loss` | Loss function. Defaults to `tf.keras.losses.CategoricalCrossentropy`. |
`metrics` | Object for evaluating top-K metrics over a corpus of candidates. These metrics measure how good the model is at picking the true candidate out of all possible candidates in the system. Note that because these metrics range over the entire candidate set, they are usually much slower to compute. Consider setting `compute_metrics=False` during training to save the time spent computing them. |
`batch_metrics` | Metrics measuring how good the model is at picking out the true candidate for a query from the other candidates in the batch. For example, a batch AUC metric would measure the probability that the true candidate is scored higher than the other candidates in the batch. |
`loss_metrics` | List of Keras metrics used to summarize the loss. |
`temperature` | Temperature of the softmax. |
`num_hard_negatives` | If positive, only the `num_hard_negatives` negative examples with the largest logits are kept when computing the cross-entropy loss. If larger than the batch size or non-positive, all negative examples are kept. |
`remove_accidental_hits` | When enabled, removes accidental hits among the examples used as negatives. An accidental hit is a candidate that is used as an in-batch negative but has the same id as the positive candidate. |
`name` | Optional task name. |
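As an illustration of the arguments above, here is a hedged sketch of a more fully configured task; `candidate_dataset` and `candidate_model` are assumed placeholders, and the specific values are arbitrary.

```python
import tensorflow as tf
import tensorflow_recommenders as tfrs

# Assumes `candidate_dataset` yields raw candidate features and
# `candidate_model` maps them to embeddings; both are placeholders.
task = tfrs.tasks.Retrieval(
    # Explicit cross-entropy over the in-batch softmax logits.
    loss=tf.keras.losses.CategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.SUM
    ),
    metrics=tfrs.metrics.FactorizedTopK(
        candidates=candidate_dataset.batch(128).map(candidate_model)
    ),
    temperature=0.05,             # sharpen the softmax over in-batch candidates
    num_hard_negatives=10,        # keep only the 10 highest-scoring negatives per query
    remove_accidental_hits=True,  # id-based: pass candidate_ids when calling the task
)
```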
Attributes | |
---|---|
`factorized_metrics` | The metrics object used to compute retrieval metrics. |
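The metrics object can also be replaced after construction, for example to point evaluation at a refreshed candidate set. A minimal sketch, reusing the placeholder names from the examples above:

```python
# Assumes `task`, `candidate_dataset` and `candidate_model` from the
# earlier sketches; reassigning the metric does not rebuild the task.
task.factorized_metrics = tfrs.metrics.FactorizedTopK(
    candidates=candidate_dataset.batch(128).map(candidate_model)
)
```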
Methods
call
```python
call(
    query_embeddings: tf.Tensor,
    candidate_embeddings: tf.Tensor,
    sample_weight: Optional[tf.Tensor] = None,
    candidate_sampling_probability: Optional[tf.Tensor] = None,
    candidate_ids: Optional[tf.Tensor] = None,
    compute_metrics: bool = True,
    compute_batch_metrics: bool = True
) -> tf.Tensor
```
Computes the task loss and metrics.
The main arguments are pairs of query and candidate embeddings: the first row of query_embeddings denotes a query for which the candidate from the first row of candidate_embeddings was selected by the user.
The task will try to maximize the affinity of these (query, candidate) pairs while minimizing the affinity between each query and the candidates belonging to other queries in the batch.
Args | |
---|---|
`query_embeddings` | [num_queries, embedding_dim] tensor of query representations. |
`candidate_embeddings` | [num_candidates, embedding_dim] tensor of candidate representations. Normally, num_candidates is the same as num_queries: there is a positive candidate corresponding to every query. However, it is also possible for num_candidates to be larger than num_queries. In this case, the extra candidates are used as extra negatives for all queries. |
`sample_weight` | [num_queries] tensor of sample weights. |
`candidate_sampling_probability` | Optional tensor of candidate sampling probabilities. When given, it is used to correct the logits to reflect the sampling probability of negative candidates. |
`candidate_ids` | Optional tensor containing candidate ids. When given, factorized top-K evaluation will be id-based rather than score-based. |
`compute_metrics` | Whether to compute metrics. Set this to False during training for faster training. |
`compute_batch_metrics` | Whether to compute batch-level metrics. In-batch loss_metrics will still be computed. |
Returns | |
---|---|
`loss` | Tensor of loss values. |
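As an illustration, here is a sketch of a `compute_loss` method that forwards the optional arguments above to the task call; the feature keys and tower attributes are placeholder assumptions, not part of this API.

```python
def compute_loss(self, features, training=False):
  query_embeddings = self.query_model(features["query_id"])
  candidate_embeddings = self.candidate_model(features["candidate_id"])
  return self.task(
      query_embeddings,
      candidate_embeddings,
      sample_weight=features.get("weight"),     # optional per-example weights
      candidate_ids=features["candidate_id"],   # enables id-based evaluation and accidental-hit removal
      compute_metrics=not training,             # skip the slow corpus metrics during training
  )
```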