Giúp bảo vệ Great Barrier Reef với TensorFlow trên Kaggle Tham Challenge

Ví dụ về đào tạo nhận thức lượng tử hóa (PCQAT) của Keras

Xem trên TensorFlow.org Chạy trong Google Colab Xem trên GitHub Tải xuống sổ ghi chép

Tổng quat

Đây là một dấu chấm hết cho ví dụ cuối cho thấy việc sử dụng các thưa thớt và cụm giữ lượng tử đào tạo (PCQAT) API biết, một phần của đường ống dẫn tối ưu hóa hợp tác các TensorFlow Mẫu Tối ưu hóa Toolkit.

Những trang khác

Đối với một giới thiệu về các đường ống và các kỹ thuật có sẵn khác, hãy xem trang tổng quan về tối ưu hóa hợp tác .

Nội dung

Trong hướng dẫn này, bạn sẽ:

  1. Đào tạo một tf.keras mô hình cho các tập dữ liệu MNIST từ đầu.
  2. Tinh chỉnh mô hình bằng cách cắt tỉa và xem độ chính xác và quan sát rằng mô hình đã được cắt tỉa thành công.
  3. Áp dụng phân cụm bảo tồn thưa thớt trên mô hình đã được cắt tỉa và quan sát rằng sự thưa thớt được áp dụng trước đó đã được bảo toàn.
  4. Áp dụng QAT và quan sát sự mất mát của sự thưa thớt và cụm.
  5. Áp dụng PCQAT và quan sát rằng cả sự thưa thớt và phân cụm được áp dụng trước đó đã được bảo toàn.
  6. Tạo mô hình TFLite và quan sát tác động của việc áp dụng PCQAT trên đó.
  7. So sánh kích thước của các mô hình khác nhau để quan sát các lợi ích nén của việc áp dụng tính chất thưa thớt, theo sau là các kỹ thuật cộng tác tối ưu hóa phân nhóm bảo toàn độ thưa thớt và PCQAT.
  8. So sánh độ chính xác của mô hình được tối ưu hóa hoàn toàn với độ chính xác của mô hình cơ sở chưa được tối ưu hóa.

Thành lập

Bạn có thể chạy Notebook Jupyter này ở địa phương của bạn virtualenv hoặc colab . Để biết chi tiết về việc thiết lập phụ thuộc, vui lòng tham khảo hướng dẫn cài đặt .

 pip install -q tensorflow-model-optimization
import tensorflow as tf

import numpy as np
import tempfile
import zipfile
import os

Đào tạo mô hình tf.keras để MNIST được cắt tỉa và phân cụm

# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Train the digit classification model
model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step
2021-09-02 11:14:14.164834: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Epoch 1/10
1688/1688 [==============================] - 8s 5ms/step - loss: 0.2842 - accuracy: 0.9215 - val_loss: 0.1078 - val_accuracy: 0.9713
Epoch 2/10
1688/1688 [==============================] - 8s 5ms/step - loss: 0.1110 - accuracy: 0.9684 - val_loss: 0.0773 - val_accuracy: 0.9783
Epoch 3/10
1688/1688 [==============================] - 8s 4ms/step - loss: 0.0821 - accuracy: 0.9760 - val_loss: 0.0676 - val_accuracy: 0.9803
Epoch 4/10
1688/1688 [==============================] - 8s 4ms/step - loss: 0.0684 - accuracy: 0.9799 - val_loss: 0.0600 - val_accuracy: 0.9825
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0590 - accuracy: 0.9828 - val_loss: 0.0601 - val_accuracy: 0.9838
Epoch 6/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0522 - accuracy: 0.9845 - val_loss: 0.0599 - val_accuracy: 0.9835
Epoch 7/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0472 - accuracy: 0.9863 - val_loss: 0.0544 - val_accuracy: 0.9862
Epoch 8/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0422 - accuracy: 0.9868 - val_loss: 0.0579 - val_accuracy: 0.9848
Epoch 9/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0384 - accuracy: 0.9884 - val_loss: 0.0569 - val_accuracy: 0.9847
Epoch 10/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0347 - accuracy: 0.9892 - val_loss: 0.0559 - val_accuracy: 0.9840
<keras.callbacks.History at 0x7f6a8212c550>

Đánh giá mô hình cơ sở và lưu nó để sử dụng sau này

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9811000227928162
Saving model to:  /tmp/tmprlekfdwb.h5

Cắt tỉa và tinh chỉnh mô hình đến 50% độ thưa thớt

Áp dụng các prune_low_magnitude() API để đạt được các mô hình tỉa mà là để được nhóm lại trong bước tiếp theo. Hãy tham khảo những hướng dẫn toàn diện cắt tỉa để biết thêm thông tin về API tỉa.

Xác định mô hình và áp dụng API thưa thớt

Lưu ý rằng mô hình được đào tạo trước được sử dụng.

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)
  }

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep()
]

