tf.keras.metrics.MeanMetricWrapper

Wraps a stateless metric function with the Mean metric.

Inherits From: Mean, Metric, Layer, Module

You can use this class to quickly build a mean metric from a function. The function must have the signature fn(y_true, y_pred) and return an array of per-sample values (for example, a per-sample loss). MeanMetricWrapper.result() returns the average of these values across all samples seen so far.

For example:

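import tensorflow as tf

def accuracy(y_true, y_pred):
  # Per-sample correctness: 1.0 where the prediction equals the label, else 0.0.
  return tf.cast(tf.math.equal(y_true, y_pred), tf.float32)

accuracy_metric = tf.keras.metrics.MeanMetricWrapper(fn=accuracy)

keras_model.compile(..., metrics=[accuracy_metric])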
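To make the averaging behaviour concrete, here is a minimal pure-Python sketch of what a mean metric wrapper does, assuming it keeps a running sum and a running count of the per-sample values returned by fn. The class SketchMeanMetricWrapper is hypothetical and is not part of the TensorFlow API; it ignores sample weights and dtype handling, and its accuracy function is a plain-Python analogue of the one in the example.

```python
from typing import Callable, List, Sequence


class SketchMeanMetricWrapper:
    """Hypothetical stand-in for MeanMetricWrapper; not the TensorFlow implementation."""

    def __init__(self, fn: Callable[[Sequence, Sequence], List[float]]):
        self.fn = fn
        self.total = 0.0  # running sum of per-sample values
        self.count = 0    # number of per-sample values seen so far

    def update_state(self, y_true: Sequence, y_pred: Sequence) -> None:
        values = self.fn(y_true, y_pred)  # one value per sample
        self.total += sum(values)
        self.count += len(values)

    def result(self) -> float:
        # Average over all samples seen so far (not an average of batch averages).
        # Returning 0.0 before any update is an assumption of this sketch.
        return self.total / self.count if self.count else 0.0


def accuracy(y_true: Sequence, y_pred: Sequence) -> List[float]:
    # Plain-Python per-sample correctness: 1.0 on a match, else 0.0.
    return [1.0 if t == p else 0.0 for t, p in zip(y_true, y_pred)]


metric = SketchMeanMetricWrapper(accuracy)
metric.update_state([1, 2, 3, 4], [1, 2, 0, 0])  # 2 of 4 correct
metric.update_state([5, 6], [5, 6])              # 2 of 2 correct
print(metric.result())
```

Because the sketch divides the running sum by the running count, the two updates give 4/6 (about 0.667) rather than 0.75, the mean of the two batch means (0.5 and 1.0). This matches the "across all samples seen so far" wording above.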

Args

fn: The metric function to wrap, with signature fn(y_true, y_pred, **kwargs).
name: (Optional) String name of the metric instance.
dtype: (Optional) Data type of the metric result.
**kwargs: Keyword arguments to pass on to fn.
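The **kwargs argument lets a single wrapper instance carry fixed options for fn. The following pure-Python sketch models that forwarding, assuming the keyword arguments given at construction are passed to fn on every update. The names KwargsForwardingMean and thresholded_accuracy are hypothetical illustrations, not TensorFlow APIs.

```python
from typing import Any, Callable, Dict, List, Sequence


def thresholded_accuracy(y_true: Sequence[int], y_pred: Sequence[float],
                         threshold: float = 0.5) -> List[float]:
    # Hypothetical metric fn: 1.0 where (prediction > threshold) agrees with the 0/1 label.
    return [1.0 if (p > threshold) == bool(t) else 0.0 for t, p in zip(y_true, y_pred)]


class KwargsForwardingMean:
    """Hypothetical sketch: stores **kwargs at construction and forwards them to fn."""

    def __init__(self, fn: Callable[..., List[float]], **kwargs: Any):
        self.fn = fn
        self.fn_kwargs: Dict[str, Any] = kwargs  # e.g. {"threshold": 0.9}
        self.total = 0.0
        self.count = 0

    def update_state(self, y_true: Sequence, y_pred: Sequence) -> None:
        values = self.fn(y_true, y_pred, **self.fn_kwargs)
        self.total += sum(values)
        self.count += len(values)

    def result(self) -> float:
        return self.total / self.count if self.count else 0.0


strict = KwargsForwardingMean(thresholded_accuracy, threshold=0.9)
loose = KwargsForwardingMean(thresholded_accuracy, threshold=0.5)
for m in (strict, loose):
    m.update_state([1, 1, 0, 0], [0.95, 0.7, 0.2, 0.3])
print(strict.result(), loose.result())
```

The same fn yields different running means depending on the stored keyword arguments: with threshold=0.9 the 0.7 prediction counts as a miss (3 of 4 correct), while with threshold=0.5 all four samples match.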

Methods

reset_state


Resets all of the metric state variables.

This function is called between epochs/steps, when a metric is evaluated during training.
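A minimal pure-Python sketch of why the reset matters, assuming the metric's state is a running sum and a running count. ResettableMean is a hypothetical illustration, not the TensorFlow implementation.

```python
from typing import List, Sequence


class ResettableMean:
    """Hypothetical sketch of a mean metric's state; not the TensorFlow implementation."""

    def __init__(self) -> None:
        self.reset_state()

    def reset_state(self) -> None:
        # Zero out all accumulated state variables.
        self.total = 0.0
        self.count = 0

    def update_state(self, values: Sequence[float]) -> None:
        self.total += sum(values)
        self.count += len(values)

    def result(self) -> float:
        return self.total / self.count if self.count else 0.0


metric = ResettableMean()
epoch_results: List[float] = []
for epoch_values in ([0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]):
    metric.reset_state()  # start each epoch from a clean state
    metric.update_state(epoch_values)
    epoch_results.append(metric.result())
print(epoch_results)
```

With the reset, each epoch reports only its own samples (0.5, then 1.0). Without it, the second epoch would report 6/8 = 0.75, mixing in the first epoch's values.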

result