Warning: This project is deprecated. TensorFlow Addons has stopped development. The project will only be providing minimal maintenance releases until May 2024. See the full announcement here or on GitHub.


Computes the triplet loss with semi-hard negative mining.


>>> import tensorflow as tf
>>> import tensorflow_addons as tfa
>>> y_true = tf.convert_to_tensor([0, 0])
>>> y_pred = tf.convert_to_tensor([[0.0, 1.0], [1.0, 0.0]])
>>> tfa.losses.triplet_semihard_loss(y_true, y_pred, distance_metric="L2")
<tf.Tensor: shape=(), dtype=float32, numpy=2.4142137>
>>> # Calling with callable `distance_metric`
>>> distance_metric = lambda x: tf.linalg.matmul(x, x, transpose_b=True)
>>> tfa.losses.triplet_semihard_loss(y_true, y_pred, distance_metric=distance_metric)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
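
To make "semi-hard negative mining" concrete, the following is a minimal pure-Python sketch under stated assumptions; it is not the library's implementation. For each anchor-positive pair it picks the closest negative that is still farther from the anchor than the positive, and falls back to the farthest negative when no such negative exists. As an assumption chosen so that the sketch agrees with the value printed in the first example above, it uses a negative distance of 0 when the batch contains no negatives at all. The function name `semihard_triplet_loss_sketch` is hypothetical, and only plain L2 distance is covered.

```python
import math


def semihard_triplet_loss_sketch(labels, embeddings, margin=1.0):
    """Illustrative semi-hard triplet loss over plain Python lists (L2 distance only)."""
    n = len(labels)
    # Pairwise Euclidean distance matrix.
    dist = [[math.dist(embeddings[i], embeddings[j]) for j in range(n)]
            for i in range(n)]

    losses = []
    for a in range(n):
        # Distances from the anchor to every sample with a different label.
        negatives = [dist[a][k] for k in range(n) if labels[k] != labels[a]]
        for p in range(n):
            if p == a or labels[p] != labels[a]:
                continue  # only distinct samples of the same class form a positive pair
            d_ap = dist[a][p]
            farther = [d for d in negatives if d > d_ap]
            if farther:
                d_an = min(farther)    # closest negative beyond the positive
            elif negatives:
                d_an = max(negatives)  # fallback: farthest negative
            else:
                d_an = 0.0             # assumption: no negatives in the batch
            losses.append(max(d_ap - d_an + margin, 0.0))

    # Assumption: return 0.0 when the batch has no positive pairs.
    return sum(losses) / len(losses) if losses else 0.0


loss = semihard_triplet_loss_sketch([0, 0], [[0.0, 1.0], [1.0, 0.0]])
print(loss)  # approximately 2.4142
```

The explicit loops are quadratic to cubic in the batch size and are meant only to show which negative is selected for each pair; a vectorized implementation would be preferable for real batches.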

Args:
y_true: 1-D integer Tensor with shape [batch_size] of multiclass integer labels.
y_pred: 2-D float Tensor of embedding vectors. Embeddings should be L2-normalized.
margin: Float, margin term in the loss definition.
distance_metric: str or Callable that determines the distance metric. Valid strings are "L2" for L2-norm distance, "squared-L2" for squared L2-norm distance, and "angular" for angular distance based on cosine similarity. A Callable should take a batch of embeddings as input and return the pairwise distance matrix.
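
As an illustration of the two requirements above (L2-normalized embeddings, and a callable that maps a batch of embeddings to a pairwise distance matrix), here is a hedged sketch. The function `pairwise_manhattan` and the sample values are hypothetical and chosen only for demonstration; any function returning a [batch_size, batch_size] matrix would play the same role.

```python
import tensorflow as tf


def pairwise_manhattan(embeddings):
    """Hypothetical custom metric: pairwise L1 distances, shape [batch, batch]."""
    # Broadcast [batch, 1, dim] against [1, batch, dim] to get all pairwise differences.
    diff = tf.expand_dims(embeddings, 1) - tf.expand_dims(embeddings, 0)
    return tf.reduce_sum(tf.abs(diff), axis=-1)


# Illustrative raw embeddings; each row is scaled to unit L2 norm before use.
raw = tf.constant([[3.0, 4.0], [0.0, 2.0], [1.0, 0.0]])
embeddings = tf.math.l2_normalize(raw, axis=1)

# Symmetric matrix with zeros on the diagonal.
distances = pairwise_manhattan(embeddings)
```

Under these assumptions, the callable could be passed as `distance_metric=pairwise_manhattan` together with `y_true` and the normalized embeddings as `y_pred`.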

Returns:
triplet_loss: float scalar with dtype of y_pred.