tff.learning.algorithms.build_fed_sgd
Builds a learning process that performs federated SGD.
```python
tff.learning.algorithms.build_fed_sgd(
    model_fn: Union[
        Callable[[], tff.learning.models.VariableModel],
        tff.learning.models.FunctionalModel,
    ],
    server_optimizer_fn: tff.learning.optimizers.Optimizer = DEFAULT_SERVER_OPTIMIZER_FN,
    model_distributor: Optional[tff.learning.templates.DistributionProcess] = None,
    model_aggregator: Optional[tff.aggregators.WeightedAggregationFactory] = None,
    metrics_aggregator: Optional[tff.learning.metrics.MetricsAggregatorType] = None,
    loop_implementation: tff.learning.LoopImplementation = tff.learning.LoopImplementation.DATASET_REDUCE
) -> tff.learning.templates.LearningProcess
```
This function creates a `tff.learning.templates.LearningProcess` that performs federated SGD on client models. The learning process has the following methods inherited from `tff.learning.templates.LearningProcess`:

- `initialize`: A `tff.Computation` with type signature `( -> S@SERVER)`, where `S` is a `tff.learning.templates.LearningAlgorithmState` representing the initial state of the server.
- `next`: A `tff.Computation` with type signature `(<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>)`, where `S` is a `LearningAlgorithmState` whose type matches that of the output of `initialize`, and `{B*}@CLIENTS` represents the client datasets, where `B` is the type of a single batch. This computation returns a `LearningAlgorithmState` representing the updated server state, along with the metrics from client training and any other metrics from the broadcast and aggregation processes.
- `get_model_weights`: A `tff.Computation` with type signature `(S -> M)`, where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type matches the output of `initialize` and `next`, and `M` represents the type of the model weights used during training.
- `set_model_weights`: A `tff.Computation` with type signature `(<S, M> -> S)`, where `S` is a `tff.learning.templates.LearningAlgorithmState` whose type matches the output of `initialize`, and `M` represents the type of the model weights used during training.
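The four methods above form a simple state-machine interface: `initialize` produces a server state, `next` consumes a state plus client data and produces a new state with metrics, and `get_model_weights`/`set_model_weights` convert between states and model weights. As a rough plain-Python analogue (this is illustrative only; the real TFF objects are federated computations, not Python methods, and all names in this sketch are made up):

```python
# A plain-Python sketch of the LearningProcess interface described above.
# The real TFF state/weights types are TFF structures, not these classes.
from dataclasses import dataclass

@dataclass
class State:                 # stands in for LearningAlgorithmState (S)
    model_weights: list      # stands in for the model weights type (M)
    round_num: int = 0

class SketchLearningProcess:
    def initialize(self) -> State:                        # ( -> S@SERVER)
        return State(model_weights=[0.0, 0.0])

    def next(self, state, client_data):                   # (<S, {B*}> -> <S, T>)
        metrics = {"num_clients": len(client_data)}
        return State(state.model_weights, state.round_num + 1), metrics

    def get_model_weights(self, state):                   # (S -> M)
        return state.model_weights

    def set_model_weights(self, state, weights) -> State:  # (<S, M> -> S)
        return State(weights, state.round_num)

process = SketchLearningProcess()
state = process.initialize()
state, metrics = process.next(state, client_data=[["batch1"], ["batch2"]])
weights = process.get_model_weights(state)
```

In TFF, the same round-trip shape applies: feed the state returned by `initialize` (or a previous `next`) back into `next` together with the clients' datasets for that round.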
Each time `next` is called, the server model is broadcast to each client using a distributor. Each client computes the gradient for each batch in its local dataset (without updating its model), sums these gradients, and averages them by its number of examples. The per-client average gradients are then aggregated at the server and applied to the server model using an optimizer.
This implements the original FedSGD algorithm in [McMahan et al., 2017](https://arxiv.org/abs/1602.05629).
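The round described above can be sketched numerically. The snippet below is a minimal NumPy sketch with made-up data, not the TFF implementation: each client accumulates squared-loss gradients over its batches (never updating its local model) and averages by example count; the server takes an example-weighted mean of the client gradients and applies one SGD step.

```python
import numpy as np

def client_update(weights, batches):
    # Sum the gradient over every batch in the local dataset (without
    # updating the model), then average by the total number of examples.
    grad_sum = np.zeros_like(weights)
    num_examples = 0
    for x, y in batches:                     # x: (n, d) features, y: (n,) targets
        grad_sum += x.T @ (x @ weights - y)  # sum of per-example squared-loss gradients
        num_examples += len(y)
    return grad_sum / num_examples, num_examples

def server_round(weights, clients, learning_rate=0.1):
    # Aggregate client gradients with an example-weighted mean, then
    # apply one SGD step on the server.
    grads, counts = zip(*(client_update(weights, c) for c in clients))
    mean_grad = sum(g * n for g, n in zip(grads, counts)) / sum(counts)
    return weights - learning_rate * mean_grad

# Synthetic linear-regression data for two clients with different sizes.
rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])

def make_client(num_batches, batch_size):
    batches = []
    for _ in range(num_batches):
        x = rng.normal(size=(batch_size, 2))
        batches.append((x, x @ true_w))
    return batches

clients = [make_client(3, 4), make_client(2, 8)]
w = np.zeros(2)
for _ in range(200):
    w = server_round(w, clients)
# w should now be close to true_w ([2.0, -1.0])
```

Note that because clients only report gradients (never locally updated weights), one call to `next` corresponds to a single global SGD step, unlike FedAvg where clients take multiple local steps.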
Args

- `model_fn`: A no-arg function that returns a `tff.learning.models.VariableModel`, or an instance of a `tff.learning.models.FunctionalModel`. When passing a callable, the callable must *not* capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
- `server_optimizer_fn`: A `tff.learning.optimizers.Optimizer` used to apply client updates to the server model.
- `model_distributor`: An optional `DistributionProcess` that distributes the model weights on the server to the clients. If set to `None`, the distributor is constructed via `distributors.build_broadcast_process`.
- `model_aggregator`: An optional `tff.aggregators.WeightedAggregationFactory` used to aggregate client updates on the server. If `None`, this is set to `tff.aggregators.MeanFactory`.
- `metrics_aggregator`: A function that takes in the metric finalizers (i.e., `tff.learning.models.VariableModel.metric_finalizers()`) and a `tff.types.StructWithPythonType` of the unfinalized metrics (i.e., the TFF type of `tff.learning.models.VariableModel.report_local_unfinalized_metrics()`), and returns a `tff.Computation` for aggregating the unfinalized metrics.
- `loop_implementation`: Changes the implementation of the generated training loop. See `tff.learning.LoopImplementation` for more details.

Returns

A `tff.learning.templates.LearningProcess`.

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-09-20 UTC.