Builds a learning process that performs federated evaluation.
tff.learning.algorithms.build_fed_eval(
    model_fn: Union[Callable[[], tff.learning.models.VariableModel], tff.learning.models.FunctionalModel],
    model_distributor: Optional[tff.learning.templates.DistributionProcess] = None,
    metrics_aggregation_process: Optional[tff.templates.AggregationProcess] = None,
    loop_implementation: tff.learning.LoopImplementation = tff.learning.LoopImplementation.DATASET_REDUCE
) -> tff.learning.templates.LearningProcess
This function creates a tff.learning.templates.LearningProcess that performs federated evaluation on clients. The learning process has the following methods inherited from tff.learning.templates.LearningProcess:

- initialize: A tff.Computation with type signature ( -> S@SERVER), where S is a tff.learning.templates.LearningAlgorithmState representing the initial state of the server.
- next: A tff.Computation with type signature (<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>), where S is a LearningAlgorithmState whose type matches that of the output of initialize, and {B*}@CLIENTS represents the client datasets, where B is the type of a single batch. The output L contains the updated server state, as well as aggregated metrics at the server, including client evaluation metrics and any other metrics from distribution and aggregation processes.
- get_model_weights: A tff.Computation with type signature (S -> M), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and next, and M represents the type of the model weights used during evaluation.
- set_model_weights: A tff.Computation with type signature (<S, M> -> S), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and M represents the type of the model weights used during evaluation.
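To make the relationship between S and M concrete, the four signatures can be mirrored in plain Python. This is a framework-free toy, not TFF code: ToyEvalState, initialize, get_model_weights, set_model_weights, and the weight tuples below are hypothetical stand-ins, whereas the real methods are tff.Computations over TFF types. It shows one common pattern, loading the weights to be evaluated into a fresh state before calling next, and that set_model_weights returns a new state of the same structure rather than mutating the old one.

```python
from dataclasses import dataclass, replace
from typing import Tuple

Weights = Tuple[float, ...]  # hypothetical stand-in for the model-weights type M


@dataclass(frozen=True)
class ToyEvalState:
    """Hypothetical stand-in for the server state S (a LearningAlgorithmState)."""

    model_weights: Weights
    metrics_state: Tuple[float, float]  # accumulated (loss_sum, num_examples)


def initialize(num_weights: int) -> ToyEvalState:
    """Mirrors ( -> S): zero weights and empty metric accumulators."""
    return ToyEvalState(model_weights=(0.0,) * num_weights, metrics_state=(0.0, 0.0))


def get_model_weights(state: ToyEvalState) -> Weights:
    """Mirrors (S -> M): read the weights that next would evaluate."""
    return state.model_weights


def set_model_weights(state: ToyEvalState, weights: Weights) -> ToyEvalState:
    """Mirrors (<S, M> -> S): return a new S with the same structure."""
    if len(weights) != len(state.model_weights):
        raise ValueError("weights do not match the structure of the state")
    # The input state is not mutated; only the weights are swapped in.
    return replace(state, model_weights=tuple(weights))


# Start from initialize(), then load the weights that should be evaluated.
state = initialize(num_weights=2)
state = set_model_weights(state, (0.5, -1.0))
print(get_model_weights(state))  # (0.5, -1.0)
```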
Each time next is called, the server model is broadcast to each client using the model distributor. Each client evaluates the model and reports its local unfinalized metrics. The local unfinalized metrics are then aggregated and finalized at the server using the metrics aggregator. Both current-round metrics and total metrics accumulated over all rounds are produced. The server model is not updated during the evaluation process.
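The round structure described above (broadcast, local unfinalized metrics, sum-then-finalize aggregation, current-round and total metrics, no model update) can be sketched in plain Python. This is a toy under stated assumptions, not TFF's implementation: the linear model, the mean-squared-error metric, and the names eval_round, client_unfinalized_metrics, and finalize are all hypothetical. In TFF these roles are played by the model_distributor, the metrics_aggregation_process (by default built from SumThenFinalizeFactory), and the metric finalizers defined in the model.

```python
from typing import Dict, List, Sequence, Tuple

Weights = Tuple[float, float]      # hypothetical linear model: y_hat = w * x + b
Batch = List[Tuple[float, float]]  # a batch of (x, y) examples
Metrics = Dict[str, float]


def client_unfinalized_metrics(weights: Weights, dataset: Sequence[Batch]) -> Metrics:
    """Evaluate locally and report unfinalized metrics (sums, not means)."""
    w, b = weights
    loss_sum, num_examples = 0.0, 0.0
    for batch in dataset:
        for x, y in batch:
            loss_sum += (w * x + b - y) ** 2
            num_examples += 1
    return {"loss_sum": loss_sum, "num_examples": num_examples}


def finalize(unfinalized: Metrics) -> Metrics:
    """Finalizer: turn the summed values into a mean squared error."""
    n = unfinalized["num_examples"]
    return {"mse": unfinalized["loss_sum"] / n if n else 0.0, "num_examples": n}


def eval_round(
    server_weights: Weights,
    total_unfinalized: Metrics,
    client_datasets: Sequence[Sequence[Batch]],
) -> Tuple[Weights, Metrics, Tuple[Metrics, Metrics]]:
    """One toy call of next: broadcast, evaluate, sum, then finalize."""
    # Broadcast: every client evaluates the same, unmodified server weights.
    reports = [client_unfinalized_metrics(server_weights, ds) for ds in client_datasets]
    # Sum the unfinalized metrics over the clients in this round.
    current = {k: sum(r[k] for r in reports) for k in ("loss_sum", "num_examples")}
    # Accumulate into the cross-round state kept by the metrics aggregator.
    total = {k: total_unfinalized[k] + current[k] for k in current}
    # Server weights are returned unchanged: evaluation never updates them.
    return server_weights, total, (finalize(current), finalize(total))


weights = (2.0, 0.0)
acc = {"loss_sum": 0.0, "num_examples": 0.0}
round1_clients = [[[(1.0, 2.0), (2.0, 5.0)]], [[(3.0, 6.0)]]]
round2_clients = [[[(0.0, 1.0)]]]
weights, acc, (cur, tot) = eval_round(weights, acc, round1_clients)
weights, acc, (cur2, tot2) = eval_round(weights, acc, round2_clients)
print(cur2["mse"], tot2["mse"])  # 1.0 0.5
```

Summing before finalizing matters: in the first round the two clients' individual MSEs are 0.5 and 0.0, whose plain average (0.25) differs from the pooled MSE of 1/3 that the sum-then-finalize approach reports.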
Args | |
---|---|
model_fn | A no-arg function that returns a tff.learning.models.VariableModel, or an instance of a tff.learning.models.FunctionalModel. When passing a callable, the callable must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error. |
model_distributor | An optional tff.learning.templates.DistributionProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENTS) and have empty state. If None, the server model is broadcast to the clients using the default tff.federated_broadcast. |
metrics_aggregation_process | An optional tff.templates.AggregationProcess that aggregates the clients' local unfinalized metrics at the server and finalizes them there. The tff.templates.AggregationProcess accumulates unfinalized metrics across rounds in its state, and produces a tuple of current-round metrics and total-rounds metrics in its result. If None, the tff.templates.AggregationProcess created by the SumThenFinalizeFactory with the metric finalizers defined in the model is used. |
loop_implementation | Changes the implementation of the generated client-side evaluation loop. See tff.learning.LoopImplementation for more details. |
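The model_fn requirement (build a new model on every call rather than returning one built ahead of time) can be illustrated without TFF. In this sketch ToyModel, check_model_fn, good_model_fn, and bad_model_fn are hypothetical; TFF performs its own validation, and the identity check here only approximates the rule. In real TFF code, a typical model_fn constructs a new Keras model inside the function and wraps it with tff.learning.models.from_keras_model.

```python
from typing import Callable, List


class ToyModel:
    """Hypothetical stand-in for a tff.learning.models.VariableModel."""

    def __init__(self) -> None:
        # New weights are created on every construction.
        self.weights: List[float] = [0.0, 0.0]


def check_model_fn(model_fn: Callable[[], object]) -> None:
    """Hypothetical check: model_fn must build a new model on each call."""
    first, second = model_fn(), model_fn()
    if first is second:
        raise ValueError("model_fn returned the same pre-constructed model twice")


def good_model_fn() -> ToyModel:
    # Everything is constructed from scratch inside the function.
    return ToyModel()


_prebuilt = ToyModel()  # built once, outside any model_fn


def bad_model_fn() -> ToyModel:
    # Returns the same pre-constructed object every time: not allowed.
    return _prebuilt


check_model_fn(good_model_fn)  # passes silently
try:
    check_model_fn(bad_model_fn)
except ValueError as err:
    print(err)
```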
Returns |
---|
A tff.learning.templates.LearningProcess that performs federated evaluation on clients and returns updated state and metrics. |