Komponen Pipa TFX Pelatih

Komponen pipeline TFX Trainer melatih model TensorFlow.

Pelatih dan TensorFlow

Trainer membuat ekstensif menggunakan Python TensorFlow API untuk model pelatihan.

Komponen

Pelatih mengambil:

  • tf.Contoh yang digunakan untuk pelatihan dan evaluasi.
  • File modul yang disediakan pengguna yang mendefinisikan logika pelatih.
  • Protobuf definisi args kereta api dan args eval.
  • (Opsional) Skema data yang dibuat oleh komponen pipeline SchemaGen dan secara opsional diubah oleh developer.
  • (Opsional) grafik transformasi yang dihasilkan oleh komponen Transform upstream.
  • (Opsional) model terlatih yang digunakan untuk skenario seperti warmstart.
  • (Opsional) hyperparameter, yang akan diteruskan ke fungsi modul pengguna. Rincian integrasi dengan Tuner dapat ditemukan di sini .

Pelatih memancarkan: Setidaknya satu model untuk inferensi/penyajian (biasanya di SavedModelFormat) dan secara opsional model lain untuk eval (biasanya EvalSavedModel).

Kami menyediakan dukungan untuk format Model alternatif seperti TFLite melalui Model Rewriting Perpustakaan . Lihat tautan ke Model Rewriting Library untuk contoh cara mengonversi model Estimator dan Keras.

Pelatih Umum

Pelatih umum memungkinkan pengembang menggunakan API model TensorFlow apa pun dengan komponen Pelatih. Selain Penaksir TensorFlow, pengembang dapat menggunakan model Keras atau loop pelatihan khusus. Untuk detail, silakan lihat RFC untuk pelatih generik .

Mengonfigurasi Komponen Pelatih

Kode DSL pipeline umum untuk Pelatih generik akan terlihat seperti ini:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer memanggil modul pelatihan, yang ditentukan dalam module_file parameter. Alih-alih trainer_fn , sebuah run_fn diperlukan dalam file modul jika GenericExecutor ditentukan dalam custom_executor_spec . The trainer_fn bertanggung jawab untuk menciptakan model. Selain itu, run_fn juga perlu untuk menangani bagian pelatihan dan output model dilatih untuk lokasi yang diinginkan diberikan oleh FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Berikut adalah modul contoh file dengan run_fn .

Perhatikan bahwa jika komponen Transform tidak digunakan dalam pipeline, maka Pelatih akan mengambil contoh dari ExampleGen secara langsung:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Keterangan lebih lanjut tersedia di Trainer API referensi .