Trainer TFX パイプライン コンポーネントは、TensorFlow モデルをトレーニングします。
トレーナーと TensorFlow
Trainer は、モデルのトレーニングに Python TensorFlow API を広範囲に使用します。
成分
トレーナーは次のことを行います。
- tf.トレーニングと評価に使用される例。
- トレーナー ロジックを定義するユーザー指定のモジュール ファイル。
- train 引数と eval 引数のProtobuf定義。
- (オプション) SchemaGen パイプライン コンポーネントによって作成され、必要に応じて開発者によって変更されるデータ スキーマ。
- (オプション) 上流の Transform コンポーネントによって生成される変換グラフ。
- (オプション) ウォームスタートなどのシナリオに使用される事前トレーニングされたモデル。
- (オプション) ハイパーパラメータ。ユーザー モジュール関数に渡されます。 Tuner との統合の詳細については、ここを参照してください。
トレーナーは、推論/提供用の少なくとも 1 つのモデル (通常は SavedModelFormat) と、オプションで eval 用の別のモデル (通常は EvalSavedModel) を出力します。
モデル書き換えライブラリを通じて、 TFLiteなどの代替モデル形式のサポートを提供します。 Estimator モデルと Keras モデルの両方を変換する方法の例については、モデル書き換えライブラリへのリンクを参照してください。
ジェネリックトレーナー
汎用トレーナーを使用すると、開発者は任意の TensorFlow モデル API を Trainer コンポーネントで使用できます。 TensorFlow Estimator に加えて、開発者は Keras モデルまたはカスタム トレーニング ループを使用できます。詳細については、汎用トレーナーの RFCを参照してください。
トレーナーコンポーネントの構成
汎用トレーナーの一般的なパイプライン DSL コードは次のようになります。
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
トレーナーは、 module_file
パラメーターで指定されたトレーニング モジュールを呼び出します。 GenericExecutor
がcustom_executor_spec
で指定されている場合は、 module ファイル内にtrainer_fn
の代わりにrun_fn
が必要です。 trainer_fn
モデルの作成を担当しました。それに加えて、 run_fn
トレーニング部分を処理し、トレーニングされたモデルをFnArgsで指定された目的の場所に出力する必要もあります。
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
run_fn
を含むモジュール ファイルの例を次に示します。
Transform コンポーネントがパイプラインで使用されていない場合、トレーナーは ExampleGen からサンプルを直接取得することに注意してください。
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
詳細については、 「トレーナー API リファレンス」を参照してください。