ScaNN approximate retrieval index for a factorized retrieval model.
Inherits From: TopK
tfrs.layers.factorized_top_k.ScaNN(
query_model: Optional[tf.keras.Model] = None,
k: int = 10,
distance_measure: Text = 'dot_product',
num_leaves: int = 100,
num_leaves_to_search: int = 10,
training_iterations: int = 12,
dimensions_per_block: int = 2,
num_reordering_candidates: Optional[int] = None,
parallelize_batch_searches: bool = True,
name: Optional[Text] = None
)
This layer uses the ScaNN library to perform approximate nearest-neighbor retrieval of the best candidates for a given query.
To understand how to use this layer effectively, have a look at the efficient retrieval tutorial.
To deploy this layer in TensorFlow Serving you can use our customized TensorFlow Serving Docker container, available on Docker Hub. You can also build the image yourself from the Dockerfile.
Methods
call
call(
queries: Union[tf.Tensor, Dict[Text, tf.Tensor]], k: Optional[int] = None
) -> Tuple[tf.Tensor, tf.Tensor]
Query the index.
Args | |
---|---|
queries
|
Query features. If query_model was provided in the constructor,
these can be raw query features that will be processed by the query
model before performing retrieval. If query_model was not provided,
these should be pre-computed query embeddings.
|
k
|
The number of candidates to retrieve. If not supplied, defaults to the
k value supplied in the constructor.
|
Returns | |
---|---|
Tuple of (top candidate scores, top candidate identifiers). |
Raises | |
---|---|
ValueError if index has not been called.
|
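As a reference point for the return contract, the following is a minimal plain-Python sketch of exact dot-product top-k retrieval. It does not use ScaNN or TensorFlow; `exact_top_k` and the toy data are hypothetical names chosen for illustration. The layer returns a tuple of the same shape (scores per query, identifiers per query), but its results may be approximate.

```python
from typing import List, Sequence, Tuple


def exact_top_k(
    queries: Sequence[Sequence[float]],
    candidates: Sequence[Sequence[float]],
    k: int,
) -> Tuple[List[List[float]], List[List[int]]]:
    """Brute-force dot-product top-k; returns (scores, candidate indices)."""
    all_scores, all_ids = [], []
    for query in queries:
        # Score every candidate against this query.
        scored = [
            (sum(q * c for q, c in zip(query, cand)), idx)
            for idx, cand in enumerate(candidates)
        ]
        # Highest score first; ties are broken by the lower index.
        scored.sort(key=lambda pair: (-pair[0], pair[1]))
        top = scored[:k]
        all_scores.append([score for score, _ in top])
        all_ids.append([idx for _, idx in top])
    return all_scores, all_ids


# Hypothetical toy data: three 2-d candidate embeddings, one query.
candidates = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
queries = [[2.0, 1.0]]
scores, ids = exact_top_k(queries, candidates, k=2)
```

Here the third candidate scores highest (3.0), followed by the first (2.0), so `ids` is `[[2, 0]]`.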
index
index(
candidates: tf.Tensor, identifiers: Optional[tf.Tensor] = None
) -> 'ScaNN'
Builds the retrieval index.
When called multiple times, the existing index is dropped and a new one is created.
Args | |
---|---|
candidates
|
Matrix of candidate embeddings. |
identifiers
|
Optional tensor of candidate identifiers. If given, these will be used as identifiers of top candidates returned when performing searches. If not given, indices into the candidates tensor will be returned instead. |
Returns | |
---|---|
Self. |
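The effect of the optional identifiers argument can be sketched in plain Python. This is an illustration of the documented behavior only, not the layer's implementation; `resolve_identifiers` and the sample identifiers are hypothetical.

```python
from typing import List, Optional, Sequence


def resolve_identifiers(
    top_indices: Sequence[Sequence[int]],
    identifiers: Optional[Sequence[str]] = None,
) -> List[list]:
    """Map row indices of top candidates to what a search would report."""
    if identifiers is None:
        # No identifiers were indexed: report indices into the candidates matrix.
        return [list(row) for row in top_indices]
    # Identifiers were indexed: report the identifier stored for each row.
    return [[identifiers[i] for i in row] for row in top_indices]


top_indices = [[2, 0]]  # rows of the candidates matrix, best first
with_ids = resolve_identifiers(top_indices, ["movie_a", "movie_b", "movie_c"])
without_ids = resolve_identifiers(top_indices)
```

With identifiers the search reports `[["movie_c", "movie_a"]]`; without them it reports the row indices `[[2, 0]]`.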
index_from_dataset
index_from_dataset(
candidates: tf.data.Dataset
) -> 'TopK'
Builds the retrieval index.
When called multiple times, the existing index is dropped and a new one is created.
Args | |
---|---|
candidates
|
Dataset of candidate embeddings or (candidate identifier, candidate embedding) pairs. If the dataset returns tuples, the identifiers will be used as identifiers of top candidates returned when performing searches. If the dataset contains only embeddings, indices into the candidates dataset will be returned instead. |
Returns | |
---|---|
Self. |
Raises | |
---|---|
ValueError if the dataset does not have the correct structure. |
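The two accepted dataset structures can be illustrated with a small plain-Python sketch. It assumes ordinary Python values standing in for dataset elements (a real `tf.data.Dataset` would typically yield batched tensors), and `split_candidates` is a hypothetical helper, not part of the library.

```python
from typing import Iterable, List, Optional, Tuple


def split_candidates(elements: Iterable) -> Tuple[Optional[List], List]:
    """Separate identifiers from embeddings for either accepted structure."""
    identifiers, embeddings = [], []
    has_ids = None  # decided by the first element seen
    for element in elements:
        is_pair = isinstance(element, tuple)
        if is_pair and len(element) != 2:
            raise ValueError("Expected (identifier, embedding) pairs.")
        if has_ids is None:
            has_ids = is_pair
        elif has_ids != is_pair:
            raise ValueError("Elements must all share one structure.")
        if is_pair:
            identifiers.append(element[0])
            embeddings.append(element[1])
        else:
            embeddings.append(element)
    # Without identifiers, callers fall back to positional indices.
    return (identifiers if has_ids else None), embeddings


pair_ids, pair_embeddings = split_candidates(
    [("a", [1.0, 0.0]), ("b", [0.0, 1.0])]
)
plain_ids, plain_embeddings = split_candidates([[1.0, 0.0], [0.0, 1.0]])
```

In the first call the identifiers `["a", "b"]` are kept alongside the embeddings; in the second there are none, so positional indices would be reported by searches.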
is_exact
is_exact() -> bool
Indicates whether the results returned by the layer are exact.
Some layers may return approximate scores: for example, the ScaNN layer may return approximate results.
Returns | |
---|---|
True if the layer returns exact results, and False otherwise. |
query_with_exclusions
@tf.function
query_with_exclusions(
queries: Union[tf.Tensor, Dict[Text, tf.Tensor]],
exclusions: tf.Tensor,
k: Optional[int] = None
) -> Tuple[tf.Tensor, tf.Tensor]
Query the index, excluding the given identifiers from the results.
Args | |
---|---|
queries
|
Query features. If query_model was provided in the constructor,
these can be raw query features that will be processed by the query
model before performing retrieval. If query_model was not provided,
these should be pre-computed query embeddings.
|
exclusions
|
[query_batch_size, num_to_exclude] tensor of identifiers to
be excluded from the top-k calculation. This is most commonly used to
exclude previously seen candidates from retrieval. For example, if a
user has already seen items with ids "42" and "43", you could set
exclusions to [["42", "43"]].
|
k
|
The number of candidates to retrieve. Defaults to constructor k
parameter if not supplied.
|
Returns | |
---|---|
Tuple of (top candidate scores, top candidate identifiers). |
Raises | |
---|---|
ValueError if index has not been called.
ValueError if queries is not a tensor (after being passed through
the query model).
|
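The exclusion semantics described above can be sketched in plain Python. This is an assumption-labeled illustration over precomputed scores, not the layer's implementation (which operates on tensors); `top_k_with_exclusions` and the sample scores are hypothetical.

```python
from typing import List, Sequence, Tuple


def top_k_with_exclusions(
    scores: Sequence[Sequence[float]],
    identifiers: Sequence[str],
    exclusions: Sequence[Sequence[str]],
    k: int,
) -> Tuple[List[List[float]], List[List[str]]]:
    """Per-query top-k over precomputed scores, skipping excluded identifiers."""
    out_scores, out_ids = [], []
    # One row of scores and one row of exclusions per query in the batch.
    for row_scores, row_excluded in zip(scores, exclusions):
        excluded = set(row_excluded)
        kept = [
            (score, ident)
            for score, ident in zip(row_scores, identifiers)
            if ident not in excluded
        ]
        kept.sort(key=lambda pair: -pair[0])  # stable sort, highest score first
        out_scores.append([score for score, _ in kept[:k]])
        out_ids.append([ident for _, ident in kept[:k]])
    return out_scores, out_ids


identifiers = ["41", "42", "43", "44"]
scores = [[0.1, 0.9, 0.8, 0.5]]  # one query, one score per candidate
top_scores, top_ids = top_k_with_exclusions(
    scores, identifiers, exclusions=[["42", "43"]], k=2
)
```

Although "42" and "43" have the highest scores, they are excluded, so the query returns "44" followed by "41".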