tf.estimator.add_metrics
Creates a new tf.estimator.Estimator with the given metrics added to its existing evaluation metrics.
tf.estimator.add_metrics(
estimator, metric_fn
)
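The idea behind add_metrics can be illustrated without TensorFlow. The following is a minimal pure-Python sketch under stated assumptions, not the library's implementation: ToyEstimator, add_metrics_sketch, toy_model_fn, and my_metrics are hypothetical names, and plain floats stand in for the metric objects or (metric_tensor, update_op) tuples a real estimator works with. It shows the two properties this page describes: a new estimator is returned instead of the original being modified, and the resulting metrics are the union of the old and new ones, with the metric_fn entry winning on a name conflict.

```python
from typing import Any, Callable, Dict, List, Tuple

# (predictions, metrics) pair produced by a toy model_fn.
ModelOutput = Tuple[List[int], Dict[str, Any]]


class ToyEstimator:
    """Hypothetical stand-in for an estimator; NOT tf.estimator.Estimator."""

    def __init__(self, model_fn: Callable[[List[int], List[int]], ModelOutput]):
        self.model_fn = model_fn

    def evaluate(self, features: List[int], labels: List[int]) -> Dict[str, Any]:
        _, metrics = self.model_fn(features, labels)
        return metrics


def add_metrics_sketch(estimator: ToyEstimator, metric_fn) -> ToyEstimator:
    """Hypothetical analogue of add_metrics: wraps model_fn, never mutates."""
    def new_model_fn(features, labels):
        predictions, metrics = estimator.model_fn(features, labels)
        # For brevity only labels and predictions are passed here, by name.
        extra = metric_fn(labels=labels, predictions=predictions)
        # Union of both dicts; on a name conflict the metric_fn entry wins.
        return predictions, {**metrics, **extra}
    return ToyEstimator(new_model_fn)


# Toy model that predicts 1 for every example (placeholder metric values).
def toy_model_fn(features, labels):
    predictions = [1 for _ in features]
    correct = sum(p == y for p, y in zip(predictions, labels))
    return predictions, {"accuracy": correct / len(labels), "auc": 0.0}


def my_metrics(labels, predictions):
    true_pos = sum(p == 1 and y == 1 for p, y in zip(predictions, labels))
    # "auc" overrides the existing entry; "recall" is a new metric.
    return {"auc": 0.5, "recall": true_pos / sum(labels)}


base = ToyEstimator(toy_model_fn)
extended = add_metrics_sketch(base, my_metrics)
features, labels = [0, 0, 0, 0], [1, 0, 1, 1]
print(base.evaluate(features, labels))      # {'accuracy': 0.75, 'auc': 0.0}
print(extended.evaluate(features, labels))  # {'accuracy': 0.75, 'auc': 0.5, 'recall': 1.0}
```

Wrapping the model function rather than editing the original object keeps the base estimator usable on its own, which is why the sketch returns a fresh ToyEstimator.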
Example:
import tensorflow as tf

def my_auc(labels, predictions):
    # Keras AUC metric fed with the classifier's 'logistic' predictions.
    auc_metric = tf.keras.metrics.AUC(name="my_auc")
    auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'])
    return {'auc': auc_metric}
estimator = tf.estimator.DNNClassifier(...)
estimator = tf.estimator.add_metrics(estimator, my_auc)
estimator.train(...)
estimator.evaluate(...)
Example of a custom metric that uses features:
def my_auc(labels, predictions, features):
    auc_metric = tf.keras.metrics.AUC(name="my_auc")
    auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'],
                            sample_weight=features['weight'])
    return {'auc': auc_metric}
estimator = tf.estimator.DNNClassifier(...)
estimator = tf.estimator.add_metrics(estimator, my_auc)
estimator.train(...)
estimator.evaluate(...)
Args
estimator: A tf.estimator.Estimator object.
metric_fn: A function with the following signature:
- Args: may take only the following four arguments (any subset, in any order):
  - predictions: Predictions Tensor or dict of Tensor created by the given estimator.
  - features: Input dict of Tensor objects created by the input_fn that is passed to estimator.evaluate.
  - labels: Labels Tensor or dict of Tensor created by the input_fn that is passed to estimator.evaluate.
  - config: The config attribute of the estimator.
- Returns: A dict of metric results keyed by name. The final metrics are the union of this dict and the estimator's existing metrics; if a name conflicts with an existing metric, the entry returned by metric_fn overrides it. Each value is either a tf.keras.metrics.Metric instance, as in the examples above, or the result of calling a metric function, namely a (metric_tensor, update_op) tuple.
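The rule that metric_fn may declare any subset of predictions, features, labels, and config, in any order, amounts to matching arguments by name. The sketch below shows one way to perform that dispatch with Python's inspect module. call_metric_fn and weighted_accuracy are hypothetical helpers for illustration, not part of tf.estimator, and plain lists stand in for Tensors.

```python
import inspect
from typing import Any, Callable, Dict

# Argument names metric_fn is allowed to declare, per the contract above.
ALLOWED_ARGS = ("predictions", "features", "labels", "config")


def call_metric_fn(metric_fn: Callable[..., Dict[str, Any]], *,
                   predictions: Any, features: Any, labels: Any,
                   config: Any) -> Dict[str, Any]:
    """Hypothetical helper: call metric_fn with only the arguments it declares.

    Arguments are matched by name, so their order in metric_fn's signature
    does not matter; any other parameter name raises ValueError.
    """
    available = {"predictions": predictions, "features": features,
                 "labels": labels, "config": config}
    declared = list(inspect.signature(metric_fn).parameters)
    unsupported = [name for name in declared if name not in ALLOWED_ARGS]
    if unsupported:
        raise ValueError(f"metric_fn has unsupported arguments: {unsupported}")
    # Pass each declared argument as a keyword, so order is irrelevant.
    return metric_fn(**{name: available[name] for name in declared})


# Hypothetical metric that reads per-example weights from the features,
# mirroring the sample_weight example above; arguments are deliberately out of order.
def weighted_accuracy(features, labels, predictions):
    weights = features["weight"]
    hits = sum(w for w, y, p in zip(weights, labels, predictions) if y == p)
    return {"weighted_accuracy": hits / sum(weights)}


result = call_metric_fn(weighted_accuracy,
                        predictions=[1, 0, 1],
                        features={"weight": [1.0, 2.0, 1.0]},
                        labels=[1, 1, 1],
                        config=None)
print(result)  # {'weighted_accuracy': 0.5}
```

Matching by name is what lets a metric function declare only what it needs; a function asking for anything outside the four allowed names fails early with a clear error instead of at call time.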
Returns
A new tf.estimator.Estimator whose metrics are the union of the original estimator's metrics and those returned by metric_fn.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2022-10-28 UTC.