These are helper methods for Losses and Metrics and will be module private when Java modularity is applied to TensorFlow Java. These methods should not be used outside of the losses and metrics packages.
Public Constructors
| Constructor |
|---|
| LossesHelper() |
Public Methods
| Modifier and type | Method | Description |
|---|---|---|
| static <T extends TNumber> Operand<TInt32> | allAxes(Ops tf, Operand<T> op) | Gets a Constant integer array representing all the axes of the operand. |
| static <T extends TNumber> Operand<T> | computeWeightedLoss(Ops tf, Operand<T> loss, Reduction reduction, Operand<T> sampleWeight) | Computes the weighted loss. |
| static <T extends TNumber> Operand<T> | rangeCheck(Ops tf, String prefix, Operand<T> values, Operand<T> minValue, Operand<T> maxValue) | Performs an inclusive range check on the values. |
| static <T extends TNumber> LossTuple<T> | removeSqueezableDimensions(Ops tf, Operand<T> labels, Operand<T> predictions) | Squeezes the last dimension if the ranks differ from the expected rank difference by exactly 1. |
| static <T extends TNumber> LossTuple<T> | removeSqueezableDimensions(Ops tf, Operand<T> labels, Operand<T> predictions, int expectedRankDiff) | Squeezes the last dimension if the ranks differ from the expected rank difference by exactly 1. |
| static <T extends TNumber> Operand<T> | safeMean(Ops tf, Operand<T> losses, long numElements) | Computes a safe mean of the losses. |
| static <T extends TNumber> LossTuple<T> | squeezeOrExpandDimensions(Ops tf, Operand<T> labels, Operand<T> predictions) | Squeezes or expands the last dimension if needed, with sample weights of one. |
| static <T extends TNumber> LossTuple<T> | squeezeOrExpandDimensions(Ops tf, Operand<T> labels, Operand<T> predictions, Operand<T> sampleWeights) | Squeezes or expands the last dimension if needed. |
| static <T extends TNumber> Operand<T> | valueCheck(Ops tf, String prefix, Operand<T> values, Operand<T> allowedValues) | Checks whether all the values are in the allowed values set. |
Public Constructors
public LossesHelper()
Public Methods
public static <T extends TNumber> Operand<TInt32> allAxes(Ops tf, Operand<T> op)
Gets a Constant integer array representing all the axes of the operand.
Parameters
| tf | the TensorFlow Ops |
|---|---|
| op | the operand whose axes are returned |
Returns
- a Constant that represents all the axes of the operand.
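For an operand of known rank r, "all the axes" is simply the sequence 0, 1, ..., r - 1. The plain-Java sketch below computes that array from a rank. It makes no TensorFlow calls, the class and method names are hypothetical, and it assumes the rank is known; how the library handles an unknown rank is not covered here.

```java
/** Hypothetical plain-Java illustration of what allAxes computes; it makes no TensorFlow calls. */
final class AllAxesSketch {
  private AllAxesSketch() {}

  /** Returns {0, 1, ..., rank - 1}, i.e. every axis of a tensor with the given (known) rank. */
  static int[] allAxes(int rank) {
    if (rank < 0) {
      throw new IllegalArgumentException("rank must be non-negative: " + rank);
    }
    int[] axes = new int[rank];
    for (int i = 0; i < rank; i++) {
      axes[i] = i;
    }
    return axes;
  }

  public static void main(String[] args) {
    // A tensor of shape [2, 3, 4] has rank 3, so its axes are [0, 1, 2].
    System.out.println(java.util.Arrays.toString(allAxes(3)));
  }
}
```

Reducing over every axis in such an array collapses a tensor to a single scalar, which is the typical reason a loss or metric helper needs it.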
public static <T extends TNumber> Operand<T> computeWeightedLoss(Ops tf, Operand<T> loss, Reduction reduction, Operand<T> sampleWeight)
Computes the weighted loss.
Parameters
| tf | the TensorFlow Ops |
|---|---|
| loss | the unweighted loss |
| reduction | the type of reduction |
| sampleWeight | the sample weight; if null, it defaults to one. |
Returns
- the weighted loss
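The weighting-and-reduction step can be modeled on plain arrays, as in the sketch below. It is not the TensorFlow Java implementation: the `Mode` enum is a hypothetical stand-in for the `Reduction` parameter, its semantics are assumed from the Keras loss reductions (NONE keeps the per-element values, SUM adds them, SUM_OVER_BATCH_SIZE divides the sum by the element count), and weight broadcasting is simplified to a null, scalar, or per-element weight.

```java
/** Hypothetical plain-Java model of computing a weighted loss; it makes no TensorFlow calls. */
final class WeightedLossSketch {
  private WeightedLossSketch() {}

  /** Assumed reduction modes, modeled on the Keras loss reductions. */
  enum Mode { NONE, SUM, SUM_OVER_BATCH_SIZE }

  /**
   * Weights each loss element, then reduces the result.
   *
   * @param loss per-element unweighted losses
   * @param mode how to reduce the weighted losses
   * @param sampleWeight null (every weight is one), a single scalar weight, or one weight per element
   * @return the weighted losses for NONE, otherwise a one-element array holding the reduced value
   */
  static double[] computeWeightedLoss(double[] loss, Mode mode, double[] sampleWeight) {
    if (sampleWeight != null && sampleWeight.length != 1 && sampleWeight.length != loss.length) {
      throw new IllegalArgumentException("sampleWeight must have length 1 or loss.length");
    }
    double[] weighted = new double[loss.length];
    for (int i = 0; i < loss.length; i++) {
      double w = sampleWeight == null ? 1.0 // default weight of one
          : sampleWeight.length == 1 ? sampleWeight[0] // scalar weight broadcast to every element
          : sampleWeight[i]; // one weight per element
      weighted[i] = loss[i] * w;
    }
    if (mode == Mode.NONE) {
      return weighted; // no reduction: keep the per-element weighted losses
    }
    double sum = 0.0;
    for (double v : weighted) {
      sum += v;
    }
    if (mode == Mode.SUM) {
      return new double[] {sum};
    }
    // SUM_OVER_BATCH_SIZE: a safe mean over all elements, zero when there are none.
    return new double[] {loss.length == 0 ? 0.0 : sum / loss.length};
  }

  public static void main(String[] args) {
    double[] loss = {1.0, 2.0, 3.0};
    double[] weights = {1.0, 0.0, 2.0};
    // Weighted losses are [1.0, 0.0, 6.0]; their sum is 7.0.
    System.out.println(java.util.Arrays.toString(computeWeightedLoss(loss, Mode.NONE, weights)));
    System.out.println(computeWeightedLoss(loss, Mode.SUM, weights)[0]);
  }
}
```

Passing `null` weights reproduces the "defaults to one" behavior of the `sampleWeight` parameter.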
public static <T extends TNumber> Operand<T> rangeCheck(Ops tf, String prefix, Operand<T> values, Operand<T> minValue, Operand<T> maxValue)
Performs an inclusive range check on the values.
Parameters
| tf | the TensorFlow Ops |
|---|---|
| prefix | A String prefix to include in the error message |
| values | the values to check |
| minValue | the minimum value |
| maxValue | the maximum value |
Returns
- the values, possibly with control dependencies if the TensorFlow Ops represents a Graph session
Throws
| IllegalArgumentException | if the TensorFlow Ops represents an Eager session and at least one value is outside the range [minValue, maxValue] |
|---|---|
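The check itself can be sketched on a plain array: every value v must satisfy minValue <= v <= maxValue, and a violation throws IllegalArgumentException. This models only the eager outcome; the graph path, which returns the values with control dependencies attached, is not modeled, and the class name and error message below are hypothetical.

```java
/** Hypothetical plain-Java analogue of an inclusive range check; it makes no TensorFlow calls. */
final class RangeCheckSketch {
  private RangeCheckSketch() {}

  /** Returns values unchanged if every element lies in [minValue, maxValue], otherwise throws. */
  static double[] rangeCheck(String prefix, double[] values, double minValue, double maxValue) {
    for (double v : values) {
      // Inclusive bounds: minValue and maxValue themselves pass; a NaN value fails the check.
      if (!(v >= minValue && v <= maxValue)) {
        throw new IllegalArgumentException(
            prefix + ": values out of range, minimum = " + minValue + ", maximum = " + maxValue);
      }
    }
    return values;
  }

  public static void main(String[] args) {
    double[] probabilities = {0.0, 0.25, 1.0};
    rangeCheck("probabilities", probabilities, 0.0, 1.0); // passes: the bounds are inclusive
    try {
      rangeCheck("probabilities", new double[] {1.5}, 0.0, 1.0);
    } catch (IllegalArgumentException e) {
      System.out.println(e.getMessage());
    }
  }
}
```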
public static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(Ops tf, Operand<T> labels, Operand<T> predictions)
Squeezes the last dimension of labels or predictions if their ranks differ from the expected rank difference by exactly 1.
Parameters
| tf | the TensorFlow Ops |
|---|---|
| labels | Label values, a Tensor whose dimensions match predictions. |
| predictions | Predicted values, a Tensor of arbitrary dimensions. |
Returns
- labels and predictions, possibly with the last dimension squeezed.
public static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(Ops tf, Operand<T> labels, Operand<T> predictions, int expectedRankDiff)
Squeezes the last dimension of labels or predictions if rank(predictions) - rank(labels) differs from expectedRankDiff by exactly 1.
Parameters
| tf | the TensorFlow Ops |
|---|---|
| labels | Label values, an Operand whose dimensions match predictions. |
| predictions | Predicted values, a Tensor of arbitrary dimensions. |
| expectedRankDiff | Expected result of rank(predictions) - rank(labels). |
Returns
- labels and predictions, possibly with the last dimension squeezed.
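The squeeze rule can be sketched on shapes alone. With rankDiff = rank(predictions) - rank(labels), predictions lose their last dimension when rankDiff == expectedRankDiff + 1, and labels lose theirs when rankDiff == expectedRankDiff - 1, in both cases only if that dimension has size 1. The sketch assumes fully known shapes, follows the Keras `remove_squeezable_dimensions` rule, and assumes the shorter overload uses an expected difference of 0; the class and method names are hypothetical and no TensorFlow calls are made.

```java
import java.util.Arrays;

/** Hypothetical shape-only illustration of removeSqueezableDimensions; no TensorFlow calls. */
final class SqueezeSketch {
  private SqueezeSketch() {}

  /** Drops the last dimension of a shape. */
  static long[] squeezeLast(long[] shape) {
    return Arrays.copyOf(shape, shape.length - 1);
  }

  /** Returns {labelsShape, predictionsShape}, with at most one of them squeezed. */
  static long[][] removeSqueezableDimensions(
      long[] labels, long[] predictions, int expectedRankDiff) {
    int rankDiff = predictions.length - labels.length;
    if (rankDiff == expectedRankDiff + 1
        && predictions.length > 0
        && predictions[predictions.length - 1] == 1) {
      return new long[][] {labels, squeezeLast(predictions)}; // predictions have one rank too many
    }
    if (rankDiff == expectedRankDiff - 1
        && labels.length > 0
        && labels[labels.length - 1] == 1) {
      return new long[][] {squeezeLast(labels), predictions}; // labels have one rank too many
    }
    return new long[][] {labels, predictions}; // nothing to squeeze
  }

  /** Assumed behavior of the shorter overload: an expected rank difference of 0. */
  static long[][] removeSqueezableDimensions(long[] labels, long[] predictions) {
    return removeSqueezableDimensions(labels, predictions, 0);
  }

  public static void main(String[] args) {
    // Predictions of shape [32, 1] against labels of shape [32]: the predictions are squeezed.
    long[][] r = removeSqueezableDimensions(new long[] {32}, new long[] {32, 1});
    System.out.println(Arrays.toString(r[0]) + " " + Arrays.toString(r[1]));
  }
}
```

A trailing dimension larger than 1 is never squeezed, so a shape such as [32, 3] passes through unchanged.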
public static <T extends TNumber> Operand<T> safeMean(Ops tf, Operand<T> losses, long numElements)
Computes a safe mean of the losses.
Parameters
| tf | the TensorFlow Ops |
|---|---|
| losses | Operand whose elements contain individual loss measurements. |
| numElements | The number of measurable elements in losses. |
Returns
- A scalar representing the mean of losses. If numElements is zero, then zero is returned.
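A "safe" mean here means the division cannot produce NaN when there is nothing to average. A minimal plain-Java sketch of that contract, with a hypothetical class name and no TensorFlow calls:

```java
/** Hypothetical plain-Java analogue of a safe mean; it makes no TensorFlow calls. */
final class SafeMeanSketch {
  private SafeMeanSketch() {}

  /** Sums losses and divides by numElements, returning 0 instead of NaN when numElements is 0. */
  static double safeMean(double[] losses, long numElements) {
    double total = 0.0;
    for (double v : losses) {
      total += v;
    }
    return numElements == 0 ? 0.0 : total / numElements;
  }

  public static void main(String[] args) {
    System.out.println(safeMean(new double[] {1.0, 2.0, 3.0}, 3)); // 2.0
    System.out.println(safeMean(new double[0], 0)); // 0.0 rather than NaN
  }
}
```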
public static <T extends TNumber> LossTuple<T> squeezeOrExpandDimensions(Ops tf, Operand<T> labels, Operand<T> predictions)
Squeezes or expands the last dimension if needed, with sample weights of one.
- Squeezes the last dimension of predictions or labels if their ranks differ by 1 (using removeSqueezableDimensions(Ops, Operand<T>, Operand<T>)).
- Squeezes or expands the last dimension of sampleWeight if its rank differs by 1 from the new rank of predictions. If sampleWeight is scalar, it is kept scalar.
Parameters
| tf | the TensorFlow Ops |
|---|---|
| labels | Optional label Operand whose dimensions match predictions. |
| predictions | Predicted values, an Operand of arbitrary dimensions. |
Returns
- A LossTuple of predictions and labels, each possibly with the last dimension squeezed; its sampleWeight is null.
public static <T extends TNumber> LossTuple<T> squeezeOrExpandDimensions(Ops tf, Operand<T> labels, Operand<T> predictions, Operand<T> sampleWeights)
Squeezes or expands the last dimension if needed.
- Squeezes the last dimension of predictions or labels if their ranks differ by 1.
- Squeezes or expands the last dimension of sampleWeight if its rank differs by 1 from the new rank of predictions. If sampleWeight is scalar, it is kept scalar.
Parameters
| tf | the TensorFlow Ops |
|---|---|
| labels | Optional label Operand whose dimensions match predictions. |
| predictions | Predicted values, an Operand of arbitrary dimensions. |
| sampleWeights | Optional sample weight(s) Operand whose dimensions match predictions. |
Returns
- A LossTuple of predictions, labels and sampleWeight. Each of them possibly has the last dimension squeezed, and sampleWeight could be extended by one dimension. If sampleWeight is null, only the possibly shape-modified predictions and labels are returned.
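The sample-weight step can be sketched on shapes: a scalar weight is left alone, a weight one rank higher than predictions loses its last dimension, and a weight one rank lower gains a trailing dimension of size 1. The sketch assumes the predictions shape has already been through the labels/predictions squeeze, treats an empty array as the scalar shape, rejects squeezing a dimension whose size is not 1, and uses a hypothetical class name; it makes no TensorFlow calls.

```java
import java.util.Arrays;

/** Hypothetical shape-only illustration of the sample-weight adjustment; no TensorFlow calls. */
final class SampleWeightShapeSketch {
  private SampleWeightShapeSketch() {}

  /** Returns the adjusted sample-weight shape, or null when there are no sample weights. */
  static long[] adjustSampleWeightShape(long[] predictions, long[] sampleWeight) {
    if (sampleWeight == null) {
      return null; // no weights: only labels and predictions matter
    }
    if (sampleWeight.length == 0) {
      return sampleWeight; // a scalar weight is kept scalar
    }
    int diff = sampleWeight.length - predictions.length;
    if (diff == 1) {
      // One rank too many: squeeze the last dimension, which must have size 1.
      long last = sampleWeight[sampleWeight.length - 1];
      if (last != 1) {
        throw new IllegalArgumentException("cannot squeeze a dimension of size " + last);
      }
      return Arrays.copyOf(sampleWeight, sampleWeight.length - 1);
    }
    if (diff == -1) {
      // One rank too few: append a trailing dimension of size 1.
      long[] expanded = Arrays.copyOf(sampleWeight, sampleWeight.length + 1);
      expanded[sampleWeight.length] = 1;
      return expanded;
    }
    return sampleWeight; // otherwise the shape is left unchanged
  }

  public static void main(String[] args) {
    // Predictions [32, 4] with per-example weights [32]: the weights become [32, 1].
    System.out.println(
        Arrays.toString(adjustSampleWeightShape(new long[] {32, 4}, new long[] {32})));
  }
}
```

Expanding [32] to [32, 1] lets one weight per example broadcast across the last dimension of the predictions.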
public static <T extends TNumber> Operand<T> valueCheck(Ops tf, String prefix, Operand<T> values, Operand<T> allowedValues)
Checks whether all the values are in the allowed values set. In Graph mode, running the returned operand throws a TFInvalidArgumentException if at least one value is not in the allowed values set. In Eager mode, this method throws an IllegalArgumentException if at least one value is not in the allowed values set.
Parameters
| tf | The TensorFlow Ops |
|---|---|
| prefix | A String prefix to include in the error message |
| values | the values to check |
| allowedValues | the allowed values |
Returns
- the values, possibly with control dependencies if the TensorFlow Ops represents a Graph session
Throws
| IllegalArgumentException | if the session is in Eager mode and at least one value is not in the allowed values set |
|---|---|
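In Eager mode the check amounts to set membership. The plain-Java sketch below, with hypothetical names and integer values for simplicity (for example, binary labels that must be 0 or 1), returns the values unchanged when all of them are allowed and throws IllegalArgumentException otherwise, mirroring the documented eager behavior. The Graph-mode assertion is not modeled.

```java
import java.util.HashSet;
import java.util.Set;

/** Hypothetical plain-Java analogue of an eager-mode allowed-values check. */
final class ValueCheckSketch {
  private ValueCheckSketch() {}

  /** Returns values unchanged if each one is in allowedValues, otherwise throws. */
  static int[] valueCheck(String prefix, int[] values, int[] allowedValues) {
    Set<Integer> allowed = new HashSet<>();
    for (int a : allowedValues) {
      allowed.add(a);
    }
    for (int v : values) {
      if (!allowed.contains(v)) {
        throw new IllegalArgumentException(
            prefix + ": value " + v + " is not in the allowed values set");
      }
    }
    return values;
  }

  public static void main(String[] args) {
    int[] labels = {0, 1, 1, 0};
    valueCheck("labels", labels, new int[] {0, 1}); // passes
    try {
      valueCheck("labels", new int[] {0, 2}, new int[] {0, 1});
    } catch (IllegalArgumentException e) {
      System.out.println(e.getMessage()); // labels: value 2 is not in the allowed values set
    }
  }
}
```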