Returns min k values and their indices of the input operand in an approximate manner.
tf.math.approx_min_k(
operand,
k,
reduction_dimension=-1,
recall_target=0.95,
reduction_input_size_override=-1,
aggregate_to_topk=True,
name=None
)
See https://arxiv.org/abs/2206.14286 for the algorithm details. This op is currently only optimized on TPU.
| Returns |
|---|
| Tuple of two arrays. The arrays are the least k values and the corresponding indices along the reduction_dimension of the input operand. The arrays' dimensions are the same as the input operand except for the reduction_dimension: when aggregate_to_topk is true, the reduction dimension is k; otherwise, it is greater than or equal to k, with the exact size being implementation-defined. |
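For instance, a minimal sketch of this shape contract (the tensor sizes are illustrative and not from the original documentation; on non-TPU backends the op may need to run under XLA, as in the jit-compiled example below):

import tensorflow as tf

x = tf.random.uniform((4, 100))
# With the default aggregate_to_topk=True, the reduction dimension becomes exactly k.
values, indices = tf.math.approx_min_k(x, k=5)
# values.shape == indices.shape == (4, 5); indices index into the last axis of x.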
We encourage users to wrap approx_min_k with jit. See the following example for nearest neighbor search over the squared L2 distance:

import tensorflow as tf

@tf.function(jit_compile=True)
def l2_ann(qy, db, half_db_norms, k=10, recall_target=0.95):
  # Rank by db_norms^2/2 - dot(qy, db^T), which orders neighbors the same way
  # as the full squared L2 distance (see the note below).
  dists = half_db_norms - tf.einsum('ik,jk->ij', qy, db)
  return tf.math.approx_min_k(dists, k=k, recall_target=recall_target)

qy = tf.random.uniform((256, 128))
db = tf.random.uniform((2048, 128))
half_db_norms = tf.norm(db, axis=1) ** 2 / 2
dists, neighbors = l2_ann(qy, db, half_db_norms)
In the example above, we compute db_norms^2/2 - dot(qy, db^T) instead of
qy^2 - 2 dot(qy, db^T) + db^2 for performance reasons. The former uses fewer
arithmetic operations and produces the same set of neighbors, since the dropped
qy^2 term is constant for each query and scaling by 1/2 does not change the ordering.
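As a quick sanity check (not part of the original example; the sizes and the exact top-k comparison below are illustrative), the reduced score ranks neighbors identically to the full squared distance:

import tensorflow as tf

qy = tf.random.uniform((8, 16))
db = tf.random.uniform((64, 16))

# Full squared L2 distances: qy^2 - 2 dot(qy, db^T) + db^2.
full = (tf.reduce_sum(qy * qy, axis=1, keepdims=True)
        - 2.0 * tf.einsum('ik,jk->ij', qy, db)
        + tf.reduce_sum(db * db, axis=1))
# Reduced score: db_norms^2/2 - dot(qy, db^T). It equals (full - qy^2) / 2, so per
# query it differs from the full distance only by a constant shift and a scale.
reduced = tf.reduce_sum(db * db, axis=1) / 2.0 - tf.einsum('ik,jk->ij', qy, db)

_, idx_full = tf.math.top_k(-full, k=10)
_, idx_reduced = tf.math.top_k(-reduced, k=10)
print(bool(tf.reduce_all(idx_full == idx_reduced)))  # True (barring float near-ties)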