Known Direct Subclasses
BinaryCrossentropy<T extends TNumber>,
CategoricalCrossentropy<T extends TNumber>,
CategoricalHinge<T extends TNumber>,
CosineSimilarity<T extends TNumber>,
Hinge<T extends TNumber>,
KLDivergence<T extends TNumber>,
LogCoshError<T extends TNumber>,
MeanAbsoluteError<T extends TNumber>,
MeanAbsolutePercentageError<T extends TNumber>,
MeanSquaredError<T extends TNumber>,
MeanSquaredLogarithmicError<T extends TNumber>,
Poisson<T extends TNumber>,
SparseCategoricalCrossentropy<T extends TNumber>,
SquaredHinge<T extends TNumber>
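Each subclass name above corresponds to one stateless loss. As a hedged illustration in plain Java (not TensorFlow Java code), here are per-sample versions of three of them. Each reduces the last axis of one sample to a single scalar. The class name PerSampleLosses is hypothetical, the formulas follow the common Keras-style definitions, and the Hinge variant assumes labels are already -1 or 1.

```java
/** Illustrative per-sample losses; hypothetical helpers, not the library implementation. */
public final class PerSampleLosses {
    private PerSampleLosses() {}

    // MeanSquaredError: mean of (y - p)^2 over the last axis of one sample.
    public static double meanSquaredError(double[] y, double[] p) {
        double sum = 0.0;
        for (int j = 0; j < y.length; j++) {
            double d = y[j] - p[j];
            sum += d * d;
        }
        return sum / y.length;
    }

    // MeanAbsoluteError: mean of |y - p| over the last axis of one sample.
    public static double meanAbsoluteError(double[] y, double[] p) {
        double sum = 0.0;
        for (int j = 0; j < y.length; j++) {
            sum += Math.abs(y[j] - p[j]);
        }
        return sum / y.length;
    }

    // Hinge: mean of max(1 - y * p, 0); assumes labels are -1 or 1.
    public static double hinge(double[] y, double[] p) {
        double sum = 0.0;
        for (int j = 0; j < y.length; j++) {
            sum += Math.max(1.0 - y[j] * p[j], 0.0);
        }
        return sum / y.length;
    }

    public static void main(String[] args) {
        double[] y = {1, -1};
        double[] p = {0.5, 0.5};
        System.out.println(meanSquaredError(y, p));  // (0.25 + 2.25) / 2 = 1.25
        System.out.println(meanAbsoluteError(y, p)); // (0.5 + 1.5) / 2 = 1.0
        System.out.println(hinge(y, p));             // (0.5 + 1.5) / 2 = 1.0
    }
}
```

Because every loss has the same shape (labels and predictions in, one scalar per sample out), a single weighted-mean wrapper can serve all of them.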
A class that bridges a stateless loss function with the Mean metric using a reduction of WEIGHTED_MEAN.
The loss function calculates the loss between the labels and predictions; this class then passes that loss to the Mean metric, which calculates the weighted mean of the loss over many iterations or epochs.
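To make the bridge concrete, here is a dependency-free Java sketch of the same idea. It is not the TensorFlow Java implementation: the class name WeightedMeanLossSketch, its methods, and the use of plain double arrays instead of Operand tensors are assumptions for illustration. It keeps a running weighted total and a running sum of weights (loosely mirroring the total and count variables listed under Inherited Methods) and divides one by the other in result().

```java
import java.util.function.BiFunction;

/** Hypothetical sketch of "stateless loss + weighted Mean"; not the TensorFlow Java API. */
public class WeightedMeanLossSketch {

    // Running state, loosely analogous to the total and count variables of the Mean metric.
    private double total = 0.0;
    private double count = 0.0;

    // Stateless loss: maps one sample's labels and predictions to a scalar loss.
    private final BiFunction<double[], double[], Double> loss;

    public WeightedMeanLossSketch(BiFunction<double[], double[], Double> loss) {
        this.loss = loss;
    }

    /** Computes the per-sample loss for a batch and folds it into the weighted mean. */
    public void updateState(double[][] labels, double[][] predictions, double[] sampleWeights) {
        for (int i = 0; i < labels.length; i++) {
            // A null weight array means every sample has weight 1 (an assumption of this sketch).
            double w = (sampleWeights == null) ? 1.0 : sampleWeights[i];
            total += w * loss.apply(labels[i], predictions[i]);
            count += w;
        }
    }

    /** Weighted mean of every loss seen since the last reset; 0 if nothing was accumulated. */
    public double result() {
        return count == 0.0 ? 0.0 : total / count;
    }

    /** Clears the accumulated state, for example at the start of a new epoch. */
    public void resetStates() {
        total = 0.0;
        count = 0.0;
    }

    // Mean squared error over the last axis of one sample.
    static double meanSquaredError(double[] y, double[] p) {
        double sum = 0.0;
        for (int j = 0; j < y.length; j++) {
            double d = y[j] - p[j];
            sum += d * d;
        }
        return sum / y.length;
    }

    public static void main(String[] args) {
        WeightedMeanLossSketch metric = new WeightedMeanLossSketch(WeightedMeanLossSketch::meanSquaredError);
        // Batch 1: per-sample losses are 0.25 and 1.0, each with weight 1.
        metric.updateState(new double[][] {{0, 1}, {1, 1}}, new double[][] {{0.5, 0.5}, {0, 0}}, null);
        // Batch 2: per-sample loss is 0.0 with weight 2.
        metric.updateState(new double[][] {{1, 0}}, new double[][] {{1, 0}}, new double[] {2.0});
        System.out.println(metric.result()); // 1.25 / 4 = 0.3125
    }
}
```

The state survives across calls to updateState, which is what lets the metric average over many batches; resetStates starts a fresh average.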
Inherited Constants
Public Methods
LossMetric<T> |
getLoss()
Gets the loss function.
|
List<Op> |
updateStateList(Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
Creates Operations that update the state of the Mean metric by calling the loss function and
passing the loss to the Mean metric, which calculates the weighted mean of the loss over many
iterations.
|
Inherited Methods
Variable<T> |
getCount()
Gets the count variable.
|
Class<T> |
getResultType()
Gets the type of the variables.
|
Variable<T> |
getTotal()
Gets the total variable.
|
Op |
resetStates()
Resets any state variables to their initial values.
|
Operand<T> |
result()
Gets the current result of the metric.
|
List<Op> |
updateStateList(Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
Updates the metric variables based on the inputs.
|
final Operand<T> | |
String |
getName()
Gets the name of this metric.
|
long |
getSeed()
Gets the random number generator seed value.
|
Ops |
getTF()
Gets the TensorFlow Ops.
|
abstract Op |
resetStates()
Resets any state variables to their initial values.
|
abstract Operand<T> |
result()
Gets the current result of the metric.
|
final Op |
updateState(Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
Creates a NoOp Operation with control dependencies to update the metric state.
|
List<Op> |
updateStateList(Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights)
Creates a List of Operations to update the metric state based on input values.
|
boolean |
equals(Object arg0)
|
final Class<?> |
getClass()
|
int |
hashCode()
|
final void |
notify()
|
final void |
notifyAll()
|
String |
toString()
|
final void |
wait(long arg0, int arg1)
|
final void |
wait(long arg0)
|
final void |
wait()
|
Public Methods
public List<Op> updateStateList (Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions, Operand<? extends TNumber> sampleWeights)
Creates Operations that update the state of the Mean metric by calling the loss function and passing the loss to the Mean metric, which calculates the weighted mean of the loss over many iterations.
Parameters
Parameter | Description |
---|---|
labels | the truth values or labels |
predictions | the predictions |
sampleWeights | Optional. sampleWeights acts as a coefficient for the loss. If a scalar is provided, the loss is simply scaled by the given value. If sampleWeights is a tensor of size [batch_size], the total loss for each sample in the batch is rescaled by the corresponding element of the sampleWeights vector. If the shape of sampleWeights is [batch_size, d0, .. dN-1] (or can be broadcast to this shape), each loss element of predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss functions reduce by one dimension, usually axis=-1.) |
Returns
- a List of control operations that update the Mean state variables.
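The two simplest sampleWeights cases described for updateStateList can be sketched in plain Java: a scalar scales every sample's loss by the same value, and a [batch_size] vector rescales each sample's loss by its own element. This is a hedged illustration, not TensorFlow Java code. The helper names broadcastToBatch and applyWeights are hypothetical, a length-1 array stands in for a scalar, and the per-element [batch_size, d0, .. dN-1] case is deliberately left out.

```java
import java.util.Arrays;

/** Hypothetical helpers illustrating scalar and [batch_size] sampleWeights; not the TensorFlow Java API. */
public final class SampleWeightSketch {
    private SampleWeightSketch() {}

    // Expands weights to one value per sample; a length-1 array is treated as a scalar.
    public static double[] broadcastToBatch(double[] weights, int batchSize) {
        if (weights.length == 1) {
            double[] out = new double[batchSize];
            Arrays.fill(out, weights[0]);
            return out;
        }
        if (weights.length == batchSize) {
            return weights.clone();
        }
        throw new IllegalArgumentException(
                "weights of length " + weights.length + " cannot be broadcast to batch size " + batchSize);
    }

    // Scales each per-sample loss by its (broadcast) weight.
    public static double[] applyWeights(double[] perSampleLoss, double[] weights) {
        double[] w = broadcastToBatch(weights, perSampleLoss.length);
        double[] out = new double[perSampleLoss.length];
        for (int i = 0; i < out.length; i++) {
            out[i] = perSampleLoss[i] * w[i];
        }
        return out;
    }

    public static void main(String[] args) {
        double[] losses = {1.0, 3.0};
        // Scalar weight: every sample's loss is scaled by 2.
        System.out.println(Arrays.toString(applyWeights(losses, new double[] {2.0})));      // [2.0, 6.0]
        // [batch_size] weights: the second sample is masked out.
        System.out.println(Arrays.toString(applyWeights(losses, new double[] {1.0, 0.0}))); // [1.0, 0.0]
    }
}
```

A weight vector of zeros and ones, as in the second call, is a simple way to exclude padded or invalid samples from the metric.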