model_remediation.min_diff.keras.utils.build_min_diff_dataset

Build MinDiff dataset from sensitive and nonsensitive datasets.

Args:
  sensitive_group_dataset: tf.data.Dataset or valid MinDiff structure (unnested dict) of tf.data.Datasets containing only examples that belong to the sensitive group.
  nonsensitive_group_dataset: tf.data.Dataset or valid MinDiff structure (unnested dict) of tf.data.Datasets containing only examples that do not belong to the sensitive group.

This function builds a tf.data.Dataset containing examples that are meant to be used only when calculating a min_diff_loss. The resulting dataset must be packed with the original dataset used for the model's original task, which can be done by calling utils.pack_min_diff_data.

Each input dataset must output a tuple in the format used in tf.keras.Model.fit. Specifically, the output must be a tuple of length 1, 2, or 3 in the form (x, y, sample_weight).

This output will be parsed internally in the following way:

batch = ...  # Batch from any of the input datasets.
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
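
The unpacking rule above can be illustrated without TensorFlow. The following is a minimal plain-Python sketch, not the Keras implementation: `unpack_batch` is a hypothetical stand-in that handles tuples only and fills missing trailing components with None.

```python
def unpack_batch(batch):
    """Hypothetical stand-in for the (x, y, sample_weight) unpacking rule.

    Assumption: anything that is not a tuple is treated as `x` alone.
    """
    if not isinstance(batch, tuple):
        return batch, None, None
    if not 1 <= len(batch) <= 3:
        raise ValueError(
            "Expected a tuple of length 1, 2 or 3, got length %d" % len(batch)
        )
    # Pad the missing trailing components (y, sample_weight) with None.
    padded = batch + (None,) * (3 - len(batch))
    return padded[0], padded[1], padded[2]


# Length-1, length-2 and length-3 batches all unpack to three values.
x_only = unpack_batch(("features",))
x_and_y = unpack_batch(("features", "labels"))
full = unpack_batch(("features", "labels", "weights"))
```

A batch that supplies only x therefore yields y and sample_weight equal to None, which is what the sample-weight handling described below relies on.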

Every batch from the returned tf.data.Dataset will contain one batch from each of the input datasets. Each returned batch will be a tuple or structure (matching the structure of the inputs) of (min_diff_x, min_diff_membership, min_diff_sample_weight) where, for each pair of input datasets:

  • min_diff_x: is formed by concatenating the x components of the paired datasets. The structure of these must match. If they do not, the dataset will raise an error at the first batch.
  • min_diff_membership: is a tensor of shape [min_diff_batch_size, 1] indicating which dataset each example comes from (1.0 for sensitive_group_dataset and 0.0 for nonsensitive_group_dataset).
  • min_diff_sample_weight: is formed by concatenating the sample_weight components of the paired datasets. If both are None, then this will be set to None. If only one is None, it is replaced with a Tensor of ones of the appropriate shape before concatenation.
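
The three components described above can be sketched for a single pair of batches. This is an illustration only, using plain Python lists in place of tensors; `build_min_diff_batch` and `_unpack` are hypothetical helpers, and the [batch_size, 1] shape used for the filled-in weights is an assumption about what "appropriate shape" means.

```python
def _unpack(batch):
    # Hypothetical helper: pad a 1-, 2- or 3-tuple out to (x, y, sample_weight).
    padded = tuple(batch) + (None,) * (3 - len(batch))
    return padded[0], padded[1], padded[2]


def build_min_diff_batch(sensitive_batch, nonsensitive_batch):
    """Combine one batch from each group into one MinDiff-style batch."""
    s_x, _, s_w = _unpack(sensitive_batch)
    n_x, _, n_w = _unpack(nonsensitive_batch)

    # min_diff_x: sensitive examples first, then nonsensitive examples.
    min_diff_x = s_x + n_x

    # min_diff_membership: 1.0 for sensitive rows, 0.0 for nonsensitive rows.
    membership = [[1.0] for _ in s_x] + [[0.0] for _ in n_x]

    # min_diff_sample_weight: None only if both inputs lack weights;
    # otherwise a missing side is filled with ones (assumed shape [n, 1]).
    if s_w is None and n_w is None:
        sample_weight = None
    else:
        if s_w is None:
            s_w = [[1.0] for _ in s_x]
        if n_w is None:
            n_w = [[1.0] for _ in n_x]
        sample_weight = s_w + n_w

    return min_diff_x, membership, sample_weight


# The sensitive batch carries weights; the nonsensitive batch does not.
sensitive = ([[0.1], [0.2]], [1, 0], [[2.0], [3.0]])
nonsensitive = ([[0.3]], [1])
min_diff_x, membership, weights = build_min_diff_batch(sensitive, nonsensitive)
```

Note that the y components are unpacked but not used: the returned batch carries only features, membership, and weights.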

Returns:
  A tf.data.Dataset whose output is a tuple or structure (matching the structure of the inputs) of (min_diff_x, min_diff_membership, min_diff_sample_weight).

Raises:
  ValueError: If either sensitive_group_dataset or nonsensitive_group_dataset is not a valid MinDiff structure (unnested dict).
  ValueError: If sensitive_group_dataset and nonsensitive_group_dataset do not have the same structure.
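
The two error conditions can be pictured with a small plain-Python check. This is a sketch under assumptions, not the library's validation code: `check_min_diff_structures` is a hypothetical helper, "unnested" is read as "no dict values inside the dict", and "same structure" is read as "both single datasets, or both dicts with identical keys".

```python
def check_min_diff_structures(sensitive, nonsensitive):
    """Hypothetical validation mirroring the two documented ValueErrors."""
    for name, struct in (
        ("sensitive_group_dataset", sensitive),
        ("nonsensitive_group_dataset", nonsensitive),
    ):
        # Assumption: a valid structure is a single dataset or an unnested dict.
        if isinstance(struct, dict) and any(
            isinstance(value, dict) for value in struct.values()
        ):
            raise ValueError("%s must be an unnested dict" % name)

    s_is_dict = isinstance(sensitive, dict)
    n_is_dict = isinstance(nonsensitive, dict)
    # Assumption: matching structure means matching dict keys.
    if s_is_dict != n_is_dict or (
        s_is_dict and set(sensitive) != set(nonsensitive)
    ):
        raise ValueError("Input structures do not match")


# Strings stand in for tf.data.Dataset objects in this sketch.
check_min_diff_structures(
    {"k1": "sensitive_ds_1", "k2": "sensitive_ds_2"},
    {"k1": "nonsensitive_ds_1", "k2": "nonsensitive_ds_2"},
)
```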