next: A tff.Computation with the functional type signature
(<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>) where S is a
tff.learning.templates.LearningAlgorithmState whose type matches the
output of initialize and {B*}@CLIENTS represents the client datasets.
The output L contains the updated server state, as well as aggregated
metrics at the server, including client training metrics and any other
metrics from distribution and aggregation processes.
get_model_weights: A tff.Computation with type signature (S -> M),
where S is a tff.learning.templates.LearningAlgorithmState whose type
matches the output of initialize and next, and M represents the type
of the model weights used during training.
set_model_weights: A tff.Computation with type signature
(<S, M> -> S), where S is a
tff.learning.templates.LearningAlgorithmState whose type matches the
output of initialize and M represents the type of the model weights
used during training.
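These three computations are typically used together in a simulation loop. The sketch below is illustrative only; it assumes a learning_process returned by this builder and a hypothetical federated_train_data, a Python list of client tf.data.Dataset objects:

```python
# Minimal usage sketch; the names `learning_process` and `federated_train_data`
# are assumptions, not part of this API.
state = learning_process.initialize()

for round_num in range(10):
  # One round of training: returns the updated state and aggregated metrics.
  output = learning_process.next(state, federated_train_data)
  state, metrics = output.state, output.metrics

# Extract the trained model weights from the algorithm state ...
model_weights = learning_process.get_model_weights(state)

# ... or inject externally obtained weights back into the state.
state = learning_process.set_model_weights(state, model_weights)
```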
Each time the next method is called, the server model is communicated to
each client using the provided model_distributor. For each client, local
training is performed using client_optimizer_fn. Each client computes the
difference between the client model after training and its initial model.
These model deltas are then aggregated at the server using an unweighted
aggregation function. The aggregate model delta is applied at the server using
a server optimizer.
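The following NumPy sketch illustrates this round structure on a flat weight vector; the toy quadratic loss, learning rates, and client data are hypothetical stand-ins for the model, optimizers, and aggregator actually supplied to this builder:

```python
import numpy as np

server_weights = np.zeros(3)
client_targets = [np.array([1.0, 0.0, 0.0]), np.array([0.0, 2.0, 0.0])]

def client_update(weights, target, client_lr=0.5, local_steps=5):
  # Local training on the client: gradient steps on a toy quadratic loss
  # 0.5 * ||w - target||^2, starting from the distributed server weights.
  w = weights.copy()
  for _ in range(local_steps):
    w -= client_lr * (w - target)
  # The client reports only the difference from its initial (server) model.
  return w - weights

# Unweighted aggregation: a plain (unweighted) mean of the client deltas.
deltas = [client_update(server_weights, t) for t in client_targets]
mean_delta = np.mean(deltas, axis=0)

# The server optimizer treats the negated delta as a pseudo-gradient; with
# plain SGD at learning rate 1.0 this reduces to adding the mean delta.
server_lr = 1.0
server_weights = server_weights + server_lr * mean_delta
```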
Args
model_fn
A no-arg function that returns a
tff.learning.models.VariableModel, or an instance of a
tff.learning.models.FunctionalModel. When passing a callable, the
callable must not capture and use TensorFlow tensors or variables.
The model must be constructed entirely from scratch on each invocation;
returning the same pre-constructed model on each call will result in an
error.
model_distributor
An optional DistributionProcess that distributes the
model weights on the server to the clients. If set to None, the
distributor is constructed via distributors.build_broadcast_process.
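For example, a model_fn satisfying the constraint above might build a fresh Keras model on every call and wrap it with tff.learning.models.from_keras_model; the layer sizes and input_spec below are placeholder assumptions, not requirements of this API:

```python
import tensorflow as tf
import tensorflow_federated as tff

# Hypothetical input_spec matching the client datasets; adjust to your data.
input_spec = (
    tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
    tf.TensorSpec(shape=[None, 1], dtype=tf.int32),
)

def model_fn():
  # Build a brand-new Keras model on every call; never return a model (or use
  # tf.Variables) created outside this function.
  keras_model = tf.keras.Sequential(
      [tf.keras.layers.Dense(10, input_shape=(784,))])
  return tff.learning.models.from_keras_model(
      keras_model,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      input_spec=input_spec,
  )
```

With model_distributor=None, these model weights are simply broadcast to the clients each round by the default broadcast process.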