Model that adds one or more loss component(s) to another model during training.

Inherits from: tf.keras.Model

Args:
original_model Instance of tf.keras.Model that will be trained with the additional min_diff_loss.
loss Single string (the name of a loss) or min_diff.losses.MinDiffLoss instance, or a dict of these, used to calculate the min_diff_loss(es).
loss_weight dict of scalars or single scalar applied to the min_diff_loss(es) before being included in training.
predictions_transform Optional if the output of original_model is a tf.Tensor. Function that transforms the output of original_model after it is called on MinDiff examples. The resulting predictions tensor is what will be passed in to the losses.MinDiffLoss(es).
**kwargs Named parameters that will be passed directly to the base class' __init__ function.

MinDiffModel wraps the model passed in, original_model, and adds a component to the loss during training and optionally during evaluation.


There are two ways to construct a MinDiffModel instance; the first is the simplest and the most common:

1 - Directly wrap your model with MinDiffModel. This is the simplest usage and is most likely what you will want to use (unless your original model has some custom implementations that need to be taken into account).

import tensorflow as tf
from tensorflow_model_remediation.min_diff.keras import MinDiffModel

model = tf.keras.Sequential([...])

model = MinDiffModel(model, ...)

In this case, all methods other than the ones listed below will use the default implementations of tf.keras.Model.

If you are in this use case, the next section is not relevant to you and you can skip to the section on usage.

2 - Subclassing MinDiffModel to integrate custom implementations. This will likely be needed if the original_model is itself a customized subclass of tf.keras.Model. If that is the case and you want to preserve the custom implementations, you can create a new custom class that inherits first from MinDiffModel and second from your custom class.

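import tensorflow as tf
from tensorflow_model_remediation.min_diff.keras import MinDiffModel

class CustomSequential(tf.keras.Sequential):

  def train_step(self, data):
    print("In a custom train_step!")
    return super().train_step(data)  # Keep the default training logic.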

class CustomMinDiffModel(MinDiffModel, CustomSequential):
  pass  # No additional implementation is required.

model = CustomSequential([...])

model = CustomMinDiffModel(model, ...)  # This will use the custom train_step.
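The subclassing pattern above relies on Python's method resolution order (MRO). The sketch below uses plain-Python stand-ins (the hypothetical classes `Model`, `MinDiffModelSketch` and `CustomSequential`, not the real Keras or MinDiff classes) and assumes, for illustration only, that the MinDiff wrapper overrides just `call`. It shows that `train_step` resolves to the custom class while `call` still resolves to the MinDiff wrapper.

```python
class Model:
    """Hypothetical stand-in for tf.keras.Model."""

    def call(self, inputs):
        return "Model.call"

    def train_step(self, data):
        return "Model.train_step"


class MinDiffModelSketch(Model):
    """Hypothetical stand-in for MinDiffModel; assumed to override only `call`."""

    def call(self, inputs):
        return "MinDiffModel.call"


class CustomSequential(Model):
    """Hypothetical stand-in for a user's custom model with its own train_step."""

    def train_step(self, data):
        return "CustomSequential.train_step"


# MinDiff wrapper first, custom class second, as in the example above.
class CustomMinDiffModel(MinDiffModelSketch, CustomSequential):
    pass


model = CustomMinDiffModel()
mro_names = [cls.__name__ for cls in CustomMinDiffModel.__mro__]
print(mro_names)               # ['CustomMinDiffModel', 'MinDiffModelSketch', 'CustomSequential', 'Model', 'object']
print(model.call(None))        # MinDiffModel.call (found on MinDiffModelSketch)
print(model.train_step(None))  # CustomSequential.train_step (found on CustomSequential)
```

Listing the MinDiff class first means that, wherever both parents define the same method, the MinDiff version is the one that runs.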

If you need to customize methods defined by MinDiffModel, then you can create a direct subclass and override whatever is needed.

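import tensorflow as tf
from tensorflow_model_remediation.min_diff.keras import MinDiffModel

class CustomMinDiffModel(MinDiffModel):

  def unpack_min_diff_data(self, inputs):
    print("In a custom MinDiffModel method!")
    # Delegate to the default implementation so min_diff_data is still found.
    return super().unpack_min_diff_data(inputs)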

model = tf.keras.Sequential([...])

model = CustomMinDiffModel(model, ...)  # This will use the custom
                                        # unpack_min_diff_data method.

Multiple Applications of MinDiff

It is possible to apply MinDiff multiple times within a single instance of MinDiffModel. To do so, you can pass in a dictionary of losses where keys are the names of each MinDiff application and the values are the names or instances of losses.MinDiffLoss that will be applied for each respective MinDiff application. Loss weights can be set as either one value that will be used for all applications or with a dictionary that specifies weights for individual applications. Weights not specified will default to 1.0.

import tensorflow as tf
from tensorflow_model_remediation import min_diff
from tensorflow_model_remediation.min_diff.keras import MinDiffModel

model = tf.keras.Sequential([...])

model = MinDiffModel(model, loss={
  "application1": min_diff.losses.MMDLoss(),  # Loss for first application.
  "application2": min_diff.losses.MMDLoss()   # Loss for second application.
}, loss_weight=2.0)  # 2.0 will be used as the weight for all applications.

A MinDiffModel initialized as shown above will expect min_diff_data to have a structure matching that of loss (i.e. a dictionary of inputs with keys matching that of loss). See MinDiffModel.compute_min_diff_loss for details.
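The weighting rules described above (one scalar shared by every application, or a dict of per-application weights where unspecified ones default to 1.0) can be sketched in plain Python. `resolve_loss_weights` is a hypothetical helper for illustration, not part of the MinDiff API, and the loss values are placeholder strings.

```python
def resolve_loss_weights(loss, loss_weight=1.0):
    """Map every MinDiff application in `loss` to a weight (hypothetical helper).

    `loss` is a dict of application name -> loss; `loss_weight` is either one
    scalar shared by all applications or a dict of per-application scalars.
    """
    if isinstance(loss_weight, dict):
        # Applications without an explicit weight default to 1.0.
        return {name: float(loss_weight.get(name, 1.0)) for name in loss}
    # A single scalar is used for every application.
    return {name: float(loss_weight) for name in loss}


# Placeholder loss values; in real code these would be MinDiffLoss instances.
loss = {"application1": "loss_a", "application2": "loss_b"}
print(resolve_loss_weights(loss, 2.0))                    # {'application1': 2.0, 'application2': 2.0}
print(resolve_loss_weights(loss, {"application2": 0.5}))  # {'application1': 1.0, 'application2': 0.5}
```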


Usage

Once you have created an instance of MinDiffModel, it can be used almost exactly the same way as the model it wraps. The main two exceptions to this are:

Optionally, inputs containing min_diff_data can be passed in to evaluate and predict. For evaluate, this will result in the min_diff_loss appearing in the metrics. For predict, this should have no visible effect.

Raises:
ValueError If predictions_transform is passed in but not callable.

Attributes:
original_model tf.keras.Model to be trained with the additional min_diff_loss term.

Inference and evaluation results also come from this model.

predictions_transform Function to be applied on MinDiff predictions before calculating loss.

MinDiff predictions are the output of original_model on the MinDiff examples (see compute_min_diff_loss for details). These might not initially be a tf.Tensor, for example if the model is multi-output. If this is the case, the predictions need to be converted into a tf.Tensor.

This can be done by selecting one of the outputs or by combining them in some way.

# Pick out a specific output to use for MinDiff.
transform = lambda predictions: predictions["output2"]

model = MinDiffModel(..., predictions_transform=transform)

# test data imitating multi_output predictions
test_predictions = {
  "output1": [1, 2, 3],
  "output2": [4, 5, 6],
}

model.predictions_transform(test_predictions)  # [4, 5, 6]
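Besides selecting a single output, outputs can also be combined. A minimal sketch, assuming predictions are a dict with the two hypothetical outputs shown above and equal lengths, averages them element-wise. Plain Python lists stand in for tensors here; in real use the transform must return a tf.Tensor.

```python
def average_outputs(predictions):
    """Hypothetical transform: element-wise mean of 'output1' and 'output2'."""
    first, second = predictions["output1"], predictions["output2"]
    return [(a + b) / 2 for a, b in zip(first, second)]


test_predictions = {
    "output1": [1, 2, 3],
    "output2": [4, 5, 6],
}
print(average_outputs(test_predictions))  # [2.5, 3.5, 4.5]
```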

If no predictions_transform parameter is passed in (or None is used), then it will default to the identity.

model = MinDiffModel(..., predictions_transform=None)

model.predictions_transform([1, 2, 3])  # [1, 2, 3]

The result of applying predictions_transform on the MinDiff predictions must be a tf.Tensor. The min_diff_loss will be calculated on these results.
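Together with the ValueError listed for a non-callable predictions_transform, the defaulting behavior described here can be sketched as follows. `resolve_predictions_transform` is a hypothetical helper that mirrors the documented behavior; it is not the library's implementation.

```python
def resolve_predictions_transform(predictions_transform=None):
    """Hypothetical helper: None falls back to identity; otherwise require a callable."""
    if predictions_transform is None:
        return lambda predictions: predictions  # Identity by default.
    if not callable(predictions_transform):
        raise ValueError(
            "`predictions_transform` must be callable, got: "
            f"{predictions_transform!r}")
    return predictions_transform


identity = resolve_predictions_transform(None)
print(identity([1, 2, 3]))  # [1, 2, 3]

pick_second = resolve_predictions_transform(lambda p: p["output2"])
print(pick_second({"output1": [1, 2, 3], "output2": [4, 5, 6]}))  # [4, 5, 6]

try:
    resolve_predictions_transform("not a function")
except ValueError as err:
    print("ValueError:", err)
```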



Methods

call

Calls original_model with optional min_diff_loss as regularization loss.

Args:
inputs Inputs to original_model, optionally containing min_diff_data as described below.
training Boolean indicating whether to run in training or inference mode. See the tf.keras.Model documentation for details.
mask Mask or list of masks, as described in the tf.keras.Model documentation.

This method should be used the same way as the call method of a standard tf.keras.Model. Depending on whether you are in training mode, inputs may need to include min_diff_data (see MinDiffModel.compute_min_diff_loss for details on what form that needs to take).

  • If training=True: inputs must contain min_diff_data (see details below).
  • If training=False: including min_diff_data is optional.

If present, the min_diff_loss is added by calling self.add_loss and will show up in self.losses.

model = ...  # MinDiffModel.

dataset = ...  # Dataset containing min_diff_data.

for batch in dataset.take(1):
  model(batch, training=True)

model.losses[0]  # First element(s) will be the min_diff_loss(es).

Including min_diff_data in inputs requires that MinDiffModel.unpack_original_inputs and MinDiffModel.unpack_min_diff_data behave as expected when called on inputs (see those methods for details).

This condition is satisfied with the default implementations if you use min_diff.keras.utils.pack_min_diff_data to create the dataset that includes min_diff_data.

Returns:
A tf.Tensor or nested structure of tf.Tensors, according to the behavior of original_model. See the tf.keras.Model documentation for details.

Raises:
ValueError If training is set to True but inputs does not include min_diff_data.
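The control flow described for this method (unpack the inputs, require min_diff_data when training=True, record the MinDiff loss through add_loss, and return the original model's output) can be sketched with plain-Python stand-ins. `PackedInputs`, `ToyMinDiffWrapper`, the toy model and the toy loss are all hypothetical; the real MinDiffModel works on tensors and Keras layers, and Keras resets per-call losses, which this toy list does not.

```python
from collections import namedtuple

# Hypothetical packed-input structure: original inputs plus MinDiff data.
PackedInputs = namedtuple("PackedInputs", ["original_inputs", "min_diff_data"])


class ToyMinDiffWrapper:
    def __init__(self, original_model, min_diff_loss_fn):
        self.original_model = original_model
        self.min_diff_loss_fn = min_diff_loss_fn
        self.losses = []  # Unlike Keras, this toy list is never cleared.

    def add_loss(self, value):
        self.losses.append(value)

    def __call__(self, inputs, training=False):
        if isinstance(inputs, PackedInputs):
            original_inputs, min_diff_data = inputs
        else:
            original_inputs, min_diff_data = inputs, None
        if training and min_diff_data is None:
            raise ValueError("min_diff_data is required when training=True.")
        if min_diff_data is not None:
            # Record the MinDiff loss so it shows up in self.losses.
            self.add_loss(self.min_diff_loss_fn(self.original_model, min_diff_data))
        # Predictions always come from the wrapped model.
        return self.original_model(original_inputs)


def double(xs):  # Toy "original model".
    return [2 * x for x in xs]


def toy_loss(model, min_diff_data):  # Toy loss: sum of MinDiff predictions.
    x, _membership = min_diff_data
    return sum(model(x))


wrapper = ToyMinDiffWrapper(double, toy_loss)
batch = PackedInputs(original_inputs=[1, 2], min_diff_data=([0, 1], [0.0, 1.0]))
print(wrapper(batch, training=True))  # [2, 4]
print(wrapper.losses)                 # [2]
```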


compile

Compile both self and original_model using the same parameters.

See tf.keras.Model.compile for details.
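Compiling the wrapper and the wrapped model with identical settings can be sketched with hypothetical plain-Python classes. `Compilable` stands in for tf.keras.Model and only records the keyword arguments it receives; `ToyWrapper` forwards the same arguments to its original_model.

```python
class Compilable:
    """Hypothetical stand-in for tf.keras.Model.compile bookkeeping."""

    def compile(self, **kwargs):
        self.compile_kwargs = dict(kwargs)


class ToyWrapper(Compilable):
    def __init__(self, original_model):
        self.original_model = original_model

    def compile(self, **kwargs):
        # Compile the wrapper itself, then forward the same arguments.
        super().compile(**kwargs)
        self.original_model.compile(**kwargs)


inner = Compilable()
wrapper = ToyWrapper(inner)
wrapper.compile(optimizer="adam", loss="binary_crossentropy")
print(inner.compile_kwargs == wrapper.compile_kwargs)  # True
```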


compute_min_diff_loss

Computes min_diff_loss(es) corresponding to min_diff_data.

Args:
min_diff_data Tuple of data or valid MinDiff structure of tuples as described below.
training Boolean indicating whether to run in training or inference mode. See the tf.keras.Model documentation for details.
mask Mask or list of masks, as described in the tf.keras.Model documentation. These will be applied when calling the original_model.

min_diff_data must have a structure (or be a single element) matching that of the loss parameter passed in during initialization. Each element of min_diff_data (and loss) corresponds to one application of MinDiff.

As with standard Keras training data, each element of min_diff_data must be a tuple of length 2 or 3. The tuple will be unpacked using the standard tf.keras.utils.unpack_x_y_sample_weight function:

import tensorflow as tf

min_diff_data_elem = ...  # Single element from a batch of min_diff_data.

min_diff_x, min_diff_membership, min_diff_sample_weight = (
    tf.keras.utils.unpack_x_y_sample_weight(min_diff_data_elem))

The components are defined as follows:

  • min_diff_x: inputs to original_model to get the corresponding MinDiff predictions.
  • min_diff_membership: numerical [batch_size, 1] Tensor indicating which group each example comes from (marked as 0.0 or 1.0).
  • min_diff_sample_weight: Optional weight Tensor. The weights will be applied to the examples during the min_diff_loss calculation.
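The unpacking step can be illustrated with a plain-Python stand-in for tf.keras.utils.unpack_x_y_sample_weight, restricted to the 2- or 3-element tuples that MinDiff expects. `unpack_min_diff_elem` is a hypothetical name used only for this sketch; nested lists stand in for [batch_size, 1] tensors.

```python
def unpack_min_diff_elem(elem):
    """Split a MinDiff element into (x, membership, sample_weight).

    Follows the (x, y, sample_weight) convention: a 2-tuple has no weights.
    """
    if isinstance(elem, list):
        elem = tuple(elem)
    if not isinstance(elem, tuple) or len(elem) not in (2, 3):
        raise ValueError("Each element of min_diff_data must be a tuple of length 2 or 3.")
    if len(elem) == 2:
        x, membership = elem
        return x, membership, None
    return elem


print(unpack_min_diff_elem(([[0.1], [0.9]], [[0.0], [1.0]])))
print(unpack_min_diff_elem(([[0.1]], [[1.0]], [[2.0]])))
```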

For each application of MinDiff, the min_diff_loss is ultimately calculated from the MinDiff predictions, which are evaluated as follows:

...  # In compute_min_diff_loss call.

min_diff_x = ...  # Single batch of MinDiff examples.

# Get predictions for MinDiff examples.
min_diff_predictions = self.original_model(min_diff_x, training=training)
# Transform the predictions if needed. By default this is the identity.
min_diff_predictions = self.predictions_transform(min_diff_predictions)

Returns:
Scalar (if only one) or list of min_diff_loss values calculated from min_diff_data.

Raises:
ValueError If the structure of min_diff_data does not match that of the loss that was passed to the model during initialization.
ValueError If the transformed min_diff_predictions is not a tf.Tensor.
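Putting the pieces of this method together, a plain-Python sketch of the per-application loop follows. It assumes dict-structured loss and min_diff_data, uses flat lists in place of tensors and a toy identity model, and swaps in a hypothetical `mean_gap_loss` (the absolute difference between the weighted mean predictions of the two membership groups) in place of a real MinDiffLoss such as MMDLoss. A structure mismatch raises ValueError, matching the documented behavior.

```python
def mean_gap_loss(membership, predictions, sample_weight=None):
    """Hypothetical stand-in loss: |weighted mean(group 1) - weighted mean(group 0)|."""
    if sample_weight is None:
        sample_weight = [1.0] * len(predictions)

    def weighted_mean(group):
        rows = [(p, w) for m, p, w in zip(membership, predictions, sample_weight) if m == group]
        return sum(p * w for p, w in rows) / sum(w for _, w in rows)

    return abs(weighted_mean(1.0) - weighted_mean(0.0))


def compute_min_diff_losses(model, losses, min_diff_data, transform=lambda p: p):
    """Sketch: one loss per MinDiff application; a scalar if there is only one."""
    if set(losses) != set(min_diff_data):
        raise ValueError("min_diff_data structure does not match `loss`.")
    results = []
    for name, loss_fn in losses.items():
        elem = min_diff_data[name]
        x, membership = elem[0], elem[1]
        sample_weight = elem[2] if len(elem) == 3 else None
        predictions = transform(model(x))  # MinDiff predictions.
        results.append(loss_fn(membership, predictions, sample_weight))
    return results[0] if len(results) == 1 else results


def toy_model(xs):  # Hypothetical model: returns its inputs as scores.
    return list(xs)


data = {"application1": ([0.2, 0.4, 0.6, 0.8], [0.0, 0.0, 1.0, 1.0])}
print(compute_min_diff_losses(toy_model, {"application1": mean_gap_loss}, data))
```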


save

Exports the model in the same way as the save method of a standard tf.keras.Model.

For subclasses of MinDiffModel that have not been registered as Keras objects, this method will likely be what you want to call to continue training your model with MinDiff after having loaded it. If you want to use the loaded model purely for inference, you will likely want to use MinDiffModel.save_original_model instead.

Apart from the exception noted above for unregistered MinDiffModel subclasses, this method behaves like the save method of a standard tf.keras.Model. To avoid these subtle differences, we strongly recommend registering MinDiffModel subclasses as Keras objects. See the documentation of tf.keras.utils.register_keras_serializable for details.


save_original_model

Exports the original_model.

Exports the original_model. When loaded, this model will have the same type as original_model and will no longer be able to train or evaluate with MinDiff data.


unpack_min_diff_data

Extracts min_diff_data from inputs if present or returns None.

Args:
inputs Inputs as described for MinDiffModel.call.

Identifies whether min_diff_data is included in inputs and returns min_diff_data if it is.

model = ...  # MinDiffModel.

inputs = ...  # Batch containing `min_diff_data`

min_diff_data = model.unpack_min_diff_data(inputs)

If min_diff_data is not included, then None is returned.

model = ...  # MinDiffModel.

# Test batch without `min_diff_data` (i.e. just passing in a simple array)
print(model.unpack_min_diff_data([1, 2, 3]))  # None

The default implementation is a pure wrapper around min_diff.keras.utils.unpack_min_diff_data. See there for implementation details.

Returns:
min_diff_data to be passed to MinDiffModel.compute_min_diff_loss if present or None otherwise.
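The documented behavior (return min_diff_data when present, otherwise None) can be mimicked with a hypothetical namedtuple standing in for whatever structure min_diff.keras.utils.pack_min_diff_data actually produces. `PackedInputs` and `toy_unpack_min_diff_data` are illustration-only names.

```python
from collections import namedtuple

# Hypothetical packed structure; the real packing format may differ.
PackedInputs = namedtuple("PackedInputs", ["original_inputs", "min_diff_data"])


def toy_unpack_min_diff_data(inputs):
    """Return min_diff_data if `inputs` is packed, else None."""
    if isinstance(inputs, PackedInputs):
        return inputs.min_diff_data
    return None


packed = PackedInputs(original_inputs=[1, 2, 3], min_diff_data=([4, 5], [0.0, 1.0]))
print(toy_unpack_min_diff_data(packed))     # ([4, 5], [0.0, 1.0])
print(toy_unpack_min_diff_data([1, 2, 3]))  # None
```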


unpack_original_inputs

Extracts original_inputs from inputs.

Args:
inputs Inputs as described for MinDiffModel.call.

Identifies whether min_diff_data is included in inputs. If it is, then what is returned is the component that is only meant to be used in the call to original_model.

model = ...  # MinDiffModel.

inputs = ...  # Batch containing `min_diff_data`

# Extracts component that is only meant to be passed to `original_model`.
original_inputs = model.unpack_original_inputs(inputs)

If min_diff_data is not included, then inputs is returned directly.

model = ...  # MinDiffModel.

# Test batch without `min_diff_data` (i.e. just passing in a simple array)
print(model.unpack_original_inputs([1, 2, 3]))  # [1, 2, 3]

The default implementation is a pure wrapper around min_diff.keras.utils.unpack_original_inputs. See there for implementation details.

Returns:
Inputs to be used in the call to original_model.
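The complementary behavior (return only the part meant for original_model, or pass unpacked inputs through unchanged) can be sketched with the same kind of hypothetical namedtuple. `PackedInputs` and `toy_unpack_original_inputs` are illustration-only names, not the library's packing format.

```python
from collections import namedtuple

# Hypothetical packed structure; the real packing format may differ.
PackedInputs = namedtuple("PackedInputs", ["original_inputs", "min_diff_data"])


def toy_unpack_original_inputs(inputs):
    """Return the part meant for original_model; pass unpacked inputs through."""
    if isinstance(inputs, PackedInputs):
        return inputs.original_inputs
    return inputs


packed = PackedInputs(original_inputs=[1, 2, 3], min_diff_data=([4, 5], [0.0, 1.0]))
print(toy_unpack_original_inputs(packed))     # [1, 2, 3]
print(toy_unpack_original_inputs([1, 2, 3]))  # [1, 2, 3]
```

Because both unpacking methods must agree on the same packed structure, a custom override of one usually needs a matching override of the other.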