도움말 Kaggle에 TensorFlow과 그레이트 배리어 리프 (Great Barrier Reef)를 보호하기 도전에 참여

Trainer TFX 파이프라인 구성 요소

Trainer TFX 파이프라인 구성 요소는 TensorFlow 모델을 훈련합니다.

트레이너와 TensorFlow

트레이너는 파이썬의 광범위한 사용하게 TensorFlow의 교육 모델에 대한 API를.

요소

트레이너는 다음을 수행합니다.

  • 훈련 및 평가에 사용되는 tf.Examples.
  • 트레이너 로직을 정의하는 사용자 제공 모듈 파일입니다.
  • Protobuf 기차 인수 및 평가 인수의 정의.
  • (선택 사항) SchemaGen 파이프라인 구성 요소에 의해 생성되고 개발자가 선택적으로 변경한 데이터 스키마입니다.
  • (선택 사항) 업스트림 Transform 구성 요소에서 생성된 변환 그래프.
  • (선택 사항) 웜스타트와 같은 시나리오에 사용되는 사전 훈련된 모델.
  • (선택 사항) 사용자 모듈 함수에 전달될 하이퍼파라미터. 튜너와 통합의 세부 사항은 찾을 수 있습니다 여기에 .

트레이너는 다음을 내보냅니다. 추론/서빙을 위한 하나 이상의 모델(일반적으로 SavedModelFormat) 및 선택적으로 eval용 다른 모델(일반적으로 EvalSavedModel).

We provide support for alternate model formats such as TFLite through the Model Rewriting Library . Estimator 및 Keras 모델을 모두 변환하는 방법에 대한 예제는 Model Rewriting Library 링크를 참조하십시오.

일반 트레이너

일반 트레이너를 사용하면 개발자가 Trainer 구성 요소와 함께 모든 TensorFlow 모델 API를 사용할 수 있습니다. TensorFlow Estimators 외에도 개발자는 Keras 모델 또는 사용자 지정 교육 루프를 사용할 수 있습니다. 자세한 내용은 참조하십시오 일반적인 트레이너에 대한 RFC를 .

트레이너 구성 요소 구성

일반 Trainer의 일반적인 파이프라인 DSL 코드는 다음과 같습니다.

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer invokes a training module, which is specified in the module_file parameter. 대신에 trainer_fn 하는 run_fn 경우 생성 모듈 파일에 필요한 GenericExecutor 에 지정된되어 custom_executor_spec . trainer_fn 모델을 만들기위한 책임. 그뿐만 아니라, run_fn 또한 교육 부분과 출력에 의해 주어진 원하는 위치로 훈련 모델을 처리해야 FnArgs를 :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

여기이다 예 모듈 파일run_fn .

Transform 구성 요소가 파이프라인에서 사용되지 않는 경우 Trainer는 ExampleGen에서 직접 예제를 가져옵니다.

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

자세한 내용은에서 사용할 수있는 트레이너 API 참조 .