tff.learning.reconstruction.build_dataset_split_fn

Builds a DatasetSplitFn for Federated Reconstruction training/evaluation.

The returned DatasetSplitFn parameterizes training and evaluation computations. It supports multiple local epochs of reconstruction, multiple epochs of post-reconstruction training, limits on the number of steps for each stage, and splitting client datasets into disjoint halves, one for each stage.
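As a rough illustration of these options, the sketch below models the documented behavior in plain Python, treating a client dataset as a list of batches. It is not the TFF implementation: build_dataset_split_fn_sketch and _repeat_then_take are hypothetical names, the default argument values are assumptions of the sketch, and the real DatasetSplitFn takes and returns tf.data.Dataset objects. Following the parameter descriptions, each stage applies its epoch count first and its step cap second.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple

# Hypothetical type: maps one client's batches to (reconstruction, post-reconstruction) batches.
SplitFn = Callable[[Sequence[Any]], Tuple[List[Any], List[Any]]]


def _repeat_then_take(batches: Sequence[Any], epochs: int, steps_max: Optional[int]) -> List[Any]:
    """Repeats `batches` for `epochs` passes, then keeps at most `steps_max` batches."""
    repeated = list(batches) * epochs
    if steps_max is not None:
        repeated = repeated[:steps_max]
    return repeated


def build_dataset_split_fn_sketch(
    recon_epochs: int = 1,  # defaults here are assumptions of this sketch
    recon_steps_max: Optional[int] = None,
    post_recon_epochs: int = 1,
    post_recon_steps_max: Optional[int] = None,
    split_dataset: bool = False,
) -> SplitFn:
    """Plain-Python model of the documented DatasetSplitFn options (hypothetical)."""

    def split_fn(client_batches: Sequence[Any]) -> Tuple[List[Any], List[Any]]:
        if split_dataset:
            # Even-indexed batches go to reconstruction, odd-indexed to post-reconstruction.
            recon = list(client_batches[0::2])
            post_recon = list(client_batches[1::2])
        else:
            # Both stages see every batch.
            recon = list(client_batches)
            post_recon = list(client_batches)
        # Epochs are applied first; each step cap then bounds the total across all epochs.
        return (
            _repeat_then_take(recon, recon_epochs, recon_steps_max),
            _repeat_then_take(post_recon, post_recon_epochs, post_recon_steps_max),
        )

    return split_fn


batches = ["b0", "b1", "b2", "b3", "b4"]
split_fn = build_dataset_split_fn_sketch(recon_epochs=2, recon_steps_max=4, split_dataset=True)
recon, post_recon = split_fn(batches)
print(recon)       # ['b0', 'b2', 'b4', 'b0']
print(post_recon)  # ['b1', 'b3']
```

Because the step cap is applied after the epochs, recon_steps_max can cut a pass short: here the fourth reconstruction batch begins a second pass over the even-indexed half, and the cap stops it there.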

Note that the returned function is used during both training and evaluation: during training, "post-reconstruction" refers to training of global variables, and during evaluation, it refers to calculation of metrics using reconstructed local variables and fixed global variables.

Args:
recon_epochs: The integer number of iterations over the dataset to make during reconstruction.
recon_steps_max: If not None, the integer maximum number of steps (batches) to iterate through during reconstruction. This maximum applies to the total across all reconstruction iterations, i.e., it is applied after recon_epochs. If None, this has no effect.
post_recon_epochs: The integer number of iterations over the client data to make after reconstruction.
post_recon_steps_max: If not None, the integer maximum number of steps (batches) to iterate through after reconstruction. This maximum applies to the total across all post-reconstruction iterations, i.e., it is applied after post_recon_epochs. If None, this has no effect.
split_dataset: If True, splits client_dataset in half for each user, using even-indexed entries during reconstruction and odd-indexed entries after reconstruction. If False, client_dataset is used for both reconstruction and post-reconstruction, with the above arguments applied. Splitting requires that multiple iterations through the dataset yield the same ordering; for example, if client_dataset.shuffle(reshuffle_each_iteration=True) has been called, the split datasets may overlap. Because splitting does not occur within batches, the dataset should contain more than one batch for reasonable results.
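The ordering requirement for split_dataset can be seen in a small plain-Python illustration. It assumes, as the requirement implies, that the two halves are drawn from separate passes over the dataset; even_half, odd_half, and the explicit orderings standing in for a reshuffle are hypothetical and not part of TFF.

```python
from typing import List, Sequence


def even_half(batches: Sequence[str]) -> List[str]:
    """Batches at even positions (the reconstruction half)."""
    return list(batches[0::2])


def odd_half(batches: Sequence[str]) -> List[str]:
    """Batches at odd positions (the post-reconstruction half)."""
    return list(batches[1::2])


# With a stable order across passes, the two halves are disjoint.
stable_pass = ["b0", "b1", "b2", "b3"]
print(sorted(set(even_half(stable_pass)) & set(odd_half(stable_pass))))  # []

# If the order changes between passes (as with reshuffle_each_iteration=True),
# the halves can overlap. This second ordering is a hypothetical reshuffle.
first_pass = ["b0", "b1", "b2", "b3"]
second_pass = ["b1", "b0", "b3", "b2"]
recon = even_half(first_pass)       # ['b0', 'b2']
post_recon = odd_half(second_pass)  # ['b0', 'b2']
print(sorted(set(recon) & set(post_recon)))  # ['b0', 'b2']
```

The same helpers show why a single-batch dataset gives poor results: odd_half(["only"]) is empty, so the post-reconstruction stage receives no data.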

Returns:
A DatasetSplitFn.