Creates a ClientWorkProcess for federated averaging.
tff.learning.templates.build_functional_model_delta_client_work(
    *,
    model: tff.learning.models.FunctionalModel,
    optimizer: tff.learning.optimizers.Optimizer,
    client_weighting: tff.learning.ClientWeighting,
    metrics_aggregator: Optional[tff.learning.metrics.MetricsAggregatorType] = None,
    loop_implementation: tff.learning.LoopImplementation = tff.learning.LoopImplementation.DATASET_REDUCE
) -> tff.learning.templates.ClientWorkProcess
This differs from tff.learning.templates.build_model_delta_client_work in that it only accepts tff.learning.models.FunctionalModel and tff.learning.optimizers.Optimizer type arguments, resulting in TensorFlow graphs that do not contain tf.Variable operations.
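To make the "no tf.Variable" point concrete, here is a minimal pure-Python sketch (it does not use TFF or TensorFlow) of client-side training in a functional style. The model is a plain gradient function, and the optimizer follows an initialize/next pattern in which state and weights are passed in and returned rather than mutated. FunctionalSGDM, mse_gradient, local_training, the toy linear model, and the sign of the returned delta (initial minus trained) are all assumptions for illustration, not TFF's implementation. The loop is written as a fold over batches, which is only loosely analogous to a dataset-reduce style loop.

```python
from functools import reduce
from typing import List, NamedTuple, Sequence, Tuple

Weights = Tuple[float, ...]
Batch = Sequence[Tuple[float, float]]  # (x, y) pairs


class SGDMState(NamedTuple):
    momentum: Weights


class FunctionalSGDM:
    """Hypothetical stateless SGD-with-momentum optimizer.

    Follows an initialize/next pattern like tff.learning.optimizers.Optimizer:
    nothing is mutated; state and weights go in, new state and weights come out.
    """

    def __init__(self, learning_rate: float, momentum: float = 0.0):
        self.learning_rate = learning_rate
        self.momentum = momentum

    def initialize(self, weights: Weights) -> SGDMState:
        return SGDMState(momentum=tuple(0.0 for _ in weights))

    def next(self, state: SGDMState, weights: Weights,
             gradients: Weights) -> Tuple[SGDMState, Weights]:
        momentum = tuple(self.momentum * m + g
                         for m, g in zip(state.momentum, gradients))
        new_weights = tuple(w - self.learning_rate * m
                            for w, m in zip(weights, momentum))
        return SGDMState(momentum), new_weights


def mse_gradient(weights: Weights, batch: Batch) -> Weights:
    """Gradient of mean squared error for the toy model y ~ w0 + w1 * x."""
    w0, w1 = weights
    n = len(batch)
    g0 = sum(2.0 * (w0 + w1 * x - y) for x, y in batch) / n
    g1 = sum(2.0 * (w0 + w1 * x - y) * x for x, y in batch) / n
    return (g0, g1)


def local_training(initial_weights: Weights, batches: List[Batch],
                   optimizer: FunctionalSGDM) -> Tuple[Weights, int]:
    """Trains on one client's batches; returns (model_delta, num_examples)."""

    def step(carry, batch):
        opt_state, weights, num_examples = carry
        grads = mse_gradient(weights, batch)
        opt_state, weights = optimizer.next(opt_state, weights, grads)
        return opt_state, weights, num_examples + len(batch)

    # All loop state lives in the carry of a fold over the batches.
    carry = (optimizer.initialize(initial_weights), initial_weights, 0)
    _, trained, num_examples = reduce(step, batches, carry)
    # Assumed sign convention for this sketch: delta = initial - trained.
    delta = tuple(i - t for i, t in zip(initial_weights, trained))
    return delta, num_examples
```

Because nothing is mutated, the same initial_weights tuple can be handed to every client without copying.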
Args | |
---|---|
model | A tff.learning.models.FunctionalModel to train. |
optimizer | A tff.learning.optimizers.Optimizer to use for local, on-client optimization. |
client_weighting | A tff.learning.ClientWeighting value that determines how each client's update is weighted (for example, uniformly or by number of examples). |
metrics_aggregator | A function that takes in the metric finalizers (i.e., tff.learning.models.VariableModel.metric_finalizers()) and returns a tff.Computation for aggregating the unfinalized metrics. If None, this is set to tff.learning.metrics.sum_then_finalize. |
loop_implementation | Changes how the generated training loop is implemented. See tff.learning.LoopImplementation for more details. |
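The client_weighting and metrics_aggregator arguments can be illustrated with a small pure-Python sketch. ClientWeighting below is a hypothetical stand-in mirroring the UNIFORM and NUM_EXAMPLES values of tff.learning.ClientWeighting; update_weight, weighted_mean, and sum_then_finalize are illustrative helpers, the last modeled loosely on the described behavior of tff.learning.metrics.sum_then_finalize (sum unfinalized metrics across clients, then finalize once). In a composed learning process the weighted averaging is typically done by a separate aggregator; the client work itself only reports each client's weight.

```python
from enum import Enum
from typing import Callable, Dict, List, Sequence, Tuple


class ClientWeighting(Enum):
    """Hypothetical stand-in mirroring tff.learning.ClientWeighting values."""
    UNIFORM = "uniform"
    NUM_EXAMPLES = "num_examples"


def update_weight(num_examples: int, weighting: ClientWeighting) -> float:
    """Weight a client attaches to its update under the chosen scheme."""
    if weighting is ClientWeighting.NUM_EXAMPLES:
        return float(num_examples)
    return 1.0


def weighted_mean(updates: Sequence[Tuple[float, ...]],
                  weights: Sequence[float]) -> Tuple[float, ...]:
    """Coordinate-wise weighted average of client updates."""
    total = sum(weights)
    return tuple(sum(w * u[i] for u, w in zip(updates, weights)) / total
                 for i in range(len(updates[0])))


def sum_then_finalize(
        client_metrics: List[Dict[str, Tuple[float, ...]]],
        finalizers: Dict[str, Callable[[Tuple[float, ...]], float]],
) -> Dict[str, float]:
    """Sums unfinalized metric values across clients, then finalizes once."""
    summed: Dict[str, Tuple[float, ...]] = {}
    for metrics in client_metrics:
        for name, values in metrics.items():
            prev = summed.get(name, (0.0,) * len(values))
            summed[name] = tuple(p + v for p, v in zip(prev, values))
    return {name: finalizers[name](values) for name, values in summed.items()}
```

Summing before finalizing matters: with per-client (loss sum, count) pairs of (2.0, 1.0) and (12.0, 3.0), finalizing after the sum gives 14.0 / 4.0 = 3.5, whereas averaging the two per-client means (2.0 and 4.0) would give 3.0.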
Returns | |
---|---|
A tff.learning.templates.ClientWorkProcess. |
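As a rough picture of the returned process's shape, the sketch below mimics a client work process with an initialize() function and a next(state, weights, client_data) function that returns state, per-client results, and measurements. The field names update and update_weight follow tff.learning.templates.ClientResult; ToyClientWorkProcess, ProcessOutput, mean_shift, the empty state, and the num_clients measurement are hypothetical and chosen only for illustration.

```python
from typing import Callable, List, NamedTuple, Sequence, Tuple

Weights = Tuple[float, ...]


class ClientResult(NamedTuple):
    """Mirrors the two fields of tff.learning.templates.ClientResult."""
    update: Weights
    update_weight: float


class ProcessOutput(NamedTuple):
    """Hypothetical stand-in for a measured-process output."""
    state: tuple
    result: List[ClientResult]
    measurements: dict


class ToyClientWorkProcess:
    """Hypothetical process exposing initialize() and next(...)."""

    def __init__(self, client_fn: Callable[[Weights, Sequence[float]],
                                           Tuple[Weights, float]]):
        self._client_fn = client_fn

    def initialize(self) -> tuple:
        # This sketch keeps no state between rounds.
        return ()

    def next(self, state: tuple, weights: Weights,
             client_data: List[Sequence[float]]) -> ProcessOutput:
        # Each client runs client_fn on the shared weights and its own data.
        results = [ClientResult(*self._client_fn(weights, data))
                   for data in client_data]
        measurements = {"num_clients": len(results)}
        return ProcessOutput(state, results, measurements)


def mean_shift(weights: Weights, data: Sequence[float]) -> Tuple[Weights, float]:
    """Hypothetical client_fn: delta toward the data mean, weighted by count."""
    target = sum(data) / len(data)
    return (weights[0] - target,), float(len(data))
```

Usage: ToyClientWorkProcess(mean_shift).next((), (1.0,), [[2.0, 4.0], [5.0]]) yields one ClientResult per client, here with updates (-2.0,) and (-4.0,) and weights 2.0 and 1.0.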