# tf.contrib.estimator.multi_head
Creates a `_Head` for multi-objective learning.
```python
tf.contrib.estimator.multi_head(
    heads, head_weights=None
)
```
The returned head merges the output of multiple `_Head` objects. Specifically:

- For training, it sums the losses of all heads (weighted by `head_weights`, if given) and calls `train_op_fn` with this final loss.
- For eval, it merges metrics by adding a `head.name` suffix to each metric key, such as `precision/head1` and `precision/head2`.
- For prediction, it merges predictions and changes each key in the prediction dict to a 2-tuple, `(head.name, prediction_key)`. It merges `export_outputs` so that, by default, the first head is served.
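The key-merging rules for eval and prediction can be sketched in plain Python. The helpers `merge_eval_metrics` and `merge_predictions` below are hypothetical illustrations, not part of TensorFlow: they operate on ordinary dicts keyed by head name rather than on `EstimatorSpec` objects, and only mimic the key formats described in the list.

```python
from typing import Any, Dict, Tuple


def merge_eval_metrics(metrics_by_head: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Append '/<head name>' to every metric key, e.g. 'precision' -> 'precision/head1'."""
    merged: Dict[str, Any] = {}
    for head_name, metrics in metrics_by_head.items():
        for key, value in metrics.items():
            merged[f"{key}/{head_name}"] = value
    return merged


def merge_predictions(
    predictions_by_head: Dict[str, Dict[str, Any]],
) -> Dict[Tuple[str, str], Any]:
    """Re-key every prediction as the 2-tuple (head_name, prediction_key)."""
    merged: Dict[Tuple[str, str], Any] = {}
    for head_name, predictions in predictions_by_head.items():
        for key, value in predictions.items():
            merged[(head_name, key)] = value
    return merged


metrics = merge_eval_metrics({
    "head1": {"precision": 0.8},
    "head2": {"precision": 0.6},
})
print(sorted(metrics))  # ['precision/head1', 'precision/head2']

predictions = merge_predictions({
    "head1": {"classes": [2]},
    "head2": {"probabilities": [0.7]},
})
print(sorted(predictions))  # [('head1', 'classes'), ('head2', 'probabilities')]
```

Because every key carries its head name, metrics and predictions from different heads cannot collide even when the heads use the same metric or prediction names.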
#### Usage:
```python
import tensorflow as tf  # TensorFlow 1.x, where tf.contrib is available.

# In `input_fn` specify labels as a dict keyed by head name:
def input_fn():
    features = ...
    labels1 = ...
    labels2 = ...
    return features, {'head1': labels1, 'head2': labels2}

# In `model_fn`, specify logits as a dict keyed by head name:
def model_fn(features, labels, mode):
    # Create simple heads and specify head name.
    head1 = tf.contrib.estimator.multi_class_head(n_classes=3, name='head1')
    head2 = tf.contrib.estimator.binary_classification_head(name='head2')
    # Create multi-head from two simple heads.
    head = tf.contrib.estimator.multi_head([head1, head2])
    # Create logits for each head, and combine them into a dict.
    # `logit_fn` is a placeholder for your own network.
    logits1, logits2 = logit_fn()
    logits = {'head1': logits1, 'head2': logits2}
    # Return the merged EstimatorSpec.
    return head.create_estimator_spec(
        features=features, mode=mode, labels=labels, logits=logits,
        train_op_fn=...)  # Supply your own train_op_fn.

# Create an estimator with this model_fn.
estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn=input_fn, steps=100)
```
Also supports `logits` as a `Tensor` of shape `[D0, D1, ... DN, logits_dimension]`. In that case, the `Tensor` is split along the last dimension into one slice per head, in the order the heads were given, with each slice as wide as that head's `logits_dimension`. E.g.:
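```python
import tensorflow as tf  # TensorFlow 1.x, where tf.contrib is available.

def model_fn(features, labels, mode):
    # Create simple heads and specify head name.
    head1 = tf.contrib.estimator.multi_class_head(n_classes=3, name='head1')
    head2 = tf.contrib.estimator.binary_classification_head(name='head2')
    # Create multi-head from two simple heads.
    head = tf.contrib.estimator.multi_head([head1, head2])
    # Create logits for the multi-head; `logit_fn` is a placeholder for your
    # own network. `head.logits_dimension` covers all heads together.
    logits = logit_fn(logits_dimension=head.logits_dimension)
    # Return the merged EstimatorSpec.
    return head.create_estimator_spec(
        features=features, mode=mode, labels=labels, logits=logits,
        train_op_fn=...)  # Supply your own train_op_fn.
```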
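A minimal sketch of the splitting step follows. It assumes rank-2 logits of shape `[batch_size, logits_dimension]` stored as nested Python lists, and assumes the 3-class head contributes 3 logits and the binary head contributes 1. The `split_logits` helper and its `(head_name, logits_dimension)` pairs are hypothetical; the real head works on `Tensor`s.

```python
from typing import Dict, List, Sequence, Tuple


def split_logits(
    logits: List[List[float]],
    heads: Sequence[Tuple[str, int]],
) -> Dict[str, List[List[float]]]:
    """Split the last dimension of `logits` into one slice per head.

    `heads` holds hypothetical (head_name, logits_dimension) pairs, in the
    same order as the list passed to multi_head.
    """
    total = sum(dim for _, dim in heads)
    split: Dict[str, List[List[float]]] = {name: [] for name, _ in heads}
    for row in logits:
        if len(row) != total:
            raise ValueError(f"expected last dimension {total}, got {len(row)}")
        start = 0
        for name, dim in heads:
            # Each head takes the next `dim` columns of the row.
            split[name].append(row[start:start + dim])
            start += dim
    return split


# head1: 3-class head (3 logits); head2: binary head (1 logit).
heads = [("head1", 3), ("head2", 1)]
logits = [[0.1, 0.2, 0.7, -1.0],
          [1.5, 0.0, 0.3, 2.0]]
parts = split_logits(logits, heads)
print(parts["head1"])  # [[0.1, 0.2, 0.7], [1.5, 0.0, 0.3]]
print(parts["head2"])  # [[-1.0], [2.0]]
```

Since the slices are assigned by position, the order of the `heads` list determines which columns each head receives.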
| Args | |
|---|---|
| `heads` | List or tuple of `_Head` instances. All heads must have `name` specified. The first head in the list is the default used at serving time. |
| `head_weights` | Optional list of weights, same length as `heads`. Used when merging losses to calculate the weighted sum of losses from each head. If `None`, all losses are weighted equally. |
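The role of `head_weights` can be illustrated with a small sketch. `merge_losses` is a hypothetical helper operating on plain floats (the real head combines scalar loss `Tensor`s), and it assumes that `None` means a weight of 1.0 per head, which matches the plain sum described for training. Its result stands for the single loss handed to `train_op_fn`.

```python
from typing import Optional, Sequence


def merge_losses(
    losses: Sequence[float],
    head_weights: Optional[Sequence[float]] = None,
) -> float:
    """Weighted sum of per-head losses; None gives every head weight 1.0."""
    if head_weights is None:
        head_weights = [1.0] * len(losses)
    if len(head_weights) != len(losses):
        raise ValueError("head_weights must have the same length as heads")
    return sum(w * loss for w, loss in zip(head_weights, losses))


print(merge_losses([0.5, 2.0]))               # 2.5
print(merge_losses([0.5, 2.0], [1.0, 0.25]))  # 1.0
```

Down-weighting a head, as in the second call, reduces its influence on the combined gradient without removing it from training.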
| Returns |
|---|
| An instance of `_Head` that merges multiple heads. |
| Raises | |
|---|---|
| `ValueError` | If `heads` is empty. |
| `ValueError` | If any of the `heads` does not have `name` specified. |
| `ValueError` | If `heads` and `head_weights` have different sizes. |
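The three error conditions can be sketched as a validation function. `validate_heads` and `FakeHead` are hypothetical stand-ins that only look at each head's `name` attribute, and the error messages are made up for illustration; they are not TensorFlow's actual messages.

```python
from typing import Optional, Sequence


class FakeHead:
    """Hypothetical stand-in for a `_Head`; only `name` is modeled."""

    def __init__(self, name: Optional[str] = None):
        self.name = name


def validate_heads(
    heads: Sequence[FakeHead],
    head_weights: Optional[Sequence[float]] = None,
) -> None:
    """Raise ValueError under the three conditions listed in the Raises table."""
    if not heads:
        raise ValueError("heads must not be empty")
    for head in heads:
        if not head.name:
            raise ValueError("every head must have a name")
    if head_weights is not None and len(head_weights) != len(heads):
        raise ValueError("heads and head_weights must have the same size")


# Passes: two named heads with matching weights.
validate_heads([FakeHead("head1"), FakeHead("head2")], head_weights=[1.0, 0.5])

bad_cases = [
    ([], None),                               # empty heads
    ([FakeHead("head1"), FakeHead()], None),  # unnamed head
    ([FakeHead("head1")], [1.0, 2.0]),        # size mismatch
]
for bad_heads, bad_weights in bad_cases:
    try:
        validate_heads(bad_heads, bad_weights)
    except ValueError as err:
        print("ValueError:", err)
```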
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2020-10-01 UTC.