MeanMetricWrapper

public class MeanMetricWrapper

A class that bridges a stateless loss function with the Mean metric using a reduction of WEIGHTED_MEAN.

The wrapper calls the loss function to compute the loss between the labels and predictions, then passes this loss to the Mean metric, which calculates the weighted mean of the loss over many iterations or epochs.
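To make the bridge concrete, here is a plain-Java sketch of the idea under stated assumptions. It is not the TensorFlow Java API, which builds graph Ops over Operand<T> tensors, and the names LossFunction, LossMean, and absoluteError are hypothetical. A stateless loss function maps labels and predictions to one loss per sample, and the wrapper folds those losses into a running total and count, so result() is the weighted mean over every batch seen so far.

```java
// Plain-Java sketch of the "stateless loss + running weighted mean" bridge.
// NOT the TensorFlow Java API; LossFunction, LossMean and absoluteError are hypothetical.

/** A stateless loss: returns one loss value per sample. */
@FunctionalInterface
interface LossFunction {
    double[] call(double[] labels, double[] predictions);
}

/** Wraps a loss function and keeps a running weighted mean of its output. */
class LossMean {
    private final LossFunction loss;
    private double total = 0.0; // running sum of weight * loss
    private double count = 0.0; // running sum of weights

    LossMean(LossFunction loss) {
        this.loss = loss;
    }

    /** Computes per-sample losses for one batch and folds them into the running state. */
    void updateState(double[] labels, double[] predictions, double[] sampleWeights) {
        double[] values = loss.call(labels, predictions);
        for (int i = 0; i < values.length; i++) {
            double w = (sampleWeights == null) ? 1.0 : sampleWeights[i]; // null means weight 1
            total += w * values[i];
            count += w;
        }
    }

    /** Weighted mean of every loss seen so far (0 if nothing has been accumulated). */
    double result() {
        return count == 0.0 ? 0.0 : total / count;
    }

    /** Example stateless loss: absolute error per sample. */
    static double[] absoluteError(double[] labels, double[] predictions) {
        double[] out = new double[labels.length];
        for (int i = 0; i < labels.length; i++) {
            out[i] = Math.abs(labels[i] - predictions[i]);
        }
        return out;
    }
}
```

With absolute error, a first batch (labels {1, 2}, predictions {1, 4}) yields losses {0, 2} and a result of 1.0. A second batch with a single loss of 3 and weight 2 moves the result to (2 + 2 × 3) / (2 + 2) = 2.0, because earlier batches stay in the running state.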

Inherited Constants

org.tensorflow.framework.metrics.impl.Reduce
String COUNT
String TOTAL

Public Methods

LossMetric<T>
getLoss()
Gets the loss function.
List<Op>
updateStateList(Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
Creates Operations that update the state of the mean metric by calling the loss function and passing the loss to the Mean metric, which calculates the weighted mean of the loss over many iterations.

Inherited Methods

org.tensorflow.framework.metrics.impl.Reduce
Variable<T>
getCount()
Gets the count variable.
Class<T>
getResultType()
Gets the type for the variables.
Variable<T>
getTotal()
Gets the total variable.
Op
resetStates()
Resets any state variables to their initial values.
Operand<T>
result()
Gets the current result of the metric.
List<Op>
updateStateList(Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
Updates the metric variables based on the inputs.
org.tensorflow.framework.metrics.Metric
final Operand<T>
callOnce(Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
Calls updateState once, followed by a call to result().
String
getName()
Gets the name of this metric.
long
getSeed()
Gets the random number generator seed value.
Ops
getTF()
Gets the TensorFlow Ops.
abstract Op
resetStates()
Resets any state variables to their initial values.
abstract Operand<T>
result()
Gets the current result of the metric.
final Op
updateState(Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
Creates a NoOp Operation with control dependencies to update the metric state.
final Op
updateState(Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
Creates a NoOp Operation with control dependencies to update the metric state.
List<Op>
updateStateList(Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
Creates a List of Operations to update the metric state based on labels and predictions.
List<Op>
updateStateList(Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
Creates a List of Operations to update the metric state based on input values.
java.lang.Object
boolean
equals(Object arg0)
final Class<?>
getClass()
int
hashCode()
final void
notify()
final void
notifyAll()
String
toString()
final void
wait(long arg0, int arg1)
final void
wait(long arg0)
final void
wait()
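The inherited Reduce and Metric methods follow an accumulate, read, reset lifecycle. The plain-Java sketch below models updateState(values, sampleWeights), result(), resetStates(), and callOnce() on plain doubles, assuming a WEIGHTED_MEAN reduction (total is the sum of weight × value, count is the sum of weights). The class name WeightedMeanReducer is hypothetical; the real TensorFlow Java classes return Op and Operand objects and keep their state in Variable<T> instances.

```java
// Hypothetical plain-Java model of the Reduce/Metric lifecycle under WEIGHTED_MEAN.
// Not the TensorFlow Java classes, which build graph Ops over Variable<T> state.
class WeightedMeanReducer {
    private double total = 0.0; // running sum of weight * value
    private double count = 0.0; // running sum of weights

    /** Mirrors updateState(values, sampleWeights); null weights mean weight 1. */
    void updateState(double[] values, double[] sampleWeights) {
        for (int i = 0; i < values.length; i++) {
            double w = (sampleWeights == null) ? 1.0 : sampleWeights[i];
            total += w * values[i];
            count += w;
        }
    }

    /** Mirrors result(): total / count, or 0 if nothing has been accumulated. */
    double result() {
        return count == 0.0 ? 0.0 : total / count;
    }

    /** Mirrors resetStates(): returns the accumulators to their initial values. */
    void resetStates() {
        total = 0.0;
        count = 0.0;
    }

    /** Mirrors callOnce(): one update followed by a result; the state is NOT reset. */
    double callOnce(double[] values, double[] sampleWeights) {
        updateState(values, sampleWeights);
        return result();
    }
}
```

In this sketch, calling callOnce() with {2, 4} and then with {9} returns 3.0 and then 5.0 (the mean of all three values), not 9.0; calling resetStates() first, for example between epochs, makes the next callOnce({9}, null) return 9.0.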

Public Methods

public LossMetric<T> getLoss()

Gets the loss function.

Returns
  • the loss function.

public List<Op> updateStateList(Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)

Creates Operations that update the state of the mean metric by calling the loss function and passing the loss to the Mean metric, which calculates the weighted mean of the loss over many iterations.

Parameters
labels the truth values or labels
predictions the predictions
sampleWeights Optional. Acts as a coefficient for the loss. If a scalar is provided, the loss is simply scaled by the given value. If sampleWeights is a tensor of size [batch_size], the total loss for each sample of the batch is rescaled by the corresponding element in the sampleWeights vector. If the shape of sampleWeights is [batch_size, d0, ..., dN-1] (or can be broadcast to this shape), each element of the loss computed from predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss functions reduce by one dimension, usually axis=-1.)
Returns
  • a List of control operations that update the Mean state variables.
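The three sampleWeights cases can be illustrated with a plain-Java sketch on a loss that has already been reduced to shape [batch_size, d0], where the single extra dimension stands in for d0, ..., dN-1. It assumes, without confirmation from this page, that the Mean metric broadcasts the weights to the loss shape and then accumulates sum(weight × loss) and sum(weight), as a WEIGHTED_MEAN reduction would. The class SampleWeightBroadcast and its methods are hypothetical and not part of TensorFlow Java.

```java
// Plain-Java sketch, NOT the TensorFlow Java API. Assumption: weights are broadcast
// to the loss shape [batch_size, d0] and reduced as sum(weight * loss) / sum(weight).
// All names here are hypothetical.
class SampleWeightBroadcast {

    /** Scalar weight: the same coefficient for every loss element. */
    static double[][] broadcast(double weight, double[][] loss) {
        double[][] w = new double[loss.length][];
        for (int i = 0; i < loss.length; i++) {
            w[i] = new double[loss[i].length];
            java.util.Arrays.fill(w[i], weight);
        }
        return w;
    }

    /** Shape [batch_size]: one coefficient per sample, repeated along d0. */
    static double[][] broadcast(double[] weights, double[][] loss) {
        if (weights.length != loss.length) {
            throw new IllegalArgumentException("expected one weight per sample");
        }
        double[][] w = new double[loss.length][];
        for (int i = 0; i < loss.length; i++) {
            w[i] = new double[loss[i].length];
            java.util.Arrays.fill(w[i], weights[i]);
        }
        return w;
    }

    /** Weighted mean once the weights already have the loss shape [batch_size, d0]. */
    static double weightedMean(double[][] loss, double[][] weights) {
        double total = 0.0; // sum of weight * loss
        double count = 0.0; // sum of weights
        for (int i = 0; i < loss.length; i++) {
            for (int j = 0; j < loss[i].length; j++) {
                total += weights[i][j] * loss[i][j];
                count += weights[i][j];
            }
        }
        return count == 0.0 ? 0.0 : total / count; // 0 for all-zero weights (sketch choice)
    }
}
```

For a loss of {{1, 3}, {5, 7}}, a scalar weight of 2.0 gives 32 / 8 = 4.0, the same as the unweighted mean, because within a single update the scalar scales the total and the count alike; it only matters relative to other updates made with different weights. Per-sample weights {1, 0} drop the second sample and give 4 / 2 = 2.0, and per-element weights {{3, 1}, {0, 0}} give (3 + 3) / 4 = 1.5.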