ウェイト クラスタリング

Arm ML ツールにより管理

このドキュメントでは、ウェイト クラスタリングがユースケースにどの程度適しているかを判断できるよう、その概要について説明します。

概要

クラスタリング(ウェイト共有)を行うと、モデル内で固有のウェイト値の数を減らすことができ、デプロイ上のメリットが得られます。この手法では、まず各レイヤーのウェイトを N 個のクラスタにグループ化し、次にクラスタに属するすべてのウェイトとして、クラスタの重心値を共有します。

この手法では、モデルの圧縮により改善が行われます。将来、フレームワークのサポートによってメモリ フットプリントが改善されれば、リソースに制約のある組み込みシステムにディープ ラーニング モデルをデプロイするための重要な前進となります。

Google は、ビジョンや音声認識のさまざまなタスクを対象にクラスタリングをテストしてきました。下記の結果が示すとおり、精度の低下を最小限に抑えながら、モデルの圧縮において最大 5 倍の改善が確認されました。

なお、クラスタリングを軸ごとのトレーニング後の量子化と併用すると、バッチ正規化層の前にある畳み込み層と全結合層のメリットが減少することに注意してください。

API 互換性マトリックス

ユーザーは、次の API でクラスタリングを適用できます。

  • モデル構築: tf.keras(Sequential モデルと Functional モデルのみ)
  • TensorFlow の各バージョン: TF 1.x(バージョン 1.14 以降)と 2.x
    • TF 2.X パッケージの tf.compat.v1 と TF 1.X パッケージの tf.compat.v2 はサポートされていません。
  • TensorFlow 実行モード: グラフモードおよび eager モード

結果

画像分類

モデル オリジナル クラスタリング後
最上位の精度(%) 圧縮済み .tflite のサイズ(MB) 構成 クラスタ数 最上位の精度(%) 圧縮済み .tflite のサイズ(MB)
MobileNetV1 70.976 14.97
選択部分(最後の 3 つの Conv2D 層) 16, 16, 16 70.294 7.69
選択部分(最後の 3 つの Conv2D 層) 32, 32, 32 70.69 8.22
完全(すべての Conv2D 層) 32 69.4 4.43
MobileNetV2 71.778 12.38
選択部分(最後の 3 つの Conv2D 層) 16, 16, 16 70.742 6.68
選択部分(最後の 3 つの Conv2D 層) 32, 32, 32 70.926 7.03
完全(すべての Conv2D 層) 32 69.744 4.05

モデルは ImageNet でトレーニングされ、テストされました。

キーワード スポッティング

モデル オリジナル クラスタリング後
最上位の精度(%) 圧縮済み .tflite のサイズ(MB) 構成 クラスタ数 最上位の精度(%) 圧縮済み .tflite のサイズ(MB)
DS-CNN-L 95.233 1.46
完全(すべての Conv2D 層) 32 95.09 0.39
完全(すべての Conv2D 層) 8 94.272 0.27

モデルは SpeechCommands v0.02 でトレーニングされ、テストされました。

  1. Keras モデルを .h5 ファイルにシリアル化します
  2. TFLiteConverter.from_keras_model_file() を使用して .h5 ファイルを .tflite に変換します
  3. .tflite ファイルを zip に圧縮します

Keras でのウェイト クラスタリングの例に加えて、次の例もご覧ください。

  • MNIST の手書き数字の分類データセットでトレーニングされた CNN モデルのウェイトをクラスタリングする: コード

ウェイト クラスタリングの実装は、『Deep Compression: Compressing Deep Neural Networks with Pruning, Trained Quantization and Huffman Coding』の論文に基づいています。第 3 章「Trained Quantization and Weight Sharing」をご覧ください。