tff.learning.algorithms.build_fed_kmeans
Builds a learning process for federated k-means clustering.
tff.learning.algorithms.build_fed_kmeans(
    num_clusters: int,
    data_shape: tuple[int, ...],
    random_seed: Optional[tuple[int, int]] = None,
    distributor: Optional[tff.learning.templates.DistributionProcess] = None,
    sum_aggregator: Optional[tff.aggregators.UnweightedAggregationFactory] = None
) -> tff.learning.templates.LearningProcess
This function creates a tff.learning.templates.LearningProcess that performs
federated k-means clustering. Specifically, this performs mini-batch k-means
clustering. Note that mini-batch k-means only processes a mini-batch of the
data at each round, and updates clusters in a weighted manner based on how
many points in the mini-batch were assigned to each cluster. In the federated
version, clients assign each of their points to a cluster locally, and the
server updates the clusters. Conceptually, the "mini-batch" being used is the
union of all client datasets involved in a given round.
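The division of work described above can be illustrated with a small pure-Python sketch. It is not the library's implementation: the helper names `assign_and_sum` and `aggregate` are hypothetical, and squared Euclidean distance is an assumption about what "nearest" means here.

```python
def assign_and_sum(points, centroids):
    """Client step (hypothetical helper): assign local points to centroids.

    Returns per-centroid coordinate sums and per-centroid point counts.
    """
    dim = len(centroids[0])
    sums = [[0.0] * dim for _ in centroids]
    counts = [0] * len(centroids)
    for p in points:
        # Squared Euclidean distance from p to every centroid (an assumption).
        dists = [sum((a - b) ** 2 for a, b in zip(p, c)) for c in centroids]
        k = dists.index(min(dists))  # ties go to the lowest index
        counts[k] += 1
        sums[k] = [s + a for s, a in zip(sums[k], p)]
    return sums, counts


def aggregate(client_results):
    """Server step (hypothetical helper): add up all clients' sums and counts."""
    first_sums, first_counts = client_results[0]
    total_sums = [list(row) for row in first_sums]
    total_counts = list(first_counts)
    for sums, counts in client_results[1:]:
        for k in range(len(total_counts)):
            total_counts[k] += counts[k]
            total_sums[k] = [t + s for t, s in zip(total_sums[k], sums[k])]
    return total_sums, total_counts


centroids = [[0.0, 0.0], [10.0, 10.0]]
client_a = [[1.0, 1.0], [9.0, 9.0]]
client_b = [[0.0, 2.0], [11.0, 10.0], [10.0, 12.0]]

# The round's "mini-batch" is the union of both clients' points.
results = [assign_and_sum(pts, centroids) for pts in (client_a, client_b)]
total_sums, total_counts = aggregate(results)
```

Only the per-cluster sums and counts leave each client in this sketch; the raw points stay local.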
The learning process has the following methods, inherited from tff.learning.templates.LearningProcess:
initialize: A tff.Computation with the functional type signature ( -> S@SERVER), where S is a LearningAlgorithmState representing the initial state of the server.
next: A tff.Computation with the functional type signature (<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>), where S is a LearningAlgorithmState whose type matches the output of initialize and {B*}@CLIENTS represents the client datasets. The output L is a tff.learning.templates.LearningProcessOutput containing the state S and metrics computed during training.
get_model_weights: A tff.Computation with type signature (S -> W), where W represents the current k-means centroids.
set_model_weights: A tff.Computation with type signature (<S, M> -> S), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and M is a new set of k-means centroids.
Here, S is a tff.learning.templates.LearningAlgorithmState. W is a tensor
holding the current centroids, of shape (num_clusters,) + data_shape. The
datasets {B*} must have elements of shape data_shape and must not be batched.
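A minimal usage sketch of the methods above follows. It assumes TensorFlow and TensorFlow Federated are installed, that client data is supplied as a list of `tf.data.Dataset` objects with float32 elements, and that the default execution context accepts that list; the wrapper name `run_fed_kmeans` and its `num_rounds` parameter are hypothetical.

```python
def run_fed_kmeans(client_points, num_clusters, data_shape, num_rounds=3):
    """Runs a few rounds of federated k-means and returns the centroids.

    client_points: one array-like per client, each of shape
    (n_i,) + data_shape. This wrapper is illustrative, not part of TFF.
    """
    # Imported lazily because both packages are heavy dependencies.
    import tensorflow as tf
    import tensorflow_federated as tff

    process = tff.learning.algorithms.build_fed_kmeans(
        num_clusters=num_clusters, data_shape=data_shape)

    # One unbatched dataset per client; each element has shape data_shape.
    # float32 is an assumption about the expected element dtype.
    client_datasets = [
        tf.data.Dataset.from_tensor_slices(tf.constant(p, dtype=tf.float32))
        for p in client_points
    ]

    state = process.initialize()
    for _ in range(num_rounds):
        output = process.next(state, client_datasets)
        state = output.state  # carry the updated server state forward

    # The result has shape (num_clusters,) + data_shape.
    return process.get_model_weights(state)
```

For example, `run_fed_kmeans([[[0.0, 1.0], [0.5, 0.5]], [[9.0, 9.0]]], num_clusters=2, data_shape=(2,))` would be expected to return an array of shape `(2, 2)`.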
The centroids are updated at each round by assigning all clients' points to
the nearest centroid, and then summing the points assigned to each centroid.
The centroids are then updated at the server based on these sums.
To do so, we keep track of how many points have been assigned to each centroid
overall, as a float tensor of shape (num_clusters,). This information can
be found in state.finalizer. Note that we begin with a "pseudo-count" of 1
for each centroid, in order to ensure that the centroids do not collapse to zero.
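One way to read the server update is as a running weighted mean, sketched below in pure Python. This is an interpretation of the description rather than the library's code: the helper name `update_centroids` is hypothetical, and the exact formula is an assumption.

```python
def update_centroids(centroids, weights, cluster_sums, cluster_counts):
    """Hypothetical server update: running weighted mean per centroid.

    weights[k] is the number of points credited to centroid k so far,
    starting from a pseudo-count of 1.
    """
    new_centroids, new_weights = [], []
    for c, w, s, n in zip(centroids, weights, cluster_sums, cluster_counts):
        total = w + n  # always >= 1 because of the pseudo-count
        # The old centroid counts as w points located at c.
        new_centroids.append([(w * ci + si) / total for ci, si in zip(c, s)])
        new_weights.append(total)
    return new_centroids, new_weights


centroids = [[0.0, 0.0], [10.0, 10.0]]
weights = [1.0, 1.0]  # pseudo-count of 1 per centroid
cluster_sums = [[1.0, 3.0], [30.0, 31.0]]  # example per-cluster sums
cluster_counts = [2.0, 3.0]  # example per-cluster counts

new_centroids, new_weights = update_centroids(
    centroids, weights, cluster_sums, cluster_counts)

# A cluster that receives no points this round keeps its centroid.
unchanged, _ = update_centroids([[4.0, 5.0]], [1.0], [[0.0, 0.0]], [0.0])
```

Under this reading, the pseudo-count keeps the denominator positive, so a centroid with no assigned points stays where it is instead of being divided by zero or reset to the origin.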
Args
num_clusters | The number of clusters to use.
data_shape | A tuple of integers specifying the shape of each data point. Note that this data shape should be unbatched, as this algorithm does not currently support batched data points.
random_seed | A tuple of two integers used to seed the initialization phase.
distributor | An optional tff.learning.templates.DistributionProcess that broadcasts the centroids on the server to the clients. If set to None, the distributor is constructed via tff.learning.templates.build_broadcast_process.
sum_aggregator | An optional tff.aggregators.UnweightedAggregationFactory used to sum updates across clients. If None, we use tff.aggregators.SumFactory.
Returns
A LearningProcess.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-09-20 UTC.