Customizing MinDiffModel


In most cases, using MinDiffModel directly as described in the "Integrating MinDiff with MinDiffModel" guide is sufficient. However, it is possible that you will need customized behavior. The two primary reasons for this are:

  • The keras.Model you are using has custom behavior that you want to preserve.
  • You want the MinDiffModel to behave differently from the default.

In either case, you will need to subclass MinDiffModel to achieve the desired results.


pip install --upgrade tensorflow-model-remediation
import tensorflow as tf
tf.get_logger().setLevel('ERROR')  # Avoid TF warnings.
from tensorflow_model_remediation import min_diff
from tensorflow_model_remediation.tools.tutorials_utils import uci as tutorials_utils

First, download the data. For succinctness, the input preparation logic has been factored out into helper functions as described in the input preparation guide. You can read the full guide for details on this process.

# Original Dataset for training, sampled at 0.3 for reduced runtimes.
train_df = tutorials_utils.get_uci_data(split='train', sample=0.3)
train_ds = tutorials_utils.df_to_dataset(train_df, batch_size=128)

# Dataset needed to train with MinDiff.
train_with_min_diff_ds = (
    tutorials_utils.get_uci_with_min_diff_dataset(split='train', sample=0.3))

Preserving Original Model Customizations

tf.keras.Model is designed to be easily customized via subclassing, as described in the Keras guide on subclassing. If your model has customized implementations that you wish to preserve when applying MinDiff, you will need to subclass MinDiffModel.

Original Custom Model

To see how you can preserve customizations, create a custom model that sets an attribute to True when its custom train_step is called. This is not a useful customization but will serve to illustrate behavior.

class CustomModel(tf.keras.Model):

  # Customized train_step
  def train_step(self, *args, **kwargs):
    self.used_custom_train_step = True  # Marker that we can check for.
    return super(CustomModel, self).train_step(*args, **kwargs)

Training such a model would look the same as a normal Sequential model.

model = tutorials_utils.get_uci_model(model_class=CustomModel)  # Use CustomModel.

model.compile(optimizer='adam', loss='binary_crossentropy')

_ = model.fit(train_ds, epochs=1, verbose=0)

# Model has used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step'))  # True

Subclassing MinDiffModel

If you were to use MinDiffModel directly to wrap this model, the resulting model would not use the custom train_step.

model = tutorials_utils.get_uci_model(model_class=CustomModel)
model = min_diff.keras.MinDiffModel(model, min_diff.losses.MMDLoss())

model.compile(optimizer='adam', loss='binary_crossentropy')

_ = model.fit(train_with_min_diff_ds, epochs=1, verbose=0)

# Model has not used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step'))  # False

In order to use the correct train_step method, you need a custom class that subclasses both MinDiffModel and CustomModel.

class CustomMinDiffModel(min_diff.keras.MinDiffModel, CustomModel):
  pass  # No need for any further implementation.

Training this model will use the train_step from CustomModel.

model = tutorials_utils.get_uci_model(model_class=CustomModel)

model = CustomMinDiffModel(model, min_diff.losses.MMDLoss())

model.compile(optimizer='adam', loss='binary_crossentropy')

_ = model.fit(train_with_min_diff_ds, epochs=1, verbose=0)

# Model has used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step'))  # True
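
Why the combined subclass picks up the custom train_step can be illustrated with Python's method resolution order (MRO). The sketch below uses plain-Python stand-ins rather than the real Keras classes: BaseModel, WrapperModel, and CustomWrapperModel are hypothetical names, and it assumes, as the behavior above suggests, that the wrapper class does not define its own train_step.

```python
class BaseModel:
    """Stand-in for tf.keras.Model (hypothetical, for illustration only)."""

    def train_step(self, data):
        return "default train_step"


class CustomModel(BaseModel):
    """Mirrors the CustomModel above: marks that its train_step ran."""

    def train_step(self, data):
        self.used_custom_train_step = True
        return super().train_step(data)


class WrapperModel(BaseModel):
    """Stand-in for MinDiffModel; assumed not to override train_step."""

    def __init__(self, original_model):
        self.original_model = original_model


class CustomWrapperModel(WrapperModel, CustomModel):
    pass  # train_step is resolved through the MRO.


# Wrapping directly: the wrapper's own train_step is the default one.
plain = WrapperModel(CustomModel())
plain.train_step(None)
plain_used = hasattr(plain, "used_custom_train_step")

# Subclassing both: CustomModel.train_step is found before BaseModel's.
combined = CustomWrapperModel(CustomModel())
combined.train_step(None)
combined_used = hasattr(combined, "used_custom_train_step")

mro_names = [cls.__name__ for cls in CustomWrapperModel.__mro__]
print(plain_used, combined_used)
print(mro_names)
```

Because the wrapper stand-in defines no train_step, attribute lookup on the combined class falls through to CustomModel before reaching the base class. Listing the wrapper first in the bases keeps its other methods ahead of the custom model's.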

Customizing default behaviors of MinDiffModel

In other cases, you may want to change specific default behaviors of MinDiffModel. The most common use case of this is changing the default unpacking behavior to properly handle your data if you don't use pack_min_diff_data.

When packing the data into a custom format, this might appear as follows.

def _reformat_input(inputs, original_labels):
  min_diff_data = min_diff.keras.utils.unpack_min_diff_data(inputs)
  original_inputs = min_diff.keras.utils.unpack_original_inputs(inputs)

  return ({
      'min_diff_data': min_diff_data,
      'original_inputs': original_inputs}, original_labels)

customized_train_with_min_diff_ds = train_with_min_diff_ds.map(_reformat_input)

The customized_train_with_min_diff_ds dataset returns batches composed of tuples (x, y) where x is a dict containing min_diff_data and original_inputs and y is the original_labels.

for x, _ in customized_train_with_min_diff_ds.take(1):
  print('Type of x:', type(x))  # dict
  print('Keys of x:', x.keys())  # 'min_diff_data', 'original_inputs'

This data format is not what MinDiffModel expects by default and passing customized_train_with_min_diff_ds to it would result in unexpected behavior. To fix this you will need to create your own subclass.

class CustomUnpackingMinDiffModel(min_diff.keras.MinDiffModel):

  def unpack_min_diff_data(self, inputs):
    return inputs['min_diff_data']

  def unpack_original_inputs(self, inputs):
    return inputs['original_inputs']

With this subclass, you can train as with the other examples.

model = tutorials_utils.get_uci_model()
model = CustomUnpackingMinDiffModel(model, min_diff.losses.MMDLoss())

model.compile(optimizer='adam', loss='binary_crossentropy')

_ = model.fit(customized_train_with_min_diff_ds, epochs=1)
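
The relationship between the custom packing function and the two overridden unpacking methods can be checked without TensorFlow. The following is a minimal sketch on toy data: PackedInputs and CustomUnpacker are hypothetical stand-ins (not part of the library), and the feature values are made up for illustration.

```python
from collections import namedtuple

# Hypothetical stand-in for the default packed structure of a batch.
PackedInputs = namedtuple("PackedInputs", ["original_inputs", "min_diff_data"])


def reformat_input(inputs, original_labels):
    """Repack a batch into the custom dict format used above."""
    return (
        {
            "min_diff_data": inputs.min_diff_data,
            "original_inputs": inputs.original_inputs,
        },
        original_labels,
    )


class CustomUnpacker:
    """Carries only the two unpacking overrides from the subclass above."""

    def unpack_min_diff_data(self, inputs):
        return inputs["min_diff_data"]

    def unpack_original_inputs(self, inputs):
        return inputs["original_inputs"]


# Toy batch: values are illustrative only.
packed = PackedInputs(
    original_inputs={"age": [39, 50]},
    min_diff_data={"age": [28, 37]},
)
x, y = reformat_input(packed, [0, 1])

unpacker = CustomUnpacker()
recovered_original = unpacker.unpack_original_inputs(x)
recovered_min_diff = unpacker.unpack_min_diff_data(x)

# The overrides must return exactly what was packed.
print(recovered_original == packed.original_inputs)
print(recovered_min_diff == packed.min_diff_data)
```

The point of the check is that the unpacking overrides must invert whatever packing the dataset applies. If the two disagree, the model would receive the wrong structure for either the original inputs or the MinDiff data.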

Limitations of a Customized MinDiffModel

Creating a custom MinDiffModel provides a huge amount of flexibility for more complex use cases. However, there are still some edge cases that it will not support.

Preprocessing or Validation of inputs before call

The biggest limitation for a subclass of MinDiffModel is that it requires the x component of the input data (i.e. the first or only element in the batch returned by the tf.data.Dataset) to be passed through to call without preprocessing or validation.

This is simply because the min_diff_data is packed into the x component of the input data. Any preprocessing or validation will not expect the additional structure containing min_diff_data and will likely break.

If the preprocessing or validation is easily customizable (e.g. factored into its own method) then this is easily addressed by overriding it to ensure it handles the additional structure correctly.

An example with validation might look like this:

class CustomMinDiffModel(min_diff.keras.MinDiffModel, CustomModel):

  # Override so that it correctly handles additional `min_diff_data`.
  def validate_inputs(self, inputs):
    original_inputs = self.unpack_original_inputs(inputs)
    ...  # Optionally also validate min_diff_data
    # Call the original validation method with the unpacked inputs.
    return super(CustomMinDiffModel, self).validate_inputs(original_inputs)
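
The override pattern can be exercised end to end with plain-Python stand-ins. This is a sketch under stated assumptions: ValidatingModel, WrapperModel, the validate_inputs hook, and the dict-shaped packed batch are all hypothetical and only illustrate the control flow.

```python
class ValidatingModel:
    """Hypothetical model whose validation expects only original features."""

    expected_keys = {"age", "hours"}

    def validate_inputs(self, inputs):
        if set(inputs) != self.expected_keys:
            raise ValueError(f"Unexpected input keys: {sorted(inputs)}")
        return inputs


class WrapperModel:
    """Stand-in for the wrapper; here the packed batch is a two-entry dict."""

    def unpack_original_inputs(self, inputs):
        return inputs["original_inputs"]


class CustomWrapperModel(WrapperModel, ValidatingModel):

    # Strip the extra structure before delegating to the original check.
    def validate_inputs(self, inputs):
        original_inputs = self.unpack_original_inputs(inputs)
        return super().validate_inputs(original_inputs)


packed = {
    "original_inputs": {"age": 39, "hours": 40},
    "min_diff_data": {"age": 28, "hours": 35},
}

# The original validation rejects the packed structure.
try:
    ValidatingModel().validate_inputs(packed)
    unpatched_ok = True
except ValueError:
    unpatched_ok = False

# The override unpacks first, so validation succeeds.
validated = CustomWrapperModel().validate_inputs(packed)
print(unpatched_ok, validated)
```

The original validation fails on the packed batch because it sees the wrapper's keys instead of the feature keys, while the overriding method unpacks first and then defers to the unchanged check.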

If the preprocessing or validation isn't easily customizable, then using MinDiffModel may not work for you and you will need to integrate MinDiff without it, as described in the "Integrating MinDiff without MinDiffModel" guide.

Method name collisions

It is possible that your model has methods whose names clash with those implemented in MinDiffModel (see full list of public methods in the API documentation).

This is only problematic if these methods will be called on an instance of the model (rather than internally in some other method). While this is highly unlikely, if you are in this situation you will have to either override and rename some methods or, if that is not possible, consider integrating MinDiff without MinDiffModel as described in the "Integrating MinDiff without MinDiffModel" guide.
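
One way to spot such clashes ahead of time is to compare the public names two classes define. The helper below is a hypothetical sketch, not part of the library, and the two classes are stand-ins. In practice you would pass min_diff.keras.MinDiffModel and your own model class instead.

```python
def public_name_collisions(wrapper_cls, model_cls):
    """Return public names defined directly on both classes."""
    wrapper_names = {n for n in vars(wrapper_cls) if not n.startswith("_")}
    model_names = {n for n in vars(model_cls) if not n.startswith("_")}
    return sorted(wrapper_names & model_names)


class WrapperModel:
    """Stand-in for the wrapper class (illustrative method names)."""

    def unpack_original_inputs(self, inputs): ...

    def unpack_min_diff_data(self, inputs): ...


class MyModel:
    """Stand-in for a custom model with one clashing method name."""

    def unpack_min_diff_data(self, inputs): ...

    def train_step(self, data): ...


collisions = public_name_collisions(WrapperModel, MyModel)
print(collisions)  # ['unpack_min_diff_data']
```

Note that vars only inspects names defined directly on each class, so methods inherited from a shared base such as tf.keras.Model are deliberately ignored. Names introduced by intermediate base classes of your model would need a walk over its MRO.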

Additional Resources