Builds a learning process that performs the FedProx algorithm.
tff.learning.algorithms.build_weighted_fed_prox(
model_fn: Union[Callable[[], tff.learning.models.VariableModel], tff.learning.models.FunctionalModel],
proximal_strength: float,
client_optimizer_fn: tff.learning.optimizers.Optimizer,
server_optimizer_fn: tff.learning.optimizers.Optimizer = DEFAULT_SERVER_OPTIMIZER_FN,
client_weighting: Optional[tff.learning.ClientWeighting] = tff.learning.ClientWeighting.NUM_EXAMPLES,
model_distributor: Optional[tff.learning.templates.DistributionProcess] = None,
model_aggregator: Optional[tff.aggregators.WeightedAggregationFactory] = None,
metrics_aggregator: Optional[tff.learning.metrics.MetricsAggregatorType] = None,
loop_implementation: tff.learning.LoopImplementation = tff.learning.LoopImplementation.DATASET_REDUCE
) -> tff.learning.templates.LearningProcess
This function creates a tff.learning.templates.LearningProcess that performs
example-weighted FedProx on client models. The algorithm behaves the same as
federated averaging, except that each client's local objective includes a
proximal regularization term that penalizes distance from the server model,
discouraging clients from drifting too far from it during local training.
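In the standard FedProx formulation, client k approximately minimizes its local loss F_k plus a proximal term anchored at the server weights w^t it received at the start of the round. Writing μ for proximal_strength (an assumption about how the parameter maps onto the usual notation), the local objective and its gradient are:

$$
h_k(w) = F_k(w) + \frac{\mu}{2}\,\lVert w - w^t \rVert^2,
\qquad
\nabla h_k(w) = \nabla F_k(w) + \mu\,(w - w^t).
$$

Setting μ = 0 removes the proximal term and recovers the ordinary local objective used by federated averaging.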
The iterative process has the following methods inherited from
tff.learning.templates.LearningProcess:
- initialize: A tff.Computation with the functional type signature ( -> S@SERVER), where S is a tff.learning.templates.LearningAlgorithmState representing the initial state of the server.
- next: A tff.Computation with the functional type signature (<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and {B*}@CLIENTS represents the client datasets. The output L contains the updated server state, as well as aggregated metrics at the server, including client training metrics and any other metrics from distribution and aggregation processes.
- get_model_weights: A tff.Computation with type signature (S -> M), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and next, and M represents the type of the model weights used during training.
- set_model_weights: A tff.Computation with type signature (<S, M> -> S), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and M represents the type of the model weights used during training.
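To make the state-threading contract concrete, here is a hypothetical pure-Python stand-in. ToyState and ToyProcess are invented for illustration and are not part of TFF, and the next step is a placeholder rather than FedProx. The point is only that every method takes the server state explicitly and returns a new one, and that next produces both a new state and metrics (returned here as a plain tuple).

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    # Hypothetical stand-in for LearningAlgorithmState: model weights plus a round counter.
    weights: float
    round_num: int


class ToyProcess:
    # Hypothetical stub with the same four methods; NOT TFF's implementation.

    def initialize(self) -> ToyState:
        # ( -> S): build the initial server state.
        return ToyState(weights=0.0, round_num=0)

    def next(self, state: ToyState, client_data: list) -> tuple:
        # (<S, {B*}> -> <L>): placeholder update that moves the weights halfway
        # toward the mean of all client examples, then returns (new_state, metrics).
        examples = [x for dataset in client_data for x in dataset]
        target = sum(examples) / len(examples)
        new_state = replace(
            state,
            weights=state.weights + 0.5 * (target - state.weights),
            round_num=state.round_num + 1,
        )
        return new_state, {"num_examples": len(examples)}

    def get_model_weights(self, state: ToyState) -> float:
        # (S -> M): read the model weights out of the state.
        return state.weights

    def set_model_weights(self, state: ToyState, weights: float) -> ToyState:
        # (<S, M> -> S): return a new state; the input state is left unchanged.
        return replace(state, weights=weights)


# Typical driver loop: feed the state returned by next() into the following call.
process = ToyProcess()
state = process.initialize()
for _ in range(3):
    state, metrics = process.next(state, [[1.0, 2.0], [3.0]])
print(state.round_num, process.get_model_weights(state))  # 3 1.75
```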
Each time the next method is called, the server model is communicated to
each client using the provided model_distributor. For each client, local
training is performed using client_optimizer_fn on the client's loss plus the
proximal term. Each client then computes the difference between the client
model after training and the initial model.
These model deltas are then aggregated at the server using a weighted
aggregation function, according to client_weighting. The aggregate model
delta is applied at the server using a server optimizer, as in the FedOpt
framework proposed in Reddi et al., 2021.
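The round described above can be sketched in pure Python on a toy 1-D least-squares problem. This is an illustration under stated assumptions, not TFF's implementation: each client holds a list of scalars, local training is full-batch gradient descent with a fixed step size, the flag weight_by_num_examples stands in for client_weighting, and the server applies SGD with learning rate 1.0 (the documented default for server_optimizer_fn), treating the negated average delta as a pseudo-gradient in the style of FedOpt. The names client_update and fedprox_round are hypothetical.

```python
def client_update(server_w, data, mu, lr, steps):
    # Local training on F_k(w) = mean((w - x)^2) / 2 plus the proximal term
    # (mu / 2) * (w - server_w)^2, using full-batch gradient descent.
    w = server_w
    for _ in range(steps):
        grad = sum(w - x for x in data) / len(data)  # gradient of F_k
        grad += mu * (w - server_w)                   # gradient of the proximal term
        w -= lr * grad
    return w - server_w  # model delta sent back to the server


def fedprox_round(server_w, client_datasets, mu, client_lr=0.1, local_steps=5,
                  server_lr=1.0, weight_by_num_examples=True):
    # 1. Broadcast: every client starts from the same server weights.
    deltas = [client_update(server_w, data, mu, client_lr, local_steps)
              for data in client_datasets]
    # 2. Aggregate: weighted mean of the deltas (by example count, or uniform).
    if weight_by_num_examples:
        weights = [len(data) for data in client_datasets]
    else:
        weights = [1.0] * len(client_datasets)
    avg_delta = sum(d * w for d, w in zip(deltas, weights)) / sum(weights)
    # 3. Server update: one SGD step on the pseudo-gradient -avg_delta,
    #    which equals server_w + server_lr * avg_delta.
    return server_w - server_lr * (-avg_delta)


datasets = [[1.0, 1.0, 1.0], [5.0]]
# With mu = 0 and a single local step of size 1.0, each client jumps to its own
# mean, so the round returns the weighted mean of the client means.
print(fedprox_round(0.0, datasets, mu=0.0, client_lr=1.0, local_steps=1))  # 2.0
print(fedprox_round(0.0, datasets, mu=0.0, client_lr=1.0, local_steps=1,
                    weight_by_num_examples=False))  # 3.0
```

With server SGD at learning rate 1.0, step 3 simply adds the weighted-average delta to the server weights, which is the FedAvg server update; swapping in a different server optimizer changes only that step.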
| Args | Description |
|---|---|
| model_fn | A no-arg function that returns a tff.learning.models.VariableModel, or an instance of a tff.learning.models.FunctionalModel. When passing a callable, the callable must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error. |
| proximal_strength | A nonnegative float representing the parameter of FedProx's regularization term. When set to 0.0, the algorithm reduces to FedAvg. Higher values penalize clients more strongly for moving away from the server model during local training. |
| client_optimizer_fn | A tff.learning.optimizers.Optimizer used for local training on each client. |
| server_optimizer_fn | A tff.learning.optimizers.Optimizer used to apply the aggregate model delta at the server. By default, this uses SGD with a learning rate of 1.0. |
| client_weighting | A member of tff.learning.ClientWeighting that specifies a built-in weighting method. By default, weighting by number of examples is used. |
| model_distributor | An optional DistributionProcess that broadcasts the model weights on the server to the clients. If set to None, the distributor is constructed via distributors.build_broadcast_process. |
| model_aggregator | An optional tff.aggregators.WeightedAggregationFactory used to aggregate client updates on the server. If None, this is set to tff.aggregators.MeanFactory. |
| metrics_aggregator | A function that takes in the metric finalizers (i.e., tff.learning.models.VariableModel.metric_finalizers()) and a tff.types.StructWithPythonType of the unfinalized metrics (i.e., the TFF type of tff.learning.models.VariableModel.report_local_unfinalized_metrics()), and returns a tff.Computation for aggregating the unfinalized metrics. If None, this is set to tff.learning.metrics.sum_then_finalize. |
| loop_implementation | Selects the implementation of the generated training loop. See tff.learning.LoopImplementation for more details. |
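The effect of proximal_strength can be illustrated with a hypothetical helper, local_drift, on a toy quadratic client loss (an assumption for illustration only, not TFF code). A client runs full-batch gradient descent starting from the server weights, and the helper reports how far it ends up from them. At mu = 0.0 the update is plain gradient descent on the client loss, matching the FedAvg case, and larger values pull the client model back toward the server model.

```python
def local_drift(server_w, data, mu, lr=0.1, steps=20):
    # Full-batch gradient descent on mean((w - x)^2) / 2 + (mu / 2) * (w - server_w)^2.
    # For this quadratic, lr * (1 + mu) must stay below 2 for the iterates to converge.
    mean_x = sum(data) / len(data)
    w = server_w
    for _ in range(steps):
        w -= lr * ((w - mean_x) + mu * (w - server_w))
    return abs(w - server_w)  # distance from the server weights after local training


drifts = {mu: local_drift(0.0, [4.0, 6.0], mu) for mu in (0.0, 0.1, 1.0, 10.0)}
for mu, drift in drifts.items():
    print(f"proximal_strength={mu}: drift={drift:.3f}")
```

In this toy, the minimizer of the regularized local loss is (mean_x + mu * server_w) / (1 + mu), so the point each client trains toward slides from its own data mean to the server weights as proximal_strength grows.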
| Returns |
|---|
| A tff.learning.templates.LearningProcess. |
| Raises | Condition |
|---|---|
| ValueError | If proximal_strength is not a nonnegative float. |
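A check that enforces the documented contract might look like the sketch below. validate_proximal_strength is a hypothetical helper, not TFF's code; it assumes that integers are acceptable (and converted to float) and that NaN counts as invalid, and the library's actual checks may differ.

```python
import math
import numbers


def validate_proximal_strength(proximal_strength):
    # Hypothetical check: reject non-numbers (bool is excluded explicitly because
    # it is a subclass of int), NaN, and negative values.
    if isinstance(proximal_strength, bool) or not isinstance(proximal_strength, numbers.Real):
        raise ValueError(
            f"proximal_strength must be a nonnegative float, got {proximal_strength!r}")
    if math.isnan(proximal_strength) or proximal_strength < 0:
        raise ValueError(
            f"proximal_strength must be a nonnegative float, got {proximal_strength!r}")
    return float(proximal_strength)


print(validate_proximal_strength(0.01))  # 0.01
try:
    validate_proximal_strength(-0.5)
except ValueError as err:
    print(err)
```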