Creates a _Head for multi-label classification.
```python
tf.contrib.estimator.multi_label_head(
    n_classes, weight_column=None, thresholds=None, label_vocabulary=None,
    loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None,
    classes_for_class_based_metrics=None, name=None
)
```
Multi-label classification handles the case where each example may have zero
or more associated labels, drawn from a discrete set. This is distinct from
multi_class_head, which has exactly one label per example.
Uses sigmoid_cross_entropy loss averaged over classes and a weighted sum over
the batch. Namely, if the input logits have shape [batch_size, n_classes],
the loss is averaged over n_classes and weighted-summed over batch_size.
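For intuition, here is a minimal sketch of roughly the reduction described above; the tensor values are made up for illustration, and the head performs this computation internally:

```python
import tensorflow as tf

# Hypothetical logits and multi-hot labels of shape [batch_size, n_classes] = [2, 3],
# plus per-example weights of shape [batch_size, 1].
logits = tf.constant([[1.0, -1.0, 0.5],
                      [0.2, 0.3, -0.7]])
labels = tf.constant([[1.0, 0.0, 1.0],
                      [0.0, 1.0, 0.0]])
weights = tf.constant([[1.0], [2.0]])

# Unreduced sigmoid cross-entropy, shape [batch_size, n_classes].
per_class = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
# Average over classes -> [batch_size, 1], then weighted sum over the batch
# divided by batch size (SUM_OVER_BATCH_SIZE).
per_example = tf.reduce_mean(per_class, axis=-1, keepdims=True)
loss = tf.reduce_sum(weights * per_example) / 2.0
```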
The head expects logits with shape [D0, D1, ... DN, n_classes]. In many
applications, the shape is [batch_size, n_classes].
Labels can be:
- A multi-hot tensor of shape [D0, D1, ... DN, n_classes].
- An integer SparseTensor of class indices. The dense_shape must be
  [D0, D1, ... DN, ?] and the values within [0, n_classes).
- If label_vocabulary is given, a string SparseTensor. The dense_shape must be
  [D0, D1, ... DN, ?] and the values within label_vocabulary, or a multi-hot
  tensor of shape [D0, D1, ... DN, n_classes].
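To make these formats concrete, here is a hypothetical sketch of labels for n_classes=3 and a batch of two examples (the class names and values are illustrative only):

```python
import tensorflow as tf

# 1. Multi-hot tensor of shape [batch_size, n_classes].
multi_hot_labels = tf.constant([[1, 0, 1],
                                [0, 0, 0]], dtype=tf.int64)

# 2. Integer SparseTensor of class indices; dense_shape is [batch_size, ?].
#    Example 0 has classes 0 and 2, example 1 has no classes.
sparse_labels = tf.SparseTensor(
    indices=[[0, 0], [0, 1]],
    values=[0, 2],
    dense_shape=[2, 2])

# 3. With label_vocabulary=['cat', 'dog', 'bird'], a string SparseTensor.
string_labels = tf.SparseTensor(
    indices=[[0, 0], [0, 1]],
    values=['cat', 'bird'],
    dense_shape=[2, 2])
```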
If weight_column is specified, weights must be of shape
[D0, D1, ... DN], or [D0, D1, ... DN, 1].
A custom loss_fn is also supported. loss_fn takes (labels, logits) or
(labels, logits, features) as arguments and returns an unreduced loss with
shape [D0, D1, ... DN, 1]. loss_fn must support indicator labels with
shape [D0, D1, ... DN, n_classes]; namely, the head applies
label_vocabulary to the input labels before passing them to loss_fn.
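For example, a hypothetical loss_fn matching this contract might look like the following sketch (the function name is made up; it simply reproduces the default sigmoid cross-entropy averaged over classes):

```python
import tensorflow as tf

def _my_loss_fn(labels, logits):
  # labels and logits have shape [D0, ... DN, n_classes]; return an
  # unreduced loss of shape [D0, ... DN, 1].
  per_class = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.cast(labels, tf.float32), logits=logits)
  return tf.reduce_mean(per_class, axis=-1, keepdims=True)

my_head = tf.contrib.estimator.multi_label_head(
    n_classes=3, loss_fn=_my_loss_fn)
```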
The head can be used with a canned estimator. Example:
```python
my_head = tf.contrib.estimator.multi_label_head(n_classes=3)
my_estimator = tf.estimator.DNNEstimator(
    head=my_head,
    hidden_units=...,
    feature_columns=...)
```
It can also be used with a custom model_fn. Example:
```python
def _my_model_fn(features, labels, mode):
  my_head = tf.contrib.estimator.multi_label_head(n_classes=3)
  logits = tf.keras.Model(...)(features)

  return my_head.create_estimator_spec(
      features=features,
      mode=mode,
      labels=labels,
      optimizer=tf.train.AdagradOptimizer(learning_rate=0.1),
      logits=logits)

my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)
```
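Constructing the head with the optional arguments documented below might look like the following sketch; the vocabulary, thresholds, metric classes, and head name here are made-up illustrations:

```python
import tensorflow as tf

# A hypothetical head using string labels via label_vocabulary, evaluation
# at two probability thresholds, and per-class metrics for two classes.
my_head = tf.contrib.estimator.multi_label_head(
    n_classes=3,
    label_vocabulary=['cat', 'dog', 'bird'],
    thresholds=[0.25, 0.5],
    classes_for_class_based_metrics=['cat', 'bird'],
    name='my_multi_label_head')
```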
| Args | |
|---|---|
| n_classes | Number of classes, must be greater than 1 (for 1 class, use binary_classification_head). |
| weight_column | A string or a _NumericColumn created by tf.feature_column.numeric_column defining the feature column representing weights. It is used to down-weight or boost examples during training; it is multiplied by the loss of the example. Per-class weighting is not supported. |
| thresholds | Iterable of floats in the range (0, 1). Accuracy, precision and recall metrics are evaluated for each threshold value. The threshold is applied to the predicted probabilities, i.e. above the threshold is true, below is false. |
| label_vocabulary | A list of strings representing possible label values. If not given, labels must already be encoded as integers within [0, n_classes) or as a multi-hot Tensor. If given, labels must be a string SparseTensor with values in label_vocabulary. An error is raised if labels are strings and no vocabulary is provided. |
| loss_reduction | One of tf.losses.Reduction except NONE. Describes how to reduce training loss over the batch. Defaults to SUM_OVER_BATCH_SIZE, namely the weighted sum of losses divided by batch size. See tf.losses.Reduction. |
| loss_fn | Optional loss function. |
| classes_for_class_based_metrics | List of integer class IDs or string class names for which per-class metrics are evaluated. If integers, all must be in the range [0, n_classes - 1]. If strings, all must be in label_vocabulary. |
| name | Name of the head. If provided, summary and metrics keys will be suffixed by "/" + name. Also used as name_scope when creating ops. |
| Returns |
|---|
| An instance of _Head for multi-label classification. |
| Raises | |
|---|---|
| ValueError | If n_classes, thresholds, loss_reduction, loss_fn or metric_class_ids is invalid. |