Creates a `ClientWorkProcess` for federated averaging.
```python
tff.learning.templates.build_model_delta_client_work(
    model_fn: Callable[[], tff.learning.models.VariableModel],
    optimizer: Union[
        tff.learning.optimizers.Optimizer,
        Callable[[], tf.keras.optimizers.Optimizer],
    ],
    client_weighting: tff.learning.ClientWeighting,
    metrics_aggregator: Optional[tff.learning.metrics.MetricsAggregatorType] = None,
    *,
    loop_implementation: tff.learning.LoopImplementation = tff.learning.LoopImplementation.DATASET_REDUCE
) -> tff.learning.templates.ClientWorkProcess
```
This client work is constructed slightly differently depending on whether `optimizer` is a `tff.learning.optimizers.Optimizer` or a no-arg callable returning a `tf.keras.optimizers.Optimizer`.
If it is a `tff.learning.optimizers.Optimizer`, we avoid creating `tf.Variable`s associated with the optimizer state within the scope of the client work, as they are not necessary. This also means that the client's model weights are updated by computing `optimizer.next` and then assigning the result to the model weights (whereas a `tf.keras.optimizers.Optimizer` modifies the model weights in place using `optimizer.apply_gradients`).
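The two update paths described above can be contrasted in plain Python. This is a minimal sketch under stated assumptions, not TFF code: `Variable`, `FunctionalSGD`, and `InPlaceSGD` are hypothetical stand-ins, weights are lists of floats, and `FunctionalSGD` only loosely mimics the `initialize`/`next` call shape of `tff.learning.optimizers.Optimizer`, while `InPlaceSGD` mimics a Keras-style `apply_gradients`.

```python
from typing import Dict, List, Tuple


class Variable:
    """Hypothetical mutable scalar standing in for a tf.Variable."""

    def __init__(self, value: float):
        self.value = value


class FunctionalSGD:
    """Hypothetical optimizer mimicking the functional call shape.

    Hyperparameters live in an explicit state value, and `next` returns new
    weights rather than mutating anything.
    """

    def __init__(self, learning_rate: float):
        self._learning_rate = learning_rate

    def initialize(self, specs) -> Dict[str, float]:
        # `specs` is accepted only to mirror the call shape; it is unused here.
        return {"learning_rate": self._learning_rate}

    def next(
        self, state: Dict[str, float], weights: List[float], gradients: List[float]
    ) -> Tuple[Dict[str, float], List[float]]:
        lr = state["learning_rate"]
        new_weights = [w - lr * g for w, g in zip(weights, gradients)]
        return state, new_weights

    def get_hparams(self, state: Dict[str, float]) -> Dict[str, float]:
        # Hyperparameters are read out of the explicit state.
        return dict(state)


class InPlaceSGD:
    """Hypothetical Keras-style optimizer that updates variables in place."""

    def __init__(self, learning_rate: float):
        # Held on the object itself, so there is no explicit state to expose.
        self.learning_rate = learning_rate

    def apply_gradients(self, grads_and_vars) -> None:
        for grad, var in grads_and_vars:
            var.value -= self.learning_rate * grad


gradients = [0.5, -1.0]

# Functional path: compute `next`, then assign the result to the model weights.
model_weights = [Variable(1.0), Variable(2.0)]
functional = FunctionalSGD(learning_rate=0.1)
opt_state = functional.initialize(specs=None)
current = [v.value for v in model_weights]
opt_state, updated = functional.next(opt_state, current, gradients)
for var, new_value in zip(model_weights, updated):
    var.value = new_value

# In-place path: the optimizer mutates the variables itself.
keras_like_weights = [Variable(1.0), Variable(2.0)]
InPlaceSGD(learning_rate=0.1).apply_gradients(zip(gradients, keras_like_weights))
```

Both paths end with the same weights here. The difference is where the update happens and whether the learning rate sits in an explicit state value that a surrounding process could expose as a hyperparameter.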
Args | |
---|---|
`model_fn` | A no-arg function that returns a `tff.learning.models.VariableModel`. This method must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error. |
`optimizer` | A `tff.learning.optimizers.Optimizer`, or a no-arg callable that returns a `tf.keras.optimizers.Optimizer`. If using a `tf.keras.optimizers.Optimizer`, the resulting process will have no hyperparameters in its state (i.e., `process.get_hparams` will return an empty dictionary), while if using a `tff.learning.optimizers.Optimizer`, the process will have the same hyperparameters as the optimizer. |
`client_weighting` | A `tff.learning.ClientWeighting` value. |
`metrics_aggregator` | A function that takes in the metric finalizers (i.e., `tff.learning.models.VariableModel.metric_finalizers()`) and returns a `tff.Computation` for aggregating the unfinalized metrics. If `None`, this is set to `tff.learning.metrics.sum_then_finalize`. |
`loop_implementation` | Changes the implementation of the generated training loop. See `tff.learning.LoopImplementation` for more details. |
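To illustrate the sum-then-finalize idea behind the default `metrics_aggregator`, here is a plain-Python sketch: each client reports unfinalized metric values, these are summed elementwise across clients, and each metric's finalizer is applied once to the sum. The metric names, the `[total, count]` layout for a mean, and the helper `sum_then_finalize_sketch` are illustrative assumptions, not the implementation of `tff.learning.metrics.sum_then_finalize`.

```python
from typing import Callable, Dict, List

UnfinalizedMetrics = Dict[str, List[float]]
Finalizers = Dict[str, Callable[[List[float]], float]]


def sum_then_finalize_sketch(
    finalizers: Finalizers, client_metrics: List[UnfinalizedMetrics]
) -> Dict[str, float]:
    """Hypothetical helper: sum unfinalized metrics over clients, then finalize."""
    summed: UnfinalizedMetrics = {}
    for metrics in client_metrics:
        for name, values in metrics.items():
            if name not in summed:
                summed[name] = [0.0] * len(values)
            summed[name] = [s + v for s, v in zip(summed[name], values)]
    # Each finalizer runs once, on the cross-client sum.
    return {name: finalizers[name](values) for name, values in summed.items()}


# Assumed layout: a mean loss is kept unfinalized as [sum_of_losses, num_examples].
finalizers: Finalizers = {
    "loss": lambda v: v[0] / v[1],
    "num_examples": lambda v: v[0],
}

clients = [
    {"loss": [6.0, 3.0], "num_examples": [3.0]},  # this client's mean loss is 2.0
    {"loss": [1.0, 1.0], "num_examples": [1.0]},  # this client's mean loss is 1.0
]

aggregated = sum_then_finalize_sketch(finalizers, clients)
```

Summing before finalizing yields an example-weighted mean loss, (6 + 1) / (3 + 1) = 1.75, rather than the unweighted mean of the per-client means (1.5).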
Returns | |
---|---|
A `ClientWorkProcess`. |
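For federated averaging, each client's contribution pairs a model update with an aggregation weight chosen by `client_weighting`. The sketch below shows that idea in plain Python under several assumptions: the update is taken as initial minus trained weights (so a server-side optimizer could treat it like a gradient), `NUM_EXAMPLES` weighting uses the local example count, `UNIFORM` uses 1.0, and `ClientWeightingSketch`, `ClientResultSketch`, `client_result`, and `weighted_average` are hypothetical names rather than TFF APIs.

```python
import enum
from typing import List, NamedTuple, Sequence


class ClientWeightingSketch(enum.Enum):
    """Hypothetical mirror of the two weighting options."""

    NUM_EXAMPLES = enum.auto()
    UNIFORM = enum.auto()


class ClientResultSketch(NamedTuple):
    update: List[float]   # model delta for this client
    update_weight: float  # weight used when averaging updates


def client_result(
    initial: Sequence[float],
    trained: Sequence[float],
    num_examples: int,
    weighting: ClientWeightingSketch,
) -> ClientResultSketch:
    # Assumed sign convention: initial weights minus locally trained weights.
    delta = [i - t for i, t in zip(initial, trained)]
    if weighting is ClientWeightingSketch.NUM_EXAMPLES:
        weight = float(num_examples)
    else:
        weight = 1.0
    return ClientResultSketch(update=delta, update_weight=weight)


def weighted_average(results: Sequence[ClientResultSketch]) -> List[float]:
    """Weighted mean of client updates: sum(w_i * delta_i) / sum(w_i)."""
    total = sum(r.update_weight for r in results)
    dims = len(results[0].update)
    return [
        sum(r.update_weight * r.update[d] for r in results) / total
        for d in range(dims)
    ]


initial = [1.0, 1.0]
by_examples = weighted_average([
    client_result(initial, [0.5, 1.0], 30, ClientWeightingSketch.NUM_EXAMPLES),
    client_result(initial, [1.0, 0.0], 10, ClientWeightingSketch.NUM_EXAMPLES),
])
uniform = weighted_average([
    client_result(initial, [0.5, 1.0], 30, ClientWeightingSketch.UNIFORM),
    client_result(initial, [1.0, 0.0], 10, ClientWeightingSketch.UNIFORM),
])
```

With example-count weighting the client holding 30 examples dominates the average; with uniform weighting both clients count equally regardless of dataset size.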