tff.learning.algorithms.build_fed_recon_eval

Builds a tff.Computation for evaluating a reconstruction Model.

The returned computation proceeds in two stages: (1) reconstruction and (2) evaluation. During the reconstruction stage, local variables are reconstructed by freezing global variables and training using reconstruction_optimizer_fn. During the evaluation stage, the reconstructed local variables and global variables are evaluated using the provided loss_fn and metrics_fn.
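
The two stages can be illustrated with a minimal pure-Python sketch. It does not use the TFF API: the toy model y_hat = global_w * x + local_b, the function name reconstruct_then_evaluate, and the lr and steps parameters are all assumptions made for illustration. Plain gradient descent stands in for whatever reconstruction_optimizer_fn would provide.

```python
from typing import List, Tuple

Example = Tuple[float, float]  # (x, y) pairs; a stand-in for a client dataset


def reconstruct_then_evaluate(
    global_w: float,
    recon_data: List[Example],
    eval_data: List[Example],
    lr: float = 0.1,
    steps: int = 50,
) -> Tuple[float, float, int]:
    """Toy two-stage evaluation for the model y_hat = global_w * x + local_b."""
    local_b = 0.0  # the local variable starts fresh on every client

    # Stage 1: reconstruction. Only local_b is updated; global_w stays frozen.
    for _ in range(steps):
        grad = sum(
            2.0 * (global_w * x + local_b - y) for x, y in recon_data
        ) / len(recon_data)
        local_b -= lr * grad

    # Stage 2: evaluation with the frozen global_w and the reconstructed local_b.
    loss = sum(
        (global_w * x + local_b - y) ** 2 for x, y in eval_data
    ) / len(eval_data)

    # The example count is returned so results can later be example-weighted.
    return local_b, loss, len(eval_data)


# Data generated from y = 2x + 3; the global weight 2.0 is already known,
# so reconstruction only has to recover the local offset.
local_b, loss, num_examples = reconstruct_then_evaluate(
    global_w=2.0,
    recon_data=[(0.0, 3.0), (1.0, 5.0)],
    eval_data=[(2.0, 7.0), (3.0, 9.0)],
)
```

In this sketch the reconstructed local variable is discarded after evaluation; only the loss and the example count leave the client.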

Usage of returned computation:

  eval_comp = build_fed_recon_eval(...)
  metrics = eval_comp(
      tff.learning.models.ReconstructionModel.get_global_variables(model),
      federated_data)

model_fn A no-arg function that returns a tff.learning.models.ReconstructionModel. This function must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
loss_fn A no-arg function returning a tf.keras.losses.Loss to use to reconstruct and evaluate the model. The loss will be applied to the model's outputs during the evaluation stage. The final loss metric is the example-weighted mean loss across batches (and across clients).
metrics_fn A no-arg function returning a list of tf.keras.metrics.Metrics to evaluate the model. The metrics will be applied to the model's outputs during the evaluation stage. Final metric values are the example-weighted mean of metric values across batches (and across clients). If None, no metrics are applied.
reconstruction_optimizer_fn A tff.learning.optimizers.Optimizer used to reconstruct the local variables with the global ones frozen.
dataset_split_fn A tff.learning.models.ReconstructionDatasetSplitFn taking in a single TF dataset and producing two TF datasets. The first is iterated over during reconstruction, and the second is iterated over during evaluation. This can be used to preprocess datasets, e.g. to iterate over them for multiple epochs or to use disjoint data for reconstruction and evaluation. If None, client data is split in half for each user, with one half used for reconstruction and the other for evaluation. See tff.learning.models.ReconstructionModel.build_dataset_split_fn for options.
model_distributor An optional tff.learning.templates.DistributionProcess that distributes the model weights on the server to the clients. If set to None, the distributor is constructed via tff.learning.templates.build_broadcast_process.
metrics_aggregation_process An optional tff.templates.AggregationProcess which aggregates the local unfinalized metrics from the clients to the server and finalizes the metrics at the server. The tff.templates.AggregationProcess accumulates unfinalized metrics across rounds in the state, and produces a tuple of current-round metrics and total-rounds metrics in the result. If None, the tff.templates.AggregationProcess created by the SumThenFinalizeFactory with metric finalizers defined in the model is used.
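
The contract described for dataset_split_fn (one client dataset in, a reconstruction dataset and an evaluation dataset out) can be sketched in plain Python, with lists standing in for TF datasets. The names split_in_half and build_repeat_split are hypothetical. The even/odd rule below is only one possible way to halve a dataset and is an assumption, not a statement of how the library's default split is implemented.

```python
from typing import Callable, List, Sequence, Tuple

Example = Tuple[float, float]
SplitFn = Callable[[Sequence[Example]], Tuple[List[Example], List[Example]]]


def split_in_half(
    client_data: Sequence[Example],
) -> Tuple[List[Example], List[Example]]:
    """Hypothetical disjoint split: even indices reconstruct, odd indices evaluate."""
    recon = list(client_data[0::2])
    evaluation = list(client_data[1::2])
    return recon, evaluation


def build_repeat_split(recon_epochs: int) -> SplitFn:
    """Hypothetical split that reuses all data and repeats it for reconstruction."""

    def split(
        client_data: Sequence[Example],
    ) -> Tuple[List[Example], List[Example]]:
        data = list(client_data)
        # Reconstruction sees the data recon_epochs times; evaluation sees it once.
        return data * recon_epochs, data

    return split


data = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (3.0, 3.0)]
half_recon, half_eval = split_in_half(data)
multi_recon, multi_eval = build_repeat_split(recon_epochs=2)(data)
```

A disjoint split keeps evaluation examples out of reconstruction, while the repeating split trades that separation for more reconstruction steps on small client datasets.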

Raises TypeError if model_distributor does not have the expected signature.

Returns a tff.learning.templates.LearningProcess that accepts global model parameters and federated data and returns example-weighted evaluation loss and metrics.
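
As a rough illustration of what "example-weighted" means for the returned loss and metrics, the sketch below combines per-client values using each client's example count as the weight. The function name example_weighted_mean and the (value, num_examples) input layout are assumptions for illustration; the library's actual aggregation goes through metrics_aggregation_process.

```python
from typing import Sequence, Tuple


def example_weighted_mean(client_results: Sequence[Tuple[float, int]]) -> float:
    """Combines (metric_value, num_examples) pairs into one weighted mean."""
    total_examples = sum(n for _, n in client_results)
    if total_examples == 0:
        raise ValueError("no examples to aggregate")
    # Each client's value counts in proportion to how many examples produced it.
    return sum(value * n for value, n in client_results) / total_examples


# A client with 30 examples pulls the result toward its own loss value,
# so the result differs from the unweighted mean of 0.5 and 1.0.
weighted = example_weighted_mean([(0.5, 10), (1.0, 30)])
```

Weighting by example count keeps a client with very little data from influencing the reported value as much as a client with a large dataset.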