Builds a learning process that performs federated SGD.
tff.learning.algorithms.build_fed_sgd(
model_fn: Union[Callable[[], tff.learning.models.VariableModel], tff.learning.models.FunctionalModel],
server_optimizer_fn: tff.learning.optimizers.Optimizer = DEFAULT_SERVER_OPTIMIZER_FN,
model_distributor: Optional[tff.learning.templates.DistributionProcess] = None,
model_aggregator: Optional[tff.aggregators.WeightedAggregationFactory] = None,
metrics_aggregator: Optional[tff.learning.metrics.MetricsAggregatorType] = None,
loop_implementation: tff.learning.LoopImplementation = tff.learning.LoopImplementation.DATASET_REDUCE
) -> tff.learning.templates.LearningProcess
This function creates a tff.learning.templates.LearningProcess that performs
federated SGD on client models. The learning process has the following methods
inherited from tff.learning.templates.LearningProcess:
- `initialize`: A `tff.Computation` with type signature `( -> S@SERVER)`, where `S` is a `tff.learning.templates.LearningAlgorithmState` representing the initial state of the server.
- `next`: A `tff.Computation` with type signature `(<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>)`, where `S` is a `LearningAlgorithmState` whose type matches that of the output of `initialize`, and `{B*}@CLIENTS` represents the client datasets, where `B` is the type of a single batch. This computation returns a `LearningAlgorithmState` representing the updated server state, together with the metrics from client training and any other metrics from the broadcast and aggregation processes.
- `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`, where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type matches the output of `initialize` and `next`, and `M` represents the type of the model weights used during training.
- `set_model_weights`: A `tff.Computation` with type signature `(<S, M> -> S)`, where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type matches the output of `initialize`, and `M` represents the type of the model weights used during training.
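As a usage sketch of how these methods are typically driven from Python (hedged: `learning_process` and `client_datasets` are hypothetical names standing in for the result of `build_fed_sgd` and a list of client `tf.data.Dataset`s, and the round count is arbitrary):

```python
# A minimal sketch, assuming `learning_process` is the result of
# `build_fed_sgd` and `client_datasets` is a list of `tf.data.Dataset`s
# whose elements match the model's `input_spec`.
state = learning_process.initialize()

for round_num in range(10):
  # Broadcast the model, compute and aggregate client gradients, and
  # apply one server optimizer step.
  output = learning_process.next(state, client_datasets)
  state = output.state
  print(f'round {round_num}, metrics: {output.metrics}')

# Read the trained weights out of (or write new weights into) the state.
model_weights = learning_process.get_model_weights(state)
state = learning_process.set_model_weights(state, model_weights)
```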
Each time `next` is called, the server model is broadcast to each client using
a distributor. Each client computes the gradient of the loss for each batch in
its local dataset (without updating its model) and averages these gradients,
weighting each by the number of examples in its batch. The per-client average
gradients are then aggregated at the server and applied to the server model
using an optimizer.
This implements the original FedSGD algorithm of McMahan et al. (2017), "Communication-Efficient Learning of Deep Networks from Decentralized Data" (https://arxiv.org/abs/1602.05629).
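The per-client step is just an example-weighted mean of per-batch gradients. The following NumPy sketch shows that arithmetic only; it is illustrative, not the actual TFF implementation, and `batch_gradients`/`batch_sizes` are hypothetical inputs:

```python
import numpy as np

def client_average_gradient(batch_gradients, batch_sizes):
  """Example-weighted mean of per-batch gradients (illustrative sketch only).

  `batch_gradients` is a list of gradient arrays, one per batch;
  `batch_sizes` gives the number of examples in each batch.
  """
  weighted_sum = sum(g * n for g, n in zip(batch_gradients, batch_sizes))
  return weighted_sum / sum(batch_sizes)

# Two batches of 10 and 30 examples:
grads = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
print(client_average_gradient(grads, [10, 30]))  # -> [2.5 3.5]

# At the server, the aggregated gradient is applied by the server
# optimizer, e.g. plain SGD: weights = weights - lr * aggregated_gradient.
```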
| Args | |
|---|---|
| `model_fn` | A no-arg function that returns a `tff.learning.models.VariableModel`, or an instance of a `tff.learning.models.FunctionalModel`. When passing a callable, the callable must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error. |
| `server_optimizer_fn` | A `tff.learning.optimizers.Optimizer` used to apply client updates to the server model. |
| `model_distributor` | An optional `DistributionProcess` that distributes the model weights on the server to the clients. If set to `None`, the distributor is constructed via `distributors.build_broadcast_process`. |
| `model_aggregator` | An optional `tff.aggregators.WeightedAggregationFactory` used to aggregate client updates on the server. If `None`, this is set to `tff.aggregators.MeanFactory`. |
| `metrics_aggregator` | A function that takes in the metric finalizers (i.e., `tff.learning.models.VariableModel.metric_finalizers()`) and a `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF type of `tff.learning.models.VariableModel.report_local_unfinalized_metrics()`), and returns a `tff.Computation` for aggregating the unfinalized metrics. |
| `loop_implementation` | Changes the implementation of the generated training loop. See `tff.learning.LoopImplementation` for more details. |
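Putting the arguments together, a minimal construction sketch might look like the following; the Keras architecture, loss, and `input_spec` here are arbitrary illustrations, not part of this API's contract:

```python
import tensorflow as tf
import tensorflow_federated as tff

def model_fn():
  # A fresh model must be constructed on every invocation, per the
  # `model_fn` contract above.
  keras_model = tf.keras.Sequential(
      [tf.keras.layers.Dense(1, input_shape=(2,))])
  return tff.learning.models.from_keras_model(
      keras_model,
      loss=tf.keras.losses.MeanSquaredError(),
      input_spec=(
          tf.TensorSpec(shape=[None, 2], dtype=tf.float32),
          tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
      ),
  )

learning_process = tff.learning.algorithms.build_fed_sgd(
    model_fn,
    server_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.1),
)
```

The returned `learning_process` can then be driven with the `initialize`/`next` loop sketched earlier.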
| Returns |
|---|
| A `tff.learning.templates.LearningProcess`. |