model_remediation.min_diff.keras.utils.pack_min_diff_data

Packs min_diff_data with the x component of the original dataset.

original_dataset tf.data.Dataset that was used before applying min diff. The output should conform to the format used in tf.keras.Model.fit.
sensitive_group_dataset tf.data.Dataset or valid MinDiff structure (unnested dict) of tf.data.Datasets containing only examples that belong to the sensitive group.

This must be passed in if nonsensitive_group_dataset is passed in. Furthermore, the x component for every batch should have the same structure as that of the original_dataset batches' x components.

nonsensitive_group_dataset tf.data.Dataset or valid MinDiff structure (unnested dict) of tf.data.Datasets containing only examples that do not belong to the sensitive group.

This must be passed in if sensitive_group_dataset is passed in. Furthermore, the x component for every batch should have the same structure as that of the original_dataset batches' x components.

min_diff_dataset tf.data.Dataset or valid MinDiff structure (unnested dict) of tf.data.Datasets containing only examples to be used to calculate the min_diff_loss.

This should only be set if neither sensitive_group_dataset nor nonsensitive_group_dataset is passed in. Furthermore, the x component for every batch should have the same structure as that of the original_dataset batches' x components.

This function should be used to create the dataset that will be passed to min_diff.keras.MinDiffModel during training and, optionally, during evaluation.

Pass either both sensitive_group_dataset and nonsensitive_group_dataset with min_diff_dataset left unset, or min_diff_dataset alone with the other two left unset. In the former case, min_diff_data will be built using utils.build_min_diff_dataset.
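The argument rule above can be sketched in plain Python. This is a minimal illustration, not the library's own validation code: check_min_diff_args, its return values and its error messages are hypothetical names introduced here, and the function only inspects which arguments are set.

```python
def check_min_diff_args(sensitive_group_dataset=None,
                        nonsensitive_group_dataset=None,
                        min_diff_dataset=None):
    """Hypothetical helper: reports which input mode is in use, or raises."""
    has_sensitive = sensitive_group_dataset is not None
    has_nonsensitive = nonsensitive_group_dataset is not None
    has_min_diff = min_diff_dataset is not None

    if has_min_diff:
        # min_diff_dataset must be passed on its own.
        if has_sensitive or has_nonsensitive:
            raise ValueError(
                "min_diff_dataset must not be combined with "
                "sensitive_group_dataset or nonsensitive_group_dataset.")
        return "min_diff_dataset"

    # Without min_diff_dataset, both group datasets are required together.
    if has_sensitive and has_nonsensitive:
        return "group_datasets"

    raise ValueError(
        "Pass either both sensitive_group_dataset and "
        "nonsensitive_group_dataset, or min_diff_dataset alone.")
```

Under these assumptions, passing only one of the two group datasets raises a ValueError, as does passing nothing at all.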

Each input dataset must output a tuple in the format used in tf.keras.Model.fit. Specifically, the output must be a tuple of length 1, 2 or 3, that is (x,), (x, y) or (x, y, sample_weight).

This output will be parsed internally in the following way:

batch = ...  # Batch from any one of the input datasets.
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
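To make the effect of this unpacking concrete, here is a pure-Python stand-in that handles only the tuple lengths named above. It is a sketch for illustration, not the Keras implementation: unpack_batch is a hypothetical name, and plain strings stand in for tensors.

```python
def unpack_batch(batch):
    """Hypothetical stand-in for the unpacking step on tuples of length 1-3."""
    if not isinstance(batch, tuple) or not 1 <= len(batch) <= 3:
        raise ValueError(
            "Expected a tuple of length 1, 2 or 3 in the form "
            "(x, y, sample_weight).")
    # Pad the missing trailing components with None.
    padded = batch + (None,) * (3 - len(batch))
    x, y, sample_weight = padded
    return x, y, sample_weight


print(unpack_batch(("x",)))           # ('x', None, None)
print(unpack_batch(("x", "y")))       # ('x', 'y', None)
print(unpack_batch(("x", "y", "w")))  # ('x', 'y', 'w')
```

Components that a batch does not provide come back as None, which is why shorter tuples are acceptable inputs.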

Every batch from the returned tf.data.Dataset will contain one batch from each of the input datasets. Each returned batch will be a tuple of (packed_inputs, original_y, original_sample_weight), matching the length of the original_dataset batches, where:

  • packed_inputs: an instance of utils.MinDiffPackedInputs containing:
    • original_inputs: the x component taken directly from the original_dataset batch.
    • min_diff_data: a batch of data formed from sensitive_group_dataset and nonsensitive_group_dataset (as described in utils.build_min_diff_dataset) or taken directly from min_diff_dataset.
  • original_y: the y component taken directly from the original_dataset batch.
  • original_sample_weight: the sample_weight component taken directly from the original_dataset batch.

min_diff_data will be used in min_diff.keras.MinDiffModel when calculating the min_diff_loss. It is a tuple or structure (matching the structure of the inputs) of (min_diff_x, min_diff_membership, min_diff_sample_weight).
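The shape of a returned batch can be sketched with plain Python containers. This is an illustration under stated assumptions rather than the library's code: PackedInputs is a stand-in for utils.MinDiffPackedInputs (assumed here to be a simple two-field container), pack_batch is a hypothetical helper, and lists stand in for tensors.

```python
import collections

# Stand-in for utils.MinDiffPackedInputs; field names follow the description.
PackedInputs = collections.namedtuple(
    "PackedInputs", ["original_inputs", "min_diff_data"])


def pack_batch(original_batch, min_diff_batch):
    """Hypothetical helper: packs one original batch with one min_diff batch."""
    packed = PackedInputs(
        original_inputs=original_batch[0],  # x taken from the original batch
        min_diff_data=min_diff_batch)
    # y and sample_weight pass through unchanged, so the output tuple has
    # the same length (1, 2 or 3) as the original batch.
    return (packed,) + tuple(original_batch[1:])


# Original batch of the form (x, y).
original_batch = ([[1.0], [2.0]], [0, 1])
# min_diff batch: (min_diff_x, min_diff_membership, min_diff_sample_weight).
min_diff_batch = ([[3.0], [4.0]], [1, 0], [1.0, 1.0])

packed_batch = pack_batch(original_batch, min_diff_batch)
print(len(packed_batch))                # 2
print(packed_batch[0].original_inputs)  # [[1.0], [2.0]]
print(packed_batch[1])                  # [0, 1]
```

Because only the x position is replaced, code that consumes (x, y, sample_weight) tuples keeps working; the model is then responsible for reading original_inputs and min_diff_data out of the packed first element.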

Returns a tf.data.Dataset whose output is a tuple of (packed_inputs, original_y, original_sample_weight) matching the output length of original_dataset.