Overview
Welcome to an end-to-end example for magnitude-based weight pruning.
Other pages
For an introduction to what pruning is and to determine if you should use it (including what's supported), see the overview page.
To quickly find the APIs you need for your use case (beyond fully pruning a model with 80% sparsity), see the comprehensive guide.
Summary
In this tutorial, you will:
- Train a Keras model for MNIST from scratch.
- Fine-tune the model by applying the pruning API and see the accuracy.
- Create 3x smaller TF and TFLite models from pruning.
- Create a 10x smaller TFLite model from combining pruning and post-training quantization.
- See the persistence of accuracy from TF to TFLite.
Setup
pip install -q tensorflow-model-optimization
import tempfile
import os
import tensorflow as tf
import numpy as np
from tensorflow_model_optimization.python.core.keras.compat import keras
%load_ext tensorboard
Train a model for MNIST without pruning
# Load MNIST dataset
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])
# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(
  train_images,
  train_labels,
  epochs=4,
  validation_split=0.1,
)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 0s 0us/step
2025-10-11 13:07:21.736504: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Epoch 1/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.2958 - accuracy: 0.9168 - val_loss: 0.1148 - val_accuracy: 0.9703
Epoch 2/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1104 - accuracy: 0.9686 - val_loss: 0.0782 - val_accuracy: 0.9795
Epoch 3/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0799 - accuracy: 0.9769 - val_loss: 0.0712 - val_accuracy: 0.9802
Epoch 4/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0665 - accuracy: 0.9802 - val_loss: 0.0639 - val_accuracy: 0.9827
<tf_keras.src.callbacks.History at 0x7fcdcdfe1a30>
Evaluate baseline test accuracy and save the model for later usage.
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
_, keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)
Baseline test accuracy: 0.9794999957084656
Saved baseline model to: /tmpfs/tmp/tmpki5olst3.h5
/tmpfs/tmp/ipykernel_13713/3790298460.py:7: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  keras.models.save_model(model, keras_file, include_optimizer=False)
Fine-tune pre-trained model with pruning
Define the model
You will apply pruning to the whole model and see this in the model summary.
In this example, you start the model with 50% sparsity (50% zeros in weights) and end with 80% sparsity.
In the comprehensive guide, you can see how to prune some layers for model accuracy improvements.
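To make "sparsity" concrete before applying the API, here is a minimal NumPy sketch of magnitude-based pruning: the entries with the smallest absolute value are set to zero until the requested fraction of zeros is reached. `prune_by_magnitude` is a hypothetical helper written for illustration; it is not part of `tfmot` and it ignores the training-time masking that the real wrapper performs.

```python
import numpy as np

def prune_by_magnitude(weights, sparsity):
    """Return a copy of `weights` with the smallest-magnitude entries zeroed.

    `sparsity` is the fraction of entries to zero (0.5 means 50% zeros).
    Hypothetical helper for illustration only; not a tfmot function.
    """
    flat = np.abs(weights).ravel()
    num_to_prune = int(round(sparsity * flat.size))
    if num_to_prune == 0:
        return weights.copy()
    # Indices of the `num_to_prune` entries with the smallest magnitude.
    prune_idx = np.argpartition(flat, num_to_prune - 1)[:num_to_prune]
    mask = np.ones(flat.size, dtype=bool)
    mask[prune_idx] = False
    return weights * mask.reshape(weights.shape)

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 5))
for target in (0.5, 0.8):
    pruned = prune_by_magnitude(w, target)
    print(target, np.mean(pruned == 0))
```

The two targets mirror the 50% starting point and 80% end point used in this example.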
import tensorflow_model_optimization as tfmot
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set. 
num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs
# Define model for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}
model_for_pruning = prune_low_magnitude(model, **pruning_params)
# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model_for_pruning.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 prune_low_magnitude_reshap  (None, 28, 28, 1)         1         
 e (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_conv2d  (None, 26, 26, 12)        230       
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_max_po  (None, 13, 13, 12)        1         
 oling2d (PruneLowMagnitude                                      
 )                                                               
                                                                 
 prune_low_magnitude_flatte  (None, 2028)              1         
 n (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_dense   (None, 10)                40572     
 (PruneLowMagnitude)                                             
                                                                 
=================================================================
Total params: 40805 (159.41 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 20395 (79.69 KB)
_________________________________________________________________
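The parameter counts in this summary can be reproduced by hand. The arithmetic below is a sketch that assumes each pruning wrapper stores a mask the size of the layer's kernel plus two scalars (a threshold and a step counter), and that wrappers around layers without weights store only a step counter. That assumption about the wrapper's internals is not stated in the summary itself, but it matches the numbers shown.

```python
# Layer shapes taken from the model defined earlier in this tutorial.
conv_kernel = 3 * 3 * 1 * 12   # 3x3 kernel, 1 input channel, 12 filters
conv_bias = 12
dense_kernel = 2028 * 10       # 13 * 13 * 12 flattened features -> 10 logits
dense_bias = 10

def wrapper_overhead(kernel_size):
    # Assumed non-trainable extras: mask (kernel_size) + threshold + step counter.
    return kernel_size + 2

conv_total = conv_kernel + conv_bias + wrapper_overhead(conv_kernel)
dense_total = dense_kernel + dense_bias + wrapper_overhead(dense_kernel)
weightless_layers = 3          # Reshape, MaxPooling2D, Flatten: one counter each

trainable = conv_kernel + conv_bias + dense_kernel + dense_bias
total = conv_total + dense_total + weightless_layers

print(conv_total, dense_total)              # 230 40572
print(total, trainable, total - trainable)  # 40805 20410 20395
```

The non-trainable parameters are bookkeeping for pruning, which is why roughly half of the listed parameters are not trainable.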
Train and evaluate the model against baseline
Fine-tune with pruning for two epochs.
tfmot.sparsity.keras.UpdatePruningStep is required during training, and tfmot.sparsity.keras.PruningSummaries provides logs for tracking progress and debugging.
logdir = tempfile.mkdtemp()
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]
model_for_pruning.fit(train_images, train_labels,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)
Epoch 1/2
422/422 [==============================] - 5s 6ms/step - loss: 0.0847 - accuracy: 0.9764 - val_loss: 0.1078 - val_accuracy: 0.9718
Epoch 2/2
422/422 [==============================] - 2s 5ms/step - loss: 0.1014 - accuracy: 0.9722 - val_loss: 0.0830 - val_accuracy: 0.9783
<tf_keras.src.callbacks.History at 0x7fcd1e61b160>
For this example, there is minimal loss in test accuracy after pruning, compared to the baseline.
_, model_for_pruning_accuracy = model_for_pruning.evaluate(
   test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy) 
print('Pruned test accuracy:', model_for_pruning_accuracy)
Baseline test accuracy: 0.9794999957084656
Pruned test accuracy: 0.9718999862670898
The logs show the progression of sparsity on a per-layer basis.
#docs_infra: no_execute
%tensorboard --logdir={logdir}
For non-Colab users, you can see the results of a previous run of this code block on TensorBoard.dev.
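Without TensorBoard, the target sparsity that the schedule requests at a given step can be approximated offline. The sketch below assumes the polynomial form `final + (initial - final) * (1 - progress) ** power` with an exponent of 3, which is my understanding of the default for `PolynomialDecay`. It also ignores that the real schedule may only update masks every so many steps. `polynomial_decay_sparsity` is a hypothetical helper, not a `tfmot` function.

```python
import math

def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Target sparsity at `step` under an assumed polynomial decay schedule."""
    progress = (step - begin_step) / (end_step - begin_step)
    # Clamp so that steps outside the pruning window hold the endpoint values.
    progress = min(max(progress, 0.0), 1.0)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power

# Reproduce the end_step arithmetic used in this tutorial.
num_images = 60000 * (1 - 0.1)                  # 90% of MNIST's training set
steps_per_epoch = math.ceil(num_images / 128)   # matches the 422 steps in the log
end_step = steps_per_epoch * 2                  # two epochs of fine-tuning

for step in (0, steps_per_epoch, end_step):
    print(step, round(polynomial_decay_sparsity(step, 0.50, 0.80, 0, end_step), 4))
```

Under these assumptions most of the increase happens early: halfway through the schedule the target is already about 76%, and the last epoch only closes the remaining gap to 80%.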
Create 3x smaller models from pruning
Both tfmot.sparsity.keras.strip_pruning and applying a standard compression algorithm (e.g. via gzip) are necessary to see the compression benefits of pruning.
- strip_pruning is necessary since it removes every tf.Variable that pruning only needs during training, which would otherwise add to model size during inference.
- Applying a standard compression algorithm is necessary since the serialized weight matrices are the same size as they were before pruning. However, pruning sets most of the weights to zero, and that redundancy lets compression algorithms shrink the model further.
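The second point can be checked on synthetic data. The sketch below uses random values, not the tutorial's model. It builds a float32 array with the same number of entries as the Dense kernel, zeroes the 80% of entries with the smallest magnitude, and compares raw and zlib-compressed sizes. The exact compressed sizes depend on the data, so none are quoted here.

```python
import zlib
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for a weight matrix: 2028 * 10 float32 values.
dense = rng.normal(size=20280).astype(np.float32)

# Zero every entry whose magnitude falls below the 80th percentile.
threshold = np.quantile(np.abs(dense), 0.8)
sparse = np.where(np.abs(dense) < threshold, 0.0, dense).astype(np.float32)

raw_dense, raw_sparse = dense.tobytes(), sparse.tobytes()
print("raw bytes:       ", len(raw_dense), len(raw_sparse))
print("compressed bytes:", len(zlib.compress(raw_dense)), len(zlib.compress(raw_sparse)))
```

The raw byte counts are identical, because a zero still occupies four bytes. Only after compression does the sparse array become noticeably smaller.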
First, create a compressible model for TensorFlow.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
_, pruned_keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
Saved pruned Keras model to: /tmpfs/tmp/tmpa3iund_u.h5
/tmpfs/tmp/ipykernel_13713/3267383138.py:4: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
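Before converting, it can be worth confirming that the stripped model really is sparse. The sketch below defines a hypothetical helper, `report_sparsity`, that takes a list of NumPy arrays and returns the fraction of exact zeros in each one. It is demonstrated on small made-up arrays so that it runs on its own.

```python
import numpy as np

def report_sparsity(weights):
    """Return the fraction of exact zeros in each array of `weights`."""
    return [float(np.mean(w == 0)) for w in weights]

# Made-up arrays standing in for a kernel and a bias.
example_weights = [
    np.array([[0.0, 0.3], [0.0, 0.0]]),
    np.array([0.1, -0.2]),
]
print(report_sparsity(example_weights))  # [0.75, 0.0]
```

In this notebook, the same helper could be called as `report_sparsity(model_for_export.get_weights())`. One would expect the kernels to sit near the 80% target and the biases near zero sparsity, assuming biases are left unpruned.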
Then, create a compressible model for TFLite.
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = converter.convert()
_, pruned_tflite_file = tempfile.mkstemp('.tflite')
with open(pruned_tflite_file, 'wb') as f:
  f.write(pruned_tflite_model)
print('Saved pruned TFLite model to:', pruned_tflite_file)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp5_xxhxp9/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp5_xxhxp9/assets
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1760188079.570481 13713 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1760188079.570521 13713 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
Saved pruned TFLite model to: /tmpfs/tmp/tmpo0gtwfw1.tflite
I0000 00:00:1760188079.576126 13713 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled
Define a helper function that compresses a model file into a zip archive with DEFLATE (the algorithm gzip also uses) and measures the compressed size.
def get_gzipped_model_size(file):
  # Returns the size, in bytes, of the model file after DEFLATE compression in a zip archive.
  import os
  import zipfile
  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)
  return os.path.getsize(zipped_file)
Compare and see that the models are 3x smaller from pruning.
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned Keras model: %.2f bytes" % (get_gzipped_model_size(pruned_keras_file)))
print("Size of gzipped pruned TFlite model: %.2f bytes" % (get_gzipped_model_size(pruned_tflite_file)))
Size of gzipped baseline Keras model: 78244.00 bytes
Size of gzipped pruned Keras model: 25702.00 bytes
Size of gzipped pruned TFlite model: 24828.00 bytes
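The "3x" figure follows directly from the sizes printed above. As a quick check of the arithmetic, using only the byte counts from this run:

```python
baseline = 78244  # compressed baseline Keras model, in bytes
sizes = {"pruned Keras": 25702, "pruned TFLite": 24828}

ratios = {name: baseline / size for name, size in sizes.items()}
for name, ratio in ratios.items():
    print(f"{name}: {ratio:.2f}x smaller")
# pruned Keras: 3.04x smaller
# pruned TFLite: 3.15x smaller
```

Other runs will produce slightly different byte counts, so the ratio should be read as approximate.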
Create a 10x smaller model from combining pruning and quantization
You can apply post-training quantization to the pruned model for additional benefits.
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()
_, quantized_and_pruned_tflite_file = tempfile.mkstemp('.tflite')
with open(quantized_and_pruned_tflite_file, 'wb') as f:
  f.write(quantized_and_pruned_tflite_model)
print('Saved quantized and pruned TFLite model to:', quantized_and_pruned_tflite_file)
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned and quantized TFlite model: %.2f bytes" % (get_gzipped_model_size(quantized_and_pruned_tflite_file)))
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp2ckf_cl5/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp2ckf_cl5/assets
W0000 00:00:1760188080.187761 13713 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1760188080.187788 13713 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
Saved quantized and pruned TFLite model to: /tmpfs/tmp/tmp5xylnj_1.tflite
Size of gzipped baseline Keras model: 78244.00 bytes
Size of gzipped pruned and quantized TFlite model: 8689.00 bytes
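To see why quantization stacks with pruning, consider a simplified symmetric 8-bit scheme. The sketch below is an illustration under stated assumptions, not the converter's actual implementation. It maps float32 weights to int8 with a single scale factor, and the helper names are hypothetical. Two properties matter here: each weight shrinks from four bytes to one, and an exact zero stays an exact zero, so the sparsity created by pruning survives.

```python
import numpy as np

def quantize_symmetric_int8(weights):
    """Map float weights to int8 using one scale; zero always maps to zero."""
    max_abs = float(np.max(np.abs(weights)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Recover approximate float32 weights from the int8 values."""
    return q.astype(np.float32) * scale

# Made-up pruned weights: half of the entries are already zero.
w = np.array([0.0, 0.0, 0.5, -1.27, 0.0, 0.25], dtype=np.float32)
q, scale = quantize_symmetric_int8(w)

print(q)                                        # int8 codes; pruned entries stay 0
print(w.nbytes, q.nbytes)                       # 24 6
print(np.max(np.abs(w - dequantize(q, scale)))) # rounding error, at most about scale / 2
```

For the run shown above, 78244 / 8689 is roughly 9.0, which is close to the 10x in the section title.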
See persistence of accuracy from TF to TFLite
Define a helper function to evaluate the TF Lite model on the test dataset.
import numpy as np
def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]
  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)
    # Run inference.
    interpreter.invoke()
    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)
  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy
Evaluate the pruned and quantized model and see that the accuracy from TensorFlow persists to the TFLite backend.
interpreter = tf.lite.Interpreter(model_content=quantized_and_pruned_tflite_model)
interpreter.allocate_tensors()
test_accuracy = evaluate_model(interpreter)
print('Pruned and quantized TFLite test_accuracy:', test_accuracy)
print('Pruned TF test accuracy:', model_for_pruning_accuracy)
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/interpreter.py:457: UserWarning:     Warning: tf.lite.Interpreter is deprecated and is scheduled for deletion in
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    
  warnings.warn(_INTERPRETER_DELETION_WARNING)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.
Pruned and quantized TFLite test_accuracy: 0.9717
Pruned TF test accuracy: 0.9718999862670898
Conclusion
In this tutorial, you saw how to create sparse models with the TensorFlow Model Optimization Toolkit API for both TensorFlow and TFLite. You then combined pruning with post-training quantization for additional benefits.
You created a 10x smaller model for MNIST, with minimal accuracy difference.
We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.