
Simple TFX Pipeline Tutorial using Penguin dataset

A short tutorial on running a simple TFX pipeline.

In this notebook-based tutorial, we will create and run a TFX pipeline for a simple classification model. The pipeline consists of three essential TFX components: ExampleGen, Trainer and Pusher. Together they cover a minimal ML workflow: importing data, training a model and exporting the trained model.

Please see Understanding TFX Pipelines to learn more about various concepts in TFX.

Set Up

We first need to install the TFX Python package and download the dataset which we will use for our model.

Upgrade Pip

To avoid upgrading Pip on a local system, we first check that we are running in Colab. Local systems can of course be upgraded separately.

try:
  import colab
  !pip install -q --upgrade pip
except:
  pass
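
The same check can be written as a small helper that catches only import failures instead of every exception. This is a minimal sketch: the helper name `running_in_colab` is hypothetical, and it assumes that the `google.colab` module is importable only inside a Colab runtime.

```python
def running_in_colab() -> bool:
    """Return True when the google.colab module can be imported."""
    try:
        # Assumption: this module is only available in Colab runtimes.
        import google.colab  # noqa: F401
        return True
    except ImportError:
        return False


print('Running in Colab:', running_in_colab())
```

Catching `ImportError` alone keeps unrelated failures, such as a keyboard interrupt, from being silently swallowed.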

Install TFX

# TODO(b/178712706): Stop using legacy resolver after PIP issue is resolved.
!pip install -q -U --use-deprecated=legacy-resolver tfx

Did you restart the runtime?

If you are using Google Colab, you must restart the runtime the first time you run the cell above, either by clicking the "RESTART RUNTIME" button above or by using the "Runtime > Restart runtime ..." menu. This is because of the way that Colab loads packages.

Check the TensorFlow and TFX versions.

import tensorflow as tf
print('TensorFlow version: {}'.format(tf.__version__))
import tfx
print('TFX version: {}'.format(tfx.__version__))
TensorFlow version: 2.4.1
TFX version: 0.27.0

Set up variables

There are some variables used to define a pipeline. You can customize these variables as you want. By default all output from the pipeline will be generated under the current directory.

import os

PIPELINE_NAME = "penguin-simple"

# Output directory to store artifacts generated from the pipeline.
PIPELINE_ROOT = os.path.join('pipelines', PIPELINE_NAME)
# Path to a SQLite DB file to use as an MLMD storage.
METADATA_PATH = os.path.join('metadata', PIPELINE_NAME, 'metadata.db')
# Output directory where created models from the pipeline will be exported.
SERVING_MODEL_DIR = os.path.join('serving_model', PIPELINE_NAME)

from absl import logging
logging.set_verbosity(logging.INFO)  # Set default logging level.
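
To make the defaults concrete, the sketch below resolves the three relative paths against a base directory such as the current working directory. It uses only the standard library, and the helper name `describe_outputs` is hypothetical and not part of TFX.

```python
import os

PIPELINE_NAME = "penguin-simple"
PIPELINE_ROOT = os.path.join('pipelines', PIPELINE_NAME)
METADATA_PATH = os.path.join('metadata', PIPELINE_NAME, 'metadata.db')
SERVING_MODEL_DIR = os.path.join('serving_model', PIPELINE_NAME)


def describe_outputs(base_dir: str) -> dict:
    """Map each setting to the path it resolves to under base_dir."""
    return {
        'PIPELINE_ROOT': os.path.join(base_dir, PIPELINE_ROOT),
        'METADATA_PATH': os.path.join(base_dir, METADATA_PATH),
        'SERVING_MODEL_DIR': os.path.join(base_dir, SERVING_MODEL_DIR),
    }


# Print where each output would land when run from the current directory.
for name, path in describe_outputs(os.getcwd()).items():
    print(f'{name}: {path}')
```

Changing `PIPELINE_NAME` moves all three locations at once, which keeps the outputs of different pipelines from mixing.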

Prepare example data

We will download the example dataset for use in our TFX pipeline. The dataset we are using is the Palmer Penguins dataset, which is also used in other TFX examples.

There are four numeric features in this dataset:

  • culmen_length_mm
  • culmen_depth_mm
  • flipper_length_mm
  • body_mass_g

All features have already been normalized to the range [0,1]. We will build a classification model which predicts the species of penguins.

Because TFX ExampleGen reads inputs from a directory, we need to create a directory and copy the dataset to it.

import urllib.request
import tempfile

DATA_ROOT = tempfile.mkdtemp(prefix='tfx-data')  # Create a temporary directory.
_data_url = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/penguin/data/penguins_processed.csv'
_data_filepath = os.path.join(DATA_ROOT, "data.csv")
urllib.request.urlretrieve(_data_url, _data_filepath)
('/tmp/tfx-data8oul1c7e/data.csv', <http.client.HTTPMessage at 0x7f20f9900208>)

Take a quick look at the CSV file.

!head {_data_filepath}
species,culmen_length_mm,culmen_depth_mm,flipper_length_mm,body_mass_g
0,0.2545454545454545,0.6666666666666666,0.15254237288135594,0.2916666666666667
0,0.26909090909090905,0.5119047619047618,0.23728813559322035,0.3055555555555556
0,0.29818181818181805,0.5833333333333334,0.3898305084745763,0.1527777777777778
0,0.16727272727272732,0.7380952380952381,0.3559322033898305,0.20833333333333334
0,0.26181818181818167,0.892857142857143,0.3050847457627119,0.2638888888888889
0,0.24727272727272717,0.5595238095238096,0.15254237288135594,0.2569444444444444
0,0.25818181818181823,0.773809523809524,0.3898305084745763,0.5486111111111112
0,0.32727272727272727,0.5357142857142859,0.1694915254237288,0.1388888888888889
0,0.23636363636363636,0.9642857142857142,0.3220338983050847,0.3055555555555556

You should see five columns. species is one of 0, 1 or 2, and all other features should have values between 0 and 1.
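
The same expectations can be checked programmatically. The following is a minimal sketch using only the standard library; the helper name `check_rows` is hypothetical, and the sample is the header plus the first three rows of the output above so that the block runs on its own.

```python
import csv
import io

# Header and first rows copied from the `head` output above.
SAMPLE = """species,culmen_length_mm,culmen_depth_mm,flipper_length_mm,body_mass_g
0,0.2545454545454545,0.6666666666666666,0.15254237288135594,0.2916666666666667
0,0.26909090909090905,0.5119047619047618,0.23728813559322035,0.3055555555555556
0,0.29818181818181805,0.5833333333333334,0.3898305084745763,0.1527777777777778
"""


def check_rows(csv_file) -> int:
    """Validate every row and return the number of rows checked."""
    count = 0
    for row in csv.DictReader(csv_file):
        assert len(row) == 5, f'expected 5 columns, got {len(row)}'
        assert int(row['species']) in (0, 1, 2), 'unexpected species label'
        for name, value in row.items():
            if name != 'species':
                assert 0.0 <= float(value) <= 1.0, f'{name} out of range'
        count += 1
    return count


print(check_rows(io.StringIO(SAMPLE)), 'rows look valid')
```

To check the downloaded file instead of the sample, the function could be given a file object opened on `_data_filepath`.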

Create a pipeline

TFX pipelines are defined using Python APIs. We will define a pipeline which consists of the following three components:

  • CsvExampleGen: Reads in data files and converts them to the TFX internal format for further processing. There are multiple ExampleGens for various formats. In this tutorial, we will use CsvExampleGen, which takes CSV file input.
  • Trainer: Trains an ML model. The Trainer component requires model definition code from users. You can use TensorFlow APIs to specify how to train a model and save it in the SavedModel format.
  • Pusher: Copies the trained model outside of the TFX pipeline. The Pusher component can be thought of as the deployment process for the trained ML model.

Before defining the pipeline, we first need to write the model code for the Trainer component.

Write model training code

We will create a simple DNN model for classification using TensorFlow Keras API. This model training code will be saved to a separate file.

In this tutorial we will use the Generic Trainer of TFX, which supports Keras-based models. You need to write a Python file containing a run_fn function, which is the entrypoint for the Trainer component.

_trainer_module_file = 'penguin_trainer.py'
%%writefile {_trainer_module_file}

from typing import List
from absl import logging
import tensorflow as tf
from tensorflow import keras
from tensorflow_transform.tf_metadata import schema_utils

from tfx.components.trainer.executor import TrainerFnArgs
from tfx.components.trainer.fn_args_utils import DataAccessor
from tfx_bsl.tfxio import dataset_options
from tensorflow_metadata.proto.v0 import schema_pb2

_FEATURE_KEYS = [
    'culmen_length_mm', 'culmen_depth_mm', 'flipper_length_mm', 'body_mass_g'
]
_LABEL_KEY = 'species'

_TRAIN_BATCH_SIZE = 20
_EVAL_BATCH_SIZE = 10

# Since we're not generating or creating a schema, we will instead create
# a feature spec.  Since there are a fairly small number of features this is
# manageable for this dataset.
_FEATURE_SPEC = {
    **{
        feature: tf.io.FixedLenFeature(shape=[1], dtype=tf.float32)
           for feature in _FEATURE_KEYS
       },
    _LABEL_KEY: tf.io.FixedLenFeature(shape=[1], dtype=tf.int64)
}


def _input_fn(file_pattern: List[str],
              data_accessor: DataAccessor,
              schema: schema_pb2.Schema,
              batch_size: int = 200) -> tf.data.Dataset:
  """Generates features and label for training.

  Args:
    file_pattern: List of paths or patterns of input tfrecord files.
    data_accessor: DataAccessor for converting input to RecordBatch.
    schema: schema of the input data.
    batch_size: representing the number of consecutive elements of returned
      dataset to combine in a single batch

  Returns:
    A dataset that contains (features, indices) tuple where features is a
      dictionary of Tensors, and indices is a single Tensor of label indices.
  """
  return data_accessor.tf_dataset_factory(
      file_pattern,
      dataset_options.TensorFlowDatasetOptions(
          batch_size=batch_size, label_key=_LABEL_KEY),
      schema=schema).repeat()


def _build_keras_model() -> tf.keras.Model:
  """Creates a DNN Keras model for classifying penguin data.

  Returns:
    A Keras Model.
  """
  # The model below is built with Functional API, please refer to
  # https://www.tensorflow.org/guide/keras/overview for all API options.
  inputs = [keras.layers.Input(shape=(1,), name=f) for f in _FEATURE_KEYS]
  d = keras.layers.concatenate(inputs)
  for _ in range(2):
    d = keras.layers.Dense(8, activation='relu')(d)
  outputs = keras.layers.Dense(3)(d)

  model = keras.Model(inputs=inputs, outputs=outputs)
  model.compile(
      optimizer=keras.optimizers.Adam(1e-2),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[keras.metrics.SparseCategoricalAccuracy()])

  model.summary(print_fn=logging.info)
  return model


# TFX Trainer will call this function.
def run_fn(fn_args: TrainerFnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """

  # This schema is usually either an output of SchemaGen or a manually-curated
  # version provided by pipeline author. A schema can also be derived from TFT
  # graph if a Transform component is used. In the case when either is missing,
  # `schema_from_feature_spec` could be used to generate schema from very simple
  # feature_spec, but the schema returned would be very primitive.
  schema = schema_utils.schema_from_feature_spec(_FEATURE_SPEC)

  train_dataset = _input_fn(
      fn_args.train_files,
      fn_args.data_accessor,
      schema,
      batch_size=_TRAIN_BATCH_SIZE)
  eval_dataset = _input_fn(
      fn_args.eval_files,
      fn_args.data_accessor,
      schema,
      batch_size=_EVAL_BATCH_SIZE)

  model = _build_keras_model()
  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)

  # The result of the training should be saved in `fn_args.serving_model_dir`
  # directory.
  model.save(fn_args.serving_model_dir, save_format='tf')
Writing penguin_trainer.py
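
As a sanity check on the network defined in `_build_keras_model`, the number of trainable parameters can be worked out by hand. The sketch below is plain arithmetic rather than a TFX or Keras call; the helper name `dense_params` is hypothetical, and the layer widths are read off the code above (four concatenated inputs, two hidden layers of 8 units, 3 output logits).

```python
def dense_params(n_inputs: int, n_units: int) -> int:
    """Weights plus one bias per unit for a fully connected layer."""
    return n_inputs * n_units + n_units


# Input width followed by the width of each Dense layer.
widths = [4, 8, 8, 3]
per_layer = [dense_params(a, b) for a, b in zip(widths, widths[1:])]
print(per_layer, sum(per_layer))  # [40, 72, 27] 139
```

These figures match the per-layer and total parameter counts in the model summary that is logged when the pipeline runs.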

Now you have completed all preparation steps to build a TFX pipeline.

Write a pipeline definition

We define a function to create a TFX pipeline. A Pipeline object represents a TFX pipeline which can be run using one of pipeline orchestration systems that TFX supports.

from tfx.components import CsvExampleGen
from tfx.components import Pusher
from tfx.components import Trainer
from tfx.components.trainer.executor import GenericExecutor
from tfx.dsl.components.base import executor_spec
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2

def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,
                     module_file: str, serving_model_dir: str,
                     metadata_path: str) -> pipeline.Pipeline:
  """Creates a three component penguin pipeline with TFX."""
  # Brings data into the pipeline.
  example_gen = CsvExampleGen(input_base=data_root)

  # Uses user-provided Python function that trains a model.
  trainer = Trainer(
      module_file=module_file,
      custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
      examples=example_gen.outputs['examples'],
      train_args=trainer_pb2.TrainArgs(num_steps=100),
      eval_args=trainer_pb2.EvalArgs(num_steps=5))

  # Pushes the model to a filesystem destination.
  pusher = Pusher(
      model=trainer.outputs['model'],
      push_destination=pusher_pb2.PushDestination(
          filesystem=pusher_pb2.PushDestination.Filesystem(
              base_directory=serving_model_dir)))

  # The following three components will be included in the pipeline.
  components = [
      example_gen,
      trainer,
      pusher,
  ]

  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path),
      components=components)
INFO:absl:tensorflow_ranking is not available: No module named 'tensorflow_ranking'
INFO:absl:tensorflow_text is not available: No module named 'tensorflow_text'
WARNING:absl:RuntimeParameter is only supported on Cloud-based DAG runner currently.
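
The `num_steps` values above reach `run_fn` as `fn_args.train_steps` and `fn_args.eval_steps`. Because `_input_fn` calls `.repeat()`, the dataset never runs out, so the step counts and batch sizes alone determine how many examples are read. The sketch below is a back-of-the-envelope calculation with a hypothetical helper name, assuming the single epoch that `model.fit` runs by default.

```python
TRAIN_BATCH_SIZE = 20  # _TRAIN_BATCH_SIZE in penguin_trainer.py
EVAL_BATCH_SIZE = 10   # _EVAL_BATCH_SIZE in penguin_trainer.py
TRAIN_STEPS = 100      # trainer_pb2.TrainArgs(num_steps=100)
EVAL_STEPS = 5         # trainer_pb2.EvalArgs(num_steps=5)


def examples_consumed(num_steps: int, batch_size: int) -> int:
    """Number of examples read when each step consumes one full batch."""
    return num_steps * batch_size


print('train:', examples_consumed(TRAIN_STEPS, TRAIN_BATCH_SIZE))  # train: 2000
print('eval:', examples_consumed(EVAL_STEPS, EVAL_BATCH_SIZE))     # eval: 50
```

Raising `num_steps` therefore lengthens training without any change to the data itself.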

Run the pipeline

TFX supports multiple orchestrators to run pipelines. In this tutorial we will use LocalDagRunner, which is included in the TFX Python package and runs pipelines in a local environment. We often call TFX pipelines "DAGs", which stands for directed acyclic graphs.
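
To illustrate what "DAG" means for this pipeline, the sketch below derives an execution order from the data dependencies between the three components. It is only an illustration using the standard library `graphlib` module (Python 3.9 or later) and makes no claim about how LocalDagRunner is implemented.

```python
from graphlib import TopologicalSorter  # standard library, Python 3.9+

# Each component maps to the set of components whose outputs it consumes.
dependencies = {
    'CsvExampleGen': set(),
    'Trainer': {'CsvExampleGen'},  # consumes the generated examples
    'Pusher': {'Trainer'},         # consumes the trained model
}

execution_order = list(TopologicalSorter(dependencies).static_order())
print(execution_order)  # ['CsvExampleGen', 'Trainer', 'Pusher']
```

Because the graph has no cycles, such an order always exists: every component runs only after the components it depends on.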

LocalDagRunner provides fast iterations for development and debugging. TFX also supports other orchestrators, including Kubeflow Pipelines and Apache Airflow, which are suitable for production use cases.

See TFX on Cloud AI Platform Pipelines or TFX Airflow Tutorial to learn more about other orchestration systems.

Now we create a LocalDagRunner and pass a Pipeline object created from the function we already defined.

The pipeline runs directly and you can see logs for the progress of the pipeline including ML model training.

import os
from tfx.orchestration.local import local_dag_runner

local_dag_runner.LocalDagRunner().run(
  _create_pipeline(
      pipeline_name=PIPELINE_NAME,
      pipeline_root=PIPELINE_ROOT,
      data_root=DATA_ROOT,
      module_file=_trainer_module_file,
      serving_model_dir=SERVING_MODEL_DIR,
      metadata_path=METADATA_PATH))
INFO:absl:Component CsvExampleGen is running.
INFO:absl:Running driver for CsvExampleGen
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:select span and version = (0, None)
INFO:absl:latest span and version = (0, None)
INFO:absl:Running executor for CsvExampleGen
INFO:absl:Generating examples.
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
INFO:absl:Processing input csv data /tmp/tfx-data8oul1c7e/* to TFExample.
WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.
INFO:absl:Examples generated.
INFO:absl:Running publisher for CsvExampleGen
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component CsvExampleGen is finished.
INFO:absl:Component Trainer is running.
INFO:absl:Running driver for Trainer
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for Trainer
INFO:absl:Train on the 'train' split when train_args.splits is not set.
INFO:absl:Evaluate on the 'eval' split when eval_args.splits is not set.
INFO:absl:Loading penguin_trainer.py because it has not been loaded before.
INFO:absl:Training model.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Model: "model"
INFO:absl:__________________________________________________________________________________________________
INFO:absl:Layer (type)                    Output Shape         Param #     Connected to                     
INFO:absl:==================================================================================================
INFO:absl:culmen_length_mm (InputLayer)   [(None, 1)]          0                                            
INFO:absl:__________________________________________________________________________________________________
INFO:absl:culmen_depth_mm (InputLayer)    [(None, 1)]          0                                            
INFO:absl:__________________________________________________________________________________________________
INFO:absl:flipper_length_mm (InputLayer)  [(None, 1)]          0                                            
INFO:absl:__________________________________________________________________________________________________
INFO:absl:body_mass_g (InputLayer)        [(None, 1)]          0                                            
INFO:absl:__________________________________________________________________________________________________
INFO:absl:concatenate (Concatenate)       (None, 4)            0           culmen_length_mm[0][0]           
INFO:absl:                                                                 culmen_depth_mm[0][0]            
INFO:absl:                                                                 flipper_length_mm[0][0]          
INFO:absl:                                                                 body_mass_g[0][0]                
INFO:absl:__________________________________________________________________________________________________
INFO:absl:dense (Dense)                   (None, 8)            40          concatenate[0][0]                
INFO:absl:__________________________________________________________________________________________________
INFO:absl:dense_1 (Dense)                 (None, 8)            72          dense[0][0]                      
INFO:absl:__________________________________________________________________________________________________
INFO:absl:dense_2 (Dense)                 (None, 3)            27          dense_1[0][0]                    
INFO:absl:==================================================================================================
INFO:absl:Total params: 139
INFO:absl:Trainable params: 139
INFO:absl:Non-trainable params: 0
INFO:absl:__________________________________________________________________________________________________
100/100 [==============================] - 1s 6ms/step - loss: 0.7571 - sparse_categorical_accuracy: 0.6395 - val_loss: 0.1308 - val_sparse_categorical_accuracy: 0.9600
INFO:tensorflow:Assets written to: pipelines/penguin-simple/Trainer/model/2/serving_model_dir/assets
INFO:tensorflow:Assets written to: pipelines/penguin-simple/Trainer/model/2/serving_model_dir/assets
INFO:absl:Training complete. Model written to pipelines/penguin-simple/Trainer/model/2/serving_model_dir. ModelRun written to pipelines/penguin-simple/Trainer/model_run/2
INFO:absl:Running publisher for Trainer
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component Trainer is finished.
INFO:absl:Component Pusher is running.
INFO:absl:Running driver for Pusher
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for Pusher
WARNING:absl:Pusher is going to push the model without validation. Consider using Evaluator or InfraValidator in your pipeline.
INFO:absl:Model version: 1615285163
INFO:absl:Model written to serving path serving_model/penguin-simple/1615285163.
INFO:absl:Model pushed to pipelines/penguin-simple/Pusher/pushed_model/3.
INFO:absl:Running publisher for Pusher
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component Pusher is finished.

You should see "INFO:absl:Component Pusher is finished." at the end of the logs if the pipeline finished successfully, because the Pusher component is the last component of the pipeline.

The Pusher component pushes the trained model to SERVING_MODEL_DIR, which is the serving_model/penguin-simple directory if you did not change the variables in the previous steps. You can see the result in the file browser in the left-side panel in Colab, or by using the following command:

# List files in created model directory.
!find {SERVING_MODEL_DIR}
serving_model/penguin-simple
serving_model/penguin-simple/1615285163
serving_model/penguin-simple/1615285163/variables
serving_model/penguin-simple/1615285163/variables/variables.data-00000-of-00001
serving_model/penguin-simple/1615285163/variables/variables.index
serving_model/penguin-simple/1615285163/assets
serving_model/penguin-simple/1615285163/saved_model.pb
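
Each push creates a new subdirectory named after the model version, as in the listing above. The sketch below picks the most recent one, assuming that version directories always have purely numeric names as they do here. The helper name `latest_model_dir` is hypothetical, and the demonstration uses a throwaway directory in which the second version number is made up.

```python
import os
import tempfile


def latest_model_dir(serving_model_dir: str) -> str:
    """Return the subdirectory with the largest numeric name."""
    versions = [
        name for name in os.listdir(serving_model_dir)
        if name.isdigit()
        and os.path.isdir(os.path.join(serving_model_dir, name))
    ]
    if not versions:
        raise FileNotFoundError(f'no model versions under {serving_model_dir}')
    return os.path.join(serving_model_dir, max(versions, key=int))


# Demonstration: 1615285163 is from the listing above, 1615285999 is made up.
demo_root = tempfile.mkdtemp(prefix='serving-demo')
for version in ('1615285163', '1615285999'):
    os.makedirs(os.path.join(demo_root, version))
print(os.path.basename(latest_model_dir(demo_root)))  # 1615285999
```

Called with `SERVING_MODEL_DIR`, the returned path could then be passed to `tf.keras.models.load_model` to load the exported model for inspection.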

Next steps

You can find more resources at https://www.tensorflow.org/tfx/tutorials

Please see Understanding TFX Pipelines to learn more about various concepts in TFX.