モデルのチェックポイント

モデルの状態を保存および復元する機能は、転移学習や事前トレーニング済みモデルを使用した推論の実行など、多くのアプリケーションにとって不可欠です。これを達成する 1 つの方法は、モデルのパラメーター (重み、バイアスなど) をチェックポイント ファイルまたはディレクトリに保存することです。

このモジュールは、 TensorFlow v2 形式のチェックポイントをロードおよび保存するための高レベルのインターフェイスと、このファイル形式への書き込みおよび読み取りを行う下位レベルのコンポーネントを提供します。

単純なモデルのロードと保存

Checkpointableプロトコルに準拠することで、コードを追加することなく、多くの単純なモデルをチェックポイントにシリアル化できます。

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

次に、次を使用して同じチェックポイントを読み取ることができます。

try model.readCheckpoint(from: directory, name: "LeNet")

モデルの読み込みと保存のためのこのデフォルトの実装では、モデル構造内のプロパティの名前に基づく、モデル内の各テンソルのパスベースの命名スキームが使用されます。たとえば、 LeNet-5 モデルの最初の畳み込み内の重みとバイアスは、それぞれconv1/filterconv1/biasという名前で保存されます。ロード時に、チェックポイント リーダーはこれらの名前のテンソルを検索します。

モデルの読み込みと保存のカスタマイズ

どの tensor を保存およびロードするか、またはそれらの tensor の名前付けをより詳細に制御したい場合、 Checkpointableプロトコルはいくつかのカスタマイズ ポイントを提供します。

特定の型のプロパティを無視するには、 Type.propertyの形式で文字列の Set を返す、 ignoredTensorPathsの実装をモデルに提供します。たとえば、すべてのアテンション レイヤーのscaleプロパティを無視するには、 ["Attention.scale"]を返すことができます。

デフォルトでは、モデル内の各より深いレベルを区切るためにスラッシュが使用されます。これは、モデルにcheckpointSeparatorを実装し、このセパレータに使用する新しい文字列を提供することでカスタマイズできます。

最後に、テンソルの命名を最大限にカスタマイズするには、 tensorNameMap実装し、モデル内のテンソルに対して生成されたデフォルトの文字列名をチェックポイント内の目的の文字列名にマッピングする関数を提供できます。最も一般的に、これは他のフレームワークで生成されたチェックポイントと相互運用するために使用されます。各フレームワークには独自の命名規則とモデル構造があります。カスタム マッピング関数を使用すると、これらのテンソルの名前の付け方を最大限にカスタマイズできます。

デフォルトのCheckpointWriter.identityMap (チェックポイントに自動的に生成されたテンソル パス名を使用する) や、ディクショナリからマッピングを構築できるCheckpointWriter.lookupMap(table:)関数など、いくつかの標準ヘルパー関数が提供されています。

カスタム マッピングを実現する方法の例については、 GPT-2 モデルを参照してください。このモデルでは、OpenAI のチェックポイントに使用される正確な命名スキームと一致するマッピング関数が使用されています。

CheckpointReader コンポーネントと CheckpointWriter コンポーネント

チェックポイント書き込みの場合、 Checkpointableプロトコルによって提供される拡張機能は、リフレクションとキーパスを使用してモデルのプロパティを反復し、文字列テンソル パスをテンソル値にマップする辞書を生成します。このディクショナリは、チェックポイントを書き込むディレクトリとともに、基礎となるCheckpointWriterに提供されます。そのCheckpointWriterは、そのディクショナリからディスク上のチェックポイントを生成するタスクを処理します。

このプロセスの逆は読み取りであり、 CheckpointReaderにディスク上のチェックポイント ディレクトリの場所が与えられます。次に、そのチェックポイントから読み取り、チェックポイント内のテンソルの名前とその保存された値をマップする辞書を形成します。このディクショナリは、モデル内の現在のテンソルをこのディクショナリ内のテンソルに置き換えるために使用されます。

ロードと保存の両方で、 Checkpointableプロトコルは、上記のマッピング関数を使用して、テンソルへの文字列パスを、対応するディスク上のテンソル名にマッピングします。

Checkpointableプロトコルに必要な機能が欠けている場合、またはチェックポイントのロードおよび保存プロセスをより詳細に制御する必要がある場合は、 CheckpointReaderクラスとCheckpointWriterクラスを単独で使用できます。

TensorFlow v2 チェックポイント形式

このヘッダーで簡単に説明されているように、TensorFlow v2 チェックポイント形式は、TensorFlow モデル チェックポイントの第 2 世代形式です。この第 2 世代形式は 2016 年末から使用されており、v1 チェックポイント形式に比べて多くの点が改善されています。 TensorFlow SavedModel は、その内部で v2 チェックポイントを使用してモデル パラメーターを保存します。

TensorFlow v2 チェックポイントは、次のような構造のディレクトリで構成されます。

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

最初のファイルはチェックポイントのメタデータを保存し、残りのファイルはモデルのシリアル化されたパラメーターを保持するバイナリ シャードです。

インデックス メタデータ ファイルには、シャードに含まれるすべてのシリアル化されたテンソルのタイプ、サイズ、場所、文字列名が含まれています。そのインデックス ファイルはチェックポイントの構造的に最も複雑な部分であり、 tensorflow::tableに基づいており、これ自体は SSTable / LevelDB に基づいています。このインデックス ファイルは一連のキーと値のペアで構成されており、キーは文字列、値はプロトコル バッファーです。文字列はソートされ、プレフィックス圧縮されます。たとえば、最初のエントリがconv1/weightで、次がconv1/biasの場合、2 番目のエントリはbias部分のみを使用します。

この全体的なインデックス ファイルは、 Snappy 圧縮を使用して圧縮される場合があります。 SnappyDecompression.swiftファイルは、圧縮データ インスタンスからの Snappy 解凍のネイティブ Swift 実装を提供します。

インデックス ヘッダー メタデータとテンソル メタデータはプロトコル バッファーとしてエンコードされ、 Swift Protobufを介して直接エンコード/デコードされます。

CheckpointIndexReaderクラスとCheckpointIndexWriterクラスは、包括的なCheckpointReaderCheckpointWriterクラスの一部として、これらのインデックス ファイルの読み込みと保存を処理します。後者は、テンソル データを含む構造的に単純なバイナリ シャードに対して何を読み書きするかを決定するための基礎としてインデックス ファイルを使用します。