Preprocess data with TensorFlow Transform

The Feature Engineering Component of TensorFlow Extended (TFX)

This Colab notebook gives a very simple example of how TensorFlow Transform (tf.Transform) can be used to preprocess data using exactly the same code for both training a model and serving inferences in production.

TensorFlow Transform is a library for preprocessing input data for TensorFlow, including creating features that require a full pass over the training dataset. For example, using TensorFlow Transform you could:

  • Normalize an input value by using the mean and standard deviation
  • Convert strings to integers by generating a vocabulary over all of the input values
  • Convert floats to integers by assigning them to buckets, based on the observed data distribution
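Each of these needs statistics computed over the whole dataset before any single value can be transformed. As a rough illustration only (plain Python, not the tf.Transform API), the three analyses might look like the sketch below. The helper names are hypothetical, and the rank-based bucket boundaries are an assumption of this sketch; tf.Transform's own bucketizing analyzers may compute boundaries differently.

```python
import statistics


def zscore_normalize(values):
    # Full pass: mean and population standard deviation over all values.
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)
    return [(v - mean) / std for v in values]


def build_vocabulary(tokens):
    # Full pass: count every token, most frequent first (ties broken alphabetically here).
    counts = {}
    for t in tokens:
        counts[t] = counts.get(t, 0) + 1
    ordered = sorted(counts, key=lambda t: (-counts[t], t))
    return {t: i for i, t in enumerate(ordered)}


def quantile_boundaries(values, num_buckets):
    # Full pass: pick boundaries so each bucket receives roughly equal counts.
    ordered = sorted(values)
    n = len(ordered)
    return [ordered[(i * n) // num_buckets] for i in range(1, num_buckets)]


def bucketize(value, boundaries):
    # Bucket index = number of boundaries that are <= value.
    return sum(1 for b in boundaries if value >= b)


xs = [1.0, 2.0, 3.0, 4.0]
print(zscore_normalize(xs))
print(build_vocabulary(['hello', 'world', 'hello']))  # {'hello': 0, 'world': 1}
bounds = quantile_boundaries(xs, 2)
print([bucketize(x, bounds) for x in xs])  # [0, 0, 1, 1]
```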

TensorFlow has built-in support for manipulations on a single example or a batch of examples. tf.Transform extends these capabilities to support full passes over the entire training dataset.

The output of tf.Transform is exported as a TensorFlow graph that you can use for both training and serving. Using the same graph for both training and serving can prevent training/serving skew, since the same transformations are applied in both stages.
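To see why this matters, here is a hypothetical plain-Python sketch (not tf.Transform code) of the skew that appears when a serving system recomputes a statistic from its own small batch instead of reusing the value computed over the training data:

```python
def center(values, mean):
    # Per-example op: subtract a precomputed mean.
    return [v - mean for v in values]


train_x = [1.0, 2.0, 3.0]
train_mean = sum(train_x) / len(train_x)  # computed once, over the full training set

serving_request = [3.0]  # a single example arriving at serving time

# Consistent: reuse the training-time constant (what an exported graph does).
consistent = center(serving_request, train_mean)

# Skewed: recompute the "mean" from the serving batch itself.
skewed = center(serving_request, sum(serving_request) / len(serving_request))

print(consistent)  # [1.0]
print(skewed)      # [0.0]
```

The same input gets a different feature value at serving time in the skewed version, which is exactly the mismatch a shared transform graph avoids.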

Upgrade Pip

To avoid upgrading pip on your system when running locally, first check that the notebook is running in Colab. Local systems can, of course, be upgraded separately.

try:
  import google.colab  # Succeeds only when running in Colab.
  !pip install --upgrade pip
except ImportError:
  pass

Install TensorFlow Transform

!pip install -q -U tensorflow_transform
# This cell is only necessary because packages were installed while python was
# running. It avoids the need to restart the runtime when running in Colab.
import pkg_resources
import importlib

importlib.reload(pkg_resources)

Imports

import pathlib
import pprint
import tempfile

import tensorflow as tf
import tensorflow_transform as tft

import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils
from tensorflow_transform.keras_lib import tf_keras

Data: Create some dummy data

We'll create some dummy data for this example:

  • raw_data is the initial raw data that we're going to preprocess
  • raw_data_metadata contains the schema that tells us the type of each column in raw_data. In this case, it's very simple.
raw_data = [
    {'x': 1, 'y': 1, 's': 'hello'},
    {'x': 2, 'y': 2, 's': 'world'},
    {'x': 3, 'y': 3, 's': 'hello'},
]

raw_data_metadata = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec({
        'y': tf.io.FixedLenFeature([], tf.float32),
        'x': tf.io.FixedLenFeature([], tf.float32),
        's': tf.io.FixedLenFeature([], tf.string),
    }))

Transform: Create a preprocessing function

The preprocessing function is the most important concept of tf.Transform. A preprocessing function is where the transformation of the dataset really happens. It accepts and returns a dictionary of tensors, where a tensor means a Tensor or SparseTensor. There are two main groups of API calls that typically form the heart of a preprocessing function:

  1. TensorFlow Ops: Any function that accepts and returns tensors, which usually means TensorFlow ops. These add TensorFlow operations to the graph that transforms raw data into transformed data one feature vector at a time. These will run for every example, during both training and serving.
  2. TensorFlow Transform Analyzers/Mappers: Any of the analyzers/mappers provided by tf.Transform. These also accept and return tensors, and typically contain a combination of TensorFlow ops and Beam computation. Unlike plain TensorFlow ops, the Beam computation runs only once, in the Beam pipeline during analysis (prior to training), and typically makes a full pass over the entire training dataset. Its results are added to your graph as tf.constant tensors. For example, tft.min computes the minimum of a tensor over the training dataset.
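The difference in when these two groups run can be sketched with a toy plain-Python model. The names are hypothetical and this is not how tf.Transform is implemented: an "analyzer" reads the whole training set once and yields a constant, while the per-example "op" runs for every training example and again for every serving request.

```python
calls = {'analyzer': 0, 'op': 0}


def min_analyzer(dataset):
    # Analyzer: one full pass over the training data, producing a constant.
    calls['analyzer'] += 1
    return min(dataset)


def make_transform(min_value):
    # The constant is captured by the per-example function,
    # loosely like a tf.constant embedded in the graph.
    def transform(x):
        calls['op'] += 1
        return x - min_value
    return transform


train = [4.0, 7.0, 5.0]
transform = make_transform(min_analyzer(train))  # analysis: runs once

train_out = [transform(x) for x in train]    # training: op runs per example
serving_out = [transform(x) for x in [6.0]]  # serving: op runs again, no re-analysis

print(train_out, serving_out)  # [0.0, 3.0, 1.0] [2.0]
print(calls)                   # {'analyzer': 1, 'op': 4}
```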
def preprocessing_fn(inputs):
    """Preprocess input columns into transformed columns."""
    x = inputs['x']
    y = inputs['y']
    s = inputs['s']
    x_centered = x - tft.mean(x)
    y_normalized = tft.scale_to_0_1(y)
    s_integerized = tft.compute_and_apply_vocabulary(s)
    x_centered_times_y_normalized = (x_centered * y_normalized)
    return {
        'x_centered': x_centered,
        'y_normalized': y_normalized,
        's_integerized': s_integerized,
        'x_centered_times_y_normalized': x_centered_times_y_normalized,
    }

Syntax

You're almost ready to put everything together and use Apache Beam to run it.

Apache Beam uses a special syntax to define and invoke transforms. For example, in this line:

result = pass_this | 'name this step' >> to_this_call

The method to_this_call is being invoked and passed the object called pass_this, and this operation will be referred to as name this step in a stack trace. The result of the call to to_this_call is returned in result. You will often see stages of a pipeline chained together like this:

result = apache_beam.Pipeline() | 'first step' >> do_this_first() | 'second step' >> do_this_last()

and since that started with a new pipeline, you can continue like this:

next_result = result | 'doing more stuff' >> another_function()
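The | and >> notation is ordinary Python operator overloading. The toy classes below are hypothetical and far simpler than Beam's real PTransform and PCollection, but they show how 'label' >> transform attaches a name to a step and how data | transform applies it:

```python
class ToyTransform:
    """A stand-in for a Beam PTransform: wraps a function and an optional label."""

    def __init__(self, fn, label=None):
        self.fn = fn
        self.label = label

    def __rrshift__(self, label):
        # str defines no >> operator, so Python falls back to this for 'label' >> transform.
        return ToyTransform(self.fn, label)

    def __ror__(self, data):
        # list defines no | operator, so Python falls back to this for data | transform.
        print(f'running step: {self.label}')
        return self.fn(data)


def Map(fn):
    # Hypothetical helper: apply fn to every element.
    return ToyTransform(lambda data: [fn(x) for x in data])


# >> binds tighter than |, and | chains left to right.
result = [1, 2, 3] | 'double' >> Map(lambda x: 2 * x) | 'add one' >> Map(lambda x: x + 1)
print(result)  # [3, 5, 7]
```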

Putting it all together

Now we're ready to transform our data. We'll use Apache Beam with a direct runner, and supply three inputs:

  1. raw_data - The raw input data that we created above
  2. raw_data_metadata - The schema for the raw data
  3. preprocessing_fn - The function that we created to do our transformation
def main(output_dir):
  # Warnings about passing instance dicts (rather than the TFXIO format) are expected here.
  with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
    transformed_dataset, transform_fn = (  # pylint: disable=unused-variable
        (raw_data, raw_data_metadata) | tft_beam.AnalyzeAndTransformDataset(
            preprocessing_fn))

  transformed_data, transformed_metadata = transformed_dataset  # pylint: disable=unused-variable

  # Save the transform_fn to the output_dir
  _ = (
      transform_fn
      | 'WriteTransformFn' >> tft_beam.WriteTransformFn(output_dir))

  return transformed_data, transformed_metadata
output_dir = pathlib.Path(tempfile.mkdtemp())

transformed_data, transformed_metadata = main(str(output_dir))

print('\nRaw data:\n{}\n'.format(pprint.pformat(raw_data)))
print('Transformed data:\n{}'.format(pprint.pformat(transformed_data)))
Raw data:
[{'s': 'hello', 'x': 1, 'y': 1},
 {'s': 'world', 'x': 2, 'y': 2},
 {'s': 'hello', 'x': 3, 'y': 3}]

Transformed data:
[{'s_integerized': 0,
  'x_centered': -1.0,
  'x_centered_times_y_normalized': -0.0,
  'y_normalized': 0.0},
 {'s_integerized': 1,
  'x_centered': 0.0,
  'x_centered_times_y_normalized': 0.0,
  'y_normalized': 0.5},
 {'s_integerized': 0,
  'x_centered': 1.0,
  'x_centered_times_y_normalized': 1.0,
  'y_normalized': 1.0}]

Is this the right answer?

Previously, we used tf.Transform to do this:

x_centered = x - tft.mean(x)
y_normalized = tft.scale_to_0_1(y)
s_integerized = tft.compute_and_apply_vocabulary(s)
x_centered_times_y_normalized = (x_centered * y_normalized)
  • x_centered - With input of [1, 2, 3] the mean of x is 2, and we subtract it from x to center our x values at 0. So our result of [-1.0, 0.0, 1.0] is correct.
  • y_normalized - We wanted to scale our y values between 0 and 1. Our input was [1, 2, 3] so our result of [0.0, 0.5, 1.0] is correct.
  • s_integerized - We wanted to map our strings to indexes in a vocabulary, and there were only 2 words in our vocabulary ("hello" and "world"). So with input of ["hello", "world", "hello"] our result of [0, 1, 0] is correct. Since "hello" occurs most frequently in this data, it will be the first entry in the vocabulary.
  • x_centered_times_y_normalized - We wanted to create a new feature by crossing x_centered and y_normalized using multiplication. Note that this multiplies the transformed results, not the original values, so our new result of [-0.0, 0.0, 1.0] is correct. (The -0.0 is just -1.0 * 0.0 in floating-point arithmetic.)
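These expected values can also be recomputed by hand. The plain-Python sketch below is only a cross-check, not tf.Transform code. It assumes that scale_to_0_1 maps y to (y - min) / (max - min) over the dataset and that the vocabulary is ordered by descending frequency; ties are not exercised by this data.

```python
raw = [
    {'x': 1, 'y': 1, 's': 'hello'},
    {'x': 2, 'y': 2, 's': 'world'},
    {'x': 3, 'y': 3, 's': 'hello'},
]
xs = [float(r['x']) for r in raw]
ys = [float(r['y']) for r in raw]
ss = [r['s'] for r in raw]

# Analysis (full pass): mean of x, min/max of y, token frequencies of s.
x_mean = sum(xs) / len(xs)
y_min, y_max = min(ys), max(ys)
counts = {t: ss.count(t) for t in set(ss)}
vocab = {t: i for i, t in enumerate(sorted(counts, key=lambda t: -counts[t]))}

# Transformation (per example), using only the constants computed above.
x_centered = [x - x_mean for x in xs]
y_normalized = [(y - y_min) / (y_max - y_min) for y in ys]
s_integerized = [vocab[s] for s in ss]
crossed = [a * b for a, b in zip(x_centered, y_normalized)]

print(x_centered)     # [-1.0, 0.0, 1.0]
print(y_normalized)   # [0.0, 0.5, 1.0]
print(s_integerized)  # [0, 1, 0]
print(crossed)        # [-0.0, 0.0, 1.0]
```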

Use the resulting transform_fn

!ls -l {output_dir}
total 8
drwxr-xr-x 4 kbuilder kbuilder 4096 Apr 30 10:54 transform_fn
drwxr-xr-x 2 kbuilder kbuilder 4096 Apr 30 10:54 transformed_metadata

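The transform_fn/ directory contains a tf.saved_model that implements the transformation, with all the constants computed during the tf.Transform analysis built into the graph.

It is possible to load this directly with tf.saved_model.load, but this is not easy to use: as the signature below shows, the serving function takes generically named inputs (inputs, inputs_1, inputs_2) rather than the feature names.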

loaded = tf.saved_model.load(str(output_dir/'transform_fn'))
loaded.signatures['serving_default']
<ConcreteFunction (*, inputs: TensorSpec(shape=(None,), dtype=tf.string, name='inputs'), inputs_1: TensorSpec(shape=(None,), dtype=tf.float32, name='inputs_1'), inputs_2: TensorSpec(shape=(None,), dtype=tf.float32, name='inputs_2')) -> Dict[['x_centered', TensorSpec(shape=(None,), dtype=tf.float32, name='x_centered')], ['s_integerized', TensorSpec(shape=<unknown>, dtype=tf.int64, name='s_integerized')], ['x_centered_times_y_normalized', TensorSpec(shape=(None,), dtype=tf.float32, name='x_centered_times_y_normalized')], ['y_normalized', TensorSpec(shape=(None,), dtype=tf.float32, name='y_normalized')]] at 0x7F452C2C6400>

A better approach is to load it using tft.TFTransformOutput. The TFTransformOutput.transform_features_layer method returns a tft.TransformFeaturesLayer object that can be used to apply the transformation:

tf_transform_output = tft.TFTransformOutput(output_dir)

tft_layer = tf_transform_output.transform_features_layer()
tft_layer
<tensorflow_transform.output_wrapper.TransformFeaturesLayer at 0x7f46bc272700>

This tft.TransformFeaturesLayer expects a dictionary of batched features. So create a Dict[str, tf.Tensor] from the List[Dict[str, Any]] in raw_data:

raw_data_batch = {
    's': tf.constant([ex['s'] for ex in raw_data]),
    'x': tf.constant([ex['x'] for ex in raw_data], dtype=tf.float32),
    'y': tf.constant([ex['y'] for ex in raw_data], dtype=tf.float32),
}
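The same row-to-column conversion can be written generically. This is a plain-Python sketch with a hypothetical helper name; it produces Python lists rather than tensors, and each list could then be passed to tf.constant with the appropriate dtype, as above.

```python
def rows_to_columns(rows):
    """Convert a list of per-example dicts into a dict of per-feature lists."""
    keys = rows[0].keys()
    # Every example must have the same features, or the batch would be ragged.
    for row in rows:
        if row.keys() != keys:
            raise ValueError(f'inconsistent features: {sorted(row)} vs {sorted(keys)}')
    return {key: [row[key] for row in rows] for key in keys}


columns = rows_to_columns([
    {'x': 1, 'y': 1, 's': 'hello'},
    {'x': 2, 'y': 2, 's': 'world'},
    {'x': 3, 'y': 3, 's': 'hello'},
])
print(columns)  # {'x': [1, 2, 3], 'y': [1, 2, 3], 's': ['hello', 'world', 'hello']}
```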

You can use the tft.TransformFeaturesLayer on its own:

transformed_batch = tft_layer(raw_data_batch)

{key: value.numpy() for key, value in transformed_batch.items()}
{'x_centered': array([-1.,  0.,  1.], dtype=float32),
 'x_centered_times_y_normalized': array([-0.,  0.,  1.], dtype=float32),
 'y_normalized': array([0. , 0.5, 1. ], dtype=float32),
 's_integerized': array([0, 1, 0])}

Export

A more typical use case would use tf.Transform to apply the transformation to the training and evaluation datasets (see the next tutorial for an example). Then, after training and before exporting the model, attach the tft.TransformFeaturesLayer as the first layer so that it is exported as part of your tf.saved_model. For a concrete example, keep reading.

An example training model

Below is a model that:

  1. takes the transformed batch,
  2. stacks them all together into a simple (batch, features) matrix,
  3. runs them through a few dense layers, and
  4. produces 10 linear outputs.

In a real use case you would apply a one-hot encoding to the s_integerized feature, since its values are vocabulary indices rather than meaningful magnitudes.
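A one-hot encoding turns each vocabulary index into a vector with a single 1, so the model does not treat index 1 as "larger" than index 0. In TensorFlow this is what tf.one_hot(indices, depth) is for; the plain-Python sketch below only illustrates the idea, and mapping out-of-range indices to all-zero rows is a choice made in this sketch.

```python
def one_hot(indices, depth):
    # Each index becomes a length-`depth` row with a single 1.0;
    # indices outside [0, depth) become all-zero rows in this sketch.
    return [[1.0 if i == j else 0.0 for j in range(depth)] for i in indices]


s_indices = [0, 1, 0]  # the s_integerized values from the transformed data above
print(one_hot(s_indices, depth=2))  # [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
```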

You could train this model on a dataset transformed by tf.Transform:

class StackDict(tf_keras.layers.Layer):
  """Stacks a dict of equal-length 1-D features into a (batch, features) matrix."""

  def call(self, inputs):
    # Sort by key so the feature (column) order is deterministic.
    values = [
        tf.cast(v, tf.float32)
        for _, v in sorted(inputs.items(), key=lambda kv: kv[0])]
    return tf.stack(values, axis=1)


class TrainedModel(tf_keras.Model):
  def __init__(self):
    super().__init__()
    self.concat = StackDict()
    self.body = tf_keras.Sequential([
        tf_keras.layers.Dense(64, activation='relu'),
        tf_keras.layers.Dense(64, activation='relu'),
        tf_keras.layers.Dense(10),
    ])

  def call(self, inputs, training=None):
    x = self.concat(inputs)
    return self.body(x, training=training)
trained_model = TrainedModel()
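Because StackDict sorts the dictionary by key before stacking, the feature columns always come out in the same alphabetical order, regardless of the order in which the dictionary was built. A plain-Python sketch of the same stacking logic (lists instead of tensors; the helper name is hypothetical):

```python
def stack_dict(inputs):
    # Sort by feature name so the column order is deterministic.
    names = sorted(inputs)
    columns = [[float(v) for v in inputs[name]] for name in names]
    # Transpose: one row per example, one column per feature (like tf.stack(axis=1)).
    rows = [list(row) for row in zip(*columns)]
    return names, rows


names, rows = stack_dict({
    'y_normalized': [0.0, 0.5, 1.0],
    's_integerized': [0, 1, 0],
    'x_centered': [-1.0, 0.0, 1.0],
    'x_centered_times_y_normalized': [-0.0, 0.0, 1.0],
})
print(names)
print(rows)
```

With the four transformed features, the model therefore sees a 3 x 4 input matrix whose first column is s_integerized.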

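Imagine we trained the model. The ... arguments below stand in for a real loss function and training data, so this cell is illustrative only: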

trained_model.compile(loss=..., optimizer='adam')
trained_model.fit(...)

This model runs on the transformed inputs:

trained_model_output = trained_model(transformed_batch)
trained_model_output.shape
TensorShape([3, 10])

An example export wrapper

Imagine you've trained the above model and want to export it.

You'll want to include the transform function in the exported model:

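class ExportModel(tf.Module):
  def __init__(self, trained_model, input_transform):
    super().__init__()
    # Storing both as attributes lets tf.saved_model.save track them, so the
    # transform's assets (such as the vocabulary) are exported too.
    self.trained_model = trained_model
    self.input_transform = input_transform

  @tf.function
  def __call__(self, inputs, training=None):
    # Raw features -> tf.Transform graph -> trained model.
    x = self.input_transform(inputs)
    return self.trained_model(x)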
export_model = ExportModel(trained_model=trained_model,
                           input_transform=tft_layer)
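Conceptually, ExportModel is function composition: the exported function applies the transform first and the trained model second. A plain-Python sketch, with hypothetical stand-ins for the frozen transform and the trained model, shows why the wrapper gives the same answer as calling the model on pre-transformed data:

```python
def make_export_fn(model_fn, transform_fn):
    # The exported function accepts raw features and applies both stages in order.
    def export_fn(raw_features):
        return model_fn(transform_fn(raw_features))
    return export_fn


# Hypothetical stand-ins: a frozen transform (training mean of x = 2.0)
# and a "trained model" that is a fixed linear function.
def frozen_transform(raw):
    return {'x_centered': [x - 2.0 for x in raw['x']]}


def toy_model(features):
    return [3.0 * v + 1.0 for v in features['x_centered']]


toy_export = make_export_fn(toy_model, frozen_transform)
raw_batch = {'x': [1.0, 2.0, 3.0]}
print(toy_export(raw_batch))                   # [-2.0, 1.0, 4.0]
print(toy_model(frozen_transform(raw_batch)))  # [-2.0, 1.0, 4.0]
```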

This combined model works on the raw data, and produces exactly the same results as calling the trained model directly:

export_model_output = export_model(raw_data_batch)
export_model_output.shape
TensorShape([3, 10])
tf.reduce_max(abs(export_model_output - trained_model_output)).numpy()
0.0

This export_model includes the tft.TransformFeaturesLayer and is entirely self-contained. You can save it and restore it in another environment and still get exactly the same result:

model_dir = tempfile.mkdtemp(suffix='tft')

tf.saved_model.save(export_model, model_dir)
reloaded = tf.saved_model.load(model_dir)

reloaded_model_output = reloaded(raw_data_batch)
reloaded_model_output.shape
TensorShape([3, 10])
tf.reduce_max(abs(export_model_output - reloaded_model_output)).numpy()
0.0