tf.estimator.add_metrics
Creates a new tf.estimator.Estimator that has the given metrics.
```python
tf.estimator.add_metrics(estimator, metric_fn)
```
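Example:

```python
import tensorflow as tf

# metric_fn: compute AUC from the estimator's 'logistic' predictions.
def my_auc(labels, predictions):
  auc_metric = tf.keras.metrics.AUC(name="my_auc")
  auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'])
  return {'auc': auc_metric}

estimator = tf.estimator.DNNClassifier(...)
# Returns a new Estimator whose evaluation also reports 'auc'.
estimator = tf.estimator.add_metrics(estimator, my_auc)
estimator.train(...)
estimator.evaluate(...)
```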
Example of a custom metric that uses features:

```python
import tensorflow as tf

def my_auc(labels, predictions, features):
  auc_metric = tf.keras.metrics.AUC(name="my_auc")
  # Weight each example by the 'weight' feature produced by input_fn.
  auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'],
                          sample_weight=features['weight'])
  return {'auc': auc_metric}

estimator = tf.estimator.DNNClassifier(...)
estimator = tf.estimator.add_metrics(estimator, my_auc)
estimator.train(...)
estimator.evaluate(...)
```
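The two examples declare different arguments: the first takes only labels and predictions, while the second also takes features. One way to honor such a contract (any subset of the allowed names, in any order) is to inspect the function's parameter names and pass only the matching values by keyword. The plain-Python sketch that follows assumes hypothetical names (call_metric_fn, ALLOWED_ARGS, and two toy metric functions that return plain numbers instead of real metric objects); it illustrates the idea and is not TensorFlow's own code.

```python
import inspect

# The four argument names metric_fn may declare, per the Args list.
ALLOWED_ARGS = {"predictions", "features", "labels", "config"}


def call_metric_fn(metric_fn, predictions, features, labels, config):
    """Hypothetical helper: call metric_fn with only the arguments it declares."""
    declared = set(inspect.signature(metric_fn).parameters)
    unexpected = declared - ALLOWED_ARGS
    if unexpected:
        raise ValueError(f"metric_fn has unexpected args: {sorted(unexpected)}")
    available = {
        "predictions": predictions,
        "features": features,
        "labels": labels,
        "config": config,
    }
    # Pass by keyword, so the order in which metric_fn declares them is irrelevant.
    kwargs = {name: available[name] for name in declared}
    return metric_fn(**kwargs)


# Toy stand-ins for real metric functions (they return plain numbers).
def labels_only_count(labels, predictions):
    return {"n": len(labels)}


def weighted_count(features, labels, predictions):
    return {"weighted_n": sum(features["weight"])}


print(call_metric_fn(labels_only_count, predictions={}, features={},
                     labels=[1, 0, 1], config=None))
# prints {'n': 3}
```

A function that declares any other name (for example mode) is rejected with a ValueError under this sketch.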
Args:
- estimator: A tf.estimator.Estimator object.
- metric_fn: A function that must obey the following signature:
  - Args: can only have the following four arguments, in any order:
    - predictions: Predictions Tensor or dict of Tensor created by the given estimator.
    - features: Input dict of Tensor objects created by the input_fn that is passed to estimator.evaluate as an argument.
    - labels: Labels Tensor or dict of Tensor created by the input_fn that is passed to estimator.evaluate as an argument.
    - config: The config attribute of the estimator.
  - Returns: Dict of metric results keyed by name. The final metrics are the union of this dict and the estimator's existing metrics; if a name appears in both, the entry from this dict overrides the existing one. Each value is either a tf.keras.metrics.Metric object, as in the examples above, or the result of calling a metric function, namely a (metric_tensor, update_op) tuple.
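The Returns rule, together with the fact that add_metrics creates a new Estimator, can be pictured as wrapping the original model function and merging metric dicts. The plain-Python sketch that follows assumes a hypothetical Spec namedtuple standing in for the spec a model function returns and a hypothetical with_extra_metrics helper; floats stand in for real metric objects, and only labels and predictions are passed to metric_fn for brevity. It shows the documented union-with-override behavior, not TensorFlow's implementation.

```python
from collections import namedtuple

# Hypothetical stand-in for a model function's output; only two fields modelled.
Spec = namedtuple("Spec", ["predictions", "eval_metric_ops"])


def with_extra_metrics(model_fn, metric_fn):
    """Return a NEW model function whose eval metrics include metric_fn's."""
    def new_model_fn(features, labels):
        spec = model_fn(features, labels)
        # Copy so the original spec's dict is not mutated.
        merged = dict(spec.eval_metric_ops or {})
        # Union of both dicts; on a name conflict the new metric wins.
        merged.update(metric_fn(labels=labels, predictions=spec.predictions))
        return spec._replace(eval_metric_ops=merged)
    return new_model_fn


def base_model_fn(features, labels):
    return Spec(predictions=[0.9, 0.2],
                eval_metric_ops={"accuracy": 0.5, "auc": 0.6})


def my_metrics(labels, predictions):
    return {"auc": 0.8, "my_metric": 1.0}


new_fn = with_extra_metrics(base_model_fn, my_metrics)
print(new_fn(None, [1, 0]).eval_metric_ops)
# prints {'accuracy': 0.5, 'auc': 0.8, 'my_metric': 1.0}
```

The conflicting "auc" entry takes the new value, while base_model_fn itself still reports its own metrics unchanged, in the spirit of returning a new Estimator rather than altering the one passed in.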
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2021-02-18 UTC.