Builds the IterativeProcess for optimization using FedRecon.
tff.learning.algorithms.build_fed_recon(
    model_fn: Callable[[], tff.learning.models.ReconstructionModel],
    *,
    loss_fn: LossFn,
    metrics_fn: Optional[MetricsFn] = None,
    server_optimizer_fn: OptimizerFn = functools.partial(tf.keras.optimizers.SGD, 1.0),
    client_optimizer_fn: OptimizerFn = functools.partial(tf.keras.optimizers.SGD, 0.1),
    reconstruction_optimizer_fn: OptimizerFn = functools.partial(tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[tff.learning.models.ReconstructionDatasetSplitFn] = None,
    client_weighting: Optional[client_weight_lib.ClientWeightType] = None,
    model_distributor: Optional[tff.learning.templates.DistributionProcess] = None,
    model_aggregator_factory: Optional[AggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[MetricFinalizersType, computation_types.StructWithPythonType], computation_base.Computation]] = tff.learning.metrics.sum_then_finalize
) -> tff.learning.templates.LearningProcess
Returns a tff.learning.templates.LearningProcess (a subclass of
tff.templates.IterativeProcess) for Federated Reconstruction. On the client,
computation is divided into two stages: (1) reconstruction of local variables
and (2) training of global variables.
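The two client stages can be illustrated with a toy sketch. It is plain Python, not TFF code, and everything in it is an assumption for illustration: a scalar model `y_hat = w * x + b` in which `w` plays the role of a global variable and `b` a local one, squared-error loss, plain SGD at the default learning rate of 0.1, and a local variable re-initialized to zero each round. The hypothetical `client_update` function reconstructs `b` with `w` frozen, then trains `w` with `b` frozen, and returns only the global delta.

```python
from typing import List, Tuple

# Illustrative toy model (an assumption, not TFF code):
#   y_hat = w * x + b, with w "global" and b "local".
Example = Tuple[float, float]  # (x, y)


def grad_sq_loss(w: float, b: float, x: float, y: float) -> Tuple[float, float]:
    """Gradients of 0.5 * (w * x + b - y) ** 2 with respect to w and b."""
    err = w * x + b - y
    return err * x, err


def client_update(
    global_w: float,
    recon_data: List[Example],
    post_recon_data: List[Example],
    recon_lr: float = 0.1,
    client_lr: float = 0.1,
) -> Tuple[float, int]:
    """Hypothetical two-stage client update; returns (global delta, num examples)."""
    # Stage 1: reconstruct the local variable, starting from zero
    # (an assumption of this sketch), with the global variable frozen.
    b = 0.0
    for x, y in recon_data:
        _, grad_b = grad_sq_loss(global_w, b, x, y)
        b -= recon_lr * grad_b

    # Stage 2: train the global variable with the local variable frozen.
    w = global_w
    for x, y in post_recon_data:
        grad_w, _ = grad_sq_loss(w, b, x, y)
        w -= client_lr * grad_w

    # Only the change to the global variable is returned; b is discarded.
    return w - global_w, len(post_recon_data)


delta, num_examples = client_update(0.0, [(1.0, 1.0)], [(1.0, 1.0)])
```

Because `b` never leaves the client, only the delta for `w` is aggregated on the server.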
Args:
model_fn: A no-arg function that returns a
  tff.learning.models.ReconstructionModel. This function must not capture
  TensorFlow tensors or variables and use them. The model must be constructed
  entirely from scratch on each invocation; returning the same pre-constructed
  model on each call will result in an error.
loss_fn: A no-arg function returning a tf.keras.losses.Loss used to compute
  local model updates during reconstruction and post-reconstruction, and to
  evaluate the model during training. The final loss metric is the
  example-weighted mean loss across batches and across clients; reconstruction
  batches are not included in it.
metrics_fn: A no-arg function returning a list of tf.keras.metrics.Metric
  objects used to evaluate the model. Metric results are computed locally as
  described by each metric and are aggregated across clients as in
  federated_aggregate_keras_metric. If None, no metrics are applied. Metrics
  are not computed on reconstruction batches.
server_optimizer_fn: A tff.learning.optimizers.Optimizer, or a no-arg
  function that returns a tf.keras.optimizers.Optimizer, for applying updates
  to the global model on the server.
client_optimizer_fn: A tff.learning.optimizers.Optimizer, or a no-arg
  function that returns a tf.keras.optimizers.Optimizer, for local client
  training after reconstruction.
reconstruction_optimizer_fn: A tff.learning.optimizers.Optimizer, or a no-arg
  function that returns a tf.keras.optimizers.Optimizer, used to reconstruct
  the local variables with the global variables frozen (i.e., the first stage
  described above).
dataset_split_fn: A tff.learning.models.ReconstructionDatasetSplitFn that
  takes a single TF dataset and produces two TF datasets. The first is
  iterated over during reconstruction, and the second is iterated over
  post-reconstruction. This can be used to preprocess datasets, e.g. to
  iterate over them for multiple epochs or to use disjoint data for
  reconstruction and post-reconstruction. If None, each client's data is split
  in half, with one half used for reconstruction and the other for
  post-reconstruction. See
  tff.learning.models.ReconstructionModel.build_dataset_split_fn for options.
client_weighting: A value of tff.learning.ClientWeighting that specifies a
  built-in weighting method, or a callable that takes the local metrics of the
  model and returns a tensor that provides the weight in the federated average
  of model deltas. If None, defaults to weighting by the number of examples.
model_distributor: An optional tff.learning.templates.DistributionProcess
  that distributes the model weights on the server to the clients. If set to
  None, the distributor is constructed via
  tff.learning.templates.build_broadcast_process.
model_aggregator_factory: An optional instance of
  tff.aggregators.WeightedAggregationFactory or
  tff.aggregators.UnweightedAggregationFactory determining the method of
  aggregation to perform. If unspecified, uses a default
  tff.aggregators.MeanFactory, which computes a stateless mean across clients
  (weighted depending on client_weighting).
metrics_aggregator: A function that takes in the metric finalizers (i.e.,
  tff.learning.Model.metric_finalizers()) and a
  tff.types.StructWithPythonType of the unfinalized metrics (i.e., the TFF
  type of tff.learning.Model.report_local_unfinalized_metrics()), and returns
  a tff.Computation for aggregating the unfinalized metrics. If None, this is
  set to tff.learning.metrics.sum_then_finalize.
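The `model_fn` requirement is about how the model is constructed, not about the model itself. In the plain-Python sketch below, `ToyModel` is a hypothetical stand-in for a `tff.learning.models.ReconstructionModel`; `model_fn` builds a new instance on each call, while the hypothetical `bad_model_fn` returns one shared, pre-built object, which is the pattern the `model_fn` description warns against.

```python
class ToyModel:
    """Hypothetical stand-in for a tff.learning.models.ReconstructionModel."""

    def __init__(self) -> None:
        # All state is created here, inside the constructor, rather than
        # captured from objects built elsewhere.
        self.global_weights = [0.0]
        self.local_weights = [0.0]


def model_fn() -> ToyModel:
    # Follows the contract: a brand-new model on every invocation.
    return ToyModel()


_prebuilt = ToyModel()


def bad_model_fn() -> ToyModel:
    # Violates the contract: hands back the same pre-constructed object
    # on every call.
    return _prebuilt
```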
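To make the `loss_fn` phrase "example-weighted mean loss across batches and across clients" concrete, here is a plain-Python sketch (not TFF code). The hypothetical `example_weighted_mean_loss` takes, for each client, the mean loss and size of each post-reconstruction batch only, since reconstruction batches are excluded from the loss metric.

```python
from typing import List, Tuple

# One post-reconstruction batch: (mean loss over the batch, examples in it).
BatchLoss = Tuple[float, int]


def example_weighted_mean_loss(clients: List[List[BatchLoss]]) -> float:
    """Pool every example from every client, then take a single mean."""
    total_loss = 0.0
    total_examples = 0
    for batches in clients:
        for batch_mean, n in batches:
            total_loss += batch_mean * n  # undo the per-batch mean
            total_examples += n
    return total_loss / total_examples


clients = [
    [(1.0, 1), (3.0, 3)],  # client A: 4 examples, client mean 2.5
    [(4.0, 2)],            # client B: 2 examples, client mean 4.0
]
loss = example_weighted_mean_loss(clients)  # 3.0
```

The result is 18 / 6 = 3.0, whereas averaging the two per-client means would give 3.25, so clients with more examples carry proportionally more weight.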
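The default behavior of `dataset_split_fn` can be modeled in plain Python over a list of batches instead of a TF dataset. This sketch assumes the default split alternates batches, sending even-indexed batches to reconstruction and odd-indexed batches to post-reconstruction; treat that detail as an assumption. The helper `split_in_half` and its epoch arguments are hypothetical and only illustrate the "multiple epochs" and "disjoint data" options mentioned above.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_in_half(
    batches: Sequence[T], recon_epochs: int = 1, post_recon_epochs: int = 1
) -> Tuple[List[T], List[T]]:
    """Hypothetical analogue of the default dataset split.

    Even-indexed batches go to reconstruction and odd-indexed batches to
    post-reconstruction (an assumption about the default), so the two
    halves are disjoint. The epoch arguments are illustrative and simply
    repeat each half.
    """
    recon = [b for i, b in enumerate(batches) if i % 2 == 0]
    post_recon = [b for i, b in enumerate(batches) if i % 2 == 1]
    return recon * recon_epochs, post_recon * post_recon_epochs


recon, post_recon = split_in_half(["b0", "b1", "b2", "b3", "b4"])
# recon == ["b0", "b2", "b4"], post_recon == ["b1", "b3"]
```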
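The default `client_weighting` (weighting by number of examples) and the default `server_optimizer_fn` (SGD with learning rate 1.0) combine roughly as in the sketch below. It is plain Python with hypothetical helper names, and it assumes the server treats the negated mean delta as a pseudo-gradient; under that assumption, a learning rate of 1.0 simply adds the mean delta to the global model.

```python
from typing import List, Tuple

# (delta of a global variable, number of examples) reported by one client.
ClientResult = Tuple[float, int]


def mean_delta(results: List[ClientResult], by_examples: bool = True) -> float:
    """Weighted mean of client deltas (by example count) or uniform mean."""
    weights = [float(n) if by_examples else 1.0 for _, n in results]
    weighted_sum = sum(w * d for w, (d, _) in zip(weights, results))
    return weighted_sum / sum(weights)


def server_sgd_step(global_w: float, avg_delta: float, lr: float = 1.0) -> float:
    """Apply SGD to the negated mean delta, treated as a pseudo-gradient."""
    pseudo_gradient = -avg_delta
    return global_w - lr * pseudo_gradient


results = [(0.5, 30), (-0.1, 10)]
new_w = server_sgd_step(1.0, mean_delta(results))
```

With example weighting, the client with 30 examples dominates (mean delta 0.35); with uniform weighting the mean delta would be 0.2.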
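The idea behind the default `metrics_aggregator`, `tff.learning.metrics.sum_then_finalize` (sum each metric's unfinalized state across clients, then apply the finalizer once), can be sketched in plain Python. The helper name, the dict-of-lists representation, and the `[sum, count]` state for a mean metric are assumptions for illustration, not TFF's internal types.

```python
from typing import Callable, Dict, List

# Unfinalized metric state per metric name, e.g. [sum of losses, count].
Unfinalized = Dict[str, List[float]]
Finalizers = Dict[str, Callable[[List[float]], float]]


def sum_then_finalize_sketch(
    client_metrics: List[Unfinalized], finalizers: Finalizers
) -> Dict[str, float]:
    """Sum unfinalized states element-wise across clients, then finalize."""
    summed: Unfinalized = {}
    for metrics in client_metrics:
        for name, values in metrics.items():
            acc = summed.setdefault(name, [0.0] * len(values))
            for i, v in enumerate(values):
                acc[i] += v
    return {name: finalizers[name](state) for name, state in summed.items()}


client_metrics = [{"loss": [6.0, 3.0]}, {"loss": [4.0, 1.0]}]
finalizers = {"loss": lambda state: state[0] / state[1]}
result = sum_then_finalize_sketch(client_metrics, finalizers)  # {"loss": 2.5}
```

Here the aggregated loss is 10 / 4 = 2.5; finalizing per client first and averaging the results (2.0 and 4.0) would give 3.0 instead.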