## Introduction

It is possible to integrate MinDiff directly into your model's implementation. While doing so lacks the convenience of using `MinDiffModel`, this option offers the highest level of control, which can be particularly useful when your model is a subclass of `tf.keras.Model`.

This guide demonstrates how you can integrate MinDiff directly into a custom model's implementation by adding to the `train_step` method.

## Setup

`pip install --upgrade tensorflow-model-remediation`

```
import tensorflow as tf
tf.get_logger().setLevel('ERROR') # Avoid TF warnings.
from tensorflow_model_remediation import min_diff
from tensorflow_model_remediation.tools.tutorials_utils import uci as tutorials_utils
```

First, download the data. For succinctness, the input preparation logic has been factored out into helper functions as described in the input preparation guide. You can read the full guide for details on this process.

```
# Original Dataset for training, sampled at 0.3 for reduced runtimes.
train_df = tutorials_utils.get_uci_data(split='train', sample=0.3)
train_ds = tutorials_utils.df_to_dataset(train_df, batch_size=128)
# Dataset needed to train with MinDiff.
train_with_min_diff_ds = (
    tutorials_utils.get_uci_with_min_diff_dataset(split='train', sample=0.3))
```

## Original Custom Model Customizations

`tf.keras.Model` is designed to be easily customized via subclassing. This usually involves changing what happens in the call to `fit`, as described here.

This guide uses a custom implementation where the `train_step` closely resembles the default `tf.keras.Model.train_step`. Normally, there would be no benefit to doing so, but here, it will help demonstrate how to integrate MinDiff.

```
class CustomModel(tf.keras.Model):

  def train_step(self, data):
    # Unpack the data.
    x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)  # Forward pass.
      # Compute the loss value.
      loss = self.compiled_loss(
          y, y_pred, sample_weight, regularization_losses=self.losses)

    # Compute gradients and update weights.
    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    # Update and return metrics.
    self.compiled_metrics.update_state(y, y_pred, sample_weight)
    return {m.name: m.result() for m in self.metrics}
```
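The first line of `train_step` relies on a simple unpacking convention: a batch may be a bare input or a tuple of one, two, or three elements. The following is a minimal pure-Python sketch of that convention, written only for illustration. `unpack_sketch` is a hypothetical stand-in, not the Keras implementation, and it assumes that missing elements default to `None`.

```python
def unpack_sketch(data):
    """Hypothetical stand-in for the (x, y, sample_weight) unpacking convention."""
    if isinstance(data, list):
        data = tuple(data)
    if not isinstance(data, tuple):
        # A bare input: no labels and no sample weights.
        return data, None, None
    if len(data) == 1:
        return data[0], None, None
    if len(data) == 2:
        return data[0], data[1], None
    if len(data) == 3:
        return data[0], data[1], data[2]
    raise ValueError("Expected a tuple of 1 to 3 elements, got %d" % len(data))


# A batch with inputs and labels but no sample weights.
x, y, sample_weight = unpack_sketch(("features", "labels"))
```

Because the weight slot falls back to `None`, the same `train_step` body can serve datasets with and without sample weights.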

Train the model as you would a typical `Model` using the Functional API.

```
model = tutorials_utils.get_uci_model(model_class=CustomModel) # Use CustomModel.
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(train_ds, epochs=1)
```

77/77 [==============================] - 2s 7ms/step - loss: 0.5331

## Integrating MinDiff directly into your model

### Adding MinDiff to the `train_step`

To integrate MinDiff, you will need to add some lines to the `CustomModel`, which is renamed here as `CustomModelWithMinDiff`.

For clarity, this guide uses a boolean flag called `apply_min_diff`. All of the code relevant to MinDiff will only be run if it is set to `True`. If set to `False`, the model behaves exactly the same as `CustomModel`.

```
min_diff_loss_fn = min_diff.losses.MMDLoss()  # Hard coded for convenience.
min_diff_weight = 2  # Arbitrary number for example, hard coded for convenience.
apply_min_diff = True  # Flag to help show where the additional lines are.

class CustomModelWithMinDiff(tf.keras.Model):

  def train_step(self, data):
    # Unpack the data.
    x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

    # Unpack the MinDiff data.
    if apply_min_diff:
      min_diff_data = min_diff.keras.utils.unpack_min_diff_data(x)
      min_diff_x, membership, min_diff_sample_weight = (
          tf.keras.utils.unpack_x_y_sample_weight(min_diff_data))
      x = min_diff.keras.utils.unpack_original_inputs(x)

    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)  # Forward pass.
      # Compute the loss value.
      loss = self.compiled_loss(
          y, y_pred, sample_weight, regularization_losses=self.losses)

      # Calculate and add the min_diff_loss. This must be done within the scope
      # of tf.GradientTape().
      if apply_min_diff:
        min_diff_predictions = self(min_diff_x, training=True)
        # Keyword arguments avoid depending on the positional order.
        min_diff_loss = min_diff_weight * min_diff_loss_fn(
            predictions=min_diff_predictions,
            membership=membership,
            sample_weight=min_diff_sample_weight)
        loss += min_diff_loss

    # Compute gradients and update weights.
    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    # Update and return metrics.
    self.compiled_metrics.update_state(y, y_pred, sample_weight)
    return {m.name: m.result() for m in self.metrics}
```
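To build intuition for what the added lines contribute, the sketch below combines an original loss with a weighted penalty that grows as the prediction distributions of two groups drift apart. It is a toy, pure-Python approximation and not the library's `MMDLoss`: the names `toy_mmd` and `total_loss`, the Gaussian kernel form, and the length scale of `0.1` are all assumptions made for illustration, and sample weights are ignored.

```python
import math


def gaussian_kernel(a, b, length=0.1):
    # Illustrative kernel; the form and length scale are assumptions.
    return math.exp(-((a - b) ** 2) / (length ** 2))


def mean_kernel(xs, ys, length=0.1):
    # Average kernel value over all pairs drawn from the two lists.
    total = sum(gaussian_kernel(x, y, length) for x in xs for y in ys)
    return total / (len(xs) * len(ys))


def toy_mmd(predictions, membership):
    """Biased squared-MMD estimate between the two membership groups.

    Assumes scalar predictions and at least one example in each group.
    """
    group_a = [p for p, m in zip(predictions, membership) if m == 1]
    group_b = [p for p, m in zip(predictions, membership) if m == 0]
    return (mean_kernel(group_a, group_a)
            + mean_kernel(group_b, group_b)
            - 2 * mean_kernel(group_a, group_b))


def total_loss(original_loss, predictions, membership, weight=2):
    # Mirrors `loss += min_diff_weight * min_diff_loss` in train_step.
    return original_loss + weight * toy_mmd(predictions, membership)


# Both groups receive the same predictions, so the penalty vanishes.
balanced = total_loss(0.5, [0.2, 0.8, 0.2, 0.8], [1, 1, 0, 0])
# The groups receive very different predictions, so the penalty is positive.
skewed = total_loss(0.5, [0.9, 0.8, 0.1, 0.2], [1, 1, 0, 0])
```

In this toy setting, `skewed` exceeds `balanced`, which is the pressure the extra term places on the optimizer: gradients now also favor predictions that are distributed similarly across the two groups.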

Training with this model looks exactly the same as with the previous one, with the exception of the dataset used.

```
model = tutorials_utils.get_uci_model(model_class=CustomModelWithMinDiff)
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(train_with_min_diff_ds, epochs=1)
```

36/36 [==============================] - 2s 8ms/step - loss: 0.6380

### Reshaping your input (optional)

Given that this approach provides full control, you can take this opportunity to reshape the input into a slightly cleaner form. When using `MinDiffModel`, the `min_diff_data` needs to be packed into the first component of every batch. This is the case with the `train_with_min_diff_ds` dataset.

```
for x, y in train_with_min_diff_ds.take(1):
  print('Type of x:', type(x))  # MinDiffPackedInputs
  print('Type of y:', type(y))  # Tensor (original labels)
```

Type of x: <class 'tensorflow_model_remediation.min_diff.keras.utils.input_utils.MinDiffPackedInputs'>

Type of y: <class 'tensorflow.python.framework.ops.EagerTensor'>

With this requirement lifted, you can reorganize the data into a slightly more intuitive structure, with the original and MinDiff data cleanly separated.

```
def _reformat_input(inputs, original_labels):
  min_diff_data = min_diff.keras.utils.unpack_min_diff_data(inputs)
  original_inputs = min_diff.keras.utils.unpack_original_inputs(inputs)
  original_data = (original_inputs, original_labels)

  return {
      'min_diff_data': min_diff_data,
      'original_data': original_data}

customized_train_with_min_diff_ds = train_with_min_diff_ds.map(_reformat_input)
```
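The effect of this reorganization can be seen on plain Python data. The sketch below is illustrative only: `PackedInputs` is a hypothetical stand-in for the packed structure (assumed here to expose `original_inputs` and `min_diff_data` fields), `reformat_input_sketch` replaces the library unpacking helpers with attribute access, and the feature values are made up.

```python
import collections

# Hypothetical stand-in for the packed batch structure; the field names are
# an assumption for this sketch.
PackedInputs = collections.namedtuple(
    "PackedInputs", ["original_inputs", "min_diff_data"])


def reformat_input_sketch(inputs, original_labels):
    # Split the packed batch into two clearly named parts.
    return {
        "min_diff_data": inputs.min_diff_data,
        "original_data": (inputs.original_inputs, original_labels),
    }


# Made-up batch: original features and labels, plus MinDiff features paired
# with group membership.
batch_x = PackedInputs(
    original_inputs={"age": [39, 50]},
    min_diff_data=({"age": [28, 37]}, [1, 0]))
batch_y = [0, 1]

reformatted = reformat_input_sketch(batch_x, batch_y)
min_diff_x, membership = reformatted["min_diff_data"]
original_x, original_y = reformatted["original_data"]
```

After the reformat, the labels travel alongside the original inputs, and the MinDiff data no longer needs to be extracted from the first component of the batch.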

This step is completely optional but can be useful to better organize the data. If you do so, the only difference in how you implement `CustomModelWithMinDiff` will be how you unpack `data` at the beginning.

```
class CustomModelWithMinDiff(tf.keras.Model):

  def train_step(self, data):
    # Unpack the MinDiff data from the custom structure.
    if apply_min_diff:
      min_diff_data = data['min_diff_data']
      min_diff_x, membership, min_diff_sample_weight = (
          tf.keras.utils.unpack_x_y_sample_weight(min_diff_data))
      data = data['original_data']

    ...  # possible preprocessing or validation on data before unpacking.

    x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

    ...
```

With this last step, you can fully control both the input format and how it is used within the model to apply MinDiff.

## Additional Resources

- For an in-depth discussion on fairness evaluation, see the Fairness Indicators guidance.
- For general information on Remediation and MinDiff, see the remediation overview.
- For details on requirements surrounding MinDiff, see this guide.
- To see an end-to-end tutorial on using MinDiff in Keras, see this tutorial.