Packs `counterfactual_data` with the `original_input`.
```python
model_remediation.counterfactual.keras.utils.pack_counterfactual_data(
    original_input: tf.data.Dataset, counterfactual_data: tf.data.Dataset
) -> tf.data.Dataset
```
Arguments | |
---|---|
`original_input` | An instance of `tf.data.Dataset` that was used for training the original model. The output should conform to the format used in `tf.keras.Model.fit`. |
`counterfactual_data` | An instance of `tf.data.Dataset` containing only examples that will be used to calculate the `counterfactual_loss`. This dataset is repeated to match the number of examples in `original_input`. |
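The note that `counterfactual_data` is repeated to match `original_input` can be illustrated with plain Python lists standing in for datasets. This is a minimal sketch under that assumption, not the library's implementation; `repeat_to_match` is a hypothetical helper. In `tf.data` terms, a similar effect could plausibly be obtained with `Dataset.repeat()` (which repeats indefinitely when no count is given) followed by `tf.data.Dataset.zip`, which stops at the end of its shortest input.

```python
from itertools import cycle, islice
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def repeat_to_match(counterfactual: Sequence[T], num_original: int) -> List[T]:
    """Hypothetical helper: cycle `counterfactual` until it has `num_original` items.

    Mirrors the documented behavior that counterfactual_data is repeated to
    match the size of original_input; this is not the library's actual code.
    """
    if not counterfactual:
        raise ValueError("counterfactual data must not be empty")
    # cycle() repeats forever; islice() truncates to the original length.
    return list(islice(cycle(counterfactual), num_original))


original = ["o0", "o1", "o2", "o3", "o4"]
counterfactual = ["c0", "c1"]
repeated = repeat_to_match(counterfactual, len(original))
print(repeated)  # ['c0', 'c1', 'c0', 'c1', 'c0']
```

Because the shorter input is cycled rather than truncated, every original example is paired with some counterfactual example.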
This function should be used to create an instance of `CounterfactualPackedInputs` that will be passed to `counterfactual.keras.CounterfactualModel` during training and, optionally, during evaluation.
Each `original_input` must output a tuple in the format used in `tf.keras.Model.fit`. Specifically, the output must be a tuple of length 1, 2 or 3 in the form `(x, y, sample_weight)`.
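The accepted shapes of an `original_input` element can be sketched as a normalization step that pads a length-1, length-2, or length-3 tuple out to `(x, y, sample_weight)`, filling missing components with `None`. `to_xyw` is a hypothetical helper written for illustration only; Keras offers a comparable public utility, `tf.keras.utils.unpack_x_y_sample_weight`, but the sketch below does not depend on it.

```python
from typing import Any, Optional, Tuple


def to_xyw(batch: Tuple[Any, ...]) -> Tuple[Any, Optional[Any], Optional[Any]]:
    """Hypothetical helper: normalize a Keras-style batch to (x, y, sample_weight).

    Missing components are filled with None, following the
    tf.keras.Model.fit convention of (x,), (x, y), or (x, y, sample_weight).
    """
    if not isinstance(batch, tuple) or not 1 <= len(batch) <= 3:
        raise ValueError(f"expected a tuple of length 1, 2 or 3, got {batch!r}")
    # Pad the tuple with None until it has exactly three components.
    x, y, sample_weight = batch + (None,) * (3 - len(batch))
    return x, y, sample_weight


print(to_xyw(([1, 2],)))                     # ([1, 2], None, None)
print(to_xyw(([1, 2], [0, 1])))              # ([1, 2], [0, 1], None)
print(to_xyw(([1, 2], [0, 1], [1.0, 0.5])))  # ([1, 2], [0, 1], [1.0, 0.5])
```

Anything other than a tuple of length 1 to 3 is rejected, matching the constraint stated above.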
Every batch from the returned `tf.data.Dataset` contains one batch from each of the input datasets, packed as a `CounterfactualPackedInputs`. Each returned batch is a tuple pairing a batch from the original dataset with a batch from the counterfactual dataset, in the format `((x, y, sample_weight), (original_x, counterfactual_x, counterfactual_sample_weight))`. The returned dataset has as many batches as `original_input`, where:
- `original_input`: a `tf.data.Dataset` that contains:
  - `x`: The `x` component taken directly from the `original_input` batch.
  - `y`: The `y` component taken directly from the `original_input` batch.
  - `sample_weight`: The `sample_weight` component taken directly from the `original_input` batch.
- `counterfactual_data`: a `tf.data.Dataset` that contains:
  - `original_x`: The `x` component taken directly from the `original_input` batch.
  - `counterfactual_x`: The counterfactual value for `original_x` (as described in `build_counterfactual_data`).
  - `counterfactual_sample_weight`: A batch taken directly from the `counterfactual_sample_weight` component of `counterfactual_data`.
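The nested structure of each returned batch can be sketched with plain Python lists standing in for datasets. `PackedBatch` is a hypothetical stand-in for `CounterfactualPackedInputs` (its field names are assumptions), and `pack_batches` only illustrates the pairing and repetition described above; it is not the library's implementation. The counterfactual batches are assumed to already be `(original_x, counterfactual_x, counterfactual_sample_weight)` triples, as produced by `build_counterfactual_data`.

```python
from itertools import cycle
from typing import Any, Iterator, NamedTuple, Sequence, Tuple


class PackedBatch(NamedTuple):
    """Hypothetical stand-in for CounterfactualPackedInputs."""
    original_input: Tuple[Any, Any, Any]       # (x, y, sample_weight)
    counterfactual_data: Tuple[Any, Any, Any]  # (original_x, counterfactual_x, cf weight)


def pack_batches(
    original_batches: Sequence[Tuple[Any, ...]],
    counterfactual_batches: Sequence[Tuple[Any, Any, Any]],
) -> Iterator[PackedBatch]:
    """Yield one PackedBatch per original batch (illustrative sketch only).

    Counterfactual batches are cycled, so the output length always matches
    the number of original batches.
    """
    if not counterfactual_batches:
        raise ValueError("counterfactual_batches must not be empty")
    # zip() stops when original_batches is exhausted; cycle() never ends.
    for original, cf in zip(original_batches, cycle(counterfactual_batches)):
        # Pad (x,) or (x, y) out to (x, y, sample_weight) with None.
        x, y, sample_weight = tuple(original) + (None,) * (3 - len(original))
        yield PackedBatch((x, y, sample_weight), tuple(cf))


originals = [([1.0], [0]), ([2.0], [1]), ([3.0], [0])]
counterfactuals = [([1.0], [1.5], [1.0])]
packed = list(pack_batches(originals, counterfactuals))
print(len(packed))  # 3
print(packed[0])
```

Using a `NamedTuple` keeps positional unpacking working (`(x, y, w), (ox, cx, cw) = batch`) while also allowing access by field name.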
The output of this function is a dataset of `CounterfactualPackedInputs` instances that can be used in `counterfactual.keras.CounterfactualModel` when calculating the `counterfactual_loss`.
Returns | |
---|---|
A `tf.data.Dataset` of `CounterfactualPackedInputs`. Each `CounterfactualPackedInputs` represents an `(original_inputs, counterfactual_data)` pair, where `original_inputs` is an `(x, y, sample_weight)` tuple and `counterfactual_data` is an `(original_x, counterfactual_x, counterfactual_sample_weight)` tuple. | |
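Because each element of the returned dataset is a pair of triples, a consumer can destructure it positionally. The sketch below assumes plain tuples in place of `CounterfactualPackedInputs`; `describe_element` is a hypothetical helper used only to name the six components.

```python
from typing import Any, Dict, Tuple

# Assumed element shape: ((x, y, sample_weight),
#                         (original_x, counterfactual_x, counterfactual_sample_weight))
PackedElement = Tuple[Tuple[Any, Any, Any], Tuple[Any, Any, Any]]


def describe_element(element: PackedElement) -> Dict[str, Any]:
    """Hypothetical helper: name each component of one returned element."""
    (x, y, sample_weight), (original_x, counterfactual_x, cf_weight) = element
    return {
        "x": x,
        "y": y,
        "sample_weight": sample_weight,
        "original_x": original_x,
        "counterfactual_x": counterfactual_x,
        "counterfactual_sample_weight": cf_weight,
    }


element = (([1.0], [0], None), ([1.0], [1.5], [1.0]))
parts = describe_element(element)
print(parts["counterfactual_x"])  # [1.5]
```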