Computes the clustering loss.
tf.contrib.losses.metric_learning.cluster_loss(
labels, embeddings, margin_multiplier, enable_pam_finetuning=True,
margin_type='nmi', print_losses=False
)
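A minimal sketch of preparing the two tensor inputs in the shapes listed under Args. It uses plain Python lists so it runs without TensorFlow. The helper names `l2_normalize_rows` and `as_column` are hypothetical. In TensorFlow itself, `tf.nn.l2_normalize` (along the embedding axis) and `tf.reshape` would typically do the same jobs.

```python
import math


def l2_normalize_rows(embeddings, eps=1e-12):
    """Scale each embedding row to unit L2 norm; all-zero rows stay zero."""
    normalized = []
    for row in embeddings:
        norm = math.sqrt(sum(v * v for v in row))
        # Guard against division by zero for all-zero rows.
        normalized.append([v / max(norm, eps) for v in row])
    return normalized


def as_column(labels):
    """Reshape a flat list of integer labels to shape [batch size, 1]."""
    return [[int(y)] for y in labels]


raw_embeddings = [[3.0, 4.0], [0.0, 2.0], [1.0, 0.0]]
raw_labels = [0, 1, 0]

embeddings = l2_normalize_rows(raw_embeddings)
labels = as_column(raw_labels)
print(embeddings)  # [[0.6, 0.8], [0.0, 1.0], [1.0, 0.0]]
print(labels)      # [[0], [1], [0]]
```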
The following structured margins are supported:
nmi: normalized mutual information
ami: adjusted mutual information
ari: adjusted Rand index
vmeasure: V-measure
const: indicator checking whether the two clusterings are the same.
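The sketch below makes the 'nmi' entry concrete. It is a minimal pure-Python implementation of normalized mutual information between two flat clusterings. NMI has several normalization conventions. This one divides by the arithmetic mean of the two entropies, which is an assumption, and the function name is hypothetical. As we understand section 3.2 of the paper, the structured margin is 1 minus such a score, so a clustering that matches the labels contributes no margin.

```python
import math
from collections import Counter


def normalized_mutual_info(labels_true, labels_pred):
    """NMI of two clusterings, normalized by the arithmetic mean of their entropies."""
    n = len(labels_true)
    if n == 0 or n != len(labels_pred):
        raise ValueError("clusterings must be non-empty and of equal length")
    count_true = Counter(labels_true)
    count_pred = Counter(labels_pred)
    joint = Counter(zip(labels_true, labels_pred))

    def entropy(counts):
        return -sum(c / n * math.log(c / n) for c in counts.values())

    h_true, h_pred = entropy(count_true), entropy(count_pred)
    if h_true == 0.0 and h_pred == 0.0:
        return 1.0  # both are a single cluster: treat them as identical
    # I(U; V) = sum over cells of p(u, v) * log(p(u, v) / (p(u) * p(v)))
    mutual_info = sum(
        c / n * math.log(n * c / (count_true[t] * count_pred[p]))
        for (t, p), c in joint.items()
    )
    return mutual_info / ((h_true + h_pred) / 2)


# NMI ignores label names: a relabeled copy scores 1, an independent split scores 0.
print(round(normalized_mutual_info([0, 0, 1, 1], [1, 1, 0, 0]), 6))  # 1.0
print(normalized_mutual_info([0, 0, 1, 1], [0, 1, 0, 1]))            # 0.0
```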
Args:
  labels: 2-D Tensor of labels of shape [batch size, 1].
  embeddings: 2-D Tensor of embeddings of shape [batch size, embedding dimension]. Embeddings should be L2-normalized.
  margin_multiplier: float32 scalar. Multiplier on the structured margin term. See section 3.2 of the paper for discussion.
  enable_pam_finetuning: Boolean. Whether to run local PAM (partitioning around medoids) refinement. See section 3.4 of the paper for discussion.
  margin_type: Type of structured margin to use: one of 'nmi', 'ami', 'ari', 'vmeasure', or 'const'. See section 3.2 of the paper for discussion.
  print_losses: Boolean. Whether to print the loss.
Paper: https://arxiv.org/abs/1612.01213
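A brute-force sketch can make concrete the loss that margin_multiplier scales. As we read section 3.2 of the paper, the loss is a hinge. It takes the best loss-augmented clustering score, the maximum over medoid sets S of size K of F(X, S) + margin_multiplier * Δ. It subtracts the score of the ground-truth clustering and clamps the result at zero. Here F(X, S) is the negative sum of distances from each embedding to its nearest medoid in S, and Δ is the structured margin. The helper names are hypothetical. The Euclidean distance, K = number of distinct labels, and the 'const'-style margin are our assumptions. Exhaustive enumeration of medoid sets stands in for whatever approximate search the library performs, and it is only feasible for tiny batches. The toy points are not L2-normalized so the distances are easy to check by hand.

```python
import itertools
import math


def facility_score(points, medoids):
    """F(X, S): negative sum of each point's distance to its nearest medoid."""
    return -sum(min(math.dist(p, points[m]) for m in medoids) for p in points)


def assign(points, medoids):
    """Nearest-medoid index for every point (the first medoid wins ties)."""
    return [min(medoids, key=lambda m: math.dist(p, points[m])) for p in points]


def same_partition(a, b):
    """True if labelings a and b group the indices identically, up to relabeling."""
    pairs = set(zip(a, b))
    return len(pairs) == len(set(a)) == len(set(b))


def ground_truth_score(points, labels):
    """Sum over true classes of the best within-class medoid score."""
    total = 0.0
    for y in set(labels):
        members = [i for i, lab in enumerate(labels) if lab == y]
        total += max(
            -sum(math.dist(points[i], points[j]) for i in members) for j in members
        )
    return total


def cluster_loss_bruteforce(points, labels, margin_multiplier):
    """Hinge of (best loss-augmented score) minus (ground-truth score)."""
    k = len(set(labels))
    augmented = max(
        facility_score(points, s)
        # 'const'-style margin: 0 if the clustering matches the labels, else 1.
        + margin_multiplier * (0.0 if same_partition(assign(points, s), labels) else 1.0)
        for s in itertools.combinations(range(len(points)), k)
    )
    return max(augmented - ground_truth_score(points, labels), 0.0)


points = [(0, 0), (0, 1), (10, 0), (10, 1)]
labels = [0, 0, 1, 1]
print(cluster_loss_bruteforce(points, labels, margin_multiplier=1.0))   # 0.0
print(cluster_loss_bruteforce(points, labels, margin_multiplier=30.0))  # 12.0
```

With a small multiplier the correct clustering (score -2) beats every wrong one, so the hinge is zero. With a multiplier of 30 the margin bonus lifts a wrong clustering (score -20) to 10. That clustering now outscores the ground truth by 12, which becomes the loss.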
Returns:
  clustering_loss: A float32 scalar Tensor.
Raises:
  ImportError: If the sklearn dependency is not installed.
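The ImportError noted under Raises means the sklearn-backed margins need scikit-learn at call time. The following minimal sketch maps margin_type to a clustering score, assuming the structured margin is 1 minus that score. It defers the scikit-learn import so that 'const' works without it. `clustering_score` and `_SKLEARN_SCORERS` are hypothetical names. The four `sklearn.metrics` functions are real scikit-learn APIs, but whether the library calls them this way is an assumption. Here 'const' is read as "same partition up to relabeling".

```python
# Hypothetical mapping from margin_type to scikit-learn metric function names.
_SKLEARN_SCORERS = {
    "nmi": "normalized_mutual_info_score",
    "ami": "adjusted_mutual_info_score",
    "ari": "adjusted_rand_score",
    "vmeasure": "v_measure_score",
}


def clustering_score(labels, predictions, margin_type="nmi"):
    """Similarity of two clusterings; the structured margin would be 1 - score."""
    if margin_type == "const":
        # Indicator: 1.0 if both labelings induce the same partition, else 0.0.
        pairs = set(zip(labels, predictions))
        same = len(pairs) == len(set(labels)) == len(set(predictions))
        return 1.0 if same else 0.0
    if margin_type not in _SKLEARN_SCORERS:
        raise ValueError("Unrecognized margin_type: %r" % (margin_type,))
    try:
        from sklearn import metrics  # only the sklearn-backed margins need this
    except ImportError as err:
        raise ImportError(
            "margin_type=%r requires scikit-learn" % (margin_type,)) from err
    scorer = getattr(metrics, _SKLEARN_SCORERS[margin_type])
    return float(scorer(labels, predictions))


print(clustering_score([0, 0, 1], [5, 5, 7], margin_type="const"))  # 1.0
```

Validating margin_type before the import means a typo raises ValueError even on machines without scikit-learn, rather than a misleading ImportError.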