Google I / O là một kết quả hoàn hảo! Cập nhật các phiên TensorFlow Xem phiên

Hướng dẫn toàn diện về cắt tỉa

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Chào mừng bạn đến với hướng dẫn toàn diện về cắt tỉa cân Keras.

Trang này ghi lại các trường hợp sử dụng khác nhau và chỉ ra cách sử dụng API cho từng trường hợp sử dụng. Một khi bạn biết được các API bạn cần, tìm các thông số và các chi tiết ở mức độ thấp trong tài liệu API .

  • Nếu bạn muốn nhìn thấy những lợi ích của cắt tỉa và những gì đang được hỗ trợ, xem cái nhìn tổng quan .
  • Đối với một đơn dụ end-to-end, xem ví dụ tỉa .

Các trường hợp sử dụng sau được đề cập:

  • Xác định và đào tạo một mô hình đã được cắt tỉa.
    • Tuần tự và Chức năng.
    • Keras model.fit và các vòng đào tạo tùy chỉnh
  • Checkpoint và deserialize một mô hình đã được cắt tỉa.
  • Triển khai một mô hình đã được lược bớt và xem các lợi ích của việc nén.

Đối với cấu hình của thuật toán cắt tỉa, hãy tham khảo tfmot.sparsity.keras.prune_low_magnitude tài liệu API.

Thành lập

Để tìm các API bạn cần và hiểu mục đích, bạn có thể chạy nhưng bỏ qua việc đọc phần này.

! pip install -q tensorflow-model-optimization

import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot

%load_ext tensorboard

import tempfile

input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=20)

def setup_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(20, input_shape=input_shape),
      tf.keras.layers.Flatten()
  ])
  return model

def setup_pretrained_weights():
  model = setup_model()

  model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
  )

  model.fit(x_train, y_train)

  _, pretrained_weights = tempfile.mkstemp('.tf')

  model.save_weights(pretrained_weights)

  return pretrained_weights

def get_gzipped_model_size(model):
  # Returns size of gzipped model, in bytes.
  import os
  import zipfile

  _, keras_file = tempfile.mkstemp('.h5')
  model.save(keras_file, include_optimizer=False)

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(keras_file)

  return os.path.getsize(zipped_file)

setup_model()
pretrained_weights = setup_pretrained_weights()

Xác định mô hình

Tỉa toàn bộ mô hình (Tuần tự và Chức năng)

Mẹo để có độ chính xác của mô hình tốt hơn:

  • Hãy thử "Tỉa một số lớp" để bỏ qua việc cắt tỉa các lớp làm giảm độ chính xác nhiều nhất.
  • Nhìn chung, tốt hơn là bạn nên cắt tỉa bằng cách cắt tỉa thay vì đào tạo từ đầu.

Để làm cho toàn bộ mô hình đào tạo với cắt tỉa, áp dụng tfmot.sparsity.keras.prune_low_magnitude cho mô hình.

base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended.

model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

model_for_pruning.summary()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:200: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_2  (None, 20)                822       
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20)                1         
=================================================================
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________

Tỉa một số lớp (Tuần tự và Chức năng)

Việc cắt tỉa mô hình có thể có ảnh hưởng tiêu cực đến độ chính xác. Bạn có thể chọn lọc các lớp của mô hình để khám phá sự cân bằng giữa độ chính xác, tốc độ và kích thước mô hình.

Mẹo để có độ chính xác của mô hình tốt hơn:

  • Nhìn chung, tốt hơn là bạn nên cắt tỉa bằng cách cắt tỉa thay vì đào tạo từ đầu.
  • Hãy thử cắt tỉa các lớp sau thay vì các lớp đầu tiên.
  • Tránh cắt tỉa các lớp quan trọng (ví dụ: cơ chế chú ý).

hơn:

Trong ví dụ dưới đây, mận chỉ Dense lớp.

# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy

# Helper function uses `prune_low_magnitude` to make only the 
# Dense layers train with pruning.
def apply_pruning_to_dense(layer):
  if isinstance(layer, tf.keras.layers.Dense):
    return tfmot.sparsity.keras.prune_low_magnitude(layer)
  return layer

# Use `tf.keras.models.clone_model` to apply `apply_pruning_to_dense` 
# to the layers of the model.
model_for_pruning = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_pruning_to_dense,
)

model_for_pruning.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_3  (None, 20)                822       
_________________________________________________________________
flatten_3 (Flatten)          (None, 20)                0         
=================================================================
Total params: 822
Trainable params: 420
Non-trainable params: 402
_________________________________________________________________

Trong khi ví dụ này sử dụng các loại hình lớp để quyết định những gì để prune, cách dễ nhất để prune một lớp đặc biệt là để thiết lập của nó name tài sản, và nhìn cho tên đó trong clone_function .

