Modifies a keras layer or model to be clustered during training.
tfmot.clustering.keras.cluster_weights(
to_cluster,
number_of_clusters,
cluster_centroids_init=CentroidInitialization.KMEANS_PLUS_PLUS,
**kwargs
)
This function wraps a keras model or layer with clustering functionality that clusters the layer's weights during training. For example, using it with number_of_clusters equal to 8 ensures that each weight tensor has no more than 8 unique values.
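To illustrate what "no more than 8 unique values" means, here is a minimal pure-Python sketch of 1-D k-means (Lloyd's algorithm) over a flat list of weights, starting from evenly spaced centroids. It is an illustration of the general idea under that assumption, not the tfmot implementation; the function names are hypothetical.

```python
# Illustrative 1-D k-means weight clustering (hypothetical helpers, not tfmot code).
from typing import List, Tuple


def linear_centroids(weights: List[float], k: int) -> List[float]:
    """Place k centroids evenly between min(weights) and max(weights)."""
    lo, hi = min(weights), max(weights)
    if k == 1:
        return [(lo + hi) / 2]
    step = (hi - lo) / (k - 1)
    return [lo + i * step for i in range(k)]


def nearest(w: float, centroids: List[float]) -> int:
    """Index of the centroid closest to w."""
    return min(range(len(centroids)), key=lambda i: abs(w - centroids[i]))


def cluster_weights_1d(
    weights: List[float], k: int, iters: int = 20
) -> Tuple[List[float], List[float]]:
    """Run k-means on a flat weight list; return (clustered weights, centroids)."""
    centroids = linear_centroids(weights, k)
    for _ in range(iters):
        # Assignment step: group every weight with its nearest centroid.
        groups: List[List[float]] = [[] for _ in centroids]
        for w in weights:
            groups[nearest(w, centroids)].append(w)
        # Update step: move each centroid to its group mean (empty groups stay put).
        centroids = [sum(g) / len(g) if g else c for g, c in zip(groups, centroids)]
    # Replace every weight by its centroid, so at most k distinct values remain.
    clustered = [centroids[nearest(w, centroids)] for w in weights]
    return clustered, centroids


if __name__ == "__main__":
    import random

    random.seed(0)
    weights = [random.gauss(0.0, 1.0) for _ in range(1000)]
    clustered, _ = cluster_weights_1d(weights, k=8)
    print(len(set(clustered)))  # never more than 8
```

In a real training setup the centroids are trained along with the rest of the model; this sketch only shows the final "snap every weight to a centroid" effect.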
Before a model is passed to the clustering API, it should already be trained and show acceptable performance on the test/validation sets.
The function accepts a single keras layer (subclass of keras.layers.Layer), a list of keras layers, or a keras model (instance of keras.models.Model), and handles each of them appropriately.
If it encounters a layer it does not know how to handle, it raises a ValueError. When clustering an entire model, even a single unsupported layer leads to an error.
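The input handling described above (layer, list of layers, or model, with a ValueError for any unsupported layer) can be modeled with a small toy. Every class and function here is a hypothetical stand-in, not a Keras or tfmot type; the set of "supported" layers is an arbitrary assumption made for the example.

```python
# Toy model of the dispatch rules; all names are hypothetical stand-ins.
from typing import List, Union


class Layer:
    """Stand-in for a keras layer."""


class Dense(Layer):
    """A layer type this toy treats as clusterable."""


class CustomLayer(Layer):
    """A layer type this toy does not know how to handle."""


class Model:
    """Stand-in for a keras model: just an ordered list of layers."""

    def __init__(self, layers: List[Layer]):
        self.layers = layers


SUPPORTED = (Dense,)  # assumption: only Dense is supported in this toy


class Clustered(Layer):
    """Wrapper marking a layer as clustered."""

    def __init__(self, inner: Layer, number_of_clusters: int):
        self.inner = inner
        self.number_of_clusters = number_of_clusters


def toy_cluster(to_cluster: Union[Layer, List[Layer], Model], number_of_clusters: int):
    if isinstance(to_cluster, Model):
        # Every layer must be supported; one unknown layer fails the whole model.
        return Model(toy_cluster(to_cluster.layers, number_of_clusters))
    if isinstance(to_cluster, list):
        return [toy_cluster(layer, number_of_clusters) for layer in to_cluster]
    if isinstance(to_cluster, SUPPORTED):
        return Clustered(to_cluster, number_of_clusters)
    raise ValueError(f"Unsupported layer: {type(to_cluster).__name__}")
```

Calling toy_cluster(Model([Dense(), CustomLayer()]), 8) raises ValueError even though the first layer is fine, mirroring the "even a single unsupported layer" rule.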
Cluster a model:
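import tensorflow_model_optimization as tfmot
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
    'number_of_clusters': 8,
    'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
}
# original_model is an already-trained tf.keras.Model.
clustered_model = cluster_weights(original_model, **clustering_params)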
Cluster a layer:
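import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_model_optimization as tfmot
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
    'number_of_clusters': 8,
    'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
}
model = tf.keras.Sequential([
    layers.Dense(10, activation='relu', input_shape=(100,)),
    # Only this layer's weights are clustered.
    cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
])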
Arguments | |
---|---|
to_cluster | A single keras layer, list of keras layers, or a tf.keras.Model instance. |
number_of_clusters | The number of cluster centroids to form when clustering a layer/model. For example, if number_of_clusters=8, then only 8 unique values will be used in each weight array. |
cluster_centroids_init | Enum value that determines how the cluster centroids are initialized, for example CentroidInitialization.KMEANS_PLUS_PLUS (the default) or CentroidInitialization.DENSITY_BASED. |
**kwargs | Additional keyword arguments to be passed to the keras layer. Ignored when to_cluster is not a keras layer. |
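The default initialization, CentroidInitialization.KMEANS_PLUS_PLUS, refers to k-means++ seeding. As a hedged illustration of that general technique (the standard D²-weighted sampling, not a transcription of tfmot's internal code), a pure-Python sketch on a flat list of weight values might look like this; the function name is hypothetical.

```python
# Textbook k-means++ seeding on 1-D values (hypothetical helper, not tfmot code).
import random
from typing import List, Optional


def kmeans_plus_plus_init(
    values: List[float], k: int, rng: Optional[random.Random] = None
) -> List[float]:
    """Pick k initial centroids from values using D^2-weighted sampling."""
    rng = rng or random.Random(0)
    # First centroid: a uniformly random value.
    centroids = [rng.choice(values)]
    while len(centroids) < k:
        # Squared distance from each value to its closest chosen centroid.
        d2 = [min((v - c) ** 2 for c in centroids) for v in values]
        if sum(d2) == 0:
            # Every value already coincides with a centroid; fall back to uniform.
            centroids.append(rng.choice(values))
            continue
        # Next centroid: sampled with probability proportional to d2,
        # so values far from existing centroids are favored.
        centroids.append(rng.choices(values, weights=d2, k=1)[0])
    return sorted(centroids)
```

Seeding this way spreads the initial centroids across the weight distribution; ordinary k-means iterations would then refine them.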
Returns | |
---|---|
Layer or model modified to include clustering-related metadata. |
Raises | |
---|---|
ValueError | If the keras layer is unsupported, or the keras model contains an unsupported layer. |