RSVP for your your local TensorFlow Everywhere event today!

Trainer TFX 流水线组件

Trainer TFX 流水线组件用于训练 TensorFlow 模型。

Trainer 和 TensorFlow

Trainer 广泛使用 Python TensorFlow API 来训练模型。

注:TFX 支持 TensorFlow 1.15 和 2.x。

组件

Trainer 需要:

  • 用于训练和评估的 tf.Examples。
  • 由用户提供、用于定义 Trainer 逻辑的模块文件。
  • 由 SchemaGen 流水线组件创建并且可由开发者选择性更改的数据架构。
  • 训练参数和评估参数的 Protobuf 定义。
  • (可选)由上游 Transform 组件生成的转换计算图。
  • (可选)用于热启动等场景的预训练模型。
  • (可选)将传递给用户模块函数的超参数。可以在此处找到与 Tuner 集成的详细信息。

Trainer 发出:至少一个用于推断/应用的模型(通常在 SavedModelFormat 中),以及另一个用于评估的可选模型(通常是一个 EvalSavedModel)。

我们通过模型重写库TFLite 等替代模型格式提供支持。有关如何同时转换 Estimator 和 Keras 模型的示例,请点击模型重写库的链接。

基于 Estimator 的 Trainer

要详细了解如何将基于 Estimator 的模型与 TFX 和 Trainer 一起使用,请参阅使用 tf.Estimator 为 TFX 设计 TensorFlow 建模代码

配置 Trainer 组件

典型的流水线 Python DSL 代码如下所示:

from tfx.components import Trainer

...

trainer = Trainer(
      module_file=module_file,
      examples=transform.outputs['transformed_examples'],
      schema=infer_schema.outputs['schema'],
      base_models=latest_model_resolver.outputs['latest_model'],
      transform_graph=transform.outputs['transform_graph'],
      train_args=trainer_pb2.TrainArgs(num_steps=10000),
      eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer 调用一个在 module_file 参数中指定的训练模块。一个典型的训练模块如下所示:

# TFX will call this function
def trainer_fn(trainer_fn_args, schema):
  """Build the estimator using the high level API.

  Args:
    trainer_fn_args: Holds args used to train the model as name/value pairs.
    schema: Holds the schema of the training examples.

  Returns:
    A dict of the following:

      - estimator: The estimator that will be used for training and eval.
      - train_spec: Spec for training.
      - eval_spec: Spec for eval.
      - eval_input_receiver_fn: Input function for eval.
  """
  # Number of nodes in the first layer of the DNN
  first_dnn_layer_size = 100
  num_dnn_layers = 4
  dnn_decay_factor = 0.7

  train_batch_size = 40
  eval_batch_size = 40

  tf_transform_output = tft.TFTransformOutput(trainer_fn_args.transform_output)

  train_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.train_files,
      tf_transform_output,
      batch_size=train_batch_size)

  eval_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.eval_files,
      tf_transform_output,
      batch_size=eval_batch_size)

  train_spec = tf.estimator.TrainSpec(  # pylint: disable=g-long-lambda
      train_input_fn,
      max_steps=trainer_fn_args.train_steps)

  serving_receiver_fn = lambda: _example_serving_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_output, schema)

  exporter = tf.estimator.FinalExporter('chicago-taxi', serving_receiver_fn)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=trainer_fn_args.eval_steps,
      exporters=[exporter],
      name='chicago-taxi-eval')

  run_config = tf.estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=1)

  run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir)
  warm_start_from = trainer_fn_args.base_models[
      0] if trainer_fn_args.base_models else None

  estimator = _build_estimator(
      # Construct layers sizes with exponetial decay
      hidden_units=[
          max(2, int(first_dnn_layer_size * dnn_decay_factor**i))
          for i in range(num_dnn_layers)
      ],
      config=run_config,
      warm_start_from=warm_start_from)

  # Create an input receiver for TFMA processing
  receiver_fn = lambda: _eval_input_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_output, schema)

  return {
      'estimator': estimator,
      'train_spec': train_spec,
      'eval_spec': eval_spec,
      'eval_input_receiver_fn': receiver_fn
  }

通用 Trainer

通用 Trainer 使开发者可以将任何 TensorFlow 模型 API 与 Trainer 组件一起使用。除了 TensorFlow Estimator 外,开发者还可以使用 Keras 模型或自定义训练循环。有关详细信息,请参阅通用 Trainer 的 RFC

配置 Trainer 组件以使用 GenericExecutor

通用 Trainer 的典型流水线 DSL 代码如下所示:

from tfx.components import Trainer
from tfx.components.base import executor_spec
from tfx.components.trainer.executor import GenericExecutor

...

trainer = Trainer(
    module_file=module_file,
    custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer 调用一个在 module_file 参数中指定的训练模块。如果在 custom_executor_spec 中指定了 GenericExecutor,则模块文件中需要 run_fn 而不是 trainer_fn

如果流水线中未使用 Transform 组件,则 Trainer 将直接从 ExampleGen 中获取样本:

trainer = Trainer(
    module_file=module_file,
    custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

这里提供了一个使用 run_fn模块文件示例