ส่วนประกอบไปป์ไลน์ TFX ของเทรนเนอร์

ส่วนประกอบไปป์ไลน์ Trainer TFX จะฝึกโมเดล TensorFlow

เทรนเนอร์และ TensorFlow

Trainer ใช้ Python TensorFlow API อย่างกว้างขวางสำหรับโมเดลการฝึก

ส่วนประกอบ

ผู้ฝึกสอนใช้เวลา:

  • tf.ตัวอย่างที่ใช้สำหรับการฝึกอบรมและประเมินผล
  • ไฟล์โมดูลที่ผู้ใช้จัดเตรียมไว้ซึ่งกำหนดตรรกะของผู้ฝึกสอน
  • คำจำกัดความ Protobuf ของ train args และ eval args
  • (ไม่บังคับ) สคีมาข้อมูลที่สร้างโดยส่วนประกอบไปป์ไลน์ SchemaGen และอาจแก้ไขโดยนักพัฒนา
  • (ทางเลือก) กราฟการแปลงที่สร้างโดยส่วนประกอบการแปลงต้นทาง
  • (ไม่บังคับ) โมเดลที่ได้รับการฝึกอบรมล่วงหน้าที่ใช้สำหรับสถานการณ์ เช่น วอร์มสตาร์ท
  • (ไม่บังคับ) ไฮเปอร์พารามิเตอร์ ซึ่งจะถูกส่งไปยังฟังก์ชันโมดูลผู้ใช้ ดูรายละเอียดของการทำงานร่วมกับ Tuner ได้ ที่นี่

ผู้ฝึกสอนส่งเสียง: อย่างน้อยหนึ่งรุ่นสำหรับการอนุมาน/การให้บริการ (โดยทั่วไปจะอยู่ใน SavedModelFormat) และอีกรุ่นหนึ่งสำหรับ eval (โดยทั่วไปคือ EvalSavedModel)

เราให้การสนับสนุนรูปแบบโมเดลทางเลือก เช่น TFLite ผ่านทาง Model Rewriting Library ดูลิงก์ไปยัง Model Rewriting Library สำหรับตัวอย่างวิธีแปลงทั้งโมเดล Estimator และ Keras

เทรนเนอร์ทั่วไป

โปรแกรมฝึกสอนทั่วไปช่วยให้นักพัฒนาสามารถใช้ API ของโมเดล TensorFlow กับส่วนประกอบของโปรแกรมฝึกสอนได้ นอกจาก TensorFlow Estimators แล้ว นักพัฒนายังสามารถใช้โมเดล Keras หรือลูปการฝึกแบบกำหนดเองได้ สำหรับรายละเอียด โปรดดู RFC สำหรับผู้ฝึกสอนทั่วไป

การกำหนดค่าส่วนประกอบเทรนเนอร์

รหัส DSL ไปป์ไลน์ทั่วไปสำหรับ Trainer ทั่วไปจะมีลักษณะดังนี้:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

ผู้ฝึกสอนเรียกใช้โมดูลการฝึกอบรม ซึ่งระบุไว้ในพารามิเตอร์ module_file แทนที่จะเป็น trainer_fn จำเป็นต้องใช้ run_fn ในไฟล์โมดูลหากระบุ GenericExecutor ใน custom_executor_spec trainer_fn มีหน้าที่รับผิดชอบในการสร้างแบบจำลอง นอกจากนั้น run_fn ยังต้องจัดการส่วนการฝึกและส่งออกโมเดลที่ได้รับการฝึกไปยังตำแหน่งที่ต้องการที่กำหนดโดย FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

นี่คือ ตัวอย่างไฟล์โมดูล ที่มี run_fn

โปรดทราบว่าหากไม่ได้ใช้ส่วนประกอบ Transform ในไปป์ไลน์ Trainer จะนำตัวอย่างจาก ExampleGen โดยตรง:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

มีรายละเอียดเพิ่มเติมใน ข้อมูลอ้างอิง Trainer API