pruned_model = prune_low_magnitude(model, **pruning_params)

# Use smaller learning rate for fine-tuning
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

pruned_model.compile(
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer.py:2223: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  warnings.warn('`layer.add_variable` is deprecated and '

Tinh chỉnh mô hình, kiểm tra độ thưa thớt và đánh giá độ chính xác so với đường cơ sở

Tinh chỉnh mô hình bằng cách cắt tỉa trong 3 kỷ nguyên.

# Fine-tune model
pruned_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1,
  callbacks=callbacks)
2021-09-02 11:15:31.836903: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
Epoch 1/3
1688/1688 [==============================] - 9s 5ms/step - loss: 0.2095 - accuracy: 0.9305 - val_loss: 0.1440 - val_accuracy: 0.9528
Epoch 2/3
1688/1688 [==============================] - 8s 4ms/step - loss: 0.1042 - accuracy: 0.9671 - val_loss: 0.0947 - val_accuracy: 0.9715
Epoch 3/3
1688/1688 [==============================] - 8s 4ms/step - loss: 0.0743 - accuracy: 0.9782 - val_loss: 0.0829 - val_accuracy: 0.9770
<keras.callbacks.History at 0x7f6a81f94250>

Xác định các hàm trợ giúp để tính toán và in độ thưa thớt và các cụm của mô hình.

def print_model_weights_sparsity(model):
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            if "kernel" not in weight.name or "centroid" in weight.name:
                continue
            weight_size = weight.numpy().size
            zero_num = np.count_nonzero(weight == 0)
            print(
                f"{weight.name}: {zero_num/weight_size:.2%} sparsity ",
                f"({zero_num}/{weight_size})",
            )

def print_model_weight_clusters(model):
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

Trước tiên, hãy tách lớp bao bọc cắt tỉa, sau đó kiểm tra xem các hạt nhân mô hình đã được cắt tỉa chính xác chưa.

stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

print_model_weights_sparsity(stripped_pruned_model)
conv2d/kernel:0: 50.00% sparsity  (54/108)
dense/kernel:0: 50.00% sparsity  (10140/20280)

Áp dụng phân cụm bảo toàn độ thưa thớt và kiểm tra ảnh hưởng của nó đối với độ thưa thớt của mô hình trong cả hai trường hợp

Tiếp theo, áp dụng phân cụm bảo tồn thưa thớt trên mô hình đã cắt tỉa và quan sát số lượng các cụm và kiểm tra xem sự thưa thớt có được bảo toàn hay không.

import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (
    cluster,
)

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

cluster_weights = cluster.cluster_weights

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
  'preserve_sparsity': True
}

sparsity_clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)

sparsity_clustered_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

print('Train sparsity preserving clustering model:')
sparsity_clustered_model.fit(train_images, train_labels,epochs=3, validation_split=0.1)
Train sparsity preserving clustering model:
Epoch 1/3
1688/1688 [==============================] - 9s 5ms/step - loss: 0.0495 - accuracy: 0.9847 - val_loss: 0.0611 - val_accuracy: 0.9843
Epoch 2/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0472 - accuracy: 0.9855 - val_loss: 0.0705 - val_accuracy: 0.9812
Epoch 3/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0463 - accuracy: 0.9846 - val_loss: 0.0796 - val_accuracy: 0.9780
<keras.callbacks.History at 0x7f6a81c10250>

Trước tiên, loại bỏ lớp bao bọc phân cụm, sau đó kiểm tra xem mô hình đã được cắt và phân cụm chính xác chưa.

stripped_clustered_model = tfmot.clustering.keras.strip_clustering(sparsity_clustered_model)

print("Model sparsity:\n")
print_model_weights_sparsity(stripped_clustered_model)

print("\nModel clusters:\n")
print_model_weight_clusters(stripped_clustered_model)
Model sparsity:

kernel:0: 51.85% sparsity  (56/108)
kernel:0: 60.83% sparsity  (12337/20280)

Model clusters:

conv2d/kernel:0: 8 clusters 
dense/kernel:0: 8 clusters

Áp dụng QAT và PCQAT và kiểm tra hiệu ứng trên các cụm mô hình và độ thưa thớt

Tiếp theo, áp dụng cả QAT và PCQAT trên mô hình phân cụm thưa thớt và quan sát rằng PCQAT bảo toàn khối lượng thưa thớt và các cụm trong mô hình của bạn. Lưu ý rằng mô hình đã loại bỏ được chuyển đến API QAT và PCQAT.

# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)

qat_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# PCQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_clustered_model)
pcqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))

