

Builds a learning process that performs federated SGD.

This function creates a tff.learning.templates.LearningProcess that performs federated SGD on client models. The learning process has the following methods inherited from tff.learning.templates.LearningProcess:

  • initialize: A tff.Computation with type signature ( -> S@SERVER), where S is a tff.learning.templates.LearningAlgorithmState representing the initial state of the server.
  • next: A tff.Computation with type signature (<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>), where S is a LearningAlgorithmState whose type matches that of the output of initialize, {B*}@CLIENTS represents the client datasets (B is the type of a single batch), and T is the type of the returned metrics. This computation returns the updated server state as a LearningAlgorithmState, together with metrics: the result of tff.learning.Model.federated_output_computation during client training, plus any other metrics from the broadcast and aggregation processes.
  • get_model_weights: A tff.Computation with type signature (S -> M), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and M represents the type of the model weights used during training.

Each time next is called, the server model is broadcast to each client using a distributor. Each client sums the gradients computed on each batch of its local dataset (without updating its model) and averages them by its number of examples. These average gradients are then aggregated at the server and applied to the server model using the server optimizer (a tff.learning.optimizers.Optimizer or a tf.keras.optimizers.Optimizer; see server_optimizer_fn).

This implements the original FedSGD algorithm in McMahan et al., 2017.
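The client and server steps described above can be illustrated without TFF. The sketch below is plain Python and makes several assumptions that are not taken from this page:

- the model is a one-feature linear model y_hat = w * x + b with squared-error loss 0.5 * (y_hat - y) ** 2;
- an example-weighted mean stands in for tff.aggregators.MeanFactory;
- a plain SGD step stands in for the server optimizer.

All function names are hypothetical.

```python
from typing import List, Sequence, Tuple

Example = Tuple[float, float]      # (x, y)
Batch = List[Example]
Weights = Tuple[float, float]      # (w, b) for the model y_hat = w * x + b


def client_update(weights: Weights, dataset: Sequence[Batch]) -> Tuple[Weights, int]:
    """Returns (average gradient, number of examples); the weights are never modified."""
    w, b = weights
    grad_w_sum, grad_b_sum, num_examples = 0.0, 0.0, 0
    for batch in dataset:
        for x, y in batch:
            residual = (w * x + b) - y  # derivative of 0.5 * (y_hat - y) ** 2 w.r.t. y_hat
            grad_w_sum += residual * x
            grad_b_sum += residual
            num_examples += 1
    if num_examples == 0:
        return (0.0, 0.0), 0
    return (grad_w_sum / num_examples, grad_b_sum / num_examples), num_examples


def weighted_mean(updates: Sequence[Tuple[Weights, int]]) -> Weights:
    """Example-weighted mean of client gradients (stand-in for a weighted MeanFactory)."""
    total = sum(n for _, n in updates)
    if total == 0:
        return (0.0, 0.0)
    return (
        sum(g[0] * n for g, n in updates) / total,
        sum(g[1] * n for g, n in updates) / total,
    )


def fed_sgd_round(server_weights: Weights, client_datasets: Sequence[Sequence[Batch]],
                  learning_rate: float) -> Weights:
    # "Broadcast": every client receives the same server weights.
    updates = [client_update(server_weights, ds) for ds in client_datasets]
    grad_w, grad_b = weighted_mean(updates)
    w, b = server_weights
    # Server step: plain SGD on the aggregated gradient.
    return (w - learning_rate * grad_w, b - learning_rate * grad_b)


clients = [
    [[(1.0, 2.0), (2.0, 4.0)]],    # client 0: one batch of two examples
    [[(3.0, 6.0)], [(4.0, 8.0)]],  # client 1: two batches of one example each
]
weights = (0.0, 0.0)
for _ in range(1000):
    weights = fed_sgd_round(weights, clients, learning_rate=0.1)
print(weights)  # approaches (2.0, 0.0), since every example satisfies y = 2x
```

Each client's average gradient is weighted by its example count before averaging. Under these assumptions, one round is therefore the same as a single full-batch gradient step over the union of all client data. This differs from approaches such as federated averaging, where clients take several local steps before reporting.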

Args:

  • model_fn: A no-arg function that returns a tff.learning.Model. This method must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
  • server_optimizer_fn: A tff.learning.optimizers.Optimizer, or a no-arg callable that returns a tf.keras.optimizers.Optimizer. The optimizer is used to apply the aggregated client updates to the server model.
  • model_distributor: An optional DistributionProcess that distributes the model weights on the server to the clients. If set to None, the distributor is constructed via distributors.build_broadcast_process.
  • model_aggregator: An optional tff.aggregators.WeightedAggregationFactory used to aggregate client updates on the server. If None, this is set to tff.aggregators.MeanFactory.
  • metrics_aggregator: A function that takes in the metric finalizers (i.e., tff.learning.Model.metric_finalizers()) and a tff.types.StructWithPythonType of the unfinalized metrics (i.e., the TFF type of tff.learning.Model.report_local_unfinalized_metrics()), and returns a tff.Computation for aggregating the unfinalized metrics.
  • use_experimental_simulation_loop: Controls the reduce loop function used to iterate over each client's input dataset. When True, an experimental reduce loop intended for simulation is used.
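The model_fn requirement amounts to building the model inside the function body on every call. The plain-Python sketch below contrasts the two patterns. LinearModel, model_fn and bad_model_fn are hypothetical names, and LinearModel only stands in for a tff.learning.Model. In TFF code the model would typically be built inside the function, for example a Keras model wrapped with tff.learning.from_keras_model.

```python
from typing import List


class LinearModel:
    """Hypothetical stand-in for a tff.learning.Model that owns its own variables."""

    def __init__(self, num_features: int) -> None:
        self.kernel: List[float] = [0.0] * num_features
        self.bias = 0.0


def model_fn() -> LinearModel:
    # Correct: every call constructs a brand-new model with brand-new variables.
    return LinearModel(num_features=2)


_prebuilt = LinearModel(num_features=2)


def bad_model_fn() -> LinearModel:
    # Incorrect: returns the same pre-constructed object, so its variables are
    # captured from outside the function rather than created on each call.
    return _prebuilt


assert model_fn() is not model_fn()        # fresh object on each call
assert bad_model_fn() is bad_model_fn()    # the same object every time
```

As stated in the model_fn description, returning the same pre-constructed model on each call results in an error. The builder should therefore be given a function shaped like model_fn, not like bad_model_fn.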

Returns: A tff.learning.templates.LearningProcess.