Builds a learning process that performs federated evaluation.
tff.learning.algorithms.build_fed_eval(
    model_fn: Union[Callable[[], tff.learning.models.VariableModel], tff.learning.models.FunctionalModel],
    model_distributor: Optional[tff.learning.templates.DistributionProcess] = None,
    metrics_aggregation_process: Optional[tff.templates.AggregationProcess] = None,
    loop_implementation: tff.learning.LoopImplementation = tff.learning.LoopImplementation.DATASET_REDUCE
) -> tff.learning.templates.LearningProcess
This function creates a tff.learning.templates.LearningProcess that performs federated evaluation on clients. The learning process has the following methods inherited from tff.learning.templates.LearningProcess:
initialize: A tff.Computation with type signature ( -> S@SERVER), where S is a tff.learning.templates.LearningAlgorithmState representing the initial state of the server.
next: A tff.Computation with type signature (<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>), where S is a tff.learning.templates.LearningAlgorithmState whose type matches that of the output of initialize, and {B*}@CLIENTS represents the client datasets, where B is the type of a single batch. The output L contains the updated server state, as well as aggregated metrics at the server, including client evaluation metrics and any other metrics from the distribution and aggregation processes.
get_model_weights: A tff.Computation with type signature (S -> M), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and next, and M represents the type of the model weights used during evaluation.
set_model_weights: A tff.Computation with type signature (<S, M> -> S), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize, and M represents the type of the model weights used during evaluation.
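To make the four type signatures concrete, here is a plain-Python mock of the same interface. It is not TFF code and does not call any TFF API: the toy types (weights as a dict of floats, datasets as lists of batches of (x, y) pairs), the class names ToyState and ToyOutput, and the function name next_round (chosen to avoid shadowing Python's built-in next) are all hypothetical. It only mirrors the shapes S, M, {B*} and L described above.

```python
from dataclasses import dataclass, replace
from typing import Dict, List, Tuple

# Hypothetical toy types: weights are a dict of floats, a client dataset is a
# list of batches, and a batch is a list of (x, y) pairs.
Weights = Dict[str, float]
Batch = List[Tuple[float, float]]


@dataclass(frozen=True)
class ToyState:
    # Plays the role of LearningAlgorithmState (S); only the weights are kept.
    global_model_weights: Weights


@dataclass(frozen=True)
class ToyOutput:
    # Plays the role of the output L: updated state plus metrics.
    state: ToyState
    metrics: Dict[str, float]


def initialize() -> ToyState:
    # ( -> S): build the initial server state.
    return ToyState(global_model_weights={"w": 0.0})


def get_model_weights(state: ToyState) -> Weights:
    # (S -> M): read the weights that will be evaluated.
    return dict(state.global_model_weights)


def set_model_weights(state: ToyState, weights: Weights) -> ToyState:
    # (<S, M> -> S): return a new state carrying the given weights.
    return replace(state, global_model_weights=dict(weights))


def next_round(state: ToyState, client_datasets: List[List[Batch]]) -> ToyOutput:
    # (<S, {B*}> -> L): evaluate y ~ w * x on every client's batches.
    w = state.global_model_weights["w"]
    sq_err, count = 0.0, 0
    for dataset in client_datasets:
        for batch in dataset:
            for x, y in batch:
                sq_err += (w * x - y) ** 2
                count += 1
    # Evaluation only: the state is passed through unchanged.
    return ToyOutput(state=state, metrics={"mse": sq_err / count, "num_examples": float(count)})
```

A typical pattern is to load trained weights with set_model_weights, call next_round, and read the metrics; get_model_weights on the returned state gives back the same weights because nothing is trained.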
Each time next is called, the server model is broadcast to each client using a distributor. Each client evaluates the model and reports local unfinalized metrics. The local unfinalized metrics are then aggregated and finalized at the server using the metrics aggregator. Both current-round and total-rounds metrics are produced. The server model is not updated during the evaluation process.
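The round described above (broadcast, local evaluation, aggregate, finalize, no update) can be sketched in plain Python. This is an illustration under stated assumptions, not TFF code: the one-parameter model y ~ weight * x, the metric names loss_sum and count, and the helper names client_report and evaluate_round are hypothetical. The sketch also shows why clients report unfinalized sums rather than their own averages: summing first and finalizing once at the server yields the true mean over all examples.

```python
from typing import Dict, List, Tuple

Example = Tuple[float, float]


def client_report(weight: float, examples: List[Example]) -> Dict[str, float]:
    # Local evaluation of y ~ weight * x. The client reports raw sums
    # ("unfinalized" metrics), not its own average.
    loss_sum = sum((weight * x - y) ** 2 for x, y in examples)
    return {"loss_sum": loss_sum, "count": float(len(examples))}


def evaluate_round(server_weight: float, client_data: List[List[Example]]) -> Tuple[float, float]:
    # 1. Broadcast: every client receives the same server weight.
    received = [server_weight for _ in client_data]
    # 2. Each client evaluates locally and returns unfinalized metrics.
    reports = [client_report(w, data) for w, data in zip(received, client_data)]
    # 3. Sum the unfinalized metrics, then finalize once at the server.
    loss_sum = sum(r["loss_sum"] for r in reports)
    count = sum(r["count"] for r in reports)
    mean_loss = loss_sum / count
    # 4. No model update: the server weight is returned as-is.
    return server_weight, mean_loss
```

For example, with weight 1.0, one client holding [(1.0, 1.0), (2.0, 2.0), (3.0, 3.0)] and another holding [(1.0, 3.0)], evaluate_round returns (1.0, 1.0), whereas averaging the two per-client means would give 2.0.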
Args
model_fn: A no-arg function that returns a tff.learning.models.VariableModel, or an instance of a tff.learning.models.FunctionalModel. When passing a callable, the callable must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
model_distributor: An optional tff.learning.templates.DistributionProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENTS) and have empty state. If None, the server model is broadcast to the clients using the default tff.federated_broadcast.
metrics_aggregation_process: An optional tff.templates.AggregationProcess that aggregates the local unfinalized metrics from clients to the server and finalizes the metrics at the server. The tff.templates.AggregationProcess accumulates unfinalized metrics across rounds in its state, and produces a tuple of current-round metrics and total-rounds metrics in its result. If None, the tff.templates.AggregationProcess created by the SumThenFinalizeFactory with the metric finalizers defined in the model is used.
loop_implementation: Changes the implementation of the loop over each client's dataset that is generated. See tff.learning.LoopImplementation for more details.
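tff.learning.LoopImplementation offers DATASET_REDUCE (the default for this builder) and DATASET_ITERATOR. The plain-Python sketch below does not use TFF; it assumes the two options differ in how the per-client loop over batches is expressed, not in what it computes, and illustrates that a reduce-style fold and an explicit iterator loop give the same totals. The function names and the squared-error metric are hypothetical.

```python
import functools
from typing import List, Tuple

Batch = List[Tuple[float, float]]


def batch_loss(weight: float, batch: Batch) -> Tuple[float, int]:
    # Squared-error sum and example count for one batch of y ~ weight * x.
    return sum((weight * x - y) ** 2 for x, y in batch), len(batch)


def eval_reduce_style(weight: float, dataset: List[Batch]) -> Tuple[float, int]:
    # Reduce style: fold an accumulator over the batches
    # (similar in spirit to DATASET_REDUCE).
    def step(acc: Tuple[float, int], batch: Batch) -> Tuple[float, int]:
        loss, n = batch_loss(weight, batch)
        return acc[0] + loss, acc[1] + n

    return functools.reduce(step, dataset, (0.0, 0))


def eval_iterator_style(weight: float, dataset: List[Batch]) -> Tuple[float, int]:
    # Iterator style: an explicit loop over the batches
    # (similar in spirit to DATASET_ITERATOR).
    total_loss, total_n = 0.0, 0
    for batch in dataset:
        loss, n = batch_loss(weight, batch)
        total_loss += loss
        total_n += n
    return total_loss, total_n
```

In this sketch the choice only changes the control-flow style; which style compiles or runs better in a real deployment is a question for the tff.learning.LoopImplementation documentation.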
Raises
TypeError: If the type of any argument does not match the expected type.
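The metrics_aggregation_process argument is documented as keeping unfinalized metrics in its state across rounds and returning a (current round, total rounds) pair. The sketch below mimics that documented behavior in plain Python; it is not the SumThenFinalizeFactory implementation. The class name, the encoding of each unfinalized metric as a list of floats (for example [loss_sum, count]), and the lambda finalizer are assumptions made for illustration.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical encoding: each metric's unfinalized value is a list of floats,
# e.g. [loss_sum, count] for a mean.
Unfinalized = Dict[str, List[float]]
Finalizer = Callable[[List[float]], float]


class SumThenFinalizeSketch:
    """Plain-Python mimic of a stateful sum-then-finalize aggregator."""

    def __init__(self, finalizers: Dict[str, Finalizer]):
        # One finalizer per metric, standing in for those defined by the model.
        self.finalizers = finalizers

    def initialize(self, template: Unfinalized) -> Unfinalized:
        # State = running sums over all rounds so far, starting at zero.
        return {name: [0.0] * len(value) for name, value in template.items()}

    @staticmethod
    def _sum(values: List[Unfinalized]) -> Unfinalized:
        # Element-wise sum of each metric's unfinalized list.
        names = values[0].keys()
        return {n: [sum(col) for col in zip(*(v[n] for v in values))] for n in names}

    def _finalize(self, value: Unfinalized) -> Dict[str, float]:
        return {n: f(value[n]) for n, f in self.finalizers.items()}

    def next(
        self, state: Unfinalized, client_values: List[Unfinalized]
    ) -> Tuple[Unfinalized, Tuple[Dict[str, float], Dict[str, float]]]:
        round_sum = self._sum(client_values)
        new_state = self._sum([state, round_sum])
        # Result: (current round metrics, total rounds metrics).
        return new_state, (self._finalize(round_sum), self._finalize(new_state))
```

With a mean finalizer lambda v: v[0] / v[1], a first round with client values [4.0, 2.0] and [2.0, 2.0] yields 1.5 for both current and total; a second round with a single client value [9.0, 3.0] yields a current-round value of 3.0 and a total of 15/7, because the state carries the running sums [15.0, 7.0].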