Packs min_diff_data with the x component of the original dataset.
```python
model_remediation.min_diff.keras.utils.pack_min_diff_data(
    original_dataset: tf.data.Dataset,
    sensitive_group_dataset=None,
    nonsensitive_group_dataset=None,
    min_diff_dataset=None
) -> tf.data.Dataset
```
| Argument | Description |
|---|---|
| original_dataset | tf.data.Dataset that was used before applying MinDiff. Its output should conform to the format used in tf.keras.Model.fit. |
| sensitive_group_dataset | tf.data.Dataset or valid MinDiff structure (unnested dict) of tf.data.Datasets containing only examples that belong to the sensitive group. This must be passed in if nonsensitive_group_dataset is passed in. |
| nonsensitive_group_dataset | tf.data.Dataset or valid MinDiff structure (unnested dict) of tf.data.Datasets containing only examples that do not belong to the sensitive group. This must be passed in if sensitive_group_dataset is passed in. |
| min_diff_dataset | tf.data.Dataset or valid MinDiff structure (unnested dict) of tf.data.Datasets containing only examples to be used to calculate the min_diff_loss. This should only be set if neither sensitive_group_dataset nor nonsensitive_group_dataset is passed in. |
This function should be used to create the dataset that will be passed to
min_diff.keras.MinDiffModel during training and, optionally, during
evaluation.
The inputs should either have both sensitive_group_dataset and
nonsensitive_group_dataset passed in with min_diff_dataset left unset, or have
min_diff_dataset passed in with both group datasets left unset. In the former
case, min_diff_data will be built using utils.build_min_diff_dataset.
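The argument rule can be made concrete with a small check. The helper below, resolve_min_diff_source, is hypothetical and not part of the library; it is only a sketch, assuming the two valid combinations described above, and the real function's validation and error messages may differ.

```python
from typing import Any, Optional


def resolve_min_diff_source(
    sensitive_group_dataset: Optional[Any] = None,
    nonsensitive_group_dataset: Optional[Any] = None,
    min_diff_dataset: Optional[Any] = None,
) -> str:
    """Hypothetical check mirroring the documented argument rules.

    Returns "build" when min_diff_data would be built from the two group
    datasets, or "direct" when min_diff_dataset is used as-is.
    Raises ValueError for any other combination.
    """
    has_sensitive = sensitive_group_dataset is not None
    has_nonsensitive = nonsensitive_group_dataset is not None
    has_min_diff = min_diff_dataset is not None

    if has_min_diff:
        # min_diff_dataset is only allowed when neither group dataset is given.
        if has_sensitive or has_nonsensitive:
            raise ValueError(
                "Pass either min_diff_dataset or both group datasets, not both.")
        return "direct"

    # Without min_diff_dataset, both group datasets are required together.
    if has_sensitive and has_nonsensitive:
        return "build"
    raise ValueError(
        "sensitive_group_dataset and nonsensitive_group_dataset must be "
        "passed in together when min_diff_dataset is unset.")
```

For example, resolve_min_diff_source(sensitive_ds, nonsensitive_ds) selects the "build" path, while passing only sensitive_ds raises an error.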
Each input dataset must output a tuple in the format used in
tf.keras.Model.fit. Specifically, the output must be a tuple of
length 1, 2 or 3 in the form (x, y, sample_weight).
This output will be parsed internally in the following way:

```python
batch = ...  # Batch from any one of the input datasets.
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
```
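To illustrate what this parsing step yields for each allowed tuple length, here is a pure-Python sketch. unpack_x_y_sample_weight_sketch is a hypothetical stand-in modeled on the behavior of tf.keras.utils.unpack_x_y_sample_weight as we understand it; it avoids a TensorFlow dependency and is not the Keras implementation.

```python
from typing import Any, Optional, Tuple


def unpack_x_y_sample_weight_sketch(
    data: Any,
) -> Tuple[Any, Optional[Any], Optional[Any]]:
    """Splits a batch into (x, y, sample_weight), filling gaps with None."""
    if isinstance(data, list):
        data = tuple(data)
    if not isinstance(data, tuple):
        # A bare (non-tuple) batch is treated as x alone.
        return data, None, None
    if len(data) == 1:
        return data[0], None, None
    if len(data) == 2:
        return data[0], data[1], None
    if len(data) == 3:
        return data[0], data[1], data[2]
    raise ValueError(f"Expected a tuple of length 1, 2 or 3, got {len(data)}.")
```

A length-2 batch such as (features, labels) therefore comes back with sample_weight set to None.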
Every batch from the returned tf.data.Dataset will contain one batch from
each of the input datasets. Each returned batch will be a tuple of
(packed_inputs, original_y, original_sample_weight) matching the length of
original_dataset batches where:
- packed_inputs: an instance of utils.MinDiffPackedInputs containing:
  - original_inputs: the x component taken directly from the original_dataset batch.
  - min_diff_data: a batch of data formed from sensitive_group_dataset and nonsensitive_group_dataset (as described in utils.build_min_diff_dataset) or taken directly from min_diff_dataset.
- original_y: the y component taken directly from the original_dataset batch.
- original_sample_weight: the sample_weight component taken directly from the original_dataset batch.
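The batch layout described above can be sketched with plain Python iterables standing in for datasets. PackedInputsSketch and pack_batches_sketch are hypothetical names; the sketch only reuses the documented field names original_inputs and min_diff_data, and it is not how the library builds its tf.data pipeline.

```python
from typing import Any, Iterable, Iterator, NamedTuple, Tuple


class PackedInputsSketch(NamedTuple):
    """Stand-in for utils.MinDiffPackedInputs (same two field names)."""
    original_inputs: Any
    min_diff_data: Any


def pack_batches_sketch(
    original_batches: Iterable[Tuple[Any, ...]],
    min_diff_batches: Iterable[Any],
) -> Iterator[Tuple[Any, ...]]:
    """Pairs one original batch with one min_diff batch per step.

    Original batches are assumed to be tuples of length 1, 2 or 3. Each
    output keeps that length: only the x slot is replaced by the packed
    inputs, while y and sample_weight pass through unchanged.
    """
    for original, min_diff_data in zip(original_batches, min_diff_batches):
        x, *rest = original  # rest is [], [y] or [y, sample_weight]
        packed = PackedInputsSketch(original_inputs=x, min_diff_data=min_diff_data)
        yield (packed, *rest)
```

Here zip stops at the shorter of the two inputs; the sketch makes no claim about how the library handles datasets of different lengths.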
min_diff_data will be used in min_diff.keras.MinDiffModel when calculating
the min_diff_loss. It is a tuple or structure (matching the structure of the
inputs) of (min_diff_x, min_diff_membership, min_diff_sample_weight).
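To make the (min_diff_x, min_diff_membership, min_diff_sample_weight) structure concrete, the sketch below combines one sensitive and one nonsensitive batch. It assumes, for illustration only, that the two batches are concatenated and that membership is 1.0 for sensitive examples and 0.0 for nonsensitive ones; consult utils.build_min_diff_dataset for the actual behavior. MinDiffBatchSketch and build_min_diff_batch_sketch are hypothetical names.

```python
from typing import List, NamedTuple, Optional, Sequence


class MinDiffBatchSketch(NamedTuple):
    min_diff_x: List[float]
    min_diff_membership: List[float]
    min_diff_sample_weight: Optional[List[float]]


def build_min_diff_batch_sketch(
    sensitive_x: Sequence[float],
    nonsensitive_x: Sequence[float],
) -> MinDiffBatchSketch:
    """Concatenates one sensitive and one nonsensitive batch.

    Assumption: sensitive examples get membership 1.0 and nonsensitive
    examples 0.0; sample weights are omitted (None) in this sketch.
    """
    min_diff_x = list(sensitive_x) + list(nonsensitive_x)
    membership = [1.0] * len(sensitive_x) + [0.0] * len(nonsensitive_x)
    return MinDiffBatchSketch(min_diff_x, membership, None)
```

When the inputs are an unnested dict, min_diff_data matches that structure, so one such tuple would be built per key.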
| Returns |
|---|
| A tf.data.Dataset whose output is a tuple of (packed_inputs, original_y, original_sample_weight) matching the output length of original_dataset. |