Model that adds a Counterfactual loss component to another model during training.

Inherits from: tf.keras.Model

original_model: Instance of tf.keras.Model that will be trained with the additional counterfactual_loss.
loss: Instance of counterfactual.losses.CounterfactualLoss, or the string name of a loss, used to calculate the counterfactual_loss. Defaults to PairwiseMSELoss.
loss_weight: Scalar applied to the counterfactual_loss before it is included in training. Defaults to 1.0.
**kwargs: Named parameters passed directly to the base class' __init__ function.
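To make the role of loss_weight concrete, here is a minimal sketch that assumes the counterfactual_loss is scaled by loss_weight and then added to the original model's loss as a plain weighted sum. The helper name total_training_loss is hypothetical and not part of the library, and the library's internals may combine the terms differently.

```python
def total_training_loss(task_loss: float,
                        counterfactual_loss: float,
                        loss_weight: float = 1.0) -> float:
    """Hypothetical helper: original loss plus the scaled counterfactual_loss.

    Assumes a plain weighted sum; the real implementation may differ.
    """
    return task_loss + loss_weight * counterfactual_loss


print(total_training_loss(0.5, 0.25))                   # 0.75
print(total_training_loss(0.5, 0.25, loss_weight=2.0))  # 1.0
print(total_training_loss(0.5, 0.25, loss_weight=0.0))  # 0.5
```

With loss_weight=0.0 the counterfactual term drops out entirely, so under this assumption training reduces to the original objective.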

CounterfactualModel wraps the model passed in, original_model, and adds a component to the loss during training and optionally during evaluation.


There are two ways to construct a CounterfactualModel instance:

1 - Directly wrap your model with CounterfactualModel. This is the simplest usage and is most likely what you will want to use (unless your original model has some custom implementations that need to be taken into account).

import tensorflow as tf

model = tf.keras.Sequential([...])

model = CounterfactualModel(model, ...)

In this case, all methods other than the ones listed below will use the default implementations of tf.keras.Model.

If this is your use case, the next section does not apply to you and you can skip ahead to the section on usage.

2 - Subclass CounterfactualModel to integrate custom implementations. This will likely be needed if original_model is itself a customized subclass of tf.keras.Model. If that is the case and you want to preserve the custom implementations, create a new class that inherits first from CounterfactualModel and second from your custom class.

import tensorflow as tf

class CustomSequential(tf.keras.Sequential):

  def train_step(self, data):
    print("In a custom train_step!")
    return super().train_step(data)  # Delegate to the default training logic.

class CustomCounterfactualModel(CounterfactualModel, CustomSequential):
  pass  # No additional implementation is required.

model = CustomSequential([...])

model = CustomCounterfactualModel(model, ...)  # This will use the custom
                                               # train_step.
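The base-class order matters because Python resolves methods along the class's method resolution order (MRO), left to right. The sketch below checks this without TensorFlow by using hypothetical stand-in classes in place of tf.keras.Model, CounterfactualModel and CustomSequential; for illustration only, it assumes the stand-in CounterfactualModel overrides compile but not train_step.

```python
# Hypothetical stand-ins: these are NOT the real Keras/library classes.
class Model:  # plays the role of tf.keras.Model
    def train_step(self, data):
        return "Model.train_step"

    def compile(self):
        return "Model.compile"


class CustomSequential(Model):  # custom subclass with its own train_step
    def train_step(self, data):
        return "CustomSequential.train_step"


class CounterfactualModel(Model):  # assumed to override compile only
    def compile(self):
        return "CounterfactualModel.compile"


# Wrapper first, custom class second, as recommended above.
class CustomCounterfactualModel(CounterfactualModel, CustomSequential):
    pass


mro_names = [cls.__name__ for cls in CustomCounterfactualModel.__mro__]
model = CustomCounterfactualModel()
print(mro_names)
# ['CustomCounterfactualModel', 'CounterfactualModel', 'CustomSequential', 'Model', 'object']
print(model.train_step(None))  # CustomSequential.train_step
print(model.compile())         # CounterfactualModel.compile
```

Reversing the two bases would let CustomSequential's methods take precedence over any method the wrapper also defines, which is why the wrapper is listed first.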

If you need to customize methods defined by CounterfactualModel, then you can create a direct subclass and override whatever is needed.

import tensorflow as tf

class CustomCounterfactualModel(CounterfactualModel):

  def update_metrics(self, inputs, *args, **kwargs):
    print("In a custom CounterfactualModel method!")
    # Forward all remaining arguments unchanged to the base implementation.
    super().update_metrics(inputs, *args, **kwargs)

model = tf.keras.Sequential([...])

model = CustomCounterfactualModel(model, ...)  # This will use the custom
                                               # update_metrics method.
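The override-and-delegate pattern above can be checked without TensorFlow. This sketch uses hypothetical stand-in classes (not the real library) and forwards *args/**kwargs, so it does not depend on the exact update_metrics signature, which this page does not show.

```python
class CounterfactualModel:  # hypothetical stand-in, NOT the real class
    def __init__(self):
        self.calls = []

    def update_metrics(self, inputs, *args, **kwargs):
        # Stand-in for the base metric update: just record that it ran.
        self.calls.append("base")


class CustomCounterfactualModel(CounterfactualModel):
    def update_metrics(self, inputs, *args, **kwargs):
        self.calls.append("custom")  # custom behaviour first ...
        # ... then delegate, so the base metric updates still happen.
        super().update_metrics(inputs, *args, **kwargs)


model = CustomCounterfactualModel()
model.update_metrics("batch_inputs", 0.1)
print(model.calls)  # ['custom', 'base']
```

Calling super() at the end keeps the base bookkeeping intact; dropping it would silently skip the parent's metric updates.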


Once you have created an instance of CounterfactualModel, it can be used almost exactly the same way as the model it wraps. The main two exceptions to this are:

Optionally, inputs containing counterfactual_data can be passed to evaluate and predict. For evaluate, this causes the counterfactual_loss to appear in the returned metrics; for predict, it should have no visible effect.

original_model: tf.keras.Model to be trained with the additional counterfactual_loss. Inference and evaluation results also come from this model.




Calls the model on new inputs and returns the outputs as tensors.

In this case, call() just reapplies all ops in the graph to the new inputs (i.e. it builds a new computational graph from the provided inputs).

inputs: Input tensor, or dict/list/tuple of input tensors.
training: Boolean or boolean scalar tensor, indicating whether to run the model in training mode or inference mode.
mask: A mask or list of masks. A mask can be either a boolean tensor or None (no mask). For more details, see the Keras guide on masking and padding.

A tensor if there is a single output, or a list of tensors if there is more than one output.



Compile both self and original_model using the same parameters.

See tf.keras.Model.compile for details.
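The parameter forwarding described above can be pictured with two hypothetical stand-in classes. This is a sketch of the idea only: the real wrapper presumably also runs tf.keras.Model.compile on itself, which this plain-Python version does not model.

```python
# Hypothetical stand-ins illustrating "compile both with the same parameters".
class OriginalModel:
    def compile(self, **kwargs):
        self.compile_kwargs = kwargs


class WrapperModel:
    def __init__(self, original_model):
        self.original_model = original_model

    def compile(self, **kwargs):
        # Forward the exact same parameters to the wrapped model ...
        self.original_model.compile(**kwargs)
        # ... and record them for the wrapper itself.
        self.compile_kwargs = kwargs


wrapper = WrapperModel(OriginalModel())
wrapper.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
print(wrapper.compile_kwargs == wrapper.original_model.compile_kwargs)  # True
```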



Computes counterfactual_loss(es) corresponding to counterfactual_data.

original_predictions: Predictions on original data.
counterfactual_predictions: Predictions of a model on counterfactual data.
counterfactual_sample_weight: Per-sample weight used to scale the counterfactual loss.

Scalar (if only one) or list of counterfactual_loss values calculated from counterfactual_data.
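As an illustration of what the default PairwiseMSELoss might compute from these arguments, here is a plain-Python sketch. It assumes flat lists of scalar predictions, pairs each original prediction with its counterfactual counterpart, and averages the weighted squared gaps over the batch size; the real loss operates on tensors and its exact reduction may differ. The function name pairwise_mse is hypothetical.

```python
from typing import Optional, Sequence


def pairwise_mse(original_predictions: Sequence[float],
                 counterfactual_predictions: Sequence[float],
                 counterfactual_sample_weight: Optional[Sequence[float]] = None) -> float:
    """Hypothetical sketch of a pairwise MSE counterfactual loss.

    Squares the gap between each original prediction and its counterfactual
    counterpart, scales it by the per-sample weight, and averages over the
    batch size (in the spirit of Keras' SUM_OVER_BATCH_SIZE reduction).
    """
    n = len(original_predictions)
    if n == 0 or n != len(counterfactual_predictions):
        raise ValueError("need two non-empty prediction batches of equal length")
    if counterfactual_sample_weight is None:
        counterfactual_sample_weight = [1.0] * n
    if len(counterfactual_sample_weight) != n:
        raise ValueError("one sample weight per example is required")
    weighted_sq = [
        w * (o - c) ** 2
        for o, c, w in zip(original_predictions, counterfactual_predictions,
                           counterfactual_sample_weight)
    ]
    return sum(weighted_sq) / n


print(pairwise_mse([1.0, 0.5], [0.0, 0.5]))              # 0.5
print(pairwise_mse([1.0, 0.5], [0.0, 0.5], [2.0, 1.0]))  # 1.0
```

A model whose predictions do not change when the input is swapped for its counterfactual gets a loss of zero, which is the behaviour this penalty is meant to encourage.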



Exports the model as described in

For subclasses of CounterfactualModel that have not been registered as Keras objects, this method will likely be what you want to call to continue training your model with Counterfactual after having loaded it. If you want to use the loaded model purely for inference, you will likely want to use CounterfactualModel.save_original_model instead.

The exception noted above for unregistered CounterfactualModel subclasses is the only difference with To avoid these subtle differences, we strongly recommend registering CounterfactualModel subclasses as Keras objects. See the documentation of tf.keras.utils.register_keras_serializable for details.



Exports the original_model.

This model will be the type of original_model and will no longer be able to train or evaluate with Counterfactual data.



Updates mean metrics being tracked for Counterfactual losses.
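To show what tracking a mean metric per counterfactual loss involves, here is a plain-Python sketch. MeanMetric is a hypothetical stand-in for a running-mean metric such as tf.keras.metrics.Mean, and update_metrics here is an illustrative helper, not the library method; it accepts either a scalar or a list to mirror the "scalar (if only one) or list" return value described for the counterfactual loss computation.

```python
class MeanMetric:
    """Hypothetical stand-in for a running-mean metric (e.g. tf.keras.metrics.Mean)."""

    def __init__(self, name: str):
        self.name = name
        self.reset_state()

    def reset_state(self) -> None:
        self.total = 0.0
        self.count = 0.0

    def update_state(self, value: float, sample_weight: float = 1.0) -> None:
        self.total += value * sample_weight
        self.count += sample_weight

    def result(self) -> float:
        return self.total / self.count if self.count else 0.0


def update_metrics(metrics, counterfactual_losses):
    """Hypothetical helper: feed each loss into its matching mean metric."""
    if not isinstance(counterfactual_losses, (list, tuple)):
        counterfactual_losses = [counterfactual_losses]  # scalar case
    if len(counterfactual_losses) != len(metrics):
        raise ValueError("expected one loss per tracked metric")
    for metric, loss in zip(metrics, counterfactual_losses):
        metric.update_state(loss)


metrics = [MeanMetric("counterfactual_loss")]
update_metrics(metrics, 0.25)    # batch 1: scalar loss
update_metrics(metrics, [0.75])  # batch 2: list with one loss
print(metrics[0].result())  # 0.5
```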