

Creates a Head for multi-label classification.

Inherits From: Head

Multi-label classification handles the case where each example may have zero or more associated labels drawn from a discrete set. This is distinct from MultiClassHead, which expects exactly one label per example.

Uses sigmoid_cross_entropy loss, averaged over classes and combined as a weighted sum over the batch. Namely, if the input logits have shape [batch_size, n_classes], the loss is the average over n_classes and the weighted sum over batch_size.
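The reduction described above can be illustrated with a minimal pure-Python sketch. It is not the head's implementation: the helper names `sigmoid_cross_entropy` and `multi_label_loss` are hypothetical, and the final division by the batch size is an assumption taken from the `expected_loss` comment in the example on this page.

```python
import math


def sigmoid_cross_entropy(label, logit):
    """Numerically stable sigmoid cross-entropy for one (label, logit) pair."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def multi_label_loss(labels, logits, weights=None):
    """Returns (per-example losses, weighted sum over the batch).

    Each per-example loss is the mean of the per-class losses.
    `weights` holds one value per example; it defaults to all ones.
    """
    if weights is None:
        weights = [1.0] * len(logits)
    per_example = []
    for label_row, logit_row in zip(labels, logits):
        class_losses = [
            sigmoid_cross_entropy(y, z) for y, z in zip(label_row, logit_row)
        ]
        per_example.append(sum(class_losses) / len(class_losses))
    weighted_sum = sum(w * l for w, l in zip(weights, per_example))
    return per_example, weighted_sum


logits = [[-1.0, 1.0], [-1.5, 1.5]]
labels = [[1, 0], [1, 1]]
per_example, weighted_sum = multi_label_loss(labels, logits)
# Assumption: divide by batch size, as in the page's expected_loss comment.
batch_mean = weighted_sum / len(logits)
print(per_example, batch_mean)
```

With these inputs the per-example values agree with the 1.31326169 and 0.9514133 quoted in the example's comment, and the batch mean rounds to 1.13.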

The head expects logits with shape [D0, D1, ... DN, n_classes]. In many applications, the shape is [batch_size, n_classes].

Labels can be:

  • A multi-hot tensor of shape [D0, D1, ... DN, n_classes]
  • An integer SparseTensor of class indices. The dense_shape must be [D0, D1, ... DN, ?] and the values must be within [0, n_classes).
  • If label_vocabulary is given, either a string SparseTensor whose dense_shape is [D0, D1, ... DN, ?] and whose values are within label_vocabulary, or a multi-hot tensor of shape [D0, D1, ... DN, n_classes].
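To make the relationship between these label formats concrete, the sketch below converts sparse class indices and vocabulary strings into multi-hot rows using plain Python lists. The function names are hypothetical and the code only mirrors the idea; the head performs its own conversion on tensors.

```python
def indices_to_multi_hot(class_indices, n_classes):
    """Turns a list of class indices for one example into a multi-hot row."""
    row = [0] * n_classes
    for idx in class_indices:
        if not 0 <= idx < n_classes:
            raise ValueError(
                'class index {} is outside [0, {})'.format(idx, n_classes))
        row[idx] = 1
    return row


def strings_to_multi_hot(class_names, label_vocabulary):
    """Maps string labels to indices via the vocabulary, then to multi-hot."""
    lookup = {name: i for i, name in enumerate(label_vocabulary)}
    indices = [lookup[name] for name in class_names]
    return indices_to_multi_hot(indices, len(label_vocabulary))


# Three examples with one, two, and zero labels respectively (n_classes = 2).
sparse_labels = [[0], [0, 1], []]
multi_hot = [indices_to_multi_hot(row, n_classes=2) for row in sparse_labels]

# String labels with a hypothetical three-entry vocabulary.
vocabulary = ['cat', 'dog', 'bird']
string_multi_hot = strings_to_multi_hot(['bird', 'cat'], vocabulary)
print(multi_hot, string_multi_hot)
```

An example with no labels becomes an all-zero row, which is the "zero or more labels" case that separates multi-label from multi-class classification.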

If weight_column is specified, weights must be of shape [D0, D1, ... DN], or [D0, D1, ... DN, 1].

The head also supports a custom loss_fn. loss_fn takes (labels, logits) or (labels, logits, features) as arguments and returns the unreduced loss with shape [D0, D1, ... DN, 1]. loss_fn must support indicator labels with shape [D0, D1, ... DN, n_classes]. Namely, the head applies label_vocabulary to the input labels before passing them to loss_fn.
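As an illustration of the loss_fn contract, here is a minimal sketch of a function with the (labels, logits) signature that returns an unreduced loss with a trailing dimension of size 1. The name `my_loss_fn` is hypothetical, and averaging the sigmoid cross-entropy over classes is only one possible choice of loss.

```python
import tensorflow as tf


def my_loss_fn(labels, logits):
  """Hypothetical custom loss: sigmoid cross-entropy averaged over classes.

  Expects indicator (multi-hot) labels of shape [D0, ..., DN, n_classes] and
  returns an unreduced loss of shape [D0, ..., DN, 1].
  """
  labels = tf.cast(labels, dtype=logits.dtype)
  per_class = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=labels, logits=logits)
  # keepdims=True preserves the trailing dimension of size 1.
  return tf.reduce_mean(per_class, axis=-1, keepdims=True)


logits = tf.constant([[-1., 1.], [-1.5, 1.5]])
labels = tf.constant([[1, 0], [1, 1]], dtype=tf.int64)
unreduced = my_loss_fn(labels, logits)
print(unreduced.shape)
```

Such a function would be passed through the head's loss_fn argument, for example `tf.estimator.MultiLabelHead(n_classes=2, loss_fn=my_loss_fn)`.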


import numpy as np
import tensorflow as tf

n_classes = 2
head = tf.estimator.MultiLabelHead(n_classes)
logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
features = {'x': np.array([[41], [42]], dtype=np.int32)}
# expected_loss = sum(_sigmoid_cross_entropy(labels, logits)) / batch_size
#               = sum(1.31326169, 0.9514133) / 2 = 1.13
loss = head.loss(labels, logits, features=features)
eval_metrics = head.metrics()
updated_metrics = head.update_metrics(
    eval_metrics, features, logits, labels)
for k in sorted(updated_metrics):
  print('{} : {:.2f}'.format(k, updated_metrics[k].result().numpy()))
# Output:
#   auc : 0.33
#   auc_precision_recall : 0.77
#   average_loss : 1.13
preds = head.predictions(logits)
print(preds['logits'])
# Output:
#   tf.Tensor(
#   [[-1.   1. ]
#    [-1.5  1.5]], shape=(2, 2), dtype=float32)

Usage with a canned estimator:

my_head = tf.estimator.MultiLabelHead(n_classes=3)
my_estimator = tf.estimator.DNNEstimator(
    head=my_head,
    hidden_units=...,
    feature_columns=...)

It can also be used with a custom model_fn. Example:

def _my_model_fn(features, labels, mode):
  my_head = tf.estimator.MultiLabelHead(n_classes=3)
  logits = tf.keras.Model(...)(features)

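  # For training, also pass either optimizer=... or train_op_fn=...
  return my_head.create_estimator_spec(
      features=features,
      mode=mode,
      labels=labels,
      logits=logits)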

my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)