Modifies a Keras layer or model to be clustered during training.
tfmot.clustering.keras.cluster_weights(
    to_cluster,
    number_of_clusters,
    cluster_centroids_init=CentroidInitialization.KMEANS_PLUS_PLUS,
    **kwargs
)
This function wraps a Keras model or layer with clustering functionality that clusters the layer's weights during training. For example, with number_of_clusters=8, each weight tensor will have no more than 8 unique values.
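As a framework-free illustration of that constraint, the sketch below clusters a flat list of weights with plain 1-D k-means and then replaces every weight with its nearest centroid. The function name `cluster_weights_1d` and the Gaussian stand-in weights are hypothetical. The library performs clustering inside a Keras wrapper during training, which this sketch does not model; it only shows why the result has at most `number_of_clusters` unique values.

```python
import random


def cluster_weights_1d(weights, number_of_clusters, iterations=20):
    """Hypothetical helper: map every weight to its nearest of `number_of_clusters` centroids."""
    lo, hi = min(weights), max(weights)
    # Start from centroids spaced evenly between the smallest and largest weight.
    if number_of_clusters == 1:
        centroids = [(lo + hi) / 2]
    else:
        step = (hi - lo) / (number_of_clusters - 1)
        centroids = [lo + i * step for i in range(number_of_clusters)]

    def nearest(w):
        # Index of the centroid closest to weight w.
        return min(range(len(centroids)), key=lambda k: abs(w - centroids[k]))

    for _ in range(iterations):
        assignment = [nearest(w) for w in weights]
        # Move each centroid to the mean of the weights currently assigned to it.
        for k in range(len(centroids)):
            members = [w for w, a in zip(weights, assignment) if a == k]
            if members:
                centroids[k] = sum(members) / len(members)

    # Weight sharing: every weight becomes a centroid value.
    return [centroids[nearest(w)] for w in weights], centroids


rng = random.Random(0)
weights = [rng.gauss(0.0, 0.05) for _ in range(1000)]  # stand-in for a trained weight tensor
clustered, centroids = cluster_weights_1d(weights, number_of_clusters=8)
print(len(set(clustered)), "unique values")  # never more than 8
```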
Before being passed to the clustering API, a model should already be trained and achieve acceptable performance on the test or validation set.
The function accepts a single Keras layer (a subclass of keras.layers.Layer), a list of Keras layers, or a Keras model (an instance of keras.models.Model), and handles each case accordingly.
If it encounters a layer it does not know how to handle, it raises an error. When clustering an entire model, a single unsupported layer is enough to cause an error.
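The dispatch behavior described above can be sketched in plain Python. Everything here is hypothetical: `Layer`, `Dense`, `MyCustomLayer`, `Model`, `ClusterWrapper`, `wrap_layer` and `cluster` are stand-ins, not Keras or tfmot classes, and the `supported` flag replaces whatever registry the real library consults. The sketch only mirrors the documented contract: three accepted input kinds, and a ValueError as soon as any layer is unsupported.

```python
class Layer:
    """Hypothetical stand-in for keras.layers.Layer."""
    supported = True


class Dense(Layer):
    pass


class MyCustomLayer(Layer):
    supported = False  # pretend the clustering API has no handler for it


class Model:
    """Hypothetical stand-in for keras.models.Model: just a list of layers."""
    def __init__(self, layers):
        self.layers = list(layers)


class ClusterWrapper:
    """Hypothetical marker meaning 'this layer will be clustered during training'."""
    def __init__(self, layer):
        self.layer = layer


def wrap_layer(layer):
    # Unsupported layers are rejected instead of being silently skipped.
    if not layer.supported:
        raise ValueError(f"Clustering is not supported for {type(layer).__name__}")
    return ClusterWrapper(layer)


def cluster(to_cluster):
    # Dispatch on the three accepted input kinds.
    if isinstance(to_cluster, Layer):
        return wrap_layer(to_cluster)
    if isinstance(to_cluster, list):
        return [wrap_layer(layer) for layer in to_cluster]
    if isinstance(to_cluster, Model):
        # Every layer goes through wrap_layer, so one unsupported layer fails the whole model.
        return Model(wrap_layer(layer) for layer in to_cluster.layers)
    raise ValueError(f"Expected a layer, list of layers, or model, got {type(to_cluster).__name__}")


print(type(cluster(Dense())).__name__)  # ClusterWrapper
try:
    cluster(Model([Dense(), MyCustomLayer()]))
except ValueError as err:
    print("ValueError:", err)
```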
Cluster a model:
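import tensorflow_model_optimization as tfmot
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
}
# original_model: an already-trained tf.keras model
clustered_model = cluster_weights(original_model, **clustering_params)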
Cluster a layer:
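import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow.keras import layers

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
}
# Only the second Dense layer is wrapped for clustering.
model = tf.keras.Sequential([
    layers.Dense(10, activation='relu', input_shape=(100,)),
    cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
])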
| Arguments | |
|---|---|
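| to_cluster | A single Keras layer, a list of Keras layers, or a tf.keras.Model instance. |
| number_of_clusters | The number of cluster centroids to form when clustering a layer or model. For example, if number_of_clusters=8, only 8 unique values will be used in each weight array. |
| cluster_centroids_init | Enum value (a member of CentroidInitialization) that determines how the cluster centroids are initialized. Can have the following values: LINEAR (centroids evenly spaced between the minimum and maximum weight values), RANDOM (centroids sampled uniformly between the minimum and maximum weight values), DENSITY_BASED (the cumulative distribution of the weights is built, its y-axis is split evenly into number_of_clusters levels, and the corresponding weight values become the centroids), or KMEANS_PLUS_PLUS (centroids selected with the k-means++ algorithm; the default). |
| **kwargs | Additional keyword arguments to be passed to the Keras layer. Ignored when to_cluster is not a Keras layer. |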
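The four CentroidInitialization values (LINEAR, RANDOM, DENSITY_BASED, KMEANS_PLUS_PLUS) can be approximated in a few lines of plain Python, which makes the differences easier to see. This is a sketch under stated assumptions, not the library's code: the function names are hypothetical, the input is a flat list of weight values, and DENSITY_BASED is approximated with evenly spaced empirical quantiles. As the parameter name suggests, these only choose the starting centroids.

```python
import random


def linear_init(values, n):
    # LINEAR: n centroids evenly spaced between the minimum and maximum weight.
    lo, hi = min(values), max(values)
    if n == 1:
        return [(lo + hi) / 2]
    return [lo + i * (hi - lo) / (n - 1) for i in range(n)]


def random_init(values, n, rng):
    # RANDOM: n centroids drawn uniformly between the minimum and maximum weight.
    lo, hi = min(values), max(values)
    return sorted(rng.uniform(lo, hi) for _ in range(n))


def density_based_init(values, n):
    # DENSITY_BASED (approximation): split the empirical CDF's y-axis evenly
    # into n levels and map each level back to a weight value (a quantile).
    s = sorted(values)
    if n == 1:
        return [s[len(s) // 2]]
    return [s[round(i * (len(s) - 1) / (n - 1))] for i in range(n)]


def kmeans_plus_plus_init(values, n, rng):
    # KMEANS_PLUS_PLUS: the first centroid is a random weight; each next one is a
    # weight sampled with probability proportional to its squared distance
    # to the nearest centroid chosen so far.
    centroids = [rng.choice(values)]
    while len(centroids) < n:
        sq_dist = [min((v - c) ** 2 for c in centroids) for v in values]
        if sum(sq_dist) == 0:
            break  # fewer distinct weights than requested clusters
        centroids.append(rng.choices(values, weights=sq_dist, k=1)[0])
    return sorted(centroids)


rng = random.Random(0)
values = [rng.gauss(0.0, 0.05) for _ in range(1000)]  # stand-in weight tensor
for name, cents in [
    ("LINEAR", linear_init(values, 8)),
    ("RANDOM", random_init(values, 8, rng)),
    ("DENSITY_BASED", density_based_init(values, 8)),
    ("KMEANS_PLUS_PLUS", kmeans_plus_plus_init(values, 8, rng)),
]:
    print(name, [round(c, 3) for c in cents])
```

In this sketch, DENSITY_BASED and KMEANS_PLUS_PLUS place more centroids where the weights are concentrated, whereas LINEAR and RANDOM look only at the range between the smallest and largest weight.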
| Returns | |
|---|---|
| Layer or model modified to include clustering-related metadata. |
| Raises | |
|---|---|
| ValueError | If the Keras layer is unsupported, or if the Keras model contains an unsupported layer. |