tf.estimator.add_metrics
Creates a new `tf.estimator.Estimator` that has the given metrics.
tf.estimator.add_metrics(
    estimator, metric_fn
)
Example:
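  import tensorflow as tf

  def my_auc(labels, predictions):
    auc_metric = tf.keras.metrics.AUC(name="my_auc")
    # 'logistic' holds the sigmoid probabilities of a binary DNNClassifier.
    auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'])
    # Return the metric object keyed by the name it should be reported under.
    return {'auc': auc_metric}

  estimator = tf.estimator.DNNClassifier(...)
  # Wrap the estimator so that evaluate() also reports 'auc'.
  estimator = tf.estimator.add_metrics(estimator, my_auc)
  estimator.train(...)
  estimator.evaluate(...)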
Example of a custom metric that uses features:
  def my_auc(labels, predictions, features):
    auc_metric = tf.keras.metrics.AUC(name="my_auc")
    # Weight each example by the 'weight' feature produced by input_fn.
    auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'],
                            sample_weight=features['weight'])
    return {'auc': auc_metric}

  estimator = tf.estimator.DNNClassifier(...)
  estimator = tf.estimator.add_metrics(estimator, my_auc)
  estimator.train(...)
  estimator.evaluate(...)
| Args | Description |
|---|---|
| `estimator` | A `tf.estimator.Estimator` object. |
| `metric_fn` | A function that obeys the signature described below. |

`metric_fn` signature:

- Args: it can only have the following four arguments, in any order:
  - `predictions`: Predictions `Tensor` or dict of `Tensor` created by the given `estimator`.
  - `features`: Input dict of `Tensor` objects created by the `input_fn` that is passed to `estimator.evaluate`.
  - `labels`: Labels `Tensor` or dict of `Tensor` created by the `input_fn` that is passed to `estimator.evaluate`.
  - `config`: The `config` attribute of the estimator.
- Returns: A dict of metric results keyed by name. The final metrics are the union of this dict and the estimator's existing metrics; if a name appears in both, the metric returned by `metric_fn` overrides the existing one. Each value is either a `tf.keras.metrics.Metric` instance (as in the examples above) or the result of calling a metric function, namely a `(metric_tensor, update_op)` tuple.
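The argument matching and metric merging described above can be illustrated with a plain-Python sketch. This is not TensorFlow's implementation: `call_metric_fn`, `merge_metrics`, and `my_accuracy` are hypothetical helpers, and plain floats stand in for the `Tensor` values, `Metric` objects, or `(metric_tensor, update_op)` tuples a real `metric_fn` would return. The sketch assumes arguments are matched by parameter name, which is one way to make their declaration order irrelevant.

```python
import inspect

# The four argument names a metric_fn may declare, per the description above.
ALLOWED_ARGS = ("features", "labels", "predictions", "config")


def call_metric_fn(metric_fn, features, labels, predictions, config):
    """Hypothetical helper: calls metric_fn with only the arguments it declares."""
    available = {
        "features": features,
        "labels": labels,
        "predictions": predictions,
        "config": config,
    }
    params = inspect.signature(metric_fn).parameters
    unknown = [name for name in params if name not in ALLOWED_ARGS]
    if unknown:
        raise ValueError("metric_fn has unsupported arguments: %s" % unknown)
    # Passing by keyword means the order of the declared arguments does not matter.
    return metric_fn(**{name: available[name] for name in params})


def merge_metrics(existing_metrics, new_metrics):
    """Hypothetical helper: union of both dicts; on a name conflict the new metric wins."""
    merged = dict(existing_metrics)
    merged.update(new_metrics)
    return merged


def my_accuracy(predictions, labels):
    # Hypothetical metric_fn; declares predictions before labels on purpose.
    correct = sum(int(p == l) for p, l in zip(predictions["classes"], labels))
    return {"accuracy": correct / len(labels)}


existing = {"accuracy": 0.5, "loss": 0.7}
new = call_metric_fn(my_accuracy, features={}, labels=[1, 0, 1, 1],
                     predictions={"classes": [1, 0, 0, 1]}, config=None)
merged = merge_metrics(existing, new)
print(merged)  # {'accuracy': 0.75, 'loss': 0.7}
```

Because the arguments are passed by keyword, `my_accuracy` can declare `predictions` before `labels`, and the merge keeps the existing `'loss'` entry while the returned `'accuracy'` replaces the existing value.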
  Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
  Last updated 2020-10-01 UTC.