Creates a Head for multi-objective learning.
Inherits From: Head
tf.estimator.MultiHead(
    heads, head_weights=None
)
This class merges the outputs of multiple `Head` objects. Specifically:
- For training, sums the losses of all heads (weighted by `head_weights`, if given) and calls `train_op_fn` with this final loss.
- For eval, merges metrics by adding a `head.name` suffix to the keys in eval metrics, such as `precision/head1` and `precision/head2` for heads named `head1` and `head2`.
- For prediction, merges predictions and updates the keys in the prediction dict to 2-tuples, `(head.name, prediction_key)`. Merges `export_outputs` such that by default the first head is served.
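The key-merging rules above can be sketched in plain Python. This is a minimal illustration, not TensorFlow's implementation: `merge_metrics` and `merge_predictions` are hypothetical helpers, plain numbers stand in for metric objects and prediction `Tensor`s, and the `<metric key>/<head name>` format is assumed from the eval bullet.

```python
# Hypothetical helpers (not part of tf.estimator) that mimic how MultiHead
# re-keys per-head outputs. Plain numbers stand in for metrics/Tensors.


def merge_metrics(metrics_by_head):
    """Suffix each metric key with '/<head name>' and merge into one dict."""
    merged = {}
    for head_name, metrics in metrics_by_head.items():
        for key, value in metrics.items():
            merged[f"{key}/{head_name}"] = value
    return merged


def merge_predictions(predictions_by_head):
    """Re-key each prediction as a (head name, prediction key) 2-tuple."""
    merged = {}
    for head_name, predictions in predictions_by_head.items():
        for key, value in predictions.items():
            merged[(head_name, key)] = value
    return merged


if __name__ == "__main__":
    print(merge_metrics({"head1": {"precision": 0.8}, "head2": {"precision": 0.6}}))
    # {'precision/head1': 0.8, 'precision/head2': 0.6}
    print(merge_predictions({"head1": {"class_ids": 2}, "head2": {"class_ids": 0}}))
    # {('head1', 'class_ids'): 2, ('head2', 'class_ids'): 0}
```

Suffixing metric keys (rather than nesting dicts) keeps the merged result a flat string-keyed dict, while tuple keys for predictions let two heads expose the same prediction key without collisions.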
Usage:
# In `input_fn`, specify labels as a dict keyed by head name:
def input_fn():
  features = ...
  labels1 = ...
  labels2 = ...
  return features, {'head1': labels1, 'head2': labels2}
# In `model_fn`, specify logits as a dict keyed by head name:
def model_fn(features, labels, mode):
  # Create simple heads and specify head name.
  head1 = tf.estimator.MultiClassHead(n_classes=3, name='head1')
  head2 = tf.estimator.BinaryClassHead(name='head2')
  # Create MultiHead from two simple heads.
  head = tf.estimator.MultiHead([head1, head2])
  # Create logits for each head, and combine them into a dict.
  logits1, logits2 = logit_fn()
  logits = {'head1': logits1, 'head2': logits2}
  # Return the merged EstimatorSpec
  return head.create_estimator_spec(..., logits=logits, ...)
# Create an estimator with this model_fn.
estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn=input_fn)
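`MultiHead` also accepts logits as a single `Tensor` of shape `[D0, D1, ... DN, logits_dimension]`, where `logits_dimension` is the sum of the heads' logits dimensions. It splits the `Tensor` along the last dimension and passes each consecutive slice, in head order, to the corresponding head. For example: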
import numpy as np

# Input logits.
logits = np.array([[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]],
                  dtype=np.float32)
# Suppose head1.logits_dimension = 2 and head2.logits_dimension = 3. After
# splitting, the result is:
logits_dict = {'head1': [[-1., 1.], [-1.5, 1.]],
               'head2': [[2., -2., 2.], [-3., 2., -2.]]}
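The split can be reproduced without TensorFlow. The sketch below assumes logits are nested Python lists and that the heads are described by an ordered list of `(name, logits_dimension)` pairs; `split_logits` is a hypothetical helper, not a `tf.estimator` function. Each head receives a consecutive slice of the last axis, in head order, which yields the `logits_dict` shown above.

```python
# Hypothetical helper (not a tf.estimator API): split the last axis of
# nested-list logits into consecutive per-head chunks, keyed by head name.


def split_logits(logits, head_dims):
    """Split `logits` so that each head gets its slice of the last axis.

    head_dims: ordered list of (head_name, logits_dimension) pairs, in the
    same order as the heads passed to MultiHead.
    """
    total = sum(dim for _, dim in head_dims)

    def split_row(row):
        # Leading dimensions: split each child, then regroup by head name.
        if row and isinstance(row[0], list):
            per_child = [split_row(child) for child in row]
            return {name: [child[name] for child in per_child] for name, _ in head_dims}
        # Last axis: its length must equal the sum of the head dimensions.
        if len(row) != total:
            raise ValueError(f"last dimension is {len(row)}, expected {total}")
        out, start = {}, 0
        for name, dim in head_dims:
            out[name] = row[start:start + dim]
            start += dim
        return out

    return split_row(logits)


if __name__ == "__main__":
    logits = [[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]]
    print(split_logits(logits, [("head1", 2), ("head2", 3)]))
```

With NumPy arrays, `np.split(logits, np.cumsum(dims)[:-1], axis=-1)` produces the same slices as a list of arrays, where `dims` is the list of per-head logits dimensions.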
Usage:
def model_fn(features, labels, mode):
  # Create simple heads and specify head name.
  head1 = tf.estimator.MultiClassHead(n_classes=3, name='head1')
  head2 = tf.estimator.BinaryClassHead(name='head2')
  # Create multi-head from two simple heads.
  head = tf.estimator.MultiHead([head1, head2])
  # Create logits for the MultiHead. Here `logits` is a single `Tensor`.
  logits = logit_fn(logits_dimension=head.logits_dimension)
  # Return the merged EstimatorSpec
  return head.create_estimator_spec(..., logits=logits, ...)
| Args | |
|---|---|
| heads | List or tuple of `Head` instances. All heads must have `name` specified. The first head in the list is the default used at serving time. |
| head_weights | Optional list of weights, same length as `heads`. Used when merging losses to calculate the weighted sum of losses from each head. If `None`, all losses are weighted equally. |
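The `head_weights` rule can be illustrated on plain floats. This is a minimal sketch, not the library code: `merge_losses` is a hypothetical helper, and it treats "weighted equally" as a weight of 1.0 per head.

```python
# Hypothetical helper (not a tf.estimator API) showing the documented
# head_weights rule on plain floats instead of loss Tensors.


def merge_losses(head_losses, head_weights=None):
    """Return the (optionally weighted) sum of per-head scalar losses."""
    if head_weights is None:
        # Assumed default: every head's loss gets a weight of 1.0.
        head_weights = [1.0] * len(head_losses)
    if len(head_weights) != len(head_losses):
        raise ValueError("head_weights must have the same length as heads")
    return sum(w * loss for w, loss in zip(head_weights, head_losses))


if __name__ == "__main__":
    print(merge_losses([0.5, 2.0]))               # 2.5
    print(merge_losses([0.5, 2.0], [1.0, 0.25]))  # 1.0
```

Down-weighting a head (here the second one, by 0.25) is a simple way to keep a head with a large raw loss from dominating the gradient of the merged loss.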
| Attributes | |
|---|---|
| logits_dimension | See `base_head.Head` for details. |
| loss_reduction | See `base_head.Head` for details. |
| name | See `base_head.Head` for details. |
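Since a single logits `Tensor` is split along its last dimension among the heads, the `MultiHead` logits dimension is assumed here to be the sum of the heads' logits dimensions. For the heads in the usage examples, `MultiClassHead(n_classes=3)` has 3 logits and `BinaryClassHead` has 1, so for $k$ heads:

$$
\text{logits\_dimension}_{\text{MultiHead}} = \sum_{i=1}^{k} \text{logits\_dimension}_{i},
\qquad \text{e.g. } 3 + 1 = 4 .
$$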
Methods
create_estimator_spec
create_estimator_spec(
    features, mode, logits, labels=None, optimizer=None, trainable_variables=None,
    train_op_fn=None, update_ops=None, regularization_losses=None
)
Returns a model_fn.EstimatorSpec.
| Args | |
|---|---|
| features | Input `dict` of `Tensor` or `SparseTensor` objects. |
| mode | Estimator's `ModeKeys`. |
| logits | Input `dict` keyed by head name, or a logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`. For many applications, the `Tensor` shape is `[batch_size, logits_dimension]`. If logits is a `Tensor`, it is split along the last dimension and distributed among the heads. See `MultiHead` for examples. |
| labels | Input `dict` keyed by head name. For each head, the label value can be an integer or string `Tensor` with shape matching its corresponding `logits`. `labels` is a required argument when `mode` equals `TRAIN` or `EVAL`. |
| optimizer | A `tf.keras.optimizers.Optimizer` instance to optimize the loss in TRAIN mode. Namely, sets `train_op = optimizer.get_updates(loss, trainable_variables)`, which updates variables to minimize `loss`. |
| trainable_variables | A list or tuple of `Variable` objects to update to minimize `loss`. In TensorFlow 1.x, by default these are the variables collected in the graph under the key `GraphKeys.TRAINABLE_VARIABLES`. As TensorFlow 2.x doesn't have collections and `GraphKeys`, `trainable_variables` needs to be passed explicitly here. |
| train_op_fn | Function that takes a scalar loss `Tensor` and returns `train_op`. Used if `optimizer` is `None`. |
| update_ops | A list or tuple of update ops to be run at training time. For example, layers such as BatchNormalization create mean and variance update ops that need to be run at training time. In TensorFlow 1.x, these are added to an `UPDATE_OPS` collection. As TensorFlow 2.x doesn't have collections, `update_ops` needs to be passed explicitly here. |
| regularization_losses | A list of additional scalar losses to be added to the training loss, such as regularization losses. These losses are usually expressed as a batch average, so for best results each head should use the default `loss_reduction=SUM_OVER_BATCH_SIZE` to avoid scaling errors. Unlike the regularization losses of individual heads, these losses regularize the merged loss of all heads and are added to the overall `MultiHead` training loss. |
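The scaling concern can be checked with a few numbers. This sketch uses hypothetical helpers on plain floats, assumes unit example weights, and simplifies `SUM_OVER_BATCH_SIZE` to a mean over the batch. It compares how much a fixed, batch-averaged regularization term of 0.1 contributes to the training loss as the batch grows, under `SUM` versus `SUM_OVER_BATCH_SIZE`.

```python
# Hypothetical illustration (plain floats, unit example weights) of why
# batch-averaged regularization_losses pair best with SUM_OVER_BATCH_SIZE.


def reduce_loss(per_example_losses, reduction):
    """Reduce per-example losses; SUM_OVER_BATCH_SIZE is simplified to a mean."""
    total = sum(per_example_losses)
    if reduction == "SUM":
        return total
    if reduction == "SUM_OVER_BATCH_SIZE":
        return total / len(per_example_losses)
    raise ValueError(f"unknown reduction: {reduction!r}")


def regularization_share(batch_size, reduction, reg_loss=0.1):
    """Fraction of the final training loss contributed by reg_loss."""
    head_loss = reduce_loss([0.5] * batch_size, reduction)
    return reg_loss / (head_loss + reg_loss)


if __name__ == "__main__":
    for n in (4, 400):
        # The SUM_OVER_BATCH_SIZE share is the same for both batch sizes;
        # the SUM share shrinks as the batch grows.
        print(n, regularization_share(n, "SUM_OVER_BATCH_SIZE"),
              regularization_share(n, "SUM"))
```

With `SUM`, the head loss grows linearly with the batch size while the regularization term does not, so the effective regularization strength silently depends on the batch size.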
| Returns | |
|---|---|
| A `model_fn.EstimatorSpec` instance. |
| Raises | |
|---|---|
| ValueError | If both `train_op_fn` and `optimizer` are `None` in TRAIN mode, or if both are set. If `mode` is not in Estimator's `ModeKeys`. |
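The `ValueError` conditions can be mirrored in a small validation function. This is a hypothetical sketch, not `tf.estimator` code: the mode strings are assumed to mirror the `tf.estimator.ModeKeys` values, and the `optimizer`/`train_op_fn` pair is only checked in TRAIN mode, where a `train_op` is built.

```python
# Hypothetical sketch (not tf.estimator code) of the documented ValueError
# conditions. The mode strings are assumed to mirror tf.estimator.ModeKeys.
VALID_MODES = ("train", "eval", "infer")


def check_spec_args(mode, optimizer=None, train_op_fn=None):
    """Raise ValueError under the conditions listed in the Raises table."""
    if mode not in VALID_MODES:
        raise ValueError(f"mode={mode!r} is not one of {VALID_MODES}")
    # The optimizer / train_op_fn pair only matters when building a train_op.
    if mode == "train":
        if optimizer is None and train_op_fn is None:
            raise ValueError("TRAIN mode needs either optimizer or train_op_fn.")
        if optimizer is not None and train_op_fn is not None:
            raise ValueError("Set only one of optimizer and train_op_fn.")
```

For example, `check_spec_args("eval")` passes without an optimizer, while `check_spec_args("train")` raises because neither option is set.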
loss
loss(
    labels, logits, features=None, mode=None, regularization_losses=None
)
Returns regularized training loss. See base_head.Head for details.
metrics
metrics(
    regularization_losses=None
)
Creates metrics. See base_head.Head for details.
predictions
predictions(
    logits, keys=None
)
Creates predictions. See base_head.Head for details.
update_metrics
update_metrics(
    eval_metrics, features, logits, labels, regularization_losses=None
)
Updates eval metrics. See base_head.Head for details.