Komponen pipeline TFX Trainer melatih model TensorFlow.
Pelatih dan TensorFlow
Trainer membuat ekstensif menggunakan Python TensorFlow API untuk model pelatihan.
Komponen
Pelatih mengambil:
- tf.Contoh yang digunakan untuk pelatihan dan evaluasi.
- File modul yang disediakan pengguna yang mendefinisikan logika pelatih.
- Protobuf definisi args kereta api dan args eval.
- (Opsional) Skema data yang dibuat oleh komponen pipeline SchemaGen dan secara opsional diubah oleh developer.
- (Opsional) grafik transformasi yang dihasilkan oleh komponen Transform upstream.
- (Opsional) model terlatih yang digunakan untuk skenario seperti warmstart.
- (Opsional) hyperparameter, yang akan diteruskan ke fungsi modul pengguna. Rincian integrasi dengan Tuner dapat ditemukan di sini .
Pelatih memancarkan: Setidaknya satu model untuk inferensi/penyajian (biasanya di SavedModelFormat) dan secara opsional model lain untuk eval (biasanya EvalSavedModel).
Kami menyediakan dukungan untuk format Model alternatif seperti TFLite melalui Model Rewriting Perpustakaan . Lihat tautan ke Model Rewriting Library untuk contoh cara mengonversi model Estimator dan Keras.
Pelatih Umum
Pelatih umum memungkinkan pengembang menggunakan API model TensorFlow apa pun dengan komponen Pelatih. Selain Penaksir TensorFlow, pengembang dapat menggunakan model Keras atau loop pelatihan khusus. Untuk detail, silakan lihat RFC untuk pelatih generik .
Mengonfigurasi Komponen Pelatih
Kode DSL pipeline umum untuk Pelatih generik akan terlihat seperti ini:
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Trainer memanggil modul pelatihan, yang ditentukan dalam module_file
parameter. Alih-alih trainer_fn
, sebuah run_fn
diperlukan dalam file modul jika GenericExecutor
ditentukan dalam custom_executor_spec
. The trainer_fn
bertanggung jawab untuk menciptakan model. Selain itu, run_fn
juga perlu untuk menangani bagian pelatihan dan output model dilatih untuk lokasi yang diinginkan diberikan oleh FnArgs :
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
Berikut adalah modul contoh file dengan run_fn
.
Perhatikan bahwa jika komponen Transform tidak digunakan dalam pipeline, maka Pelatih akan mengambil contoh dari ExampleGen secara langsung:
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Keterangan lebih lanjut tersedia di Trainer API referensi .