model_remediation.counterfactual.keras.utils.pack_counterfactual_data

Packs counterfactual_data with the original_input.

original_input: An instance of tf.data.Dataset that was used to train the original model. Its output should conform to the format used in tf.keras.Model.fit.
counterfactual_data: An instance of tf.data.Dataset containing only the examples that will be used to calculate the counterfactual_loss. This dataset is repeated to match the number of examples in original_input.
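
Because counterfactual_data is repeated until original_input is exhausted, the two datasets do not need the same number of batches. The following is a minimal pure-Python sketch of that pairing, assuming each dataset is modeled as a plain list of batches. The function name pair_with_repeated is hypothetical and not part of the library, which works on tf.data.Dataset objects rather than lists.

```python
import itertools
from typing import Any, Iterable, List, Tuple


def pair_with_repeated(
    original_batches: Iterable[Any], counterfactual_batches: Iterable[Any]
) -> List[Tuple[Any, Any]]:
    """Pairs every original batch with a counterfactual batch.

    Hypothetical helper that models the repeat-to-match behavior with lists.
    The counterfactual batches are cycled, so the result always contains
    exactly one pair per original batch.
    """
    counterfactual_list = list(counterfactual_batches)
    if not counterfactual_list:
        # Design choice of this sketch: cycling an empty list yields nothing,
        # so an empty counterfactual source is rejected explicitly.
        raise ValueError("counterfactual_batches must contain at least one batch")
    # zip stops when original_batches runs out; itertools.cycle never does.
    return list(zip(original_batches, itertools.cycle(counterfactual_list)))


pairs = pair_with_repeated(["o1", "o2", "o3", "o4", "o5"], ["c1", "c2"])
print(pairs)
```

In a tf.data pipeline, a comparable effect is typically obtained by zipping original_input with counterfactual_data.repeat(), so that the length of the result is governed by original_input alone.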

Use this function to create a dataset of CounterfactualPackedInputs that is passed to counterfactual.keras.CounterfactualModel during training and, optionally, during evaluation.

Each batch from original_input must be a tuple in the format used in tf.keras.Model.fit. Specifically, the batch must be a tuple of length 1, 2 or 3 in the form (x, y, sample_weight).
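
The accepted batch shapes can be illustrated with a small normalizing helper. This is a plain-Python sketch, assuming a batch is a tuple of length 1, 2 or 3; unpack_batch is a hypothetical name, not a library function. Keras provides a comparable public utility, tf.keras.utils.unpack_x_y_sample_weight, which is more permissive (for example, it also accepts a bare x value).

```python
from typing import Any, Optional, Tuple


def unpack_batch(batch: Any) -> Tuple[Any, Optional[Any], Optional[Any]]:
    """Normalizes a Keras-style batch into an (x, y, sample_weight) triple.

    Hypothetical helper: missing trailing components are filled with None.
    """
    if not isinstance(batch, tuple):
        raise TypeError(f"Expected a tuple, got {type(batch).__name__}")
    if len(batch) == 1:
        return (batch[0], None, None)  # (x,)
    if len(batch) == 2:
        return (batch[0], batch[1], None)  # (x, y)
    if len(batch) == 3:
        return (batch[0], batch[1], batch[2])  # (x, y, sample_weight)
    raise ValueError(f"Expected a tuple of length 1, 2 or 3, got length {len(batch)}")


print(unpack_batch(([1.0, 2.0], [0, 1])))
```

Padding with None keeps every downstream consumer working with a fixed three-element shape, which mirrors the (x, y, sample_weight) layout of the packed output.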

Every batch from the returned tf.data.Dataset contains one batch from each of the input datasets, packed as a CounterfactualPackedInputs. Each returned batch is a tuple drawn from the original dataset and the counterfactual dataset, in the format ((x, y, sample_weight), (original_x, counterfactual_x, counterfactual_sample_weight)), and the number of returned batches matches the number of batches in original_input, where:

  • original_input: a tf.data.Dataset that contains:

    • x: The x component taken directly from the original_input batch.
    • y: The y component taken directly from the original_input batch.
    • sample_weight: The sample_weight component taken directly from the original_input batch.
  • counterfactual_data: a tf.data.Dataset that contains:

    • original_x: The x component taken directly from the original_input batch.
    • counterfactual_x: The counterfactual value for original_x (as described in build_counterfactual_data).
    • counterfactual_sample_weight: A batch taken directly from the counterfactual_sample_weight component of counterfactual_data.
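
The nested layout of each packed element can be modeled with a named tuple. In this sketch, PackedInputs is a hypothetical stand-in for CounterfactualPackedInputs, assuming it exposes its two components as original_input and counterfactual_data fields; the batch values ("x0", "x0_cf", and so on) are made-up placeholders.

```python
from typing import Any, Dict, NamedTuple, Tuple


class PackedInputs(NamedTuple):
    """Hypothetical stand-in for CounterfactualPackedInputs."""

    # (x, y, sample_weight)
    original_input: Tuple[Any, Any, Any]
    # (original_x, counterfactual_x, counterfactual_sample_weight)
    counterfactual_data: Tuple[Any, Any, Any]


def describe(packed: PackedInputs) -> Dict[str, Any]:
    """Unpacks both halves of a packed element into named components."""
    x, y, sample_weight = packed.original_input
    original_x, counterfactual_x, counterfactual_sample_weight = packed.counterfactual_data
    return {
        "x": x,
        "y": y,
        "sample_weight": sample_weight,
        "original_x": original_x,
        "counterfactual_x": counterfactual_x,
        "counterfactual_sample_weight": counterfactual_sample_weight,
    }


packed = PackedInputs(
    original_input=(["x0"], [1], [1.0]),
    counterfactual_data=(["x0"], ["x0_cf"], [0.5]),
)
print(describe(packed)["counterfactual_x"])
```

Because a named tuple is also a plain tuple, the same element can be unpacked positionally as ((x, y, sample_weight), (original_x, counterfactual_x, counterfactual_sample_weight)), matching the format described above.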

Each element of the returned dataset is an instance of CounterfactualPackedInputs that counterfactual.keras.CounterfactualModel can use when calculating the counterfactual_loss.

Returns a tf.data.Dataset of CounterfactualPackedInputs. Each CounterfactualPackedInputs represents an (original_input, counterfactual_data) pair, where original_input is an (x, y, sample_weight) tuple and counterfactual_data is an (original_x, counterfactual_x, counterfactual_sample_weight) tuple.