tff.learning.algorithms.build_unweighted_fed_avg

Builds a learning process that performs federated averaging.

This function creates a tff.learning.templates.LearningProcess that performs federated averaging on client models. The returned process exposes the methods it inherits from tff.learning.templates.LearningProcess, including initialize, which constructs the initial server state, and next, which runs one round of training.

Each time next is called, the server model is sent to every client using the provided model_distributor. Each client then trains locally using client_optimizer_fn and computes its model delta: the difference between its model after training and the initial model it received. These model deltas are aggregated at the server with an unweighted aggregation, so every client contributes equally regardless of how many examples it holds. The aggregated model delta is then applied to the server model using the server optimizer.
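
The round described above can be simulated without TFF. The following is a minimal pure-Python sketch, assuming a two-parameter linear model, one epoch of plain SGD per client, and a server optimizer without momentum. It also assumes a sign convention in which the negated mean delta is fed to the server optimizer as a pseudo-gradient, so with a server learning rate of 1.0 the new server model equals the unweighted average of the trained client models. ToyUnweightedFedAvg, local_train, and unweighted_mean are hypothetical names, not TFF APIs; the class only mirrors the initialize/next shape of a LearningProcess.

```python
from typing import List, Sequence, Tuple

Weights = List[float]                    # [slope, intercept] for y = slope * x + intercept
Dataset = Sequence[Tuple[float, float]]  # (x, y) examples held by one client


def local_train(weights: Weights, data: Dataset, lr: float = 0.1) -> Weights:
    """Hypothetical client update: one epoch of SGD on 0.5 * squared error."""
    slope, intercept = weights
    for x, y in data:
        err = (slope * x + intercept) - y
        slope -= lr * err * x
        intercept -= lr * err
    return [slope, intercept]


def unweighted_mean(vectors: Sequence[Weights]) -> Weights:
    """Each client counts once, no matter how many examples it holds."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


class ToyUnweightedFedAvg:
    """Hypothetical stand-in mirroring the initialize/next shape of a LearningProcess."""

    def __init__(self, client_lr: float = 0.1, server_lr: float = 1.0):
        self.client_lr = client_lr
        self.server_lr = server_lr

    def initialize(self) -> Weights:
        return [0.0, 0.0]

    def next(self, server_weights: Weights, client_data: Sequence[Dataset]) -> Weights:
        deltas = []
        for data in client_data:
            # Broadcast: every client starts from the current server weights.
            trained = local_train(server_weights, data, self.client_lr)
            # Model delta = weights after training minus the initial weights.
            deltas.append([t - s for t, s in zip(trained, server_weights)])
        mean_delta = unweighted_mean(deltas)
        # Server SGD (no momentum) with -mean_delta as the pseudo-gradient.
        return [s - self.server_lr * (-d) for s, d in zip(server_weights, mean_delta)]


clients = [
    [(1.0, 2.0), (2.0, 4.0)],              # 2 examples, all on y = 2x
    [(0.5, 1.0), (1.5, 3.0), (3.0, 6.0)],  # 3 examples, all on y = 2x
]
process = ToyUnweightedFedAvg()
state = process.initialize()
for _ in range(200):
    state = process.next(state, clients)
print(state)  # close to [2.0, 0.0]
```

Because unweighted_mean gives each client the same weight, the 2-example client moves the server model as much as the 3-example client; an example-weighted mean would instead scale each delta by its client's example count.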

Args

model_fn: A no-arg function that returns a tff.learning.models.VariableModel, or an instance of tff.learning.models.FunctionalModel. When passing a callable, the callable must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
client_optimizer_fn: A tff.learning.optimizers.Optimizer used for local training on each client.
server_optimizer_fn: An optional tff.learning.optimizers.Optimizer used to apply the aggregated model delta at the server. Defaults to tff.learning.optimizers.build_sgdm(learning_rate=1.0).
model_distributor: An optional DistributionProcess that distributes the model weights on the server to the clients. If None, the distributor is constructed via distributors.build_broadcast_process.
model_aggregator: An optional tff.aggregators.UnweightedAggregationFactory used to aggregate client updates on the server. If None, defaults to tff.aggregators.UnweightedMeanFactory.
metrics_aggregator: An optional function that takes the metric finalizers (i.e., tff.learning.models.VariableModel.metric_finalizers()) and a tff.types.StructWithPythonType of the unfinalized metrics (i.e., the TFF type of tff.learning.models.VariableModel.report_local_unfinalized_metrics()), and returns a tff.Computation that aggregates the unfinalized metrics. If None, defaults to tff.learning.metrics.sum_then_finalize.
loop_implementation: Changes the implementation of the generated training loop. See tff.learning.LoopImplementation for details.
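
The default metrics_aggregator, tff.learning.metrics.sum_then_finalize, sums unfinalized metric values across clients before finalizing them. The following is a plain-Python sketch of that idea, assuming each unfinalized metric is a list of numbers that can be summed element-wise. sum_then_finalize_sketch and the metric layout are hypothetical; the real function builds a tff.Computation over TFF types rather than operating on dictionaries.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical layout: each unfinalized metric is a list of partial sums.
Unfinalized = Dict[str, List[float]]
Finalizer = Callable[[List[float]], float]


def sum_then_finalize_sketch(
    finalizers: Dict[str, Finalizer], client_metrics: Sequence[Unfinalized]
) -> Dict[str, float]:
    """Sum each unfinalized metric across clients, then finalize once on the server."""
    summed: Dict[str, List[float]] = {}
    for name in finalizers:
        # Element-wise sum of the partial sums reported by every client.
        summed[name] = [sum(parts) for parts in zip(*(m[name] for m in client_metrics))]
    return {name: finalize(summed[name]) for name, finalize in finalizers.items()}


finalizers = {
    "accuracy": lambda v: v[0] / v[1],  # [num_correct, num_examples] -> ratio
    "num_examples": lambda v: v[0],     # [num_examples] -> count
}
client_metrics = [
    {"accuracy": [9.0, 10.0], "num_examples": [10.0]},   # per-client accuracy 0.9
    {"accuracy": [45.0, 90.0], "num_examples": [90.0]},  # per-client accuracy 0.5
]
print(sum_then_finalize_sketch(finalizers, client_metrics))
# {'accuracy': 0.54, 'num_examples': 100.0}
```

Averaging the already-finalized per-client accuracies would give 0.7, over-weighting the 10-example client; summing numerators and denominators first yields the accuracy over all 100 examples.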

Returns

A tff.learning.templates.LearningProcess.