tff.learning.algorithms.build_fed_recon

Builds the IterativeProcess for optimization using FedRecon.

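Returns a tff.learning.templates.LearningProcess (a subclass of tff.templates.IterativeProcess) for Federated Reconstruction. On each client, computation is divided into two stages: (1) reconstruction of the local variables, with the global variables frozen, and (2) training of the global variables, with the local variables frozen.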
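The two client stages can be illustrated with a pure-Python toy, assuming a scalar model y_hat = w * x + b in which w is the global variable and b is the local variable. This is a conceptual sketch of one FedRecon client step, not the TFF implementation: client_update and mse_grads are hypothetical names, plain gradient descent stands in for reconstruction_optimizer_fn and client_optimizer_fn, and re-initializing the local variable to 0.0 on every call is an assumption of the sketch.

```python
from typing import List, Sequence, Tuple

# A batch is a list of (x, y) pairs for the hypothetical toy model y_hat = w * x + b.
Batch = Sequence[Tuple[float, float]]


def mse_grads(w: float, b: float, batch: Batch) -> Tuple[float, float]:
    """Gradients of the mean squared error over one batch with respect to (w, b)."""
    n = len(batch)
    grad_w = sum(2.0 * (w * x + b - y) * x for x, y in batch) / n
    grad_b = sum(2.0 * (w * x + b - y) for x, y in batch) / n
    return grad_w, grad_b


def client_update(
    w_global: float,
    recon_batches: List[Batch],
    post_recon_batches: List[Batch],
    recon_lr: float = 0.1,
    client_lr: float = 0.1,
) -> Tuple[float, float]:
    """Hypothetical FedRecon client step; returns (global delta, local variable)."""
    # Stage 1: reconstruct the local variable; the global weight stays frozen.
    # Starting from 0.0 on each call is an assumption of this sketch.
    b_local = 0.0
    for batch in recon_batches:
        _, grad_b = mse_grads(w_global, b_local, batch)
        b_local -= recon_lr * grad_b

    # Stage 2: train the global weight; the reconstructed local variable stays frozen.
    w = w_global
    for batch in post_recon_batches:
        grad_w, _ = mse_grads(w, b_local, batch)
        w -= client_lr * grad_w

    # Only the global delta would be sent for aggregation; b_local stays on the client.
    return w - w_global, b_local


if __name__ == "__main__":
    data = [(1.0, 3.0), (2.0, 5.0)]  # toy points on y = 2x + 1
    delta, b_local = client_update(2.0, [data], [data])
    print(f"global delta={delta:.4f}, reconstructed b={b_local:.4f}")
```

The returned delta is the quantity the server would aggregate across clients and apply with server_optimizer_fn, while the reconstructed local variable never leaves the client.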

Args

model_fn A no-arg function that returns a tff.learning.models.ReconstructionModel. This function must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
loss_fn A no-arg function returning a tf.keras.losses.Loss, used to compute local model updates during reconstruction and post-reconstruction, and to evaluate the model during training. The final loss metric is the example-weighted mean loss across batches and across clients. Reconstruction batches are excluded from the loss metric.
metrics_fn A no-arg function returning a list of tf.keras.metrics.Metric objects used to evaluate the model. Metric results are computed locally as described by each metric and are aggregated across clients as in federated_aggregate_keras_metric. If None, no metrics are applied. Metrics are not computed on reconstruction batches.
server_optimizer_fn A tff.learning.optimizers.Optimizer, or a no-arg function that returns a tf.keras.optimizers.Optimizer for applying updates to the global model on the server.
client_optimizer_fn A tff.learning.optimizers.Optimizer, or a no-arg function that returns a tf.keras.optimizers.Optimizer for local client training after reconstruction.
reconstruction_optimizer_fn A tff.learning.optimizers.Optimizer, or a no-arg function that returns a tf.keras.optimizers.Optimizer, used to reconstruct the local variables while the global variables are frozen, i.e., the first stage described above.
dataset_split_fn A tff.learning.models.ReconstructionDatasetSplitFn that takes a single tf.data.Dataset and produces two datasets. The first is iterated over during reconstruction, and the second is iterated over post-reconstruction. This can be used to preprocess datasets, e.g., to iterate over them for multiple epochs or to use disjoint data for reconstruction and post-reconstruction. If None, each client's data is split in half, with one half used for reconstruction and the other for post-reconstruction. See tff.learning.models.ReconstructionModel.build_dataset_split_fn for options.
client_weighting A value of tff.learning.ClientWeighting that specifies a built-in weighting method, or a callable that takes the local metrics of the model and returns a tensor that provides the weight in the federated average of model deltas. If None, defaults to weighting by number of examples.
model_distributor An optional tff.templates.DistributionProcess that distributes the model weights on the server to the clients. If None, the distributor is constructed via tff.learning.templates.build_broadcast_process.
model_aggregator_factory An optional instance of tff.aggregators.WeightedAggregationFactory or tff.aggregators.UnweightedAggregationFactory determining the method of aggregation to perform. If unspecified, uses a default tff.aggregators.MeanFactory which computes a stateless mean across clients (weighted depending on client_weighting).
metrics_aggregator A function that takes in the metric finalizers (i.e., tff.learning.Model.metric_finalizers()) and a tff.types.StructWithPythonType of the unfinalized metrics (i.e., the TFF type of tff.learning.Model.report_local_unfinalized_metrics()), and returns a tff.Computation for aggregating the unfinalized metrics. If None, this is set to tff.learning.metrics.sum_then_finalize.
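The default "split in half" behavior and the multiple-epoch use case mentioned under dataset_split_fn can be illustrated with a pure-Python sketch over a list of batches. The helper split_in_half is hypothetical, not the TFF API: sending even-indexed batches to reconstruction and odd-indexed batches to post-reconstruction is one assumed way to split the data into disjoint halves, and the recon_epochs repetition stands in for the preprocessing a real split function might apply.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_in_half(batches: Sequence[T], recon_epochs: int = 1) -> Tuple[List[T], List[T]]:
    """Hypothetical split function over a list of batches (not the TFF API).

    Even-indexed batches go to reconstruction and odd-indexed batches go to
    post-reconstruction, so the two halves are disjoint. The reconstruction
    half is repeated `recon_epochs` times to mimic multi-epoch iteration.
    """
    recon = [b for i, b in enumerate(batches) if i % 2 == 0]
    post_recon = [b for i, b in enumerate(batches) if i % 2 == 1]
    return recon * recon_epochs, post_recon


if __name__ == "__main__":
    print(split_in_half(["b0", "b1", "b2", "b3", "b4"]))  # (['b0', 'b2', 'b4'], ['b1', 'b3'])
    print(split_in_half(["b0", "b1"], recon_epochs=2))  # (['b0', 'b0'], ['b1'])
```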
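The effect of client_weighting on the federated average of model deltas can be sketched in plain Python. The ClientWeighting enum here is a local stand-in mirroring the two built-in options, deltas are flat lists of floats, and the "num_examples" metric key and weighted_mean_of_deltas helper are hypothetical; in the real process, this averaging is performed by the aggregator built from model_aggregator_factory (by default tff.aggregators.MeanFactory).

```python
import enum
from typing import Callable, Dict, List, Optional, Sequence, Union


class ClientWeighting(enum.Enum):
    """Local stand-in for the two built-in weighting options."""
    UNIFORM = 1
    NUM_EXAMPLES = 2


Metrics = Dict[str, float]
Weighting = Union[ClientWeighting, Callable[[Metrics], float]]


def weighted_mean_of_deltas(
    deltas: Sequence[List[float]],
    local_metrics: Sequence[Metrics],
    client_weighting: Optional[Weighting] = None,
) -> List[float]:
    """Weighted federated average of per-client model deltas (hypothetical helper)."""
    if client_weighting is None:
        # Documented default: weight by number of examples.
        client_weighting = ClientWeighting.NUM_EXAMPLES
    if client_weighting is ClientWeighting.UNIFORM:
        weights = [1.0] * len(deltas)
    elif client_weighting is ClientWeighting.NUM_EXAMPLES:
        weights = [float(m["num_examples"]) for m in local_metrics]
    else:
        # A callable maps each client's local metrics to that client's weight.
        weights = [float(client_weighting(m)) for m in local_metrics]
    total = sum(weights)
    return [
        sum(w * delta[i] for w, delta in zip(weights, deltas)) / total
        for i in range(len(deltas[0]))
    ]


if __name__ == "__main__":
    deltas = [[1.0, 0.0], [3.0, 2.0]]
    metrics = [{"num_examples": 1}, {"num_examples": 3}]
    print(weighted_mean_of_deltas(deltas, metrics))  # [2.5, 1.5]
    print(weighted_mean_of_deltas(deltas, metrics, ClientWeighting.UNIFORM))  # [2.0, 1.0]
```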
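The example-weighted mean loss described under loss_fn, aggregated with the sum-then-finalize strategy named under metrics_aggregator, can be sketched as follows. Each client reports an unfinalized state [loss sum, example count] computed only from post-reconstruction batches; the server sums these states across clients and then applies a finalizer. All function names are hypothetical, and the sketch mirrors the idea behind tff.learning.metrics.sum_then_finalize rather than its actual signature.

```python
from typing import Callable, List, Sequence


def client_unfinalized_loss(post_recon_batches: Sequence[Sequence[float]]) -> List[float]:
    """Unfinalized loss state [sum of per-example losses, example count].

    Only post-reconstruction batches are passed in, so reconstruction batches
    never contribute to the reported loss.
    """
    loss_sum, count = 0.0, 0
    for per_example_losses in post_recon_batches:
        loss_sum += sum(per_example_losses)
        count += len(per_example_losses)
    return [loss_sum, count]


def sum_then_finalize(
    client_states: Sequence[List[float]],
    finalizer: Callable[[List[float]], float],
) -> float:
    """Sum each state component across clients, then finalize once on the server."""
    summed = [sum(values) for values in zip(*client_states)]
    return finalizer(summed)


def mean_loss_finalizer(state: List[float]) -> float:
    loss_sum, count = state
    return loss_sum / count


if __name__ == "__main__":
    states = [
        client_unfinalized_loss([[1.0, 3.0]]),  # client A: 2 examples, mean 2.0
        client_unfinalized_loss([[5.0]]),       # client B: 1 example, mean 5.0
    ]
    print(sum_then_finalize(states, mean_loss_finalizer))  # 3.0
```

Summing before dividing weights every example equally, so client A's two examples count twice as much as client B's one; averaging each client's own mean instead would give 3.5 and weight the clients equally.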

Returns

A tff.learning.templates.LearningProcess.

Raises

TypeError If model_fn does not return instances of tff.learning.models.ReconstructionModel.
ValueError If model_aggregator_factory is a tff.aggregators.UnweightedAggregationFactory and client_weighting is any value other than tff.learning.ClientWeighting.UNIFORM.
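The ValueError condition can be expressed as a small pre-flight check. The factory classes, the enum, and check_aggregation_config below are hypothetical stand-ins for tff.aggregators.UnweightedAggregationFactory, tff.aggregators.WeightedAggregationFactory, tff.learning.ClientWeighting, and whatever validation build_fed_recon performs internally. Because client_weighting=None defaults to weighting by number of examples, this sketch also rejects None when the factory is unweighted, in line with the "any value other than UNIFORM" rule.

```python
import enum


class ClientWeighting(enum.Enum):
    """Local stand-in for the built-in weighting options."""
    UNIFORM = 1
    NUM_EXAMPLES = 2


class WeightedAggregationFactory:
    """Hypothetical stand-in for a weighted aggregation factory."""


class UnweightedAggregationFactory:
    """Hypothetical stand-in for an unweighted aggregation factory."""


def check_aggregation_config(factory, client_weighting=None):
    """Raise ValueError for the factory/weighting combination the docs reject."""
    if client_weighting is None:
        # Documented default for client_weighting.
        client_weighting = ClientWeighting.NUM_EXAMPLES
    if (isinstance(factory, UnweightedAggregationFactory)
            and client_weighting is not ClientWeighting.UNIFORM):
        raise ValueError(
            "An unweighted aggregation factory requires "
            f"client_weighting=ClientWeighting.UNIFORM, got {client_weighting}.")


if __name__ == "__main__":
    check_aggregation_config(UnweightedAggregationFactory(), ClientWeighting.UNIFORM)
    try:
        check_aggregation_config(UnweightedAggregationFactory())
    except ValueError as err:
        print("rejected:", err)
```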