Builds a learning process for FedAvg with client optimizer scheduling.

This function creates a LearningProcess that performs federated averaging on client models. The iterative process has the following methods inherited from LearningProcess:

  • initialize: A tff.Computation with the functional type signature ( -> S@SERVER), where S is a LearningAlgorithmState representing the initial state of the server.
  • next: A tff.Computation with the functional type signature (<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>) where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and {B*}@CLIENTS represents the client datasets. The output L contains the updated server state, as well as aggregated metrics at the server, including client training metrics and any other metrics from distribution and aggregation processes.
  • get_model_weights: A tff.Computation with type signature (S -> M), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and M represents the type of the model weights used during training.
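
The calling pattern implied by these three methods can be sketched in plain Python. The class below is a hypothetical stand-in, not TFF code: ToyProcess, ToyOutput, the dictionary state, and the num_examples metric are assumptions made only to show how state flows from initialize through repeated next calls to get_model_weights.

```python
from typing import NamedTuple


class ToyOutput(NamedTuple):
    """Hypothetical analogue of the output L: updated state plus metrics."""
    state: dict
    metrics: dict


class ToyProcess:
    """Plain-Python stand-in for the three-method interface (not TFF)."""

    def initialize(self) -> dict:
        # ( -> S): build the initial server state.
        return {"weights": [0.0, 0.0], "round_num": 0}

    def next(self, state: dict, client_datasets: list) -> ToyOutput:
        # (<S, {B*}> -> L): consume state and client data, return new state.
        num_examples = sum(len(d) for d in client_datasets)
        new_state = {
            "weights": list(state["weights"]),  # training is omitted here
            "round_num": state["round_num"] + 1,
        }
        return ToyOutput(new_state, {"num_examples": num_examples})

    def get_model_weights(self, state: dict) -> list:
        # (S -> M): extract the model weights from the state.
        return state["weights"]


process = ToyProcess()
state = process.initialize()
for _ in range(3):
    output = process.next(state, [[1, 2, 3], [4, 5]])
    state = output.state  # feed the updated state into the next round
weights = process.get_model_weights(state)
```

The point of the sketch is that the caller owns the state: each call to next returns a fresh state that must be passed back in on the following round.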

Each time the next method is called, the server model is broadcast to each client using a broadcast function. For each client, local training is performed using client_optimizer_fn. Each client computes the difference between the client model after training and the initial broadcast model. These model deltas are then aggregated at the server using a weighted aggregation function. Clients are weighted by the number of examples they see throughout local training. The aggregate model delta is applied at the server using a server optimizer.
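
A single round as described above can be sketched in plain Python for a scalar model. This is an illustration under stated assumptions, not the TFF implementation: the names client_update and fed_avg_round, the squared-error loss, and the treatment of the server step as plain SGD on the negated mean delta are all choices made for this example.

```python
def client_update(broadcast_w, examples, learning_rate):
    """Trains locally and returns (model delta, number of examples seen)."""
    w = broadcast_w
    for x, y in examples:
        grad = (w * x - y) * x  # gradient of 0.5 * (w * x - y) ** 2
        w -= learning_rate * grad
    return w - broadcast_w, len(examples)


def fed_avg_round(server_w, client_datasets, client_lr, server_lr=1.0):
    """Broadcast, local training, example-weighted mean of deltas, server step."""
    results = [client_update(server_w, ds, client_lr) for ds in client_datasets]
    total_examples = sum(n for _, n in results)
    mean_delta = sum(delta * n for delta, n in results) / total_examples
    # Assumption: the server runs SGD on the negated mean delta, so a server
    # learning rate of 1.0 adds the mean delta directly to the model.
    return server_w + server_lr * mean_delta


# Two hypothetical clients drawing from y = 2 * x, with 1 and 3 examples.
client_datasets = [
    [(1.0, 2.0)],
    [(1.0, 2.0), (1.0, 2.0), (1.0, 2.0)],
]
new_w = fed_avg_round(0.0, client_datasets, client_lr=0.5)
```

Here the first client produces a delta of 1.0 and the second a delta of 1.75, so the example-weighted mean is (1.0 * 1 + 1.75 * 3) / 4 = 1.5625, which the server step with learning rate 1.0 applies unchanged.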

The primary purpose of this implementation of FedAvg is that it allows for the client optimizer to be scheduled across rounds. The process keeps track of how many iterations of .next have occurred (starting at 0), and for each such round_num, the clients will use client_optimizer_fn(client_learning_rate_fn(round_num)) to perform local optimization. This allows learning rate scheduling (e.g., starting with a large learning rate and decaying it over time) as well as more general optimizer scheduling (e.g., switching optimizers as learning progresses).
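
The round-based scheduling can be illustrated with two small plain-Python callables. Everything here is a hypothetical sketch: the halving schedule, the 0.03 threshold, and the dictionaries standing in for optimizer objects are assumptions, and a real learning rate function would additionally have to be serializable by TFF, which this example does not attempt.

```python
def client_learning_rate_fn(round_num: int) -> float:
    """Hypothetical schedule: start at 0.1 and halve every 10 rounds."""
    return 0.1 * (0.5 ** (round_num // 10))


def client_optimizer_fn(learning_rate: float) -> dict:
    """Hypothetical factory; a dict stands in for a real optimizer object.

    Uses momentum SGD while the learning rate is large and switches to
    plain SGD once the rate has decayed below 0.03.
    """
    if learning_rate >= 0.03:
        return {"name": "sgd_momentum", "learning_rate": learning_rate,
                "momentum": 0.9}
    return {"name": "sgd", "learning_rate": learning_rate}


# Mimic the process: the round counter starts at 0 and increments per round.
optimizers_by_round = {}
for round_num in range(30):
    optimizers_by_round[round_num] = client_optimizer_fn(
        client_learning_rate_fn(round_num))
```

Because the optimizer factory only sees the learning rate, any switch between optimizers in this style has to be keyed on the value the schedule returns rather than on the round number itself.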

model_fn A no-arg function that returns a tff.learning.Model. This method must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
client_learning_rate_fn A callable accepting an integer round number and returning a float to be used as the learning rate for the client optimizer. The client work will call client_optimizer_fn(client_learning_rate_fn(round_num)), where round_num is the integer round number. Note that the round numbers supplied will start at 0 and increment by one each time .next is called on the resulting process. Also note that this function must be serializable by TFF.
client_optimizer_fn A callable accepting a float learning rate, and returning a tff.learning.optimizers.Optimizer or a tf.keras.Optimizer.
server_optimizer_fn A tff.learning.optimizers.Optimizer, or a no-arg callable that returns a tf.keras.Optimizer. By default, this uses tf.keras.optimizers.SGD with a learning rate of 1.0.
model_distributor An optional DistributionProcess that distributes the model weights on the server to the clients. If set to None, the distributor is constructed via distributors.build_broadcast_process.
model_aggregator An optional tff.aggregators.WeightedAggregationFactory used to aggregate client updates on the server. If None, this is set to tff.aggregators.MeanFactory.
metrics_aggregator A function that takes in the metric finalizers (i.e., tff.learning.Model.metric_finalizers()) and a tff.types.StructWithPythonType of the unfinalized metrics (i.e., the TFF type of tff.learning.Model.report_local_unfinalized_metrics()), and returns a tff.Computation for aggregating the unfinalized metrics. If None, this is set to tff.learning.metrics.sum_then_finalize.
use_experimental_simulation_loop Controls the reduce loop function used for the input dataset. If True, an experimental reduce loop is used for simulation. It is currently necessary to set this flag to True for performant GPU simulations.

A LearningProcess.