pcqat_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train pcqat model:')
pcqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model:
422/422 [==============================] - 4s 8ms/step - loss: 0.0343 - accuracy: 0.9892 - val_loss: 0.0600 - val_accuracy: 0.9858
Train pcqat model:
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['conv2d/kernel:0', 'dense/kernel:0'] when minimizing the loss.
422/422 [==============================] - 4s 8ms/step - loss: 0.0371 - accuracy: 0.9880 - val_loss: 0.0664 - val_accuracy: 0.9832
<keras.callbacks.History at 0x7f6a81792910>
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("\nQAT Model sparsity:")
print_model_weights_sparsity(qat_model)
print("\nPCQAT Model clusters:")
print_model_weight_clusters(pcqat_model)
print("\nPCQAT Model sparsity:")
print_model_weights_sparsity(pcqat_model)
QAT Model clusters:
quant_conv2d/conv2d/kernel:0: 101 clusters 
quant_dense/dense/kernel:0: 18285 clusters 

QAT Model sparsity:
conv2d/kernel:0: 7.41% sparsity  (8/108)
dense/kernel:0: 7.64% sparsity  (1549/20280)

PCQAT Model clusters:
quant_conv2d/conv2d/kernel:0: 8 clusters 
quant_dense/dense/kernel:0: 8 clusters 

PCQAT Model sparsity:
conv2d/kernel:0: 51.85% sparsity  (56/108)
dense/kernel:0: 60.84% sparsity  (12338/20280)

Xem lợi ích nén của mô hình PCQAT

Xác định chức năng trợ giúp để lấy tệp mô hình nén.

def get_gzipped_model_size(file):
  # It returns the size of the gzipped model in kilobytes.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

Quan sát thấy rằng việc áp dụng thưa thớt, phân cụm và PCQAT cho một mô hình sẽ mang lại những lợi ích nén đáng kể.

# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# PCQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(pcqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
pcqat_tflite_model = converter.convert()
pcqat_model_file = 'pcqat_model.tflite'
# Save the model.
with open(pcqat_model_file, 'wb') as f:
    f.write(pcqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("PCQAT model size: ", get_gzipped_model_size(pcqat_model_file), ' KB')
WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmp6_obh00g/assets
INFO:tensorflow:Assets written to: /tmp/tmp6_obh00g/assets
2021-09-02 11:16:32.221664: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format.
2021-09-02 11:16:32.221712: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency.
WARNING:absl:Found untraced functions such as reshape_layer_call_and_return_conditional_losses, reshape_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, conv2d_layer_call_fn, flatten_layer_call_and_return_conditional_losses while saving (showing 5 of 20). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tmpuqqwyk0s/assets
INFO:tensorflow:Assets written to: /tmp/tmpuqqwyk0s/assets
QAT model size:  13.723  KB
PCQAT model size:  7.352  KB
2021-09-02 11:16:33.766310: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format.
2021-09-02 11:16:33.766350: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency.

Xem độ chính xác lâu dài từ TF sang TFLite

Xác định một chức năng trợ giúp để đánh giá mô hình TFLite trên tập dữ liệu thử nghiệm.

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Đánh giá mô hình, đã được lược bớt, phân cụm và lượng hóa, sau đó thấy rằng độ chính xác từ TensorFlow vẫn tồn tại trong phần phụ trợ TFLite.

interpreter = tf.lite.Interpreter(pcqat_model_file)
interpreter.allocate_tensors()

pcqat_test_accuracy = eval_model(interpreter)

print('Pruned, clustered and quantized TFLite test_accuracy:', pcqat_test_accuracy)
print('Baseline TF test accuracy:', baseline_model_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned, clustered and quantized TFLite test_accuracy: 0.9803
Baseline TF test accuracy: 0.9811000227928162

Sự kết luận

Trong hướng dẫn này, bạn đã học như thế nào để tạo ra một mô hình, tỉa nó bằng cách sử dụng prune_low_magnitude() API, và áp dụng thưa thớt bảo quản phân nhóm bằng cách sử dụng cluster_weights() API để giữ gìn thưa thớt trong khi phân nhóm các trọng.

Tiếp theo, đào tạo nhận thức lượng tử hóa (PCQAT) được áp dụng để duy trì độ thưa thớt và cụm mô hình trong khi sử dụng QAT. Mô hình PCQAT cuối cùng được so sánh với mô hình QAT để chỉ ra rằng sự thưa thớt và các cụm được bảo tồn ở mô hình trước và bị mất ở mô hình sau.

Tiếp theo, các mô hình được chuyển đổi sang TFLite để hiển thị các lợi ích nén của kỹ thuật tối ưu hóa mô hình chuỗi thưa thớt, phân cụm và PCQAT và mô hình TFLite được đánh giá để đảm bảo rằng độ chính xác vẫn tồn tại trong phần phụ trợ TFLite.

Cuối cùng, độ chính xác của mô hình PCQAT TFLite được so sánh với độ chính xác của mô hình cơ sở trước khi tối ưu hóa để cho thấy rằng các kỹ thuật tối ưu hóa cộng tác được quản lý để đạt được lợi ích nén trong khi vẫn duy trì độ chính xác tương tự so với mô hình ban đầu.