tff.learning.algorithms.build_fed_eval

Builds a learning process that performs federated evaluation.


This function creates a tff.learning.templates.LearningProcess that performs federated evaluation on clients. The learning process has the following methods inherited from tff.learning.templates.LearningProcess:

  • initialize: A tff.Computation with type signature ( -> S@SERVER), where S is a tff.learning.templates.LearningAlgorithmState representing the initial state of the server.
  • next: A tff.Computation with type signature (<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>) where S is a LearningAlgorithmState whose type matches that of the output of initialize, and {B*}@CLIENTS represents the client datasets, where B is the type of a single batch. The output L contains the updated server state, as well as aggregated metrics at the server, including client evaluation metrics and any other metrics from distribution and aggregation processes.
  • get_model_weights: A tff.Computation with type signature (S -> M), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and next, and M represents the type of the model weights used during evaluation.
  • set_model_weights: A tff.Computation with type signature (<S, M> -> S), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and M represents the type of the model weights used during evaluation.
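The type relationships above can be illustrated with a pure-Python toy. It is not TFF code: ToyState, initialize, get_model_weights, and set_model_weights below are hypothetical stand-ins that only mimic the ( -> S), (S -> M), and (<S, M> -> S) signatures. The sketch assumes an immutable state that holds the model weights next to other sub-process state.

```python
from dataclasses import dataclass, replace
from typing import Tuple

# Hypothetical stand-in for the model weights M: a tuple of floats.
ModelWeights = Tuple[float, ...]


@dataclass(frozen=True)
class ToyState:
    """Hypothetical, simplified stand-in for the server state S."""
    global_model_weights: ModelWeights
    aggregator_state: Tuple[float, float]  # e.g. accumulated metric sums


def initialize() -> ToyState:
    # ( -> S): build the initial server state.
    return ToyState(global_model_weights=(0.0, 0.0), aggregator_state=(0.0, 0.0))


def get_model_weights(state: ToyState) -> ModelWeights:
    # (S -> M): read the weights that will be evaluated.
    return state.global_model_weights


def set_model_weights(state: ToyState, weights: ModelWeights) -> ToyState:
    # (<S, M> -> S): return a new state with the weights replaced;
    # every other field of the state is carried over unchanged.
    return replace(state, global_model_weights=tuple(weights))
```

A typical use is to copy weights produced by a separate training process into the evaluation state with set_model_weights before calling next. In this toy, set_model_weights(state, (0.5, -1.0)) returns a new state and leaves the original one untouched.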

Each time next is called, the server model is broadcast to each client using the model distributor. Each client evaluates the model and reports its local unfinalized metrics. The local unfinalized metrics are then aggregated and finalized at the server using the metrics aggregator. Both current-round and total-rounds metrics are produced. The server model is not updated during the evaluation process.
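A single evaluation round can be sketched in plain Python. This is a conceptual simulation, not the TFF implementation: the scalar weight w, the toy model y_hat = w * x, the mean-squared-error metric, and all function names are hypothetical. It covers only the current-round metrics; accumulating totals across rounds is left to the metrics aggregator.

```python
from typing import Dict, List, Tuple

# Toy client dataset: a list of batches, each batch a list of (x, y) pairs.
Batch = List[Tuple[float, float]]
Dataset = List[Batch]


def local_unfinalized_metrics(w: float, dataset: Dataset) -> Dict[str, float]:
    """Evaluate the toy model y_hat = w * x and return unfinalized sums."""
    sse, n = 0.0, 0.0
    for batch in dataset:
        for x, y in batch:
            sse += (w * x - y) ** 2
            n += 1.0
    return {"sum_squared_error": sse, "num_examples": n}


def finalize(m: Dict[str, float]) -> Dict[str, float]:
    """Turn summed unfinalized values into a finalized metric (MSE)."""
    n = m["num_examples"]
    return {"mse": m["sum_squared_error"] / n if n else 0.0, "num_examples": n}


def evaluation_round(
    server_w: float, client_datasets: List[Dataset]
) -> Tuple[float, Dict[str, float]]:
    # 1. Broadcast: every client receives the same server weight.
    client_weights = [server_w] * len(client_datasets)
    # 2. Local evaluation: each client reports unfinalized metrics only.
    local = [
        local_unfinalized_metrics(w, ds)
        for w, ds in zip(client_weights, client_datasets)
    ]
    # 3. Aggregate at the server by summing, then finalize once.
    summed = {"sum_squared_error": 0.0, "num_examples": 0.0}
    for m in local:
        for key in summed:
            summed[key] += m[key]
    # 4. The server weight is returned untouched: evaluation never updates it.
    return server_w, finalize(summed)
```

With clients = [[[(1.0, 2.0), (2.0, 5.0)]], [[(3.0, 6.0)]]], evaluation_round(2.0, clients) returns the weight 2.0 unchanged and an MSE of 1/3 over 3 examples.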

model_fn A no-arg function that returns a tff.learning.models.VariableModel, or an instance of a tff.learning.models.FunctionalModel. When passing a callable, the callable must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
model_distributor An optional tff.learning.templates.DistributionProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENTS) and have empty state. If None, the server model is broadcast to the clients using the default tff.federated_broadcast.
metrics_aggregation_process An optional tff.templates.AggregationProcess that aggregates the clients' local unfinalized metrics to the server and finalizes them there. The tff.templates.AggregationProcess accumulates unfinalized metrics across rounds in its state, and produces a tuple of current-round metrics and total-rounds metrics in its result. If None, the tff.templates.AggregationProcess created by the SumThenFinalizeFactory with the metric finalizers defined in the model is used.
loop_implementation Changes the implementation of the generated training loop. See tff.learning.LoopImplementation for more details.
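The accumulate-then-finalize behavior described for metrics_aggregation_process can be sketched with a small pure-Python class. ToySumThenFinalizeProcess is hypothetical and only imitates the described contract: its state holds unfinalized metric sums accumulated across rounds, and each call to next returns a tuple of (current round, total rounds) finalized metrics. The accuracy metric and its num_correct / num_examples counts are assumptions for illustration.

```python
from typing import Dict, List, Tuple

Unfinalized = Dict[str, float]


def finalize_accuracy(m: Unfinalized) -> float:
    """Finalizer: turn summed counts into an accuracy."""
    return m["num_correct"] / m["num_examples"] if m["num_examples"] else 0.0


class ToySumThenFinalizeProcess:
    """Hypothetical stand-in for a metrics aggregation process."""

    def initialize(self) -> Unfinalized:
        # State: unfinalized metrics accumulated over all rounds so far.
        return {"num_correct": 0.0, "num_examples": 0.0}

    def next(
        self, state: Unfinalized, client_metrics: List[Unfinalized]
    ) -> Tuple[Unfinalized, Tuple[float, float]]:
        # Sum this round's unfinalized client metrics.
        current = {key: 0.0 for key in state}
        for m in client_metrics:
            for key in current:
                current[key] += m[key]
        # Fold the round into the running totals kept in the state.
        new_state = {key: state[key] + current[key] for key in state}
        # Result: (current round metrics, total rounds metrics), both finalized.
        return new_state, (finalize_accuracy(current), finalize_accuracy(new_state))
```

Feeding one round with client counts 3/4 and 1/6, then a second round with 5/10, yields current/total accuracies of 0.4/0.4 and then 0.5/0.45. Summing counts before finalizing weights each client by its number of examples; averaging the per-client accuracies of the first round (0.75 and about 0.167) would instead give about 0.458.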

A tff.learning.templates.LearningProcess that performs federated evaluation on clients and returns updated state and metrics.

TypeError If any argument type mismatches.