tf.estimator.MultiHead

Creates a Head for multi-objective learning.

Inherits From: Head

This class merges the output of multiple Head objects. Specifically:

  • For training, sums the losses of all heads and calls train_op_fn with the summed loss.
  • For eval, merges metrics by appending each head's name as a suffix to the eval metric keys, such as precision/head1.name and precision/head2.name.
  • For prediction, merges predictions and rewrites each key in the prediction dict as a 2-tuple, (head.name, prediction_key). Merges export_outputs so that, by default, the first head is served.
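The key-merging conventions in the list above can be pictured with plain dictionaries. The following is only a sketch: the helper names merge_eval_metrics and merge_predictions and the sample values are hypothetical, and the code imitates the naming scheme described here rather than reproducing how MultiHead is implemented.

```python
# Illustrative only: imitates the key-merging conventions described above
# with plain dicts. These helpers are hypothetical, not part of TensorFlow.

def merge_eval_metrics(metrics_by_head):
    """Suffix each metric key with '/<head name>'."""
    merged = {}
    for head_name, metrics in metrics_by_head.items():
        for key, value in metrics.items():
            merged['{}/{}'.format(key, head_name)] = value
    return merged


def merge_predictions(predictions_by_head):
    """Re-key each prediction as a (head name, prediction key) 2-tuple."""
    merged = {}
    for head_name, predictions in predictions_by_head.items():
        for key, value in predictions.items():
            merged[(head_name, key)] = value
    return merged


# Made-up per-head outputs, used only to show the resulting keys.
metrics = merge_eval_metrics({
    'head1': {'precision': 0.5},
    'head2': {'precision': 0.75},
})
predictions = merge_predictions({
    'head1': {'logits': [-10.0, 10.0]},
    'head2': {'logits': [20.0, -20.0, 20.0]},
})
print(sorted(metrics))      # metric keys carry the head name as a suffix
print(sorted(predictions))  # prediction keys are (head name, key) tuples
```

Keeping the head name in every key is what lets metrics and predictions from several heads share one dictionary without collisions.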

Usage:

import numpy as np
import tensorflow as tf

head1 = tf.estimator.MultiLabelHead(n_classes=2, name='head1')
head2 = tf.estimator.MultiLabelHead(n_classes=3, name='head2')
multi_head = tf.estimator.MultiHead([head1, head2])
# Logits and labels are dicts keyed by head name.
logits = {
    'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
    'head2': np.array([[20., -20., 20.], [-30., 20., -20.]],
                      dtype=np.float32),
}
labels = {
    'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
    'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
}
features = {'x': np.array(((42,),), dtype=np.float32)}
# For large logits, sigmoid cross entropy loss is approximated as:
# loss = labels * (logits < 0) * (-logits) +
#        (1 - labels) * (logits > 0) * logits =>
# head1: expected_unweighted_loss = [[10., 10.], [15., 0.]]
# loss1 = ((10 + 10) / 2 + (15 + 0) / 2) / 2 = 8.75
# head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]]
# loss2 = ((20 + 20 + 20) / 3 + (30 + 0 + 0) / 3) / 2 = 15.00
# loss = loss1 + loss2 = 8.75 + 15.00 = 23.75
loss = multi_head.loss(labels, logits, features=features)
print('{:.2f}'.format(loss.numpy()))
# Prints: 23.75
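The arithmetic in the comments above can be checked with NumPy alone, without building any heads. This is a minimal sketch that applies the large-logit approximation quoted in the comments, not the exact sigmoid cross entropy, and averages in the same order the comments do (over classes, then over examples). The function names approx_sigmoid_cross_entropy and head_loss are hypothetical.

```python
import numpy as np


def approx_sigmoid_cross_entropy(labels, logits):
    """Large-logit approximation quoted in the usage comments."""
    labels = labels.astype(np.float32)
    return (labels * (logits < 0) * (-logits)
            + (1.0 - labels) * (logits > 0) * logits)


def head_loss(labels, logits):
    """Mean over classes per example, then mean over examples."""
    per_example = approx_sigmoid_cross_entropy(labels, logits).mean(axis=1)
    return per_example.mean()


logits = {
    'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
    'head2': np.array([[20., -20., 20.], [-30., 20., -20.]],
                      dtype=np.float32),
}
labels = {
    'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
    'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
}

loss1 = head_loss(labels['head1'], logits['head1'])
loss2 = head_loss(labels['head2'], logits['head2'])
total = float(loss1 + loss2)  # the merged loss is the sum over heads
print('{:.2f} + {:.2f} = {:.2f}'.format(float(loss1), float(loss2), total))
# 8.75 + 15.00 = 23.75
```

The exact loss differs from this approximation only by a small term that shrinks as the logits grow in magnitude, which is why the printed value above still rounds to 23.75.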