TensorFlow.orgで表示 | Google Colab で実行 | GitHub でソースを表示{ | ノートブックをダウンロード/a0} |
TensorFlow Model Optimization ツールキットの一部である重みクラスタリングの総合ガイドへようこそ。
このページでは、さまざまなユースケースを示し、それぞれで API を使用する方法を説明します。どの API が必要であるかを特定したら、API ドキュメントでパラメータと詳細を確認してください。
- 重みクラスタリングのメリットとサポート対象を確認する場合は、概要をご覧ください。
- 単一のエンドツーエンドの例については、重みクラスタリングの例をご覧ください。
このガイドでは、次のユースケースについて説明しています。
- クラスタモデルを定義する
- クラスタモデルのチェックポイントと逆シリアル化
- クラスタモデルの精度を改善する
- デプロイのみについて、ステップを実行して圧縮のメリットを確認する必要があります。
セットアップ
! pip install -q tensorflow-model-optimization
import tensorflow as tf
import numpy as np
import tempfile
import os
import tensorflow_model_optimization as tfmot
input_dim = 20
output_dim = 20
x_train = np.random.randn(1, input_dim).astype(np.float32)
y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=output_dim)
def setup_model():
model = tf.keras.Sequential([
tf.keras.layers.Dense(input_dim, input_shape=[input_dim]),
tf.keras.layers.Flatten()
])
return model
def train_model(model):
model.compile(
loss=tf.keras.losses.categorical_crossentropy,
optimizer='adam',
metrics=['accuracy']
)
model.summary()
model.fit(x_train, y_train)
return model
def save_model_weights(model):
_, pretrained_weights = tempfile.mkstemp('.h5')
model.save_weights(pretrained_weights)
return pretrained_weights
def setup_pretrained_weights():
model= setup_model()
model = train_model(model)
pretrained_weights = save_model_weights(model)
return pretrained_weights
def setup_pretrained_model():
model = setup_model()
pretrained_weights = setup_pretrained_weights()
model.load_weights(pretrained_weights)
return model
def save_model_file(model):
_, keras_file = tempfile.mkstemp('.h5')
model.save(keras_file, include_optimizer=False)
return keras_file
def get_gzipped_model_size(model):
# It returns the size of the gzipped model in bytes.
import os
import zipfile
keras_file = save_model_file(model)
_, zipped_file = tempfile.mkstemp('.zip')
with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
f.write(keras_file)
return os.path.getsize(zipped_file)
setup_model()
pretrained_weights = setup_pretrained_weights()
クラスタモデルを定義する
モデル全体のクラスタリング(Sequential と Functional)
モデルの精度を高めるためのヒント:
- この API には許容できる精度のトレーニング済みモデルを渡す必要があります。クラスタリングを使用してモデルを最初からトレーニングすると、精度が低くなります。
- 一部のケースでは、特定のレイヤーをクラスタリングすると、モデルの精度に悪影響が及びます。精度に最も大きく影響するレイヤーのクラスタリングを省略する方法について、「一部のレイヤーをクラスタリングする」をご覧ください。
すべてのレイヤーをクラスタリングするには、モデルに tfmot.clustering.keras.cluster_weights
を適用します。
import tensorflow_model_optimization as tfmot
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
clustering_params = {
'number_of_clusters': 3,
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
}
model = setup_model()
model.load_weights(pretrained_weights)
clustered_model = cluster_weights(model, **clustering_params)
clustered_model.summary()
Model: "sequential_2" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= cluster_dense_2 (ClusterWeig (None, 20) 423 _________________________________________________________________ cluster_flatten_2 (ClusterWe (None, 20) 0 ================================================================= Total params: 423 Trainable params: 23 Non-trainable params: 400 _________________________________________________________________
一部のレイヤーをクラスタリングする(Sequential モデルと Functional モデル)
モデルの精度を高めるためのヒント:
- You must pass a pre-trained model with acceptable accuracy to this API. Training models from scratch with clustering results in subpar accuracy.
- 初期のレイヤーと比較し、後のレイヤーはより多い冗長パラメータ(
tf.keras.layers.Dense
、tf.keras.layers.Conv2D
など)でクラスタリングします。 - 微調整中、クラスタリングレイヤーの前に初期のレイヤーを凍結します。凍結したレイヤーの数をハイパーパラメータとして処理します。経験的に、現在のクラスタリング API では、最も初期のレイヤーを凍結することが理想的です。
- クリティカルレイヤー(注意メカニズムなど)のクラスリングを回避します。
その他: tfmot.clustering.keras.cluster_weights
API ドキュメントには、レイヤーごとにクラスタ構成を変える方法が示されています。
# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights)
# Helper function uses `cluster_weights` to make only
# the Dense layers train with clustering
def apply_clustering_to_dense(layer):
if isinstance(layer, tf.keras.layers.Dense):
return cluster_weights(layer, **clustering_params)
return layer
# Use `tf.keras.models.clone_model` to apply `apply_clustering_to_dense`
# to the layers of the model.
clustered_model = tf.keras.models.clone_model(
base_model,
clone_function=apply_clustering_to_dense,
)
clustered_model.summary()
Model: "sequential_3" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= cluster_dense_3 (ClusterWeig (None, 20) 423 _________________________________________________________________ flatten_3 (Flatten) (None, 20) 0 ================================================================= Total params: 423 Trainable params: 23 Non-trainable params: 400 _________________________________________________________________
モデルのチェックポイントと逆シリアル化
ユースケース: このコードは、HDF5 モデル形式のみで必要です(HDF5 重みまたはその他の形式では不要です)。
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights)
clustered_model = cluster_weights(base_model, **clustering_params)
# Save or checkpoint the model.
_, keras_model_file = tempfile.mkstemp('.h5')
clustered_model.save(keras_model_file, include_optimizer=True)
# `cluster_scope` is needed for deserializing HDF5 models.
with tfmot.clustering.keras.cluster_scope():
loaded_model = tf.keras.models.load_model(keras_model_file)
loaded_model.summary()
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually. Model: "sequential_4" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= cluster_dense_4 (ClusterWeig (None, 20) 423 _________________________________________________________________ cluster_flatten_4 (ClusterWe (None, 20) 0 ================================================================= Total params: 423 Trainable params: 23 Non-trainable params: 400 _________________________________________________________________
クラスタモデルの精度を改善する
特定のユースケースについて、次のヒントを考慮できます。
最終的な最適化モデルの精度には、重心の初期化に重要な役割があります。通常、線形初期化は大きな重みを見逃さない傾向があるため、密度とランダム初期化のパフォーマンスを上回ります。ただし、密度の初期化は、二峰性分布で重みに非常に少数のクラスタを使用した場合に、より優れた精度を示すことが確認されています。
クラスタモデルを微調整する際は、トレーニングに使用されている学習率よりも低い率を設定します。
モデルの精度を改善するための一般的なアイデアについては、「クラスタモデルを定義する」に記載のケース別のヒントをご覧ください。
デプロイ
サイズ圧縮によるモデルのエクスポート
一般的な過ち: strip_clustering
と標準圧縮アルゴリズム(gzip など)の適用は、クラスタリングの圧縮のメリットを確認する上で必要です。
model = setup_model()
clustered_model = cluster_weights(model, **clustering_params)
clustered_model.compile(
loss=tf.keras.losses.categorical_crossentropy,
optimizer='adam',
metrics=['accuracy']
)
clustered_model.fit(
x_train,
y_train
)
final_model = tfmot.clustering.keras.strip_clustering(clustered_model)
print("final model")
final_model.summary()
print("\n")
print("Size of gzipped clustered model without stripping: %.2f bytes"
% (get_gzipped_model_size(clustered_model)))
print("Size of gzipped clustered model with stripping: %.2f bytes"
% (get_gzipped_model_size(final_model)))
1/1 [==============================] - 0s 345ms/step - loss: 1.4791 - accuracy: 0.0000e+00 final model Model: "sequential_5" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_5 (Dense) (None, 20) 420 _________________________________________________________________ flatten_5 (Flatten) (None, 20) 0 ================================================================= Total params: 420 Trainable params: 420 Non-trainable params: 0 _________________________________________________________________ Size of gzipped clustered model without stripping: 1871.00 bytes Size of gzipped clustered model with stripping: 1475.00 bytes