model_remediation.counterfactual.losses.PairwiseMSELoss
Pairwise mean squared error loss between the original and counterfactual predictions.
Inherits From: CounterfactualLoss
model_remediation.counterfactual.losses.PairwiseMSELoss(
name: Optional[str] = None
)
Arguments
name: Name used for logging and tracking. Defaults to 'pairwise_mse_loss'.
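To make "pairwise" concrete, here is one way to write the loss, assuming predictions of shape [batch_size, D] (B examples with D values each). Each original prediction o_i is compared only with its own counterfactual c_i. The batch-level step (a weighted average over the batch, with w_i = 1 when no sample_weight is given and w_i = s for a scalar weight s) is an assumption based on the default Keras reduction, not something this page states:

```latex
\ell_i = \frac{1}{D} \sum_{j=1}^{D} \left( o_{ij} - c_{ij} \right)^2,
\qquad
L = \frac{1}{B} \sum_{i=1}^{B} w_i \, \ell_i
```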
Methods
__call__
__call__(
original: types.TensorType,
counterfactual: types.TensorType,
sample_weight: Optional[types.TensorType] = None
)
Computes the counterfactual loss.
Arguments
original: The predictions from the original examples. Tensor of type float32 or float64 with shape [batch_size, d0, .. dN]. Required.
counterfactual: The predictions from the counterfactual examples. Tensor of the same type and shape as original. Required.
sample_weight: (Optional) A coefficient for the loss. If a scalar is provided, the loss is simply scaled by that value. If sample_weight is a tensor of size [batch_size], the total loss for each sample in the batch is rescaled by the corresponding element of the sample_weight vector.
Returns
The computed counterfactual loss.
Raises
ValueError: If any of the input arguments are invalid.
TypeError: If any of the arguments are not of the expected type.
InvalidArgumentError: If original, counterfactual, or sample_weight have incompatible shapes.
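The effect of sample_weight can be illustrated with a minimal pure-Python sketch (no TensorFlow). This is not the library's implementation. pairwise_mse_loss is a hypothetical helper, and it assumes 2-D inputs of shape [batch_size, d], a per-example mean over the last axis, and a Keras-style average of the weighted per-example losses over the batch. Where the real API raises InvalidArgumentError for incompatible shapes, the sketch raises ValueError.

```python
from typing import Optional, Sequence, Union

Number = Union[int, float]


def pairwise_mse_loss(
    original: Sequence[Sequence[float]],
    counterfactual: Sequence[Sequence[float]],
    sample_weight: Optional[Union[Number, Sequence[Number]]] = None,
) -> float:
    """Hypothetical sketch of a weighted pairwise MSE loss.

    Assumptions (not taken from the library): inputs are 2-D with shape
    [batch_size, d]; the squared differences are averaged over the last
    axis; the weighted per-example losses are averaged over the batch.
    """
    batch_size = len(original)
    if batch_size == 0 or batch_size != len(counterfactual):
        # Stand-in for the shape error the real API reports.
        raise ValueError("original and counterfactual need the same non-zero batch size")

    # Per-example loss: mean squared difference between each original
    # prediction and its own counterfactual prediction.
    per_example = []
    for o_row, c_row in zip(original, counterfactual):
        if len(o_row) == 0 or len(o_row) != len(c_row):
            raise ValueError("each original/counterfactual pair needs the same non-zero length")
        per_example.append(sum((o - c) ** 2 for o, c in zip(o_row, c_row)) / len(o_row))

    # None -> weight 1.0 for every example; scalar -> the same weight for
    # every example; sequence -> one weight per example.
    if sample_weight is None:
        weights = [1.0] * batch_size
    elif isinstance(sample_weight, (int, float)):
        weights = [float(sample_weight)] * batch_size
    else:
        weights = [float(w) for w in sample_weight]
        if len(weights) != batch_size:
            raise ValueError("sample_weight must be a scalar or have length batch_size")

    # Average the weighted per-example losses over the batch size.
    return sum(w * l for w, l in zip(weights, per_example)) / batch_size


original = [[1.0, 2.0], [0.0, 0.0]]
counterfactual = [[1.0, 0.0], [1.0, 1.0]]
print(pairwise_mse_loss(original, counterfactual))              # 1.5
print(pairwise_mse_loss(original, counterfactual, 2.0))         # 3.0
print(pairwise_mse_loss(original, counterfactual, [1.0, 0.0]))  # 1.0
```

In this sketch a scalar weight rescales the whole loss, while a per-example weight of 0.0 removes that pair's contribution but still counts toward the batch size in the final average.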
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2022-07-01 UTC.