print(base_model.layers[0].name)
dense_3

Dễ đọc hơn nhưng độ chính xác của mô hình có thể thấp hơn

Điều này không tương thích với tinh chỉnh với cắt tỉa, đó là lý do tại sao nó có thể kém chính xác hơn các ví dụ hỗ trợ tinh chỉnh ở trên.

Trong khi prune_low_magnitude thể được áp dụng trong khi xác định mô hình ban đầu, tải trọng sau khi không làm việc ở dưới đây ví dụ.

Ví dụ chức năng

# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
i = tf.keras.Input(shape=(20,))
x = tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(10))(i)
o = tf.keras.layers.Flatten()(x)
model_for_pruning = tf.keras.Model(inputs=i, outputs=o)

model_for_pruning.summary()
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 20)]              0         
_________________________________________________________________
prune_low_magnitude_dense_4  (None, 10)                412       
_________________________________________________________________
flatten_4 (Flatten)          (None, 10)                0         
=================================================================
Total params: 412
Trainable params: 210
Non-trainable params: 202
_________________________________________________________________

Ví dụ tuần tự

# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
model_for_pruning = tf.keras.Sequential([
  tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(20, input_shape=input_shape)),
  tf.keras.layers.Flatten()
])

model_for_pruning.summary()
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_5  (None, 20)                822       
_________________________________________________________________
flatten_5 (Flatten)          (None, 20)                0         
=================================================================
Total params: 822
Trainable params: 420
Non-trainable params: 402
_________________________________________________________________

Cắt tỉa lớp Keras tùy chỉnh hoặc sửa đổi các phần của lớp để cắt tỉa

Sai lầm phổ biến: cắt tỉa thiên vị thường gây tổn hại mô hình chính xác quá nhiều.

tfmot.sparsity.keras.PrunableLayer phục vụ hai trường hợp sử dụng:

  1. Tỉa một lớp Keras tùy chỉnh
  2. Sửa đổi các phần của lớp Keras tích hợp để cắt tỉa.

Đối với một ví dụ, giá trị mặc định API để chỉ cắt tỉa hạt nhân của Dense lớp. Ví dụ dưới đây cũng loại bỏ sự thiên vị.

class MyDenseLayer(tf.keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):

  def get_prunable_weights(self):
    # Prune bias also, though that usually harms model accuracy too much.
    return [self.kernel, self.bias]

# Use `prune_low_magnitude` to make the `MyDenseLayer` layer train with pruning.
model_for_pruning = tf.keras.Sequential([
  tfmot.sparsity.keras.prune_low_magnitude(MyDenseLayer(20, input_shape=input_shape)),
  tf.keras.layers.Flatten()
])

model_for_pruning.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_my_dense (None, 20)                843       
_________________________________________________________________
flatten_6 (Flatten)          (None, 20)                0         
=================================================================
Total params: 843
Trainable params: 420
Non-trainable params: 423
_________________________________________________________________

Mô hình xe lửa

Model.fit

Gọi tfmot.sparsity.keras.UpdatePruningStep callback trong đào tạo.

Để đào tạo debug giúp đỡ, sử dụng tfmot.sparsity.keras.PruningSummaries gọi lại.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

log_dir = tempfile.mkdtemp()
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Log sparsity and other metrics in Tensorboard.
    tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)
]

model_for_pruning.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
)

model_for_pruning.fit(
    x_train,
    y_train,
    callbacks=callbacks,
    epochs=2,
)

#docs_infra: no_execute
%tensorboard --logdir={log_dir}
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Epoch 1/2
1/1 [==============================] - 0s 3ms/step - loss: 1.2485 - accuracy: 0.0000e+00
Epoch 2/2
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
1/1 [==============================] - 0s 2ms/step - loss: 1.1999 - accuracy: 0.0000e+00

Đối với người dùng không Colab, bạn có thể thấy kết quả của một hoạt động trước đó của khối mã này trên TensorBoard.dev .

Vòng đào tạo tùy chỉnh

Gọi tfmot.sparsity.keras.UpdatePruningStep callback trong đào tạo.

Để đào tạo debug giúp đỡ, sử dụng tfmot.sparsity.keras.PruningSummaries gọi lại.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Boilerplate
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()
log_dir = tempfile.mkdtemp()
unused_arg = -1
epochs = 2
batches = 1 # example is hardcoded so that the number of batches cannot change.

# Non-boilerplate.
model_for_pruning.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(model_for_pruning)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir) # Log sparsity and other metrics in Tensorboard.
log_callback.set_model(model_for_pruning)

step_callback.on_train_begin() # run pruning callback
for _ in range(epochs):
  log_callback.on_epoch_begin(epoch=unused_arg) # run pruning callback
  for _ in range(batches):
    step_callback.on_train_batch_begin(batch=unused_arg) # run pruning callback

    with tf.GradientTape() as tape:
      logits = model_for_pruning(x_train, training=True)
      loss_value = loss(y_train, logits)
      grads = tape.gradient(loss_value, model_for_pruning.trainable_variables)
      optimizer.apply_gradients(zip(grads, model_for_pruning.trainable_variables))

  step_callback.on_epoch_end(batch=unused_arg) # run pruning callback

#docs_infra: no_execute
%tensorboard --logdir={log_dir}
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.

Đối với người dùng không Colab, bạn có thể thấy kết quả của một hoạt động trước đó của khối mã này trên TensorBoard.dev .

Cải thiện độ chính xác của mô hình được cắt tỉa

Thứ nhất, nhìn vào tfmot.sparsity.keras.prune_low_magnitude tài liệu API để hiểu những gì một lịch trình cắt tỉa là và toán học của từng loại cắt tỉa đúng tiến độ.

Mẹo:

  • Có tỷ lệ học tập không quá cao hoặc quá thấp khi mô hình đang cắt tỉa. Hãy xem xét các kế hoạch cắt tỉa là một hyperparameter.

  • Như một thử nghiệm nhanh, cố gắng thử nghiệm với cắt tỉa một mô hình để các thưa thớt thức vào đầu đào tạo bằng cách thiết lập begin_step 0 với một tfmot.sparsity.keras.ConstantSparsity lịch. Bạn có thể gặp may mắn với kết quả tốt.

  • Không cắt tỉa thường xuyên để mô hình có thời gian phục hồi. Các lịch trình cắt tỉa cung cấp một tần số mặc định phong nha.

  • Để biết các ý tưởng chung nhằm cải thiện độ chính xác của mô hình, hãy tìm các mẹo cho (các) trường hợp sử dụng của bạn trong "Xác định mô hình".

Checkpoint và deserialize

Bạn phải duy trì bước tối ưu hóa trong quá trình kiểm tra. Điều này có nghĩa là trong khi bạn có thể sử dụng các kiểu Keras HDF5 để kiểm tra, bạn không thể sử dụng các trọng lượng Keras HDF5.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

_, keras_model_file = tempfile.mkstemp('.h5')

# Checkpoint: saving the optimizer is necessary (include_optimizer=True is the default).
model_for_pruning.save(keras_model_file, include_optimizer=True)
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.

Những điều trên được áp dụng chung. Mã bên dưới chỉ cần thiết cho định dạng mô hình HDF5 (không phải trọng lượng HDF5 và các định dạng khác).

# Deserialize model.
with tfmot.sparsity.keras.prune_scope():
  loaded_model = tf.keras.models.load_model(keras_model_file)

loaded_model.summary()
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_8  (None, 20)                822       
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20)                1         
=================================================================
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________

Triển khai mô hình cắt tỉa

Xuất mô hình với nén kích thước

Sai lầm phổ biến: cả hai strip_pruning và áp dụng một thuật toán nén chuẩn (ví dụ như thông qua gzip) là cần thiết để thấy được lợi ích nén của tỉa.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Typically you train the model here.

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

print("final model")
model_for_export.summary()

print("\n")
print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning)))
print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export)))
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
final model
Model: "sequential_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_9 (Dense)              (None, 20)                420       
_________________________________________________________________
flatten_10 (Flatten)         (None, 20)                0         
=================================================================
Total params: 420
Trainable params: 420
Non-trainable params: 0
_________________________________________________________________


Size of gzipped pruned model without stripping: 3299.00 bytes
Size of gzipped pruned model with stripping: 2876.00 bytes

Tối ưu hóa phần cứng cụ thể

Khi backends khác nhau cho phép cắt tỉa để cải thiện độ trễ , sử dụng khối thưa thớt có thể cải thiện độ trễ cho phần cứng nhất định.

Tăng kích thước khối sẽ làm giảm độ thưa thớt tối đa có thể đạt được để có độ chính xác của mô hình mục tiêu. Mặc dù vậy, độ trễ vẫn có thể cải thiện.

Để biết chi tiết về những gì đang được hỗ trợ cho khối thưa thớt, xem tfmot.sparsity.keras.prune_low_magnitude tài liệu API.

base_model = setup_model()

# For using intrinsics on a CPU with 128-bit registers, together with 8-bit
# quantized weights, a 1x16 block size is nice because the block perfectly
# fits into the register.
pruning_params = {'block_size': [1, 16]}
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model, **pruning_params)

model_for_pruning.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Model: "sequential_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_10 (None, 20)                822       
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20)                1         
=================================================================
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________