O componente de pipeline Trainer TFX treina um modelo do TensorFlow.
Treinador e TensorFlow
O Trainer faz uso extensivo da API Python TensorFlow para modelos de treinamento.
Componente
O treinador leva:
- tf.Exemplos usados para treinamento e avaliação.
- Um arquivo de módulo fornecido pelo usuário que define a lógica do treinador.
- Definição protobuf de train args e eval args.
- (Opcional) Um esquema de dados criado por um componente de pipeline SchemaGen e opcionalmente alterado pelo desenvolvedor.
- (Opcional) gráfico de transformação produzido por um componente Transform upstream.
- (Opcional) modelos pré-treinados usados para cenários como warmstart.
- Hiperparâmetros (opcional), que serão passados para a função do módulo do usuário. Detalhes da integração com o Tuner podem ser encontrados aqui .
O instrutor emite: Pelo menos um modelo para inferência/exibição (normalmente em SavedModelFormat) e opcionalmente outro modelo para avaliação (normalmente um EvalSavedModel).
Fornecemos suporte para formatos de modelo alternativos, como TFLite, por meio da Model Rewriting Library . Consulte o link para a Biblioteca de Reescrita de Modelo para obter exemplos de como converter os modelos Estimator e Keras.
Treinador genérico
O treinador genérico permite que os desenvolvedores usem qualquer API de modelo do TensorFlow com o componente Trainer. Além dos estimadores do TensorFlow, os desenvolvedores podem usar modelos Keras ou loops de treinamento personalizados. Para obter detalhes, consulte a RFC para treinador genérico .
Configurando o componente Trainer
O código DSL de pipeline típico para o Trainer genérico seria assim:
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
O Trainer invoca um módulo de treinamento, que é especificado no parâmetro module_file
. Em vez de trainer_fn
, um run_fn
será necessário no arquivo do módulo se GenericExecutor
for especificado em custom_executor_spec
. O trainer_fn
foi o responsável pela criação do modelo. Além disso, run_fn
também precisa lidar com a parte de treinamento e enviar o modelo treinado para o local desejado fornecido por FnArgs :
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
Aqui está um exemplo de arquivo de módulo com run_fn
.
Observe que se o componente Transform não for usado no pipeline, o Trainer pegaria os exemplos diretamente do ExampleGen:
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Mais detalhes estão disponíveis na referência da API Trainer .