tff.learning.algorithms.build_fed_recon
Builds the IterativeProcess for optimization using FedRecon.
tff.learning.algorithms.build_fed_recon(
    model_fn: Callable[[], tff.learning.models.ReconstructionModel],
    *,
    loss_fn: LossFn,
    metrics_fn: Optional[MetricsFn] = None,
    server_optimizer_fn: tff.learning.optimizers.Optimizer = sgdm.build_sgdm(learning_rate=1.0),
    client_optimizer_fn: tff.learning.optimizers.Optimizer = sgdm.build_sgdm(learning_rate=0.1),
    reconstruction_optimizer_fn: tff.learning.optimizers.Optimizer = sgdm.build_sgdm(learning_rate=0.1),
    dataset_split_fn: Optional[tff.learning.models.ReconstructionDatasetSplitFn] = None,
    client_weighting: Optional[client_weight_lib.ClientWeightType] = None,
    model_distributor: Optional[tff.learning.templates.DistributionProcess] = None,
    model_aggregator_factory: Optional[AggregationFactory] = None,
    metrics_aggregator: Optional[Callable[[MetricFinalizersType, computation_types.StructWithPythonType], computation_base.Computation]] = tff.learning.metrics.sum_then_finalize
) -> tff.learning.templates.LearningProcess
Returns a tff.learning.templates.LearningProcess (a subclass of tff.templates.IterativeProcess) for Federated Reconstruction. On the client, computation is divided into two stages: (1) reconstruction of local variables and (2) training of global variables.
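The two client stages can be illustrated with a toy, pure-Python sketch. It is not the TFF implementation: it assumes a hypothetical scalar model `y_hat = w * x + b`, where `w` stands in for a global variable and `b` for a client-local variable, and it uses plain gradient descent on mean squared error in place of the configured `reconstruction_optimizer_fn` and `client_optimizer_fn`. The helper names `grads` and `client_update` are illustrative only.

```python
from typing import List, Tuple

Example = Tuple[float, float]  # one (x, y) pair


def grads(w: float, b: float, data: List[Example]) -> Tuple[float, float]:
    """Gradients of mean squared error for the toy model y_hat = w * x + b."""
    n = len(data)
    dw = sum(2.0 * (w * x + b - y) * x for x, y in data) / n
    db = sum(2.0 * (w * x + b - y) for x, y in data) / n
    return dw, db


def client_update(global_w: float,
                  recon_data: List[Example],
                  post_recon_data: List[Example],
                  recon_lr: float = 0.1,
                  client_lr: float = 0.1,
                  steps: int = 50) -> Tuple[float, float]:
    """Returns (delta of the global variable, reconstructed local variable)."""
    # Stage 1: reconstruct the local variable b from scratch, global w frozen.
    b = 0.0
    for _ in range(steps):
        _, db = grads(global_w, b, recon_data)
        b -= recon_lr * db
    # Stage 2: train the global variable w, local b held fixed.
    w = global_w
    for _ in range(steps):
        dw, _ = grads(w, b, post_recon_data)
        w -= client_lr * dw
    # Only the global delta would be sent to the server; b stays on the client.
    return w - global_w, b


# Hypothetical client data drawn from y = 2x + 3, split into disjoint halves.
data = [(0.0, 3.0), (1.0, 5.0), (2.0, 7.0), (3.0, 9.0)]
delta, b = client_update(2.0, recon_data=data[:2], post_recon_data=data[2:])
print(f"reconstructed b = {b:.4f}, global delta = {delta:.6f}")
```

Because the global weight already matches the client's data, reconstruction recovers a local value close to 3 and the global delta is close to zero; starting from a worse global weight (for example 1.5) yields a positive delta that moves `w` toward 2.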
Args
model_fn
A no-arg function that returns a tff.learning.models.ReconstructionModel. This function must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
loss_fn
A no-arg function returning a tf.keras.losses.Loss used to compute local model updates during reconstruction and post-reconstruction, and to evaluate the model during training. The final loss metric is the example-weighted mean loss across batches and across clients. Reconstruction batches are not included in the loss metric.
metrics_fn
A no-arg function returning a list of tf.keras.metrics.Metric objects used to evaluate the model. Metric results are computed locally as described by each metric and are aggregated across clients as in federated_aggregate_keras_metric. If None, no metrics are applied. Metrics are not computed on reconstruction batches.
server_optimizer_fn
A tff.learning.optimizers.Optimizer for applying updates to the global model on the server.
client_optimizer_fn
A tff.learning.optimizers.Optimizer for local client training after reconstruction.
reconstruction_optimizer_fn
A tff.learning.optimizers.Optimizer used to reconstruct the local variables while the global variables are frozen, i.e., the first stage described above.
dataset_split_fn
A tff.learning.models.ReconstructionDatasetSplitFn that takes a single TF dataset and produces two TF datasets. The first is iterated over during reconstruction, and the second is iterated over post-reconstruction. This can be used to preprocess datasets, e.g., to iterate over them for multiple epochs or to use disjoint data for reconstruction and post-reconstruction. If None, each client's data is split in half, with one half used for reconstruction and the other for post-reconstruction. See tff.learning.models.ReconstructionModel.build_dataset_split_fn for options.
client_weighting
A value of tff.learning.ClientWeighting that specifies a built-in weighting method, or a callable that takes the local metrics of the model and returns a tensor giving the weight of that client's model delta in the federated average. If None, clients are weighted by their number of examples.
model_distributor
An optional DistributionProcess that distributes the model weights on the server to the clients. If set to None, the distributor is constructed via tff.learning.templates.build_broadcast_process.
model_aggregator_factory
An optional instance of tff.aggregators.WeightedAggregationFactory or tff.aggregators.UnweightedAggregationFactory that determines the method of aggregation. If unspecified, a default tff.aggregators.MeanFactory is used, which computes a stateless mean across clients (weighted according to client_weighting).
metrics_aggregator
A function that takes the metric finalizers (i.e., tff.learning.Model.metric_finalizers()) and a tff.types.StructWithPythonType of the unfinalized metrics (i.e., the TFF type of tff.learning.Model.report_local_unfinalized_metrics()), and returns a tff.Computation for aggregating the unfinalized metrics. If None, this is set to tff.learning.metrics.sum_then_finalize.
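The loss_fn description says the reported loss is an example-weighted mean across batches and clients, with reconstruction batches excluded. The arithmetic can be sketched in plain Python; this is not TFF code, and the per-batch `(mean loss, number of examples)` records and client names are hypothetical inputs chosen for illustration.

```python
from typing import Dict, List, Tuple

# Hypothetical per-client records: (mean loss, number of examples) for each
# post-reconstruction batch. Reconstruction batches are deliberately absent.
ClientBatches = List[Tuple[float, int]]


def example_weighted_mean_loss(clients: Dict[str, ClientBatches]) -> float:
    total_loss = 0.0
    total_examples = 0
    for batches in clients.values():
        for batch_mean_loss, num_examples in batches:
            # Undo the per-batch mean so that every example counts equally.
            total_loss += batch_mean_loss * num_examples
            total_examples += num_examples
    return total_loss / total_examples


clients = {
    "client_a": [(1.0, 10), (3.0, 10)],  # 20 examples
    "client_b": [(4.0, 5)],              # 5 examples
}
loss = example_weighted_mean_loss(clients)
print(loss)  # 60 / 25 = 2.4
```

Weighting by examples gives 2.4 here, whereas averaging the two per-client means would give 3.0, so clients with more post-reconstruction data influence the reported loss proportionally more.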
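The default dataset_split_fn (when None) splits each client's data in half, and a custom split can repeat data for several epochs. A minimal sketch of that idea follows, assuming plain Python lists of batches stand in for tf.data datasets; `split_in_half` is a hypothetical helper, not part of TFF, and the built-in split may assign batches to the two halves differently.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_in_half(batches: Sequence[T],
                  recon_epochs: int = 1,
                  post_recon_epochs: int = 1) -> Tuple[List[T], List[T]]:
    """Hypothetical split: first half for reconstruction, second half after it.

    Repeating a half emulates iterating over it for several epochs.
    """
    mid = len(batches) // 2
    recon = list(batches[:mid]) * recon_epochs
    post_recon = list(batches[mid:]) * post_recon_epochs
    return recon, post_recon


recon, post_recon = split_in_half(["b0", "b1", "b2", "b3"], recon_epochs=2)
print(recon)       # ['b0', 'b1', 'b0', 'b1']
print(post_recon)  # ['b2', 'b3']
```

Keeping the two halves disjoint means the local variables are reconstructed on data that is not reused when training the global variables.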
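The client_weighting and server_optimizer_fn arguments jointly determine the server step: client deltas are averaged (by default weighted by number of examples) and the server optimizer applies the result. The pure-Python sketch below assumes model weights are flat lists of floats and that the server optimizer treats the negated averaged delta as a gradient; under that assumption, plain SGD with the default learning rate of 1.0 simply adds the averaged delta to the global model. The function names are hypothetical, not TFF APIs.

```python
from typing import List, Optional

Weights = List[float]  # a flat list of model parameters, for illustration


def weighted_mean_delta(deltas: List[Weights],
                        client_weights: Optional[List[float]] = None) -> Weights:
    """Averages client deltas; a uniform average when client_weights is None."""
    if client_weights is None:
        client_weights = [1.0] * len(deltas)
    total = sum(client_weights)
    return [sum(w * d[i] for w, d in zip(client_weights, deltas)) / total
            for i in range(len(deltas[0]))]


def server_sgd_step(global_weights: Weights, mean_delta: Weights,
                    learning_rate: float = 1.0) -> Weights:
    """Plain SGD that treats the negated delta as the gradient."""
    return [g - learning_rate * (-d) for g, d in zip(global_weights, mean_delta)]


deltas = [[1.0, 0.0], [0.0, 1.0]]
num_examples = [30.0, 10.0]
by_examples = weighted_mean_delta(deltas, num_examples)  # [0.75, 0.25]
uniform = weighted_mean_delta(deltas)                     # [0.5, 0.5]
new_global = server_sgd_step([0.0, 0.0], by_examples)     # [0.75, 0.25]
print(by_examples, uniform, new_global)
```

Weighting by examples lets the client with 30 examples pull the average three times as hard as the client with 10, while uniform weighting treats both equally.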
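The default metrics_aggregator, tff.learning.metrics.sum_then_finalize, sums unfinalized metric values across clients and only then applies each metric's finalizer. The plain-Python sketch below mirrors that idea without TFF types or computations; representing accuracy as an unfinalized `[num_correct, num_examples]` pair and the name `sum_then_finalize_sketch` are assumptions for illustration.

```python
from typing import Callable, Dict, List

# Unfinalized metrics are lists of numbers that can be summed across clients.
Unfinalized = Dict[str, List[float]]
Finalizers = Dict[str, Callable[[List[float]], float]]


def sum_then_finalize_sketch(finalizers: Finalizers,
                             client_metrics: List[Unfinalized]) -> Dict[str, float]:
    summed: Unfinalized = {}
    for metrics in client_metrics:
        for name, values in metrics.items():
            if name not in summed:
                summed[name] = [0.0] * len(values)
            # Element-wise sum of the unfinalized values.
            summed[name] = [s + v for s, v in zip(summed[name], values)]
    # Finalize once, after summing across all clients.
    return {name: finalizers[name](values) for name, values in summed.items()}


finalizers = {
    # Accuracy is kept unfinalized as [num_correct, num_examples].
    "accuracy": lambda v: v[0] / v[1],
}
client_metrics = [
    {"accuracy": [9.0, 10.0]},   # 90% on 10 examples
    {"accuracy": [10.0, 40.0]},  # 25% on 40 examples
]
result = sum_then_finalize_sketch(finalizers, client_metrics)
print(result)  # accuracy = 19 / 50 = 0.38
```

Summing before finalizing yields 0.38, the accuracy over all 50 examples; averaging the finalized per-client accuracies would instead give 0.575.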
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-09-20 UTC.