This guide introduces the basic concepts of tf.Transform
and how to use them.
It will:
- Define a preprocessing function, a logical description of the pipeline that transforms the raw data into the data used to train a machine learning model.
- Show the Apache Beam implementation used to transform data by converting the preprocessing function into a Beam pipeline.
- Show additional usage examples.
Setup
pip install -U tensorflow_transform
pip install pyarrow
import pkg_resources
import importlib
# Reload pkg_resources so the packages installed above are visible to this runtime.
importlib.reload(pkg_resources)
import os
import tempfile
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils
from tfx_bsl.public import tfxio
Define a preprocessing function
The preprocessing function is the most important concept of tf.Transform. The preprocessing function is a logical description of a transformation of the dataset. The preprocessing function accepts and returns a dictionary of tensors, where a tensor means Tensor or SparseTensor. There are two kinds of functions used to define the preprocessing function:
- Any function that accepts and returns tensors. These add TensorFlow operations to the graph that transform raw data into transformed data.
- Any of the analyzers provided by tf.Transform. Analyzers also accept and return tensors, but unlike TensorFlow functions, they do not add operations to the graph. Instead, analyzers cause tf.Transform to compute a full-pass operation outside of TensorFlow. They use the input tensor values over the entire dataset to generate a constant tensor that is returned as the output. For example, tft.min computes the minimum of a tensor over the dataset. tf.Transform provides a fixed set of analyzers, but this will be extended in future versions.
Preprocessing function example
By combining analyzers and regular TensorFlow functions, users can create flexible pipelines for transforming data. The following preprocessing function transforms each of the three features in different ways, and combines two of the features:
def preprocessing_fn(inputs):
  x = inputs['x']
  y = inputs['y']
  s = inputs['s']
  x_centered = x - tft.mean(x)
  y_normalized = tft.scale_to_0_1(y)
  s_integerized = tft.compute_and_apply_vocabulary(s)
  x_centered_times_y_normalized = x_centered * y_normalized
  return {
      'x_centered': x_centered,
      'y_normalized': y_normalized,
      'x_centered_times_y_normalized': x_centered_times_y_normalized,
      's_integerized': s_integerized
  }
Here, x, y and s are Tensors that represent input features. The first new tensor that is created, x_centered, is built by applying tft.mean to x and subtracting this from x. tft.mean(x) returns a tensor representing the mean of the tensor x. x_centered is the tensor x with the mean subtracted.

The second new tensor, y_normalized, is created in a similar manner but using the convenience method tft.scale_to_0_1. This method does something similar to computing x_centered, namely computing a maximum and minimum and using these to scale y.
The tensor s_integerized shows an example of string manipulation. In this case, we take a string and map it to an integer. This uses the convenience function tft.compute_and_apply_vocabulary. This function uses an analyzer to compute the unique values taken by the input strings, and then uses TensorFlow operations to convert the input strings to indices in the table of unique values.
The final column shows that it is possible to use TensorFlow operations to create new features by combining tensors.
The preprocessing function defines a pipeline of operations on a dataset. In order to apply the pipeline, we rely on a concrete implementation of the tf.Transform API. The Apache Beam implementation provides a PTransform that applies a user's preprocessing function to data. The typical workflow of a tf.Transform user is to construct a preprocessing function, then incorporate it into a larger Beam pipeline, creating the data for training.
Batching
Batching is an important part of TensorFlow. Since one of the goals of tf.Transform is to provide a TensorFlow graph for preprocessing that can be incorporated into the serving graph (and, optionally, the training graph), batching is also an important concept in tf.Transform.
While not obvious in the example above, the user-defined preprocessing function is passed tensors representing batches and not individual instances, just as happens during training and serving with TensorFlow. On the other hand, analyzers perform a computation over the entire dataset that returns a single value and not a batch of values. x is a Tensor with a shape of (batch_size,), while tft.mean(x) is a Tensor with a shape of (). The subtraction x - tft.mean(x) broadcasts, subtracting the value of tft.mean(x) from every element of the batch represented by x.
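For illustration, this broadcasting can be sketched with plain TensorFlow. The scalar below stands in for the result of the tft.mean analyzer, whose value is only available once the dataset has been analyzed:

import tensorflow as tf

# A batch of three instances for feature x: shape (batch_size,) == (3,).
x = tf.constant([1.0, 2.0, 3.0])

# Stand-in for tft.mean(x): a scalar of shape (), known only after analysis.
x_mean = tf.constant(2.0)

# Broadcasting subtracts the scalar from every element of the batch.
x_centered = x - x_mean
print(x_centered.numpy())  # [-1.  0.  1.]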
Apache Beam Implementation
While the preprocessing function is intended as a logical description of a preprocessing pipeline that can be implemented on multiple data processing frameworks, tf.Transform provides a canonical implementation used on Apache Beam. This implementation demonstrates the functionality required from an implementation. There is no formal API for this functionality, so each implementation can use an API that is idiomatic for its particular data processing framework.
The Apache Beam implementation provides two PTransforms used to process data for a preprocessing function. The following shows the usage for the composite PTransform tft_beam.AnalyzeAndTransformDataset:
raw_data = [
{'x': 1, 'y': 1, 's': 'hello'},
{'x': 2, 'y': 2, 's': 'world'},
{'x': 3, 'y': 3, 's': 'hello'}
]
raw_data_metadata = dataset_metadata.DatasetMetadata(
schema_utils.schema_from_feature_spec({
'y': tf.io.FixedLenFeature([], tf.float32),
'x': tf.io.FixedLenFeature([], tf.float32),
's': tf.io.FixedLenFeature([], tf.string),
}))
with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
  transformed_dataset, transform_fn = (
      (raw_data, raw_data_metadata) |
      tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
transformed_data, transformed_metadata = transformed_dataset
The transformed_data content is shown below and contains the transformed columns in the same format as the raw data. In particular, the values of s_integerized are [0, 1, 0]. These values depend on how the words hello and world were mapped to integers, which is deterministic. For the column x_centered, we subtracted the mean so the values of the column x, which were [1.0, 2.0, 3.0], became [-1.0, 0.0, 1.0]. Similarly, the rest of the columns match their expected values.
transformed_data
[{'s_integerized': 0,
  'x_centered': -1.0,
  'x_centered_times_y_normalized': -0.0,
  'y_normalized': 0.0},
 {'s_integerized': 1,
  'x_centered': 0.0,
  'x_centered_times_y_normalized': 0.0,
  'y_normalized': 0.5},
 {'s_integerized': 0,
  'x_centered': 1.0,
  'x_centered_times_y_normalized': 1.0,
  'y_normalized': 1.0}]
Both raw_data and transformed_data are datasets. The next two sections show how the Beam implementation represents datasets and how to read and write data to disk. The other return value, transform_fn, represents the transformation applied to the data, which is covered in detail below.
The tft_beam.AnalyzeAndTransformDataset class is the composition of the two fundamental transforms provided by the implementation, tft_beam.AnalyzeDataset and tft_beam.TransformDataset. So the following two code snippets are equivalent:
my_data = (raw_data, raw_data_metadata)
with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
  transformed_data, transform_fn = (
      my_data | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
  transform_fn = my_data | tft_beam.AnalyzeDataset(preprocessing_fn)
  transformed_data = (my_data, transform_fn) | tft_beam.TransformDataset()
transform_fn is a pure function that represents an operation that is applied to each row of the dataset. In particular, the analyzer values are already computed and treated as constants. In the example, the transform_fn contains as constants the mean of column x, the min and max of column y, and the vocabulary used to map the strings to integers.
An important feature of tf.Transform is that transform_fn represents a map over rows: it is a pure function applied to each row separately. All of the computation for aggregating rows is done in AnalyzeDataset. Furthermore, the transform_fn is represented as a TensorFlow Graph which can be embedded into the serving graph.
AnalyzeAndTransformDataset is provided for optimizations in this special case. This is the same pattern used in scikit-learn, which provides the fit, transform, and fit_transform methods.
Data Formats and Schema
The TFT Beam implementation accepts two different input data formats. The "instance dict" format (as seen in the example above and in simple.ipynb & simple_example.py) is an intuitive format suitable for small datasets, while the TFXIO (Apache Arrow) format provides improved performance and is suitable for large datasets.
The "metadata" accompanying the PCollection
tells the Beam implementation the format of the PCollection
.
(raw_data, raw_data_metadata) | tft_beam.AnalyzeDataset(...)
- If raw_data_metadata is a dataset_metadata.DatasetMetadata (see "The 'instance dict' format" section below), then raw_data is expected to be in the "instance dict" format.
- If raw_data_metadata is a tfxio.TensorAdapterConfig (see "The TFXIO format" section below), then raw_data is expected to be in the TFXIO format.
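As a minimal sketch, the two call forms look like this (the variable names for the TFXIO case are illustrative; both formats are described in the sections below):

# "Instance dict" format: metadata is a dataset_metadata.DatasetMetadata.
transformed_dataset, transform_fn = (
    (raw_data, raw_data_metadata)
    | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))

# TFXIO format: metadata is a tfxio.TensorAdapterConfig
# (raw_record_batches and tensor_adapter_config are hypothetical names).
transformed_dataset, transform_fn = (
    (raw_record_batches, tensor_adapter_config)
    | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))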
The "instance dict" format
The previous code examples used this format. The metadata contains the schema that defines the layout of the data and how it is read from and written to various formats. Even this in-memory format is not self-describing and requires the schema in order to be interpreted as tensors.
Again, here is the definition of the schema for the example data:
import tensorflow_transform as tft
raw_data_metadata = tft.DatasetMetadata.from_feature_spec({
's': tf.io.FixedLenFeature([], tf.string),
'y': tf.io.FixedLenFeature([], tf.float32),
'x': tf.io.FixedLenFeature([], tf.float32),
})
The Schema proto contains the information needed to parse the data from its on-disk or in-memory format into tensors. It is typically constructed by calling schema_utils.schema_from_feature_spec with a dict mapping feature keys to tf.io.FixedLenFeature, tf.io.VarLenFeature, and tf.io.SparseFeature values. See the documentation for tf.parse_example for more details.
Above we use tf.io.FixedLenFeature to indicate that each feature contains a fixed number of values, in this case a single scalar value. Because tf.Transform batches instances, the actual Tensor representing the feature will have shape (None,), where the unknown dimension is the batch dimension.
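As a sketch, a feature with a variable number of values per instance would instead use tf.io.VarLenFeature and be presented to the preprocessing function as a SparseTensor. The 'tags' feature here is hypothetical and not part of the example data:

import tensorflow as tf
from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils

feature_spec = {
    # Dense scalar feature; batched shape (None,).
    'x': tf.io.FixedLenFeature([], tf.float32),
    # Variable-length feature; presented as a SparseTensor.
    'tags': tf.io.VarLenFeature(tf.string),
}
metadata = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(feature_spec))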
The TFXIO format
With this format, the data is expected to be contained in a pyarrow.RecordBatch.
For tabular data, our Apache Beam implementation accepts Arrow RecordBatches that consist of columns of the following types:
- pa.list_(<primitive>), where <primitive> is pa.int64(), pa.float32(), pa.binary() or pa.large_binary().
- pa.large_list(<primitive>)
The toy input dataset we used above, when represented as a RecordBatch, looks like the following:
import pyarrow as pa
raw_data = [
pa.record_batch(
data=[
pa.array([[1], [2], [3]], pa.list_(pa.float32())),
pa.array([[1], [2], [3]], pa.list_(pa.float32())),
pa.array([['hello'], ['world'], ['hello']], pa.list_(pa.binary())),
],
names=['x', 'y', 's'])
]
Similar to the dataset_metadata.DatasetMetadata instance that accompanies the "instance dict" format, a tfxio.TensorAdapterConfig must accompany the RecordBatches. It consists of the Arrow schema of the RecordBatches, and tfxio.TensorRepresentations to uniquely determine how columns in RecordBatches can be interpreted as TensorFlow Tensors (including but not limited to tf.Tensor and tf.SparseTensor).
tfxio.TensorRepresentations is a type alias for Dict[str, tensorflow_metadata.proto.v0.schema_pb2.TensorRepresentation] which establishes the relationship between a Tensor that a preprocessing_fn accepts and columns in the RecordBatches. For example:
from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2
tensor_representation = {
'x': text_format.Parse(
"""dense_tensor { column_name: "col1" shape { dim { size: 2 } } }""",
schema_pb2.TensorRepresentation())
}
This means that inputs['x'] in preprocessing_fn should be a dense tf.Tensor whose values come from a column named 'col1' in the input RecordBatches, and its (batched) shape should be [batch_size, 2].
A schema_pb2.TensorRepresentation is a Protobuf defined in TensorFlow Metadata.
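For instance, a RecordBatch compatible with the representation above could look like the following sketch (the column name 'col1' and the values are illustrative):

import pyarrow as pa

# Every list has exactly 2 elements, matching shape { dim { size: 2 } } above.
record_batch = pa.record_batch(
    data=[pa.array([[1.0, 2.0], [3.0, 4.0]], pa.list_(pa.float32()))],
    names=['col1'])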
Compatibility with TensorFlow
tf.Transform provides support for exporting the transform_fn as a SavedModel; see the simple tutorial for an example. Before the 0.30 release, the default behavior exported a TF 1.x SavedModel. Starting with the 0.30 release, the default behavior is to export a TF 2.x SavedModel unless TF 2.x behaviors are explicitly disabled (by calling tf.compat.v1.disable_v2_behavior()).
If using TF 1.x concepts such as tf.estimator and tf.Sessions, you can retain the previous behavior by passing force_tf_compat_v1=True to tft_beam.Context if using tf.Transform as a standalone library, or to the Transform component in TFX.
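For example, when using tf.Transform standalone, a sketch of opting back into the TF 1.x export behavior looks like this (reusing raw_data, raw_data_metadata and preprocessing_fn from above):

import tempfile

# Retain the TF 1.x SavedModel export behavior.
with tft_beam.Context(temp_dir=tempfile.mkdtemp(), force_tf_compat_v1=True):
  transformed_dataset, transform_fn = (
      (raw_data, raw_data_metadata)
      | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))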
When exporting the transform_fn as a TF 2.x SavedModel, the preprocessing_fn is expected to be traceable using tf.function. Additionally, if running your pipeline remotely (for example with the DataflowRunner), ensure that the preprocessing_fn and any dependencies are packaged properly as described here.
Known issues with using tf.Transform to export a TF 2.x SavedModel are documented here.
Input and output with Apache Beam
So far, we've seen input and output data in python lists (of RecordBatches or instance dictionaries). This is a simplification that relies on Apache Beam's ability to work with lists as well as its main representation of data, the PCollection.
A PCollection is a data representation that forms a part of a Beam pipeline. A Beam pipeline is formed by applying various PTransforms, including AnalyzeDataset and TransformDataset, and running the pipeline. A PCollection is not created in the memory of the main binary, but instead is distributed among the workers (although this section uses the in-memory execution mode).
Pre-canned PCollection Sources (TFXIO)
The RecordBatch format that our implementation accepts is a common format that other TFX libraries accept. Therefore TFX offers convenient "sources" (a.k.a. TFXIO) that read files of various formats on disk, produce RecordBatches, and can also provide a tfxio.TensorAdapterConfig, including inferred tfxio.TensorRepresentations.
Those TFXIOs can be found in the package tfx_bsl (tfx_bsl.public.tfxio).
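For example, a sketch of reading TFRecord files of serialized tf.Example protos with tfxio.TFExampleRecord (the file pattern is hypothetical):

from tfx_bsl.public import tfxio

# Produces RecordBatches from tf.Example records, with TensorRepresentations
# inferred from (or read out of) the given schema.
example_tfxio = tfxio.TFExampleRecord(
    file_pattern='/path/to/examples-*.tfrecord',  # hypothetical path
    schema=schema)  # a schema_pb2.Schema, e.g. from schema_utils.schema_from_feature_spec

# In a Beam pipeline:
#   raw_data = pipeline | 'ReadExamples' >> example_tfxio.BeamSource()
#   raw_dataset = (raw_data, example_tfxio.TensorAdapterConfig())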
Example: "Census Income" dataset
The following example requires both reading and writing data on disk and representing data as a PCollection (not a list); see: census_example.py.
Below we show how to download the data and run this example. The "Census Income"
dataset is provided by the
UCI Machine Learning Repository.
This dataset contains both categorical and numeric data.
Here is some code to download and preview this data:
wget https://storage.googleapis.com/artifacts.tfx-oss-public.appspot.com/datasets/census/adult.data
import pandas as pd
train_data_file = "adult.data"
There's some configuration code hidden in the cell below.
ORDERED_CSV_COLUMNS = [
'age', 'workclass', 'fnlwgt', 'education', 'education-num',
'marital-status', 'occupation', 'relationship', 'race', 'sex',
'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'label'
]
CATEGORICAL_FEATURE_KEYS = [
'workclass',
'education',
'marital-status',
'occupation',
'relationship',
'race',
'sex',
'native-country',
]
NUMERIC_FEATURE_KEYS = [
'age',
'capital-gain',
'capital-loss',
'hours-per-week',
'education-num',
]
LABEL_KEY = 'label'
RAW_DATA_FEATURE_SPEC = dict(
[(name, tf.io.FixedLenFeature([], tf.string))
for name in CATEGORICAL_FEATURE_KEYS] +
[(name, tf.io.FixedLenFeature([], tf.float32))
for name in NUMERIC_FEATURE_KEYS] +
[(LABEL_KEY, tf.io.FixedLenFeature([], tf.string))]
)
SCHEMA = tft.tf_metadata.dataset_metadata.DatasetMetadata(
tft.tf_metadata.schema_utils.schema_from_feature_spec(RAW_DATA_FEATURE_SPEC)).schema
pd.read_csv(train_data_file, names = ORDERED_CSV_COLUMNS).head()
The columns of the dataset are either categorical or numeric. This dataset describes a classification problem: predicting whether the individual earns more or less than 50K per year, as recorded in the last column. However, from the perspective of tf.Transform, this label is just another categorical column.
We use a pre-canned tfxio.BeamRecordCsvTFXIO to translate the CSV lines into RecordBatches. TFXIO requires two important pieces of information:
- a TensorFlow Metadata Schema (tensorflow_metadata.proto.v0.schema_pb2) that contains type and shape information about each CSV column. schema_pb2.TensorRepresentations are an optional part of the Schema; if not provided (which is the case in this example), they will be inferred from the type and shape information. One can get the Schema either by using a helper function we provide to translate from TF parsing specs (shown in this example), or by running TensorFlow Data Validation.
- a list of column names, in the order they appear in the CSV file. Note that those names must match the feature names in the Schema.
pip install -U -q tfx_bsl
from tfx_bsl.public import tfxio
from tfx_bsl.coders.example_coder import RecordBatchToExamples
import apache_beam as beam
pipeline = beam.Pipeline()
csv_tfxio = tfxio.BeamRecordCsvTFXIO(
physical_format='text', column_names=ORDERED_CSV_COLUMNS, schema=SCHEMA)
raw_data = (
pipeline
| 'ReadTrainData' >> beam.io.ReadFromText(
train_data_file, coder=beam.coders.BytesCoder())
| 'FixCommasTrainData' >> beam.Map(
lambda line: line.replace(b', ', b','))
| 'DecodeTrainData' >> csv_tfxio.BeamSource())
raw_data
<PCollection[[21]: DecodeTrainData/RawRecordToRecordBatch/CollectRecordBatchTelemetry/ProfileRecordBatches.None] at 0x7feeaa6fd5b0>
Note that we had to do some additional fix-ups after the CSV lines were read in. Otherwise, we could rely on tfxio.CsvTFXIO to handle both reading the files and translating them to RecordBatches:
csv_tfxio = tfxio.CsvTFXIO(
    train_data_file,
    telemetry_descriptors=[],  # No telemetry descriptors for this standalone example.
    column_names=ORDERED_CSV_COLUMNS,
    schema=SCHEMA)
p2 = beam.Pipeline()
raw_data_2 = p2 | 'TFXIORead' >> csv_tfxio.BeamSource()
Preprocessing for this dataset is similar to the previous example, except the preprocessing function is programmatically generated instead of manually specifying each column. In the preprocessing function below, NUMERIC_FEATURE_KEYS and CATEGORICAL_FEATURE_KEYS are lists that contain the names of the numeric and categorical columns:
NUM_OOV_BUCKETS = 1
def preprocessing_fn(inputs):
  """Preprocess input columns into transformed columns."""
  # Since we are modifying some features and leaving others unchanged, we
  # start by setting `outputs` to a copy of `inputs`.
  outputs = inputs.copy()

  # Scale numeric columns to have range [0, 1].
  for key in NUMERIC_FEATURE_KEYS:
    outputs[key] = tft.scale_to_0_1(outputs[key])

  # For all categorical columns except the label column, we generate a
  # vocabulary but do not modify the feature. This vocabulary is instead
  # used in the trainer, by means of a feature column, to convert the feature
  # from a string to an integer id.
  for key in CATEGORICAL_FEATURE_KEYS:
    outputs[key] = tft.compute_and_apply_vocabulary(
        tf.strings.strip(inputs[key]),
        num_oov_buckets=NUM_OOV_BUCKETS,
        vocab_filename=key)

  # For the label column we provide the mapping from string to index.
  with tf.init_scope():
    # `init_scope` - Only initialize the table once.
    initializer = tf.lookup.KeyValueTensorInitializer(
        keys=['>50K', '<=50K'],
        values=tf.cast(tf.range(2), tf.int64),
        key_dtype=tf.string,
        value_dtype=tf.int64)
    table = tf.lookup.StaticHashTable(initializer, default_value=-1)
  outputs[LABEL_KEY] = table.lookup(outputs[LABEL_KEY])

  return outputs
One difference from the previous example is that the label column manually specifies the mapping from the string to an index. So '>50K' is mapped to 0 and '<=50K' is mapped to 1, because it's useful to know which index in the trained model corresponds to which label.
The raw_data variable represents a PCollection of pyarrow.RecordBatches. The tensor adapter config is given by csv_tfxio.TensorAdapterConfig(), which is inferred from SCHEMA (and ultimately, in this example, from the TF parsing specs).
The final stage is to write the transformed data to disk; this has a similar form to reading the raw data. The schema used to do this is part of the output of tft_beam.AnalyzeAndTransformDataset, which infers a schema for the output data. The schema is a part of the metadata, but the two are used somewhat interchangeably in the tf.Transform API. Be aware that the code below writes to a different format than the input: instead of textio.WriteToText, it uses Beam's built-in support for the TFRecord format with a coder to encode the data as Example protos. This is a better format to use for training, as shown in the next section. The path passed to WriteToTFRecord provides the base filename for the individual shards that are written.
raw_dataset = (raw_data, csv_tfxio.TensorAdapterConfig())

working_dir = tempfile.mkdtemp()
with tft_beam.Context(temp_dir=working_dir):
  transformed_dataset, transform_fn = (
      raw_dataset | tft_beam.AnalyzeAndTransformDataset(
          preprocessing_fn, output_record_batches=True))

output_dir = tempfile.mkdtemp()
transformed_data, _ = transformed_dataset

_ = (
    transformed_data
    | 'EncodeTrainData' >>
    beam.FlatMapTuple(lambda batch, _: RecordBatchToExamples(batch))
    | 'WriteTrainData' >> beam.io.WriteToTFRecord(
        os.path.join(output_dir, 'transformed.tfrecord')))
In addition to the training data, transform_fn
is also written out with the
metadata:
_ = (
transform_fn
| 'WriteTransformFn' >> tft_beam.WriteTransformFn(output_dir))
Run the entire Beam pipeline with pipeline.run().wait_until_finish()
. Up until this point, the Beam pipeline represents a deferred, distributed computation. It provides instructions for what will be done, but the instructions have not been executed. This final call executes the specified pipeline.
result = pipeline.run().wait_until_finish()
After running the pipeline, the output directory contains two artifacts:
- The transformed data, and the metadata describing it.
- The tf.saved_model containing the resulting preprocessing_fn.
ls {output_dir}
transform_fn transformed.tfrecord-00000-of-00001 transformed_metadata
To see how to use these artifacts refer to the Advanced preprocessing tutorial.
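As a preview, a minimal sketch of loading these artifacts back with tft.TFTransformOutput:

# Wraps the output directory written by AnalyzeAndTransformDataset above.
tf_transform_output = tft.TFTransformOutput(output_dir)

# The transform graph as a Keras layer, for applying the same preprocessing
# at training or serving time.
tft_layer = tf_transform_output.transform_features_layer()

# The feature spec describing the transformed data.
transformed_feature_spec = tf_transform_output.transformed_feature_spec()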