トレーナー TFX パイプライン コンポーネント

Trainer TFX パイプライン コンポーネントは、TensorFlow モデルをトレーニングします。

トレーナーと TensorFlow

Trainer は、モデルのトレーニングに Python TensorFlow API を広範囲に使用します。

成分

トレーナーは次のことを行います。

  • tf.トレーニングと評価に使用される例。
  • トレーナー ロジックを定義するユーザー指定のモジュール ファイル。
  • train 引数と eval 引数のProtobuf定義。
  • (オプション) SchemaGen パイプライン コンポーネントによって作成され、必要に応じて開発者によって変更されるデータ スキーマ。
  • (オプション) 上流の Transform コンポーネントによって生成される変換グラフ。
  • (オプション) ウォームスタートなどのシナリオに使用される事前トレーニングされたモデル。
  • (オプション) ハイパーパラメータ。ユーザー モジュール関数に渡されます。 Tuner との統合の詳細については、ここを参照してください。

トレーナーは、推論/提供用の少なくとも 1 つのモデル (通常は SavedModelFormat) と、オプションで eval 用の別のモデル (通常は EvalSavedModel) を出力します。

モデル書き換えライブラリを通じて、 TFLiteなどの代替モデル形式のサポートを提供します。 Estimator モデルと Keras モデルの両方を変換する方法の例については、モデル書き換えライブラリへのリンクを参照してください。

ジェネリックトレーナー

汎用トレーナーを使用すると、開発者は任意の TensorFlow モデル API を Trainer コンポーネントで使用できます。 TensorFlow Estimator に加えて、開発者は Keras モデルまたはカスタム トレーニング ループを使用できます。詳細については、汎用トレーナーの RFCを参照してください。

トレーナーコンポーネントの構成

汎用トレーナーの一般的なパイプライン DSL コードは次のようになります。

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

トレーナーは、 module_fileパラメーターで指定されたトレーニング モジュールを呼び出します。 GenericExecutorcustom_executor_specで指定されている場合は、 module ファイル内にtrainer_fnの代わりにrun_fnが必要です。 trainer_fnモデルの作成を担当しました。それに加えて、 run_fnトレーニング部分を処理し、トレーニングされたモデルをFnArgsで指定された目的の場所に出力する必要もあります。

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

run_fnを含むモジュール ファイルの例を次に示します。

Transform コンポーネントがパイプラインで使用されていない場合、トレーナーは ExampleGen からサンプルを直接取得することに注意してください。

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

詳細については、 「トレーナー API リファレンス」を参照してください。