model_remediation.counterfactual.losses.PairwiseMSELoss

Pairwise mean squared error loss between the original and counterfactual predictions.

Inherits From: CounterfactualLoss

Arguments
name (Optional) Name used for logging and tracking. Defaults to 'pairwise_mse_loss'.

Methods

__call__


Computes the counterfactual loss between the original and counterfactual predictions.

Arguments
original The predictions from the original examples. Shape: [batch_size, d0, ..., dN]. Tensor of type float32 or float64. Required.
counterfactual The predictions from the counterfactual examples. Shape: [batch_size, d0, ..., dN]. Tensor of the same type and shape as original. Required.
sample_weight (Optional) Coefficient applied to the loss. If a scalar is provided, the loss is scaled by that value. If a tensor of shape [batch_size] is provided, the loss for each example in the batch is scaled by the corresponding element of sample_weight.

Returns
The computed counterfactual loss.
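
To make the arithmetic concrete, here is a minimal pure-Python sketch of a pairwise MSE with optional sample weights. It is not the library's implementation: `pairwise_mse` is a hypothetical helper, inputs are limited to nested lists of shape [batch_size, d0], and the final reduction (sum of weighted per-example losses divided by batch_size) is an assumption modeled on the default behavior of Keras mean squared error losses.

```python
from typing import Optional, Sequence, Union


def pairwise_mse(
    original: Sequence[Sequence[float]],
    counterfactual: Sequence[Sequence[float]],
    sample_weight: Optional[Union[int, float, Sequence[float]]] = None,
) -> float:
    """Hypothetical reference helper; not part of the library."""
    if len(original) != len(counterfactual):
        raise ValueError("original and counterfactual must have the same batch size")
    batch_size = len(original)

    # Per-example loss: mean of squared differences over the last axis.
    per_example = []
    for orig_row, cf_row in zip(original, counterfactual):
        if len(orig_row) != len(cf_row):
            raise ValueError("original and counterfactual must have the same shape")
        squared = [(o - c) ** 2 for o, c in zip(orig_row, cf_row)]
        per_example.append(sum(squared) / len(squared))

    # A scalar weight scales every example; a vector weights each example separately.
    if sample_weight is None:
        weights = [1.0] * batch_size
    elif isinstance(sample_weight, (int, float)):
        weights = [float(sample_weight)] * batch_size
    else:
        weights = list(sample_weight)
        if len(weights) != batch_size:
            raise ValueError("sample_weight must be a scalar or have length batch_size")

    # Assumed reduction: weighted sum divided by batch_size.
    weighted_total = sum(w * loss for w, loss in zip(weights, per_example))
    return weighted_total / batch_size


original = [[1.0, 0.0], [0.5, 0.5]]
counterfactual = [[0.0, 0.0], [0.5, 1.0]]

unweighted = pairwise_mse(original, counterfactual)
scaled = pairwise_mse(original, counterfactual, sample_weight=2.0)
per_sample = pairwise_mse(original, counterfactual, sample_weight=[1.0, 0.0])
```

Under these assumptions the per-example losses are 0.5 and 0.125, so `unweighted` is 0.3125, the scalar weight doubles it to 0.625, and the vector weight [1.0, 0.0] keeps only the first example, giving 0.25.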

Raises
ValueError If any of the input arguments are invalid.
TypeError If any of the arguments are not of the expected type.
InvalidArgumentError If original, counterfactual or sample_weight have incompatible shapes.