![]() | ![]() | ![]() | ![]() |
Chào mừng bạn đến với hướng dẫn toàn diện về cắt tỉa cân Keras.
Trang này ghi lại các trường hợp sử dụng khác nhau và chỉ ra cách sử dụng API cho từng trường hợp sử dụng. Một khi bạn biết được các API bạn cần, tìm các thông số và các chi tiết ở mức độ thấp trong tài liệu API .
- Nếu bạn muốn nhìn thấy những lợi ích của cắt tỉa và những gì đang được hỗ trợ, xem cái nhìn tổng quan .
- Đối với một đơn dụ end-to-end, xem ví dụ tỉa .
Các trường hợp sử dụng sau được đề cập:
- Xác định và đào tạo một mô hình đã được cắt tỉa.
- Tuần tự và Chức năng.
- Keras model.fit và các vòng đào tạo tùy chỉnh
- Checkpoint và deserialize một mô hình đã được cắt tỉa.
- Triển khai một mô hình đã được lược bớt và xem các lợi ích của việc nén.
Đối với cấu hình của thuật toán cắt tỉa, hãy tham khảo tfmot.sparsity.keras.prune_low_magnitude
tài liệu API.
Thành lập
Để tìm các API bạn cần và hiểu mục đích, bạn có thể chạy nhưng bỏ qua việc đọc phần này.
! pip install -q tensorflow-model-optimization
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
%load_ext tensorboard
import tempfile
input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=20)
def setup_model():
model = tf.keras.Sequential([
tf.keras.layers.Dense(20, input_shape=input_shape),
tf.keras.layers.Flatten()
])
return model
def setup_pretrained_weights():
model = setup_model()
model.compile(
loss=tf.keras.losses.categorical_crossentropy,
optimizer='adam',
metrics=['accuracy']
)
model.fit(x_train, y_train)
_, pretrained_weights = tempfile.mkstemp('.tf')
model.save_weights(pretrained_weights)
return pretrained_weights
def get_gzipped_model_size(model):
# Returns size of gzipped model, in bytes.
import os
import zipfile
_, keras_file = tempfile.mkstemp('.h5')
model.save(keras_file, include_optimizer=False)
_, zipped_file = tempfile.mkstemp('.zip')
with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
f.write(keras_file)
return os.path.getsize(zipped_file)
setup_model()
pretrained_weights = setup_pretrained_weights()
Xác định mô hình
Tỉa toàn bộ mô hình (Tuần tự và Chức năng)
Mẹo để có độ chính xác của mô hình tốt hơn:
- Hãy thử "Tỉa một số lớp" để bỏ qua việc cắt tỉa các lớp làm giảm độ chính xác nhiều nhất.
- Nhìn chung, tốt hơn là bạn nên cắt tỉa bằng cách cắt tỉa thay vì đào tạo từ đầu.
Để làm cho toàn bộ mô hình đào tạo với cắt tỉa, áp dụng tfmot.sparsity.keras.prune_low_magnitude
cho mô hình.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended.
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
model_for_pruning.summary()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:200: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: Please use `layer.add_weight` method instead. Model: "sequential_2" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= prune_low_magnitude_dense_2 (None, 20) 822 _________________________________________________________________ prune_low_magnitude_flatten_ (None, 20) 1 ================================================================= Total params: 823 Trainable params: 420 Non-trainable params: 403 _________________________________________________________________
Tỉa một số lớp (Tuần tự và Chức năng)
Việc cắt tỉa mô hình có thể có ảnh hưởng tiêu cực đến độ chính xác. Bạn có thể chọn lọc các lớp của mô hình để khám phá sự cân bằng giữa độ chính xác, tốc độ và kích thước mô hình.
Mẹo để có độ chính xác của mô hình tốt hơn:
- Nhìn chung, tốt hơn là bạn nên cắt tỉa bằng cách cắt tỉa thay vì đào tạo từ đầu.
- Hãy thử cắt tỉa các lớp sau thay vì các lớp đầu tiên.
- Tránh cắt tỉa các lớp quan trọng (ví dụ: cơ chế chú ý).
hơn:
- Các
tfmot.sparsity.keras.prune_low_magnitude
tài liệu API cung cấp chi tiết về làm thế nào để thay đổi cấu hình tỉa cho mỗi lớp.
Trong ví dụ dưới đây, mận chỉ Dense
lớp.
# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
# Helper function uses `prune_low_magnitude` to make only the
# Dense layers train with pruning.
def apply_pruning_to_dense(layer):
if isinstance(layer, tf.keras.layers.Dense):
return tfmot.sparsity.keras.prune_low_magnitude(layer)
return layer
# Use `tf.keras.models.clone_model` to apply `apply_pruning_to_dense`
# to the layers of the model.
model_for_pruning = tf.keras.models.clone_model(
base_model,
clone_function=apply_pruning_to_dense,
)
model_for_pruning.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details. Model: "sequential_3" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= prune_low_magnitude_dense_3 (None, 20) 822 _________________________________________________________________ flatten_3 (Flatten) (None, 20) 0 ================================================================= Total params: 822 Trainable params: 420 Non-trainable params: 402 _________________________________________________________________
Trong khi ví dụ này sử dụng các loại hình lớp để quyết định những gì để prune, cách dễ nhất để prune một lớp đặc biệt là để thiết lập của nó name
tài sản, và nhìn cho tên đó trong clone_function
.
print(base_model.layers[0].name)
dense_3
Dễ đọc hơn nhưng độ chính xác của mô hình có thể thấp hơn
Điều này không tương thích với tinh chỉnh với cắt tỉa, đó là lý do tại sao nó có thể kém chính xác hơn các ví dụ hỗ trợ tinh chỉnh ở trên.
Trong khi prune_low_magnitude
thể được áp dụng trong khi xác định mô hình ban đầu, tải trọng sau khi không làm việc ở dưới đây ví dụ.
Ví dụ chức năng
# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
i = tf.keras.Input(shape=(20,))
x = tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(10))(i)
o = tf.keras.layers.Flatten()(x)
model_for_pruning = tf.keras.Model(inputs=i, outputs=o)
model_for_pruning.summary()
Model: "functional_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 20)] 0 _________________________________________________________________ prune_low_magnitude_dense_4 (None, 10) 412 _________________________________________________________________ flatten_4 (Flatten) (None, 10) 0 ================================================================= Total params: 412 Trainable params: 210 Non-trainable params: 202 _________________________________________________________________
Ví dụ tuần tự
# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
model_for_pruning = tf.keras.Sequential([
tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(20, input_shape=input_shape)),
tf.keras.layers.Flatten()
])
model_for_pruning.summary()
Model: "sequential_4" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= prune_low_magnitude_dense_5 (None, 20) 822 _________________________________________________________________ flatten_5 (Flatten) (None, 20) 0 ================================================================= Total params: 822 Trainable params: 420 Non-trainable params: 402 _________________________________________________________________
Cắt tỉa lớp Keras tùy chỉnh hoặc sửa đổi các phần của lớp để cắt tỉa
Sai lầm phổ biến: cắt tỉa thiên vị thường gây tổn hại mô hình chính xác quá nhiều.
tfmot.sparsity.keras.PrunableLayer
phục vụ hai trường hợp sử dụng:
- Tỉa một lớp Keras tùy chỉnh
- Sửa đổi các phần của lớp Keras tích hợp để cắt tỉa.
Đối với một ví dụ, giá trị mặc định API để chỉ cắt tỉa hạt nhân của Dense
lớp. Ví dụ dưới đây cũng loại bỏ sự thiên vị.
class MyDenseLayer(tf.keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):
def get_prunable_weights(self):
# Prune bias also, though that usually harms model accuracy too much.
return [self.kernel, self.bias]
# Use `prune_low_magnitude` to make the `MyDenseLayer` layer train with pruning.
model_for_pruning = tf.keras.Sequential([
tfmot.sparsity.keras.prune_low_magnitude(MyDenseLayer(20, input_shape=input_shape)),
tf.keras.layers.Flatten()
])
model_for_pruning.summary()
Model: "sequential_5" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= prune_low_magnitude_my_dense (None, 20) 843 _________________________________________________________________ flatten_6 (Flatten) (None, 20) 0 ================================================================= Total params: 843 Trainable params: 420 Non-trainable params: 423 _________________________________________________________________
Mô hình xe lửa
Model.fit
Gọi tfmot.sparsity.keras.UpdatePruningStep
callback trong đào tạo.
Để đào tạo debug giúp đỡ, sử dụng tfmot.sparsity.keras.PruningSummaries
gọi lại.
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
log_dir = tempfile.mkdtemp()
callbacks = [
tfmot.sparsity.keras.UpdatePruningStep(),
# Log sparsity and other metrics in Tensorboard.
tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)
]
model_for_pruning.compile(
loss=tf.keras.losses.categorical_crossentropy,
optimizer='adam',
metrics=['accuracy']
)
model_for_pruning.fit(
x_train,
y_train,
callbacks=callbacks,
epochs=2,
)
#docs_infra: no_execute
%tensorboard --logdir={log_dir}
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details. Epoch 1/2 1/1 [==============================] - 0s 3ms/step - loss: 1.2485 - accuracy: 0.0000e+00 Epoch 2/2 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01. Instructions for updating: use `tf.profiler.experimental.stop` instead. 1/1 [==============================] - 0s 2ms/step - loss: 1.1999 - accuracy: 0.0000e+00
Đối với người dùng không Colab, bạn có thể thấy kết quả của một hoạt động trước đó của khối mã này trên TensorBoard.dev .
Vòng đào tạo tùy chỉnh
Gọi tfmot.sparsity.keras.UpdatePruningStep
callback trong đào tạo.
Để đào tạo debug giúp đỡ, sử dụng tfmot.sparsity.keras.PruningSummaries
gọi lại.
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
# Boilerplate
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()
log_dir = tempfile.mkdtemp()
unused_arg = -1
epochs = 2
batches = 1 # example is hardcoded so that the number of batches cannot change.
# Non-boilerplate.
model_for_pruning.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(model_for_pruning)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir) # Log sparsity and other metrics in Tensorboard.
log_callback.set_model(model_for_pruning)
step_callback.on_train_begin() # run pruning callback
for _ in range(epochs):
log_callback.on_epoch_begin(epoch=unused_arg) # run pruning callback
for _ in range(batches):
step_callback.on_train_batch_begin(batch=unused_arg) # run pruning callback
with tf.GradientTape() as tape:
logits = model_for_pruning(x_train, training=True)
loss_value = loss(y_train, logits)
grads = tape.gradient(loss_value, model_for_pruning.trainable_variables)
optimizer.apply_gradients(zip(grads, model_for_pruning.trainable_variables))
step_callback.on_epoch_end(batch=unused_arg) # run pruning callback
#docs_infra: no_execute
%tensorboard --logdir={log_dir}
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Đối với người dùng không Colab, bạn có thể thấy kết quả của một hoạt động trước đó của khối mã này trên TensorBoard.dev .
Cải thiện độ chính xác của mô hình được cắt tỉa
Thứ nhất, nhìn vào tfmot.sparsity.keras.prune_low_magnitude
tài liệu API để hiểu những gì một lịch trình cắt tỉa là và toán học của từng loại cắt tỉa đúng tiến độ.
Mẹo:
Có tỷ lệ học tập không quá cao hoặc quá thấp khi mô hình đang cắt tỉa. Hãy xem xét các kế hoạch cắt tỉa là một hyperparameter.
Như một thử nghiệm nhanh, cố gắng thử nghiệm với cắt tỉa một mô hình để các thưa thớt thức vào đầu đào tạo bằng cách thiết lập
begin_step
0 với mộttfmot.sparsity.keras.ConstantSparsity
lịch. Bạn có thể gặp may mắn với kết quả tốt.Không cắt tỉa thường xuyên để mô hình có thời gian phục hồi. Các lịch trình cắt tỉa cung cấp một tần số mặc định phong nha.
Để biết các ý tưởng chung nhằm cải thiện độ chính xác của mô hình, hãy tìm các mẹo cho (các) trường hợp sử dụng của bạn trong "Xác định mô hình".
Checkpoint và deserialize
Bạn phải duy trì bước tối ưu hóa trong quá trình kiểm tra. Điều này có nghĩa là trong khi bạn có thể sử dụng các kiểu Keras HDF5 để kiểm tra, bạn không thể sử dụng các trọng lượng Keras HDF5.
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
_, keras_model_file = tempfile.mkstemp('.h5')
# Checkpoint: saving the optimizer is necessary (include_optimizer=True is the default).
model_for_pruning.save(keras_model_file, include_optimizer=True)
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Những điều trên được áp dụng chung. Mã bên dưới chỉ cần thiết cho định dạng mô hình HDF5 (không phải trọng lượng HDF5 và các định dạng khác).
# Deserialize model.
with tfmot.sparsity.keras.prune_scope():
loaded_model = tf.keras.models.load_model(keras_model_file)
loaded_model.summary()
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually. Model: "sequential_8" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= prune_low_magnitude_dense_8 (None, 20) 822 _________________________________________________________________ prune_low_magnitude_flatten_ (None, 20) 1 ================================================================= Total params: 823 Trainable params: 420 Non-trainable params: 403 _________________________________________________________________
Triển khai mô hình cắt tỉa
Xuất mô hình với nén kích thước
Sai lầm phổ biến: cả hai strip_pruning
và áp dụng một thuật toán nén chuẩn (ví dụ như thông qua gzip) là cần thiết để thấy được lợi ích nén của tỉa.
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
# Typically you train the model here.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
print("final model")
model_for_export.summary()
print("\n")
print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning)))
print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export)))
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details. final model Model: "sequential_9" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_9 (Dense) (None, 20) 420 _________________________________________________________________ flatten_10 (Flatten) (None, 20) 0 ================================================================= Total params: 420 Trainable params: 420 Non-trainable params: 0 _________________________________________________________________ Size of gzipped pruned model without stripping: 3299.00 bytes Size of gzipped pruned model with stripping: 2876.00 bytes
Tối ưu hóa phần cứng cụ thể
Khi backends khác nhau cho phép cắt tỉa để cải thiện độ trễ , sử dụng khối thưa thớt có thể cải thiện độ trễ cho phần cứng nhất định.
Tăng kích thước khối sẽ làm giảm độ thưa thớt tối đa có thể đạt được để có độ chính xác của mô hình mục tiêu. Mặc dù vậy, độ trễ vẫn có thể cải thiện.
Để biết chi tiết về những gì đang được hỗ trợ cho khối thưa thớt, xem tfmot.sparsity.keras.prune_low_magnitude
tài liệu API.
base_model = setup_model()
# For using intrinsics on a CPU with 128-bit registers, together with 8-bit
# quantized weights, a 1x16 block size is nice because the block perfectly
# fits into the register.
pruning_params = {'block_size': [1, 16]}
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model, **pruning_params)
model_for_pruning.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2 WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details. Model: "sequential_10" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= prune_low_magnitude_dense_10 (None, 20) 822 _________________________________________________________________ prune_low_magnitude_flatten_ (None, 20) 1 ================================================================= Total params: 823 Trainable params: 420 Non-trainable params: 403 _________________________________________________________________