Tham gia cùng chúng tôi tại DevFest cho Ukraine Ngày 14-15 tháng 6 Trực tuyến Đăng ký ngay

Thành phần đường ống TFX của Trainer

Thành phần đường ống Trainer TFX huấn luyện mô hình TensorFlow.

Trainer và TensorFlow

Trainer làm cho sử dụng rộng rãi của Python TensorFlow API cho các mô hình đào tạo.

Thành phần

Huấn luyện viên có:

  • tf. Các ví dụ được sử dụng để đào tạo và đánh giá.
  • Tệp mô-đun do người dùng cung cấp xác định logic của trình huấn luyện.
  • Protobuf định nghĩa của args xe lửa và args eval.
  • (Tùy chọn) Một lược đồ dữ liệu được tạo bởi thành phần đường dẫn SchemaGen và được nhà phát triển thay đổi tùy chọn.
  • (Tùy chọn) đồ thị biến đổi do thành phần Biến đổi ngược dòng tạo ra.
  • (Tùy chọn) các mô hình được đào tạo trước được sử dụng cho các tình huống như khởi động lại.
  • (Tùy chọn) siêu tham số, sẽ được chuyển cho chức năng mô-đun người dùng. Chi tiết về hội nhập với Tuner có thể được tìm thấy ở đây .

Người huấn luyện phát ra: Ít nhất một mô hình cho suy luận / phục vụ (thường là trong SavedModelFormat) và tùy chọn một mô hình khác cho eval (thường là EvalSavedModel).

Chúng tôi cung cấp hỗ trợ cho các định dạng mô hình thay thế như TFLite qua mẫu Viết lại Thư viện . Xem liên kết tới Thư viện Viết lại Mô hình để biết ví dụ về cách chuyển đổi cả hai mô hình Công cụ ước tính và Keras.

Huấn luyện viên Chung

Trình huấn luyện chung cho phép các nhà phát triển sử dụng bất kỳ API mô hình TensorFlow nào với thành phần Huấn luyện viên. Ngoài Công cụ ước tính TensorFlow, các nhà phát triển có thể sử dụng mô hình Keras hoặc các vòng huấn luyện tùy chỉnh. Để biết chi tiết, vui lòng xem RFC cho huấn luyện chung .

Cấu hình thành phần huấn luyện viên

Mã DSL đường ống điển hình cho Trainer chung sẽ giống như sau:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer gọi một mô-đun đào tạo, được quy định tại các module_file tham số. Thay vì trainer_fn , một run_fn là cần thiết trong file mô-đun nếu GenericExecutor được quy định trong custom_executor_spec . Các trainer_fn chịu trách nhiệm cho việc tạo mô hình. Bên cạnh đó, run_fn cũng cần phải xử lý phần đào tạo và đầu ra mô hình đào tạo để một vị trí mong muốn do FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Dưới đây là một ví dụ tập tin mô-đun với run_fn .

Lưu ý rằng nếu thành phần Chuyển đổi không được sử dụng trong đường dẫn, thì Người huấn luyện sẽ trực tiếp lấy các ví dụ từ ExampleGen:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Thông tin chi tiết có sẵn trong các tài liệu tham khảo Trainer API .