

Builds a learning process for federated k-means clustering.

This function creates a tff.learning.templates.LearningProcess that performs federated k-means clustering, specifically mini-batch k-means. Mini-batch k-means processes only a mini-batch of the data at each round and updates the clusters in a weighted manner, based on how many points in the mini-batch were assigned to each cluster. In the federated version, clients assign each of their points to a cluster locally, and the server updates the clusters. Conceptually, the "mini-batch" being used is the union of all client datasets involved in a given round.
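The client/server split described above can be sketched in plain Python. This is a minimal illustration, not TFF's implementation: the helper names (squared_distance, client_assign, server_sum) are hypothetical, points are assumed to be flat lists of floats, and "nearest" is assumed to mean smallest squared Euclidean distance.

```python
from typing import List, Sequence, Tuple

Point = Sequence[float]
ClusterStats = Tuple[List[List[float]], List[int]]


def squared_distance(a: Point, b: Point) -> float:
    """Squared Euclidean distance between two points of equal length."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def client_assign(points: Sequence[Point], centroids: Sequence[Point]) -> ClusterStats:
    """Runs on one client: per-cluster sums and counts of its local points."""
    dim = len(centroids[0])
    sums = [[0.0] * dim for _ in centroids]
    counts = [0] * len(centroids)
    for point in points:
        # Index of the centroid closest to this point.
        nearest = min(
            range(len(centroids)),
            key=lambda k: squared_distance(point, centroids[k]),
        )
        counts[nearest] += 1
        for d in range(dim):
            sums[nearest][d] += point[d]
    return sums, counts


def server_sum(client_results: Sequence[ClusterStats]) -> ClusterStats:
    """Runs on the server: element-wise sum of all clients' sums and counts."""
    first_sums, first_counts = client_results[0]
    total_sums = [[0.0] * len(row) for row in first_sums]
    total_counts = [0] * len(first_counts)
    for sums, counts in client_results:
        for k, row in enumerate(sums):
            total_counts[k] += counts[k]
            for d, value in enumerate(row):
                total_sums[k][d] += value
    return total_sums, total_counts


# Two clients take part in this round; their union is the "mini-batch".
centroids = [[0.0, 0.0], [10.0, 10.0]]
client_a = [[1.0, 1.0], [9.0, 9.0]]
client_b = [[0.0, 2.0], [11.0, 10.0], [10.0, 12.0]]

total_sums, total_counts = server_sum(
    [client_assign(client_a, centroids), client_assign(client_b, centroids)]
)
print(total_sums, total_counts)  # [[1.0, 3.0], [30.0, 31.0]] [2, 3]
```

In this sketch only per-cluster sums and counts leave each client, not the raw points, which is why an unweighted sum across clients is enough for the server to process the round.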

The learning process has the following methods inherited from tff.learning.templates.LearningProcess:

  • initialize: A tff.Computation with the functional type signature ( -> S@SERVER), where S is a LearningAlgorithmState representing the initial state of the server.
  • next: A tff.Computation with the functional type signature (<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>) where S is a LearningAlgorithmState whose type matches the output of initialize and {B*}@CLIENTS represents the client datasets. The output L is a tff.learning.templates.LearningProcessOutput containing the state S and metrics computed during training.
  • get_model_weights: A tff.Computation with type signature (S -> W), where W represents the current k-means centroids.
  • set_model_weights: A tff.Computation with type signature (<S, M> -> S), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and M is a new set of k-means centroids.

Here, S is a tff.learning.templates.LearningAlgorithmState. W is a tensor holding the current centroids, with shape (num_clusters,) + data_shape. The datasets {B*} must have elements of shape data_shape and must not be batched.

At each round, every client point is assigned to its nearest centroid, and the points assigned to each centroid are summed. The server then updates the centroids from these sums. To do so, it keeps track of how many points have been assigned to each centroid overall, as an integer tensor of shape (num_clusters,). This information can be found in state.finalizer. Note that each count begins at a "pseudo-count" of 1, in order to ensure that the centroids do not collapse to zero.
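One way to picture the server-side step is as a weighted running mean, sketched below in plain Python. The text above does not spell out the exact formula, so the rule used here (old centroid weighted by its running count, plus the new sum, divided by the new total count) is an assumption, and update_centroids is a hypothetical helper rather than a TFF function.

```python
from typing import List, Sequence, Tuple


def update_centroids(
    centroids: Sequence[Sequence[float]],
    weights: Sequence[int],
    cluster_sums: Sequence[Sequence[float]],
    cluster_counts: Sequence[int],
) -> Tuple[List[List[float]], List[int]]:
    """Assumed update rule: running mean of all points seen per cluster.

    weights[k] is the number of points assigned to cluster k so far,
    including the initial pseudo-count of 1, so it is never zero.
    """
    new_centroids = []
    new_weights = []
    for centroid, weight, point_sum, count in zip(
        centroids, weights, cluster_sums, cluster_counts
    ):
        total = weight + count
        new_centroids.append(
            [(weight * c + s) / total for c, s in zip(centroid, point_sum)]
        )
        new_weights.append(total)
    return new_centroids, new_weights


# Initial state: two centroids, each with a pseudo-count of 1.
centroids = [[0.0, 0.0], [10.0, 10.0]]
weights = [1, 1]

# Per-cluster sums and counts aggregated from the clients in one round.
round_sums = [[1.0, 3.0], [30.0, 31.0]]
round_counts = [2, 3]

new_centroids, new_weights = update_centroids(
    centroids, weights, round_sums, round_counts
)
print(new_weights)       # [3, 4]
print(new_centroids[1])  # [10.0, 10.25]
```

Under this assumed rule, a cluster that receives no points in a round has a zero sum and a zero count, so its weight is unchanged and its centroid stays where it was instead of being divided by zero or pulled to the origin.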

num_clusters The number of clusters to use.
data_shape A tuple of integers specifying the shape of each data point. Note that this data shape should be unbatched, as this algorithm does not currently support batched data points.
random_seed A tuple of two integers used to seed the initialization phase.
distributor An optional tff.learning.templates.DistributionProcess that broadcasts the centroids on the server to the clients. If set to None, the distributor is constructed via tff.learning.templates.build_broadcast_process.
sum_aggregator An optional tff.aggregators.UnweightedAggregationFactory used to sum updates across clients. If None, we use tff.aggregators.SumFactory.

Returns a tff.learning.templates.LearningProcess.