Weight clustering comprehensive guide


Welcome to the comprehensive guide for weight clustering, part of the TensorFlow Model Optimization toolkit.

This page documents various use cases and shows how to use the API for each one. Once you know which APIs you need, find the parameters and the low-level details in the API docs.

In this guide, the following use cases are covered:

  • Define a clustered model.
  • Checkpoint and deserialize a clustered model.
  • Improve the accuracy of the clustered model.
  • Deploy a clustered model and take the steps required to see compression benefits.

Setup

! pip install -q tensorflow-model-optimization

import tensorflow as tf
import numpy as np
import tempfile
import os
import tensorflow_model_optimization as tfmot

input_dim = 20
output_dim = 20
x_train = np.random.randn(1, input_dim).astype(np.float32)
# Draw one random integer class label, then one-hot encode it.
y_train = tf.keras.utils.to_categorical(np.random.randint(output_dim, size=1), num_classes=output_dim)

def setup_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(input_dim, input_shape=[input_dim]),
      tf.keras.layers.Flatten()
  ])
  return model

def train_model(model):
  model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
  )
  model.summary()
  model.fit(x_train, y_train)
  return model

def save_model_weights(model):
  _, pretrained_weights = tempfile.mkstemp('.h5')
  model.save_weights(pretrained_weights)
  return pretrained_weights

def setup_pretrained_weights():
  model = setup_model()
  model = train_model(model)
  pretrained_weights = save_model_weights(model)
  return pretrained_weights

def setup_pretrained_model():
  model = setup_model()
  pretrained_weights = setup_pretrained_weights()
  model.load_weights(pretrained_weights)
  return model

def save_model_file(model):
  _, keras_file = tempfile.mkstemp('.h5') 
  model.save(keras_file, include_optimizer=False)
  return keras_file

def get_gzipped_model_size(model):
  # Returns the size in bytes of the model after zip (DEFLATE) compression.
  import zipfile

  keras_file = save_model_file(model)

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(keras_file)
  return os.path.getsize(zipped_file)

setup_model()
pretrained_weights = setup_pretrained_weights()

Define a clustered model

Cluster a whole model (sequential and functional)

Tips for better model accuracy:

  • You must pass a pre-trained model with acceptable accuracy to this API. Training models from scratch with clustering results in subpar accuracy.
  • In some cases, clustering certain layers has a detrimental effect on model accuracy. Check "Cluster some layers" to see how to skip clustering the layers that affect accuracy the most.

To cluster all layers, apply tfmot.clustering.keras.cluster_weights to the model.

import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 3,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}

model = setup_model()
model.load_weights(pretrained_weights)

clustered_model = cluster_weights(model, **clustering_params)

clustered_model.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
cluster_dense_2 (ClusterWeig (None, 20)                823       
_________________________________________________________________
cluster_flatten_2 (ClusterWe (None, 20)                0         
=================================================================
Total params: 823
Trainable params: 423
Non-trainable params: 400
_________________________________________________________________
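
The parameter counts in the summary can be reproduced by hand. The sketch below is one reading of those numbers, not the toolkit's implementation: it assumes that for each clustered weight tensor the wrapper adds a trainable vector of centroids and a non-trainable index tensor with the same shape as the weight. The helper name clustered_dense_param_counts is hypothetical.

```python
def clustered_dense_param_counts(inputs, units, number_of_clusters, cluster_bias=False):
    """Return (total, trainable, non_trainable) for one wrapped Dense layer.

    Assumption (not taken from the toolkit source): each clustered weight
    tensor gets one trainable centroid vector and one non-trainable index
    tensor of the same shape as the weight.
    """
    kernel = inputs * units
    bias = units
    clustered_sizes = [kernel] + ([bias] if cluster_bias else [])

    centroids = number_of_clusters * len(clustered_sizes)  # trainable
    indices = sum(clustered_sizes)                          # non-trainable

    trainable = kernel + bias + centroids
    non_trainable = indices
    return trainable + non_trainable, trainable, non_trainable


print(clustered_dense_param_counts(20, 20, 3))                     # (823, 423, 400)
print(clustered_dense_param_counts(20, 20, 3, cluster_bias=True))  # (846, 426, 420)
```

Under this accounting, the 400 non-trainable parameters are the per-weight centroid indices of the 20x20 kernel, and the 3 extra trainable parameters are the centroids themselves.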

Cluster some layers (sequential and functional models)

Tips for better model accuracy:

  • You must pass a pre-trained model with acceptable accuracy to this API. Training models from scratch with clustering results in subpar accuracy.
  • Cluster later layers with more redundant parameters (e.g. tf.keras.layers.Dense, tf.keras.layers.Conv2D), as opposed to the early layers.
  • Freeze early layers prior to the clustered layers during fine-tuning. Treat the number of frozen layers as a hyperparameter. Empirically, freezing most early layers is ideal for the current clustering API.
  • Avoid clustering critical layers (e.g. attention mechanism).

More: the tfmot.clustering.keras.cluster_weights API docs provide details on how to vary the clustering configuration per layer.

# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights)

# Helper function uses `cluster_weights` to make only 
# the Dense layers train with clustering
def apply_clustering_to_dense(layer):
  if isinstance(layer, tf.keras.layers.Dense):
    return cluster_weights(layer, **clustering_params)
  return layer

# Use `tf.keras.models.clone_model` to apply `apply_clustering_to_dense` 
# to the layers of the model.
clustered_model = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_clustering_to_dense,
)

clustered_model.summary()
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
cluster_dense_3 (ClusterWeig (None, 20)                823       
_________________________________________________________________
flatten_3 (Flatten)          (None, 20)                0         
=================================================================
Total params: 823
Trainable params: 423
Non-trainable params: 400
_________________________________________________________________

Cluster custom Keras layer or specify which weights of layer to cluster

tfmot.clustering.keras.ClusterableLayer serves two use cases:

  1. Cluster any layer that is not supported natively, including a custom Keras layer.
  2. Specify which weights of a supported layer are to be clustered.

For example, the API defaults to clustering only the kernel of the Dense layer. The example below shows how to modify it to also cluster the bias. When deriving from a Keras layer, you need to override the function get_clusterable_weights, which returns a list of pairs: the name of each trainable variable to be clustered and the variable itself. If you return an empty list [], no weights will be clustered.

Common mistake: clustering the bias usually harms model accuracy too much.

class MyDenseLayer(tf.keras.layers.Dense, tfmot.clustering.keras.ClusterableLayer):

  def get_clusterable_weights(self):
    # Cluster kernel and bias. This is just an example; clustering
    # the bias usually hurts model accuracy.
    return [('kernel', self.kernel), ('bias', self.bias)]

# Use `cluster_weights` to make the `MyDenseLayer` layer train with clustering as usual.
model_for_clustering = tf.keras.Sequential([
  tfmot.clustering.keras.cluster_weights(MyDenseLayer(20, input_shape=[input_dim]), **clustering_params),
  tf.keras.layers.Flatten()
])

model_for_clustering.summary()
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
cluster_my_dense_layer (Clus (None, 20)                846       
_________________________________________________________________
flatten_4 (Flatten)          (None, 20)                0         
=================================================================
Total params: 846
Trainable params: 426
Non-trainable params: 420
_________________________________________________________________

You may also use tfmot.clustering.keras.ClusterableLayer to cluster a custom Keras layer. To do this, extend tf.keras.layers.Layer as usual and implement the __init__, call, and build functions. You also need to extend the clusterable_layer.ClusterableLayer class and implement:

  1. get_clusterable_weights, where you specify the weight kernel to be clustered, as shown above.
  2. get_clusterable_algorithm, where you specify the clustering algorithm for the weight tensor. This is needed because you have to specify how the custom layer's weights are shaped for clustering. The returned clustering algorithm class should be derived from the clustering_algorithm.ClusteringAlgorithm class, and its get_pulling_indices function should be overridden. An example of this function, which supports 1D, 2D, and 3D weights, can be found here.

An example of this use case can be found here.
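
To make the role of the pulling indices concrete, here is a minimal plain-Python sketch of the idea for a 2D weight: every weight is mapped to the index of its nearest centroid, and the clustered weight is rebuilt by looking those indices up. This is a conceptual stand-in written for this guide, not the toolkit's ClusteringAlgorithm class; the function bodies and the helper apply_centroids are assumptions.

```python
def get_pulling_indices(weights, centroids):
    """For each entry of a 2D weight (list of rows), return the index of
    the nearest centroid. Conceptual stand-in, not the tfmot method."""
    return [
        [min(range(len(centroids)), key=lambda i: abs(w - centroids[i])) for w in row]
        for row in weights
    ]


def apply_centroids(indices, centroids):
    """Rebuild the clustered weight by looking up each index."""
    return [[centroids[i] for i in row] for row in indices]


weights = [[0.9, -0.1, 0.45], [-1.2, 0.05, 1.1]]
centroids = [-1.0, 0.0, 1.0]

indices = get_pulling_indices(weights, centroids)
clustered = apply_centroids(indices, centroids)

print(indices)    # [[2, 1, 1], [0, 1, 2]]
print(clustered)  # [[1.0, 0.0, 0.0], [-1.0, 0.0, 1.0]]
```

A custom layer whose weights have an unusual shape needs its own version of this mapping, which is why the algorithm class has to be supplied alongside the clusterable weights.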

Checkpoint and deserialize a clustered model

Your use case: this code is only needed for the HDF5 model format (not HDF5 weights or other formats).

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights)
clustered_model = cluster_weights(base_model, **clustering_params)

# Save or checkpoint the model.
_, keras_model_file = tempfile.mkstemp('.h5')
clustered_model.save(keras_model_file, include_optimizer=True)

# `cluster_scope` is needed for deserializing HDF5 models.
with tfmot.clustering.keras.cluster_scope():
  loaded_model = tf.keras.models.load_model(keras_model_file)

loaded_model.summary()
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
cluster_dense_4 (ClusterWeig (None, 20)                823       
_________________________________________________________________
cluster_flatten_5 (ClusterWe (None, 20)                0         
=================================================================
Total params: 823
Trainable params: 423
Non-trainable params: 400
_________________________________________________________________

Improve the accuracy of the clustered model

Consider the following tips when tuning the accuracy of a clustered model for your use case:

  • Centroid initialization plays a key role in the final optimized model accuracy. In general, kmeans++ initialization outperforms linear, density and random initialization. When not using kmeans++, linear initialization tends to outperform density and random initialization, since it does not tend to miss large weights. However, density initialization has been observed to give better accuracy for the case of using very few clusters on weights with bimodal distributions.

  • Set a learning rate that is lower than the one used in training when fine-tuning the clustered model.

  • For general ideas to improve model accuracy, look for tips for your use case(s) under "Define a clustered model".
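
The remark that linear initialization "does not tend to miss large weights" can be illustrated with a small sketch. The two functions below are simplified assumptions, not the toolkit's code: linear_init spaces centroids evenly between the minimum and maximum weight, and density_init is a quantile-based approximation that places centroids at evenly spaced points of the empirical distribution.

```python
def linear_init(weights, k):
    """Centroids evenly spaced between the smallest and largest weight."""
    lo, hi = min(weights), max(weights)
    if k == 1:
        return [(lo + hi) / 2]
    step = (hi - lo) / (k - 1)
    return [lo + i * step for i in range(k)]


def density_init(weights, k):
    """Quantile-based approximation: centroids at evenly spaced positions
    of the sorted weights, so dense regions receive the centroids."""
    ordered = sorted(weights)
    n = len(ordered)
    return [ordered[min(n - 1, int((i + 0.5) / k * n))] for i in range(k)]


# Mostly small weights plus one large outlier.
weights = [-0.1, -0.05, 0.0, 0.02, 0.05, 0.1, 2.0]

linear = linear_init(weights, 3)
density = density_init(weights, 3)

print(linear)   # last centroid sits at the outlier, 2.0
print(density)  # [-0.05, 0.02, 0.1]: the outlier is not represented
```

In this toy case the density-style centroids all land among the small weights, so the large weight would be pulled far from its original value, while the linear centroids cover the full range.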

Deployment

Export model with size compression

Common mistake: skipping one of the two steps. Both strip_clustering and a standard compression algorithm (e.g. gzip) are necessary to see the compression benefits of clustering.

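model = setup_model()
# Start from pre-trained weights, as recommended in "Define a clustered model".
model.load_weights(pretrained_weights)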
clustered_model = cluster_weights(model, **clustering_params)

clustered_model.compile(
    loss=tf.keras.losses.categorical_crossentropy,
    optimizer='adam',
    metrics=['accuracy']
)

clustered_model.fit(
    x_train,
    y_train
)

final_model = tfmot.clustering.keras.strip_clustering(clustered_model)

print("final model")
final_model.summary()

print("\n")
print("Size of gzipped clustered model without stripping: %.2f bytes" 
      % (get_gzipped_model_size(clustered_model)))
print("Size of gzipped clustered model with stripping: %.2f bytes" 
      % (get_gzipped_model_size(final_model)))
1/1 [==============================] - 0s 343ms/step - loss: 16.1181 - accuracy: 0.0000e+00
final model
Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_5 (Dense)              (None, 20)                420       
_________________________________________________________________
flatten_6 (Flatten)          (None, 20)                0         
=================================================================
Total params: 420
Trainable params: 420
Non-trainable params: 0
_________________________________________________________________


Size of gzipped clustered model without stripping: 3447.00 bytes
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
Size of gzipped clustered model with stripping: 1433.00 bytes
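
Why does compression only help after clustering? A rough, TensorFlow-free sketch: a weight tensor that takes only a few distinct values is highly repetitive, and DEFLATE (the algorithm behind both zip and gzip) exploits that repetition. The snippet below uses made-up random weights and fixed centroids chosen for illustration; the helper compressed_size is hypothetical, and the sizes it prints are not comparable to the model file sizes above.

```python
import random
import struct
import zlib


def compressed_size(values):
    """Pack the values as float32 and return the DEFLATE-compressed size in bytes."""
    raw = struct.pack(f"{len(values)}f", *values)
    return len(zlib.compress(raw, 9))


random.seed(0)

# 400 unclustered weights, the same count as the 20x20 Dense kernel.
dense = [random.gauss(0.0, 1.0) for _ in range(400)]

# Snap every weight to the nearest of 3 illustrative centroids.
centroids = [-1.0, 0.0, 1.0]
clustered = [min(centroids, key=lambda c: abs(w - c)) for w in dense]

dense_size = compressed_size(dense)
clustered_size = compressed_size(clustered)

print("unclustered:", dense_size, "bytes")
print("clustered:  ", clustered_size, "bytes")
```

Without stripping, the saved model still carries the clustering wrapper's extra variables, which is consistent with the larger unstripped size reported above.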