Komponent potoku Trainer TFX

Komponent potoku Trainer TFX szkoli model TensorFlow.

Trener i TensorFlow

Trainer w szerokim zakresie wykorzystuje interfejs API Pythona TensorFlow do modeli szkoleniowych.

Część

Trener bierze:

  • tf.Przykłady użyte do szkolenia i ewaluacji.
  • Plik modułu dostarczony przez użytkownika, który definiuje logikę trenera.
  • Definicja protobufa argumentów pociągowych i argumentów ewaluacyjnych.
  • (Opcjonalnie) Schemat danych utworzony przez komponent potoku SchemaGen i opcjonalnie zmieniony przez programistę.
  • (Opcjonalnie) wykres transformacji utworzony przez nadrzędny komponent Transform.
  • (Opcjonalnie) wstępnie przeszkolone modele używane w scenariuszach, takich jak ciepły start.
  • (Opcjonalnie) hiperparametry, które zostaną przekazane do funkcji modułu użytkownika. Szczegóły integracji z Tunerem znajdziesz tutaj .

Trener emituje: Co najmniej jeden model do wnioskowania/obsługiwania (zwykle w formacie SavedModelFormat) i opcjonalnie inny model do eval (zazwyczaj EvalSavedModel).

Zapewniamy obsługę alternatywnych formatów modeli, takich jak TFLite, za pośrednictwem biblioteki przepisywania modeli . Zobacz łącze do Biblioteki przepisywania modeli, aby zapoznać się z przykładami konwertowania modeli estymatora i modelu Keras.

Trener ogólny

Generic trainer umożliwia programistom korzystanie z dowolnego interfejsu API modelu TensorFlow z komponentem Trainer. Oprócz estymatorów TensorFlow programiści mogą korzystać z modeli Keras lub niestandardowych pętli szkoleniowych. Szczegółowe informacje można znaleźć w dokumencie RFC dotyczącym generycznego trenera .

Konfiguracja komponentu Trainer

Typowy kod DSL potoku dla ogólnego Trainera wyglądałby następująco:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trener wywołuje moduł szkoleniowy, który jest określony w parametrze module_file . Zamiast trainer_fn , w pliku modułu wymagany jest run_fn , jeśli w pliku custom_executor_spec określono GenericExecutor . Za stworzenie modelu odpowiedzialny był trainer_fn . Oprócz tego run_fn musi również obsłużyć część szkoleniową i wyprowadzić przeszkolony model do żądanej lokalizacji określonej przez FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Oto przykładowy plik modułu z run_fn .

Należy pamiętać, że jeśli w potoku nie zostanie użyty komponent Transform, Trainer pobierze przykłady bezpośrednio z PrzykładGen:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Więcej szczegółów można znaleźć w dokumentacji Trainer API .