Computes the lifted structured loss.
@tf.function
tfa.losses.lifted_struct_loss(labels: tfa.types.TensorLike, embeddings: tfa.types.TensorLike, margin: tfa.types.FloatTensorLike = 1.0) -> tf.Tensor
Args | |
---|---|
labels | 1-D tf.int32 Tensor with shape [batch_size] of multiclass integer labels. |
embeddings | 2-D float Tensor of embedding vectors. Embeddings should not be L2-normalized. |
margin | Float, margin term in the loss definition. |
Returns | |
---|---|
lifted_loss | float scalar with dtype of embeddings. |
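
To make the role of `labels`, `embeddings`, and `margin` more concrete, here is a minimal NumPy sketch of the lifted structured loss as it is commonly stated in the metric-learning literature: for each same-label pair (i, j), the term log(sum over negatives of exp(margin - D)) + D[i, j] is clipped at zero, squared, and averaged. This is an illustration only, not the TensorFlow Addons implementation, and its numerical results may differ from the library's. The function name `lifted_struct_loss_sketch`, the use of unsquared Euclidean distances, and returning 0.0 when the batch has no positive pairs are all assumptions made for the example.

```python
import numpy as np


def lifted_struct_loss_sketch(labels, embeddings, margin=1.0):
    """Illustrative lifted structured loss (hypothetical helper, not tfa code)."""
    labels = np.asarray(labels)
    x = np.asarray(embeddings, dtype=np.float64)

    # Pairwise Euclidean distances D[i, j] (assumption: unsquared distances).
    deltas = x[:, None, :] - x[None, :, :]
    dist = np.sqrt(np.sum(deltas ** 2, axis=-1))

    # same[i, j] is True when examples i and j share a label.
    same = labels[:, None] == labels[None, :]

    # For each anchor i: sum over its negatives k of exp(margin - D[i, k]).
    neg_exp = np.where(~same, np.exp(margin - dist), 0.0).sum(axis=1)

    total, num_pos = 0.0, 0
    n = labels.shape[0]
    for i in range(n):
        for j in range(i + 1, n):
            if not same[i, j]:
                continue
            num_pos += 1
            s = neg_exp[i] + neg_exp[j]
            if s > 0.0:  # with no (numerically visible) negatives the term is 0
                j_ij = np.log(s) + dist[i, j]
                total += max(0.0, j_ij) ** 2

    # Assumption: define the loss as 0.0 when there are no positive pairs.
    return total / (2.0 * num_pos) if num_pos else 0.0


# Example batch: four 2-D embeddings, two classes.
labels = [0, 0, 1, 1]
embeddings = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [1.0, 2.0]]
print(lifted_struct_loss_sketch(labels, embeddings, margin=1.0))
```

In this sketch a larger `margin` increases the contribution of every negative, so same-label pairs are penalized more unless their negatives are pushed further away.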