tff.learning.algorithms.build_weighted_fed_avg

Builds a learning process that performs federated averaging.

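This function creates a tff.learning.templates.LearningProcess that performs federated averaging on client models. The learning process has the following methods inherited from tff.learning.templates.LearningProcess:

- initialize: a tff.Computation that returns the initial server state, a tff.learning.templates.LearningAlgorithmState.
- next: a tff.Computation that takes the server state and the client datasets, runs one round of training, and returns the updated server state together with aggregated metrics.
- get_model_weights: a tff.Computation that extracts the model weights from a server state.
- set_model_weights: a tff.Computation that returns a server state with its model weights replaced by the given weights.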

Each time the next method is called, the server model is broadcast to each client using the provided model_distributor. Each client performs local training using client_optimizer_fn and then computes the difference between its model after training and its initial model. These model deltas are aggregated at the server with model_aggregator, weighted according to client_weighting. Finally, the aggregate model delta is applied to the server model using the server optimizer (server_optimizer_fn).
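The round described above can be sketched in plain Python. This is a toy illustration, not TFF code. A model is the two-element list [w, b]. Local training is a few steps of full-batch gradient descent on mean squared error. The names local_train and fed_avg_round are hypothetical. The server step mimics the default server optimizer (SGD with a learning rate of 1.0) by treating the negated weighted-average delta as a gradient.

```python
from typing import List, Sequence, Tuple

# Toy stand-ins (assumptions, not TFF types): a model is [w, b] for
# y approx w * x + b, and a client dataset is a list of (x, y) pairs.
Model = List[float]
Dataset = List[Tuple[float, float]]


def local_train(model: Model, data: Dataset, lr: float = 0.05, steps: int = 5) -> Model:
    """Hypothetical client update: full-batch gradient descent on MSE."""
    w, b = model
    n = len(data)
    for _ in range(steps):
        grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in data) / n
        w, b = w - lr * grad_w, b - lr * grad_b
    return [w, b]


def fed_avg_round(server_model: Model, client_datasets: Sequence[Dataset],
                  uniform: bool = False, server_lr: float = 1.0) -> Model:
    """One round: broadcast, train locally, weight-average deltas, apply on server."""
    deltas, weights = [], []
    for data in client_datasets:
        trained = local_train(list(server_model), data)  # client starts from server model
        deltas.append([t - s for t, s in zip(trained, server_model)])  # model delta
        weights.append(1.0 if uniform else float(len(data)))  # uniform vs. example count
    total = sum(weights)
    avg_delta = [sum(wt * d[i] for wt, d in zip(weights, deltas)) / total
                 for i in range(len(server_model))]
    # Server SGD step that treats the negated average delta as the gradient.
    return [s - server_lr * (-d) for s, d in zip(server_model, avg_delta)]


clients = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]
model = [0.0, 0.0]
for _ in range(3):
    model = fed_avg_round(model, clients)
print(model)
```

With server_lr=1.0, the new server model equals the weighted average of the locally trained client models. Passing uniform=True weights every client equally instead of by its number of examples, which is the distinction client_weighting controls.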

Args

model_fn: A no-arg function that returns a tff.learning.models.VariableModel, or an instance of a tff.learning.models.FunctionalModel. When passing a callable, the callable must not capture and use TensorFlow tensors or variables. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
client_optimizer_fn: A tff.learning.optimizers.Optimizer, or a no-arg callable that returns a tf.keras.Optimizer. If model_fn is a tff.learning.models.FunctionalModel, this must be a tff.learning.optimizers.Optimizer.
server_optimizer_fn: A tff.learning.optimizers.Optimizer, a no-arg callable that returns a tf.keras.Optimizer, or None. By default, this uses tff.learning.optimizers.build_sgdm with a learning rate of 1.0. If model_fn is a tff.learning.models.FunctionalModel, this must be a tff.learning.optimizers.Optimizer.
client_weighting: A member of tff.learning.ClientWeighting that specifies a built-in weighting method. By default, clients are weighted by their number of examples.
model_distributor: An optional tff.learning.templates.DistributionProcess that distributes the model weights on the server to the clients. If set to None, the distributor is constructed via tff.learning.templates.build_broadcast_process.
model_aggregator: An optional tff.aggregators.WeightedAggregationFactory used to aggregate client updates on the server. If None, this is set to tff.aggregators.MeanFactory.
metrics_aggregator: A function that takes the metric finalizers (i.e., tff.learning.models.VariableModel.metric_finalizers()) and a tff.types.StructWithPythonType of the unfinalized metrics (i.e., the TFF type of tff.learning.models.VariableModel.report_local_unfinalized_metrics()), and returns a tff.Computation for aggregating the unfinalized metrics. If None, this is set to tff.learning.metrics.sum_then_finalize.
loop_implementation: Changes the implementation of the generated training loop. See tff.learning.LoopImplementation for more details.
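The default metrics_aggregator, tff.learning.metrics.sum_then_finalize, first sums the unfinalized metric values across clients and only then applies the finalizers. The pure-Python sketch below illustrates that idea for an accuracy metric. The dictionary layout, the (num_correct, num_examples) encoding and the function name sum_then_finalize_sketch are assumptions for illustration, not the TFF implementation.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical encoding: each metric's unfinalized state is a tuple of floats.
Unfinalized = Dict[str, Tuple[float, ...]]
Finalizer = Callable[[Tuple[float, ...]], float]


def sum_then_finalize_sketch(finalizers: Dict[str, Finalizer],
                             client_metrics: List[Unfinalized]) -> Dict[str, float]:
    """Sums unfinalized values element-wise across clients, then finalizes once."""
    summed: Dict[str, Tuple[float, ...]] = {}
    for metrics in client_metrics:
        for name, values in metrics.items():
            prev = summed.get(name, (0.0,) * len(values))
            summed[name] = tuple(p + v for p, v in zip(prev, values))
    return {name: finalizers[name](values) for name, values in summed.items()}


# Accuracy kept as (num_correct, num_examples); finalized as a ratio.
finalizers = {"accuracy": lambda v: v[0] / v[1]}
client_metrics = [
    {"accuracy": (9.0, 10.0)},   # small client: 90% accurate
    {"accuracy": (10.0, 90.0)},  # large client: about 11% accurate
]
print(sum_then_finalize_sketch(finalizers, client_metrics))  # {'accuracy': 0.19}
```

Summing before finalizing yields the accuracy over all 100 examples (0.19). Averaging the two per-client accuracies instead would give roughly 0.51, which overweights the small client.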

Returns

A tff.learning.templates.LearningProcess.

Raises

TypeError: If arguments are not of the documented types.