Constructs a tff.templates.IterativeProcess for Federated Averaging or SGD.

This provides the TFF orchestration logic that connects two pieces: the common server logic, which applies aggregated model deltas to the server model, and a ClientDeltaFn, which specifies how weight_deltas are computed on device.
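The round structure described above can be sketched in plain Python, without TFF. Everything in this block is a hypothetical stand-in chosen for illustration: the names client_delta, weighted_mean, server_apply, and run_round, the scalar toy model, and the choice to treat the negated mean delta as a pseudo-gradient for a plain SGD server step are assumptions, not the library's implementation.

```python
from typing import Callable, List, Sequence, Tuple

Weights = List[float]
# Hypothetical stand-in for a ClientDeltaFn:
# (broadcast weights, local data) -> (weights_delta, aggregation weight)
DeltaFn = Callable[[Weights, Sequence[float]], Tuple[Weights, float]]


def client_delta(weights: Weights, data: Sequence[float]) -> Tuple[Weights, float]:
    """Toy client update for a scalar model w with loss mean((w - x)^2).

    Takes one local gradient step and returns the change in weights,
    weighted by the number of local examples.
    """
    w = weights[0]
    grad = sum(2.0 * (w - x) for x in data) / len(data)
    local_lr = 0.5  # assumed local learning rate
    new_w = w - local_lr * grad
    return [new_w - w], float(len(data))


def weighted_mean(updates: List[Tuple[Weights, float]]) -> Weights:
    """Aggregates client deltas into one delta (CLIENTS -> SERVER)."""
    total = sum(wt for _, wt in updates)
    dim = len(updates[0][0])
    return [sum(d[i] * wt for d, wt in updates) / total for i in range(dim)]


def server_apply(weights: Weights, mean_delta: Weights, server_lr: float = 1.0) -> Weights:
    """Applies the aggregated delta with a plain SGD step.

    Assumption: the negated delta is used as a pseudo-gradient, so that
    server_lr = 1.0 simply adds the mean delta to the server weights.
    """
    pseudo_grad = [-d for d in mean_delta]
    return [w - server_lr * g for w, g in zip(weights, pseudo_grad)]


def run_round(server_weights: Weights,
              client_datasets: List[Sequence[float]],
              delta_fn: DeltaFn = client_delta) -> Weights:
    """One round: broadcast, compute client deltas, aggregate, apply."""
    broadcast = [list(server_weights) for _ in client_datasets]  # SERVER -> CLIENTS
    updates = [delta_fn(w, data) for w, data in zip(broadcast, client_datasets)]
    return server_apply(server_weights, weighted_mean(updates))
```

Calling run_round([0.0], [[1.0, 1.0], [4.0]]) moves the server weight toward the example-weighted mean of the client updates. Swapping delta_fn changes only how the deltas are computed on device, while the server-side logic stays the same, which is the separation this function is built around.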

model_fn A no-arg function that returns a tff.learning.Model.
model_to_client_delta_fn A function from a model_fn to a ClientDeltaFn.
server_optimizer_fn A no-arg function that returns a tf.keras.optimizers.Optimizer. The apply_gradients method of this optimizer is used to apply client updates to the server model.
broadcast_process A tff.templates.MeasuredProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENTS). If None (the default), the server model is broadcast to the clients using tff.federated_broadcast.
aggregation_process A tff.templates.MeasuredProcess that aggregates the model updates on the clients back to the server. It must support the signature ({input_values}@CLIENTS -> output_values@SERVER). Must be None if model_update_aggregation_factory is not None.
model_update_aggregation_factory An optional tff.aggregators.WeightedAggregationFactory that constructs a tff.templates.AggregationProcess for aggregating the client model updates on the server. If None, uses a default-constructed tff.aggregators.MeanFactory, creating a stateless mean aggregation. Must be None if aggregation_process is not None.

A tff.templates.IterativeProcess.

ProcessTypeError if broadcast_process or aggregation_process does not conform to the signature of broadcast (SERVER->CLIENTS) or aggregation (CLIENTS->SERVER), respectively.
DisjointArgumentError if both aggregation_process and model_update_aggregation_factory are not None.
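
The mutual-exclusion rule between aggregation_process and model_update_aggregation_factory can be illustrated with a small validation sketch. This is not TFF code: the DisjointArgumentError class below is a local stand-in defined only for this example, and resolve_aggregator and its return values are hypothetical names.

```python
class DisjointArgumentError(Exception):
    """Local stand-in for the error named above; not imported from TFF."""


def resolve_aggregator(aggregation_process=None,
                       model_update_aggregation_factory=None):
    """Returns a (kind, value) pair describing which aggregator to use.

    Mirrors the documented contract: at most one of the two arguments
    may be set, and a mean aggregation is the fallback when both are None.
    """
    if (aggregation_process is not None
            and model_update_aggregation_factory is not None):
        raise DisjointArgumentError(
            'Specify only one of aggregation_process and '
            'model_update_aggregation_factory.')
    if aggregation_process is not None:
        return ('process', aggregation_process)
    if model_update_aggregation_factory is not None:
        return ('factory', model_update_aggregation_factory)
    # Placeholder string standing in for a default mean aggregation factory.
    return ('factory', 'default_mean')
```

Checking the pair up front means a caller who sets both arguments fails immediately, rather than having one of them silently ignored.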