TFX Keras Component Tutorial

A Component-by-Component Introduction to TensorFlow Extended (TFX)

This Colab-based tutorial will interactively walk through each built-in component of TensorFlow Extended (TFX).

It covers every step in an end-to-end machine learning pipeline, from data ingestion to pushing a model to serving.

When you're done, the contents of this notebook can be automatically exported as TFX pipeline source code, which you can orchestrate with Apache Airflow and Apache Beam.

Background

This notebook demonstrates how to use TFX in a Jupyter/Colab environment. Here, we walk through the Chicago Taxi example in an interactive notebook.

Working in an interactive notebook is a useful way to become familiar with the structure of a TFX pipeline. It's also useful as a lightweight development environment for your own pipelines, but be aware that interactive notebooks differ from production deployments in how they are orchestrated and how they access metadata artifacts.

Orchestration

In a production deployment of TFX, you will use an orchestrator such as Apache Airflow, Kubeflow Pipelines, or Apache Beam to orchestrate a pre-defined pipeline graph of TFX components. In an interactive notebook, the notebook itself is the orchestrator, running each TFX component as you execute the notebook cells.
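To picture what "the notebook itself is the orchestrator" means, here is a minimal plain-Python sketch (not TFX code) that runs this tutorial's components in an order consistent with their dependencies. The graph is an assumption assembled from the inputs each component receives later in this notebook, and `run_in_order` is a hypothetical helper. It needs Python 3.9+ for `graphlib`.

```python
from graphlib import TopologicalSorter

# Hypothetical dependency graph: each component maps to the set of upstream
# components whose outputs it consumes in this tutorial.
PIPELINE_GRAPH = {
    "CsvExampleGen": set(),
    "StatisticsGen": {"CsvExampleGen"},
    "SchemaGen": {"StatisticsGen"},
    "ExampleValidator": {"StatisticsGen", "SchemaGen"},
    "Transform": {"CsvExampleGen", "SchemaGen"},
}


def run_in_order(graph):
    """Return component names so that every upstream runs before its consumers."""
    executed = []
    for name in TopologicalSorter(graph).static_order():
        # In the notebook, this is the point where you would run the component's cell.
        executed.append(name)
    return executed


if __name__ == "__main__":
    print(run_in_order(PIPELINE_GRAPH))
```

In the notebook, you provide this ordering yourself by running the cells from top to bottom.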

Metadata

In a production deployment of TFX, you will access metadata through the ML Metadata (MLMD) API. MLMD stores metadata properties in a database such as MySQL or SQLite, and stores the metadata payloads in a persistent store such as your filesystem. In an interactive notebook, the properties are stored in an ephemeral SQLite database and the payloads in a temporary directory, both under /tmp on the Jupyter notebook or Colab server.
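The split between small, queryable properties and larger payloads can be illustrated with Python's built-in sqlite3 module. This is a toy sketch under stated assumptions: the `artifacts` table, its columns, and the `open_store`/`publish` helpers are hypothetical and do not reflect MLMD's actual schema or API.

```python
import os
import sqlite3
import tempfile


def open_store(db_path):
    """Open (or create) a toy SQLite database that holds artifact properties."""
    conn = sqlite3.connect(db_path)
    conn.execute(
        "CREATE TABLE IF NOT EXISTS artifacts ("
        "id INTEGER PRIMARY KEY, type TEXT, uri TEXT, split_names TEXT)")
    return conn


def publish(conn, artifact_type, uri, split_names, payload):
    """Write the payload to `uri` on disk, then record its properties in SQLite."""
    os.makedirs(os.path.dirname(uri), exist_ok=True)
    with open(uri, "w") as f:
        f.write(payload)  # the payload lives on the filesystem...
    cur = conn.execute(  # ...while only its properties go into the database
        "INSERT INTO artifacts (type, uri, split_names) VALUES (?, ?, ?)",
        (artifact_type, uri, split_names))
    conn.commit()
    return cur.lastrowid


if __name__ == "__main__":
    root = tempfile.mkdtemp()  # ephemeral, like the notebook's temporary directory
    conn = open_store(os.path.join(root, "metadata.sqlite"))
    uri = os.path.join(root, "CsvExampleGen", "examples", "1", "payload.txt")
    publish(conn, "Examples", uri, '["train", "eval"]', "serialized examples")
    print(conn.execute("SELECT type, uri, split_names FROM artifacts").fetchall())
```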

Setup

First, we install and import the necessary packages, set up paths, and download data.

Upgrade Pip

To avoid upgrading Pip on your own system when running locally, first check that we're running in Colab. Local systems can, of course, be upgraded separately.

import sys
if 'google.colab' in sys.modules:
  !pip install --upgrade pip

Install TFX

pip install tfx

Uninstall shapely

TODO(b/263441833) This is a temporary workaround to avoid an ImportError. Ultimately, it should be handled by supporting a recent version of BigQuery, instead of uninstalling other extra dependencies.

pip uninstall shapely -y

Did you restart the runtime?

If you are using Google Colab, the first time that you run the cell above, you must restart the runtime (Runtime > Restart runtime ...). This is because of the way that Colab loads packages.

Import packages

We import necessary packages, including standard TFX component classes.

import os
import pprint
import tempfile
import urllib.request

import absl
import tensorflow as tf
import tensorflow_model_analysis as tfma
tf.get_logger().propagate = False
pp = pprint.PrettyPrinter()

from tfx import v1 as tfx
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

%load_ext tfx.orchestration.experimental.interactive.notebook_extensions.skip

Let's check the library versions.

print('TensorFlow version: {}'.format(tf.__version__))
print('TFX version: {}'.format(tfx.__version__))
TensorFlow version: 2.12.1
TFX version: 1.13.0

Set up pipeline paths

# This is the root directory for your TFX pip package installation.
_tfx_root = tfx.__path__[0]

# This is the directory containing the TFX Chicago Taxi Pipeline example.
_taxi_root = os.path.join(_tfx_root, 'examples/chicago_taxi_pipeline')

# This is the path where your model will be pushed for serving.
_serving_model_dir = os.path.join(
    tempfile.mkdtemp(), 'serving_model/taxi_simple')

# Set up logging.
absl.logging.set_verbosity(absl.logging.INFO)

Download example data

We download the example dataset for use in our TFX pipeline.

The dataset we're using is the Taxi Trips dataset released by the City of Chicago. The columns in this dataset are:

pickup_community_area    fare                     trip_start_month
trip_start_hour          trip_start_day           trip_start_timestamp
pickup_latitude          pickup_longitude         dropoff_latitude
dropoff_longitude        trip_miles               pickup_census_tract
dropoff_census_tract     payment_type             company
trip_seconds             dropoff_community_area   tips

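With this dataset, we will build a model that predicts the tips of a trip. Specifically, the label (defined in the Transform code below) marks whether the tip was more than 20% of the fare.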

_data_root = tempfile.mkdtemp(prefix='tfx-data')
DATA_PATH = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/chicago_taxi_pipeline/data/simple/data.csv'
_data_filepath = os.path.join(_data_root, "data.csv")
urllib.request.urlretrieve(DATA_PATH, _data_filepath)
('/tmpfs/tmp/tfx-dataf9g2t5h4/data.csv',
 <http.client.HTTPMessage at 0x7f9ac42f4460>)

Take a quick look at the CSV file.

!head {_data_filepath}
pickup_community_area,fare,trip_start_month,trip_start_hour,trip_start_day,trip_start_timestamp,pickup_latitude,pickup_longitude,dropoff_latitude,dropoff_longitude,trip_miles,pickup_census_tract,dropoff_census_tract,payment_type,company,trip_seconds,dropoff_community_area,tips
,12.45,5,19,6,1400269500,,,,,0.0,,,Credit Card,Chicago Elite Cab Corp. (Chicago Carriag,0,,0.0
,0,3,19,5,1362683700,,,,,0,,,Unknown,Chicago Elite Cab Corp.,300,,0
60,27.05,10,2,3,1380593700,41.836150155,-87.648787952,,,12.6,,,Cash,Taxi Affiliation Services,1380,,0.0
10,5.85,10,1,2,1382319000,41.985015101,-87.804532006,,,0.0,,,Cash,Taxi Affiliation Services,180,,0.0
14,16.65,5,7,5,1369897200,41.968069,-87.721559063,,,0.0,,,Cash,Dispatch Taxi Affiliation,1080,,0.0
13,16.45,11,12,3,1446554700,41.983636307,-87.723583185,,,6.9,,,Cash,,780,,0.0
16,32.05,12,1,1,1417916700,41.953582125,-87.72345239,,,15.4,,,Cash,,1200,,0.0
30,38.45,10,10,5,1444301100,41.839086906,-87.714003807,,,14.6,,,Cash,,2580,,0.0
11,14.65,1,1,3,1358213400,41.978829526,-87.771166703,,,5.81,,,Cash,,1080,,0.0
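To make the data and the prediction target concrete, the sketch below parses the header and first three rows shown above with Python's csv module, counts empty (missing) fields, and applies a rule that mirrors the "big tipper" label the Transform code defines later (tip greater than 20% of the fare, and 0 when the fare is NaN). It is a plain-Python illustration; `to_float` and `is_big_tipper` are hypothetical helper names, not part of TFX.

```python
import csv
import io
import math

# Header and first three data rows, copied from the `head` output above.
SAMPLE_CSV = """\
pickup_community_area,fare,trip_start_month,trip_start_hour,trip_start_day,trip_start_timestamp,pickup_latitude,pickup_longitude,dropoff_latitude,dropoff_longitude,trip_miles,pickup_census_tract,dropoff_census_tract,payment_type,company,trip_seconds,dropoff_community_area,tips
,12.45,5,19,6,1400269500,,,,,0.0,,,Credit Card,Chicago Elite Cab Corp. (Chicago Carriag,0,,0.0
,0,3,19,5,1362683700,,,,,0,,,Unknown,Chicago Elite Cab Corp.,300,,0
60,27.05,10,2,3,1380593700,41.836150155,-87.648787952,,,12.6,,,Cash,Taxi Affiliation Services,1380,,0.0
"""


def to_float(value):
    """Parse a CSV field as a float; an empty field becomes NaN (missing)."""
    return float(value) if value != "" else math.nan


def is_big_tipper(fare, tips):
    """Return 1 if the tip exceeds 20% of the fare, else 0 (0 if fare is NaN)."""
    if math.isnan(fare):
        return 0
    return int(tips > 0.2 * fare)


rows = list(csv.DictReader(io.StringIO(SAMPLE_CSV)))
for row in rows:
    fare, tips = to_float(row["fare"]), to_float(row["tips"])
    missing = [key for key, value in row.items() if value == ""]
    print(f"fare={fare} tips={tips} label={is_big_tipper(fare, tips)} "
          f"missing_fields={len(missing)}")
```

None of these three trips tipped, so all three labels are 0; the rows also show how often location fields are missing.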

Disclaimer: This site provides applications using data that has been modified for use from its original source, www.cityofchicago.org, the official website of the City of Chicago. The City of Chicago makes no claims as to the content, accuracy, timeliness, or completeness of any of the data provided at this site. The data provided at this site is subject to change at any time. It is understood that the data provided at this site is being used at one’s own risk.

Create the InteractiveContext

Last, we create an InteractiveContext, which will allow us to run TFX components interactively in this notebook.

# Here, we create an InteractiveContext using default parameters. This will
# use a temporary directory with an ephemeral ML Metadata database instance.
# To use your own pipeline root or database, the optional properties
# `pipeline_root` and `metadata_connection_config` may be passed to
# InteractiveContext. Calls to InteractiveContext are no-ops outside of the
# notebook.
context = InteractiveContext()
WARNING:absl:InteractiveContext pipeline_root argument not provided: using temporary directory /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86 as root for pipeline outputs.
WARNING:absl:InteractiveContext metadata_connection_config not provided: using SQLite ML Metadata database at /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/metadata.sqlite.

Run TFX components interactively

In the cells that follow, we create TFX components one-by-one, run each of them, and visualize their output artifacts.

ExampleGen

The ExampleGen component is usually at the start of a TFX pipeline. It will:

  1. Split data into training and evaluation sets (by default, 2/3 training + 1/3 eval)
  2. Convert data into the tf.Example format
  3. Copy data into the pipeline root directory (managed by the InteractiveContext) for other components to access

ExampleGen takes as input the path to your data source. In our case, this is the _data_root path that contains the downloaded CSV.
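The default 2:1 split is hash-based, so a given record lands in the same split every time the pipeline runs. Here is a minimal sketch of that idea, assuming a 2-bucket train / 1-bucket eval configuration. MD5 is only a stand-in, since ExampleGen uses its own hashing internally, and `assign_split` is a hypothetical helper.

```python
import hashlib

# Hash buckets per split: 2 for train and 1 for eval gives the 2/3 : 1/3 ratio.
SPLITS = [("train", 2), ("eval", 1)]


def assign_split(record: str) -> str:
    """Deterministically map a record to a split by hashing its contents."""
    total = sum(buckets for _, buckets in SPLITS)
    bucket = int(hashlib.md5(record.encode("utf-8")).hexdigest(), 16) % total
    for name, buckets in SPLITS:
        if bucket < buckets:
            return name
        bucket -= buckets
    raise AssertionError("unreachable: bucket is always < total")


if __name__ == "__main__":
    records = [f"row-{i}" for i in range(3000)]
    counts = {"train": 0, "eval": 0}
    for r in records:
        counts[assign_split(r)] += 1
    print(counts)  # roughly 2000 train / 1000 eval
```

Because the split depends only on the record's contents, rerunning ExampleGen on the same data reproduces the same train/eval partition.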

Enabling the Cache

When using the InteractiveContext in a notebook to develop a pipeline, you can control when individual components cache their outputs. Set enable_cache to True when you want to reuse the output artifacts that the component previously generated. Set enable_cache to False when you want to recompute the output artifacts for a component, for example if you are making changes to its code.
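Conceptually, caching means keying a component's outputs on what went into it. The toy runner below is a hypothetical stand-in for InteractiveContext, not its implementation: it caches results by component name and inputs, and `enable_cache=False` forces a rerun. Because the function's code is not part of the cache key here, editing the code alone would not invalidate the cache, which is exactly the situation where you would turn caching off.

```python
import hashlib
import json


class ToyContext:
    """Hypothetical notebook runner that caches component outputs (not TFX code)."""

    def __init__(self):
        self._cache = {}
        self.executions = 0

    def run(self, name, fn, inputs, enable_cache=True):
        # The cache key covers the component name and its (JSON-serializable)
        # inputs; the function's code is deliberately NOT part of the key.
        key = hashlib.sha256(
            json.dumps([name, inputs], sort_keys=True).encode("utf-8")).hexdigest()
        if enable_cache and key in self._cache:
            return self._cache[key]  # reuse the previous output
        self.executions += 1
        result = fn(**inputs)
        self._cache[key] = result
        return result


def count_rows(data):
    return len(data)


if __name__ == "__main__":
    ctx = ToyContext()
    ctx.run("StatisticsGen", count_rows, {"data": [1, 2, 3]})
    ctx.run("StatisticsGen", count_rows, {"data": [1, 2, 3]})  # cache hit
    ctx.run("StatisticsGen", count_rows, {"data": [1, 2, 3]}, enable_cache=False)
    print(ctx.executions)  # 2
```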

example_gen = tfx.components.CsvExampleGen(input_base=_data_root)
context.run(example_gen, enable_cache=True)
INFO:absl:Running driver for CsvExampleGen
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:select span and version = (0, None)
INFO:absl:latest span and version = (0, None)
INFO:absl:Running executor for CsvExampleGen
INFO:absl:Generating examples.
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
INFO:absl:Processing input csv data /tmpfs/tmp/tfx-dataf9g2t5h4/* to TFExample.
WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.
INFO:absl:Examples generated.
INFO:absl:Running publisher for CsvExampleGen
INFO:absl:MetadataStore with DB connection initialized

Let's examine the output artifacts of ExampleGen. This component produces a single examples artifact containing two splits, training examples and evaluation examples:

artifact = example_gen.outputs['examples'].get()[0]
print(artifact.split_names, artifact.uri)
["train", "eval"] /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/CsvExampleGen/examples/1

We can also take a look at the first three training examples:

# Get the URI of the output artifact representing the training examples, which is a directory
train_uri = os.path.join(example_gen.outputs['examples'].get()[0].uri, 'Split-train')

# Get the list of files in this directory (all compressed TFRecord files)
tfrecord_filenames = [os.path.join(train_uri, name)
                      for name in os.listdir(train_uri)]

# Create a `TFRecordDataset` to read these files
dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")

# Iterate over the first 3 records and decode them.
for tfrecord in dataset.take(3):
  serialized_example = tfrecord.numpy()
  example = tf.train.Example()
  example.ParseFromString(serialized_example)
  pp.pprint(example)
features {
  feature {
    key: "company"
    value {
      bytes_list {
        value: "Chicago Elite Cab Corp. (Chicago Carriag"
      }
    }
  }
  feature {
    key: "dropoff_census_tract"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "dropoff_community_area"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "dropoff_latitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "dropoff_longitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "fare"
    value {
      float_list {
        value: 12.449999809265137
      }
    }
  }
  feature {
    key: "payment_type"
    value {
      bytes_list {
        value: "Credit Card"
      }
    }
  }
  feature {
    key: "pickup_census_tract"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "pickup_community_area"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "pickup_latitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "pickup_longitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "tips"
    value {
      float_list {
        value: 0.0
      }
    }
  }
  feature {
    key: "trip_miles"
    value {
      float_list {
        value: 0.0
      }
    }
  }
  feature {
    key: "trip_seconds"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "trip_start_day"
    value {
      int64_list {
        value: 6
      }
    }
  }
  feature {
    key: "trip_start_hour"
    value {
      int64_list {
        value: 19
      }
    }
  }
  feature {
    key: "trip_start_month"
    value {
      int64_list {
        value: 5
      }
    }
  }
  feature {
    key: "trip_start_timestamp"
    value {
      int64_list {
        value: 1400269500
      }
    }
  }
}

features {
  feature {
    key: "company"
    value {
      bytes_list {
        value: "Taxi Affiliation Services"
      }
    }
  }
  feature {
    key: "dropoff_census_tract"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "dropoff_community_area"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "dropoff_latitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "dropoff_longitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "fare"
    value {
      float_list {
        value: 27.049999237060547
      }
    }
  }
  feature {
    key: "payment_type"
    value {
      bytes_list {
        value: "Cash"
      }
    }
  }
  feature {
    key: "pickup_census_tract"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "pickup_community_area"
    value {
      int64_list {
        value: 60
      }
    }
  }
  feature {
    key: "pickup_latitude"
    value {
      float_list {
        value: 41.836151123046875
      }
    }
  }
  feature {
    key: "pickup_longitude"
    value {
      float_list {
        value: -87.64878845214844
      }
    }
  }
  feature {
    key: "tips"
    value {
      float_list {
        value: 0.0
      }
    }
  }
  feature {
    key: "trip_miles"
    value {
      float_list {
        value: 12.600000381469727
      }
    }
  }
  feature {
    key: "trip_seconds"
    value {
      int64_list {
        value: 1380
      }
    }
  }
  feature {
    key: "trip_start_day"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "trip_start_hour"
    value {
      int64_list {
        value: 2
      }
    }
  }
  feature {
    key: "trip_start_month"
    value {
      int64_list {
        value: 10
      }
    }
  }
  feature {
    key: "trip_start_timestamp"
    value {
      int64_list {
        value: 1380593700
      }
    }
  }
}

features {
  feature {
    key: "company"
    value {
      bytes_list {
      }
    }
  }
  feature {
    key: "dropoff_census_tract"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "dropoff_community_area"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "dropoff_latitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "dropoff_longitude"
    value {
      float_list {
      }
    }
  }
  feature {
    key: "fare"
    value {
      float_list {
        value: 16.450000762939453
      }
    }
  }
  feature {
    key: "payment_type"
    value {
      bytes_list {
        value: "Cash"
      }
    }
  }
  feature {
    key: "pickup_census_tract"
    value {
      int64_list {
      }
    }
  }
  feature {
    key: "pickup_community_area"
    value {
      int64_list {
        value: 13
      }
    }
  }
  feature {
    key: "pickup_latitude"
    value {
      float_list {
        value: 41.98363494873047
      }
    }
  }
  feature {
    key: "pickup_longitude"
    value {
      float_list {
        value: -87.72357940673828
      }
    }
  }
  feature {
    key: "tips"
    value {
      float_list {
        value: 0.0
      }
    }
  }
  feature {
    key: "trip_miles"
    value {
      float_list {
        value: 6.900000095367432
      }
    }
  }
  feature {
    key: "trip_seconds"
    value {
      int64_list {
        value: 780
      }
    }
  }
  feature {
    key: "trip_start_day"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "trip_start_hour"
    value {
      int64_list {
        value: 12
      }
    }
  }
  feature {
    key: "trip_start_month"
    value {
      int64_list {
        value: 11
      }
    }
  }
  feature {
    key: "trip_start_timestamp"
    value {
      int64_list {
        value: 1446554700
      }
    }
  }
}
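The printed examples show how each CSV column becomes a typed feature list (bytes_list, float_list, or int64_list), with empty CSV fields turning into empty lists. The following plain-Python sketch reproduces that mapping using dicts rather than tf.train.Example protos, so it runs without TensorFlow. The type assignments are read off the output above, and `row_to_example` is a hypothetical helper. Note that float_list stores 32-bit floats, which is why a fare of 27.05 prints as 27.049999237060547 above; this sketch keeps Python's 64-bit floats.

```python
import csv

# Column order of the CSV file, from its header row.
HEADER = ["pickup_community_area", "fare", "trip_start_month",
          "trip_start_hour", "trip_start_day", "trip_start_timestamp",
          "pickup_latitude", "pickup_longitude", "dropoff_latitude",
          "dropoff_longitude", "trip_miles", "pickup_census_tract",
          "dropoff_census_tract", "payment_type", "company", "trip_seconds",
          "dropoff_community_area", "tips"]

# Feature list types, as they appear in the printed examples above.
FEATURE_TYPES = {
    "bytes_list": ["company", "payment_type"],
    "float_list": ["dropoff_latitude", "dropoff_longitude", "fare",
                   "pickup_latitude", "pickup_longitude", "tips", "trip_miles"],
    "int64_list": ["dropoff_census_tract", "dropoff_community_area",
                   "pickup_census_tract", "pickup_community_area",
                   "trip_seconds", "trip_start_day", "trip_start_hour",
                   "trip_start_month", "trip_start_timestamp"],
}

# How to convert a non-empty CSV field for each list type.
CASTS = {
    "bytes_list": lambda v: v.encode("utf-8"),
    "float_list": float,
    "int64_list": int,
}


def row_to_example(row):
    """Convert a CSV row (dict of strings) into {key: (list_type, values)}.

    An empty field becomes an empty value list, like the empty
    `int64_list {}` and `float_list {}` entries printed above.
    """
    example = {}
    for list_type, keys in FEATURE_TYPES.items():
        for key in keys:
            raw = row[key]
            example[key] = (list_type, [CASTS[list_type](raw)] if raw != "" else [])
    return dict(sorted(example.items()))  # keys sorted, as in the printout


if __name__ == "__main__":
    line = ("60,27.05,10,2,3,1380593700,41.836150155,-87.648787952,,,12.6,,,"
            "Cash,Taxi Affiliation Services,1380,,0.0")
    row = dict(zip(HEADER, next(csv.reader([line]))))
    for key, (list_type, values) in row_to_example(row).items():
        print(key, list_type, values)
```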

Now that ExampleGen has finished ingesting the data, the next step is data analysis.

StatisticsGen

The StatisticsGen component computes statistics over your dataset for data analysis, as well as for use in downstream components. It uses the TensorFlow Data Validation library.

StatisticsGen takes as input the dataset we just ingested using ExampleGen.

statistics_gen = tfx.components.StatisticsGen(
    examples=example_gen.outputs['examples'])
context.run(statistics_gen, enable_cache=True)
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Running driver for StatisticsGen
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for StatisticsGen
INFO:absl:Generating statistics for split train.
INFO:absl:Statistics for split train written to /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/StatisticsGen/statistics/2/Split-train.
INFO:absl:Generating statistics for split eval.
INFO:absl:Statistics for split eval written to /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/StatisticsGen/statistics/2/Split-eval.
INFO:absl:Running publisher for StatisticsGen
INFO:absl:MetadataStore with DB connection initialized

After StatisticsGen finishes running, we can visualize the output statistics. Try playing with the different plots!

context.show(statistics_gen.outputs['statistics'])
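StatisticsGen relies on TensorFlow Data Validation, which computes much richer statistics than shown here (for example, histograms and quantiles for each split). As a minimal plain-Python sketch of the core idea, assuming missing values are represented as None, the hypothetical `feature_stats` helper summarizes one numeric feature. The sample values are the fares and pickup community areas from the nine rows printed by `head` earlier.

```python
import statistics


def feature_stats(values):
    """Summarize one numeric feature; None marks a missing value."""
    present = [v for v in values if v is not None]
    summary = {"count": len(values), "missing": len(values) - len(present)}
    if present:
        summary.update(
            mean=statistics.fmean(present),
            std_dev=statistics.pstdev(present),
            min=min(present),
            max=max(present),
        )
    return summary


if __name__ == "__main__":
    fares = [12.45, 0.0, 27.05, 5.85, 16.65, 16.45, 32.05, 38.45, 14.65]
    pickup_areas = [None, None, 60, 10, 14, 13, 16, 30, 11]
    print(feature_stats(fares))
    print(feature_stats(pickup_areas))
```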

SchemaGen

The SchemaGen component generates a schema based on your data statistics. (A schema defines the expected bounds, types, and properties of the features in your dataset.) It also uses the TensorFlow Data Validation library.

SchemaGen will take as input the statistics that we generated with StatisticsGen. Because exclude_splits is not set, it processes every split (here, train and eval), as the log below shows.

schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=False)
context.run(schema_gen, enable_cache=True)
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Running driver for SchemaGen
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for SchemaGen
INFO:absl:Processing schema from statistics for split train.
INFO:absl:Processing schema from statistics for split eval.
INFO:absl:Schema written to /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/SchemaGen/schema/3/schema.pbtxt.
INFO:absl:Running publisher for SchemaGen
INFO:absl:MetadataStore with DB connection initialized

After SchemaGen finishes running, we can visualize the generated schema as a table.

context.show(schema_gen.outputs['schema'])

Each feature in your dataset shows up as a row in the schema table, alongside its properties. The schema also captures all the values that a categorical feature takes on, denoted as its domain.

To learn more about schemas, see the SchemaGen documentation.
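As a rough illustration of what inferring a schema involves, the sketch below assigns each feature a type, records whether it was always present, and captures a domain for string features. It is a toy under stated assumptions: None marks a missing value, and `infer_schema`, `always_present`, and `max_domain_size` are hypothetical names. The real schema is a protobuf produced by TensorFlow Data Validation, with many more properties.

```python
def infer_schema(columns, max_domain_size=20):
    """Infer a toy schema from {feature name: list of values (None = missing)}."""
    schema = {}
    for name, values in columns.items():
        present = [v for v in values if v is not None]
        if not present:
            ftype = "UNKNOWN"  # nothing observed, so no type can be inferred
        elif all(isinstance(v, int) for v in present):
            ftype = "INT"
        elif all(isinstance(v, (int, float)) for v in present):
            ftype = "FLOAT"
        else:
            ftype = "BYTES"
        entry = {"type": ftype, "always_present": len(present) == len(values)}
        distinct = sorted(set(present), key=str)
        # A small set of string values becomes the feature's domain.
        if ftype == "BYTES" and len(distinct) <= max_domain_size:
            entry["domain"] = distinct
        schema[name] = entry
    return schema


if __name__ == "__main__":
    # Values from the nine rows printed by `head` earlier.
    sample = {
        "payment_type": ["Credit Card", "Unknown", "Cash", "Cash", "Cash",
                         "Cash", "Cash", "Cash", "Cash"],
        "trip_seconds": [0, 300, 1380, 180, 1080, 780, 1200, 2580, 1080],
        "pickup_community_area": [None, None, 60, 10, 14, 13, 16, 30, 11],
    }
    for name, entry in infer_schema(sample).items():
        print(name, entry)
```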

ExampleValidator

The ExampleValidator component detects anomalies in your data, based on the expectations defined by the schema. It also uses the TensorFlow Data Validation library.

ExampleValidator will take as input the statistics from StatisticsGen, and the schema from SchemaGen.

example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'])
context.run(example_validator, enable_cache=True)
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Running driver for ExampleValidator
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for ExampleValidator
INFO:absl:Validating schema against the computed statistics for split train.
INFO:absl:Validation complete for split train. Anomalies written to /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/ExampleValidator/anomalies/4/Split-train.
INFO:absl:Validating schema against the computed statistics for split eval.
INFO:absl:Validation complete for split eval. Anomalies written to /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/ExampleValidator/anomalies/4/Split-eval.
INFO:absl:Running publisher for ExampleValidator
INFO:absl:MetadataStore with DB connection initialized

After ExampleValidator finishes running, we can visualize the anomalies as a table.

context.show(example_validator.outputs['anomalies'])

In the anomalies table, we can see that there are no anomalies. This is what we'd expect, since this is the first dataset that we've analyzed and the schema is tailored to it. You should review this schema: anything unexpected in it indicates an anomaly in the data. Once reviewed, the schema can be used to guard future data, and the anomalies produced here can be used to debug model performance, understand how your data evolves over time, and identify data errors.
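To show how a reviewed schema guards future data, here is a toy validator that checks a batch against a small hand-written schema. It reports missing features, missing values in always-present features, type mismatches, and out-of-domain strings. The schema format and the `find_anomalies` helper are hypothetical and only loosely mirror what ExampleValidator reports. The "Mobile" payment type in the usage example is an invented value used purely for illustration.

```python
def find_anomalies(schema, columns):
    """Compare a batch of data against a toy schema; return {feature: [problems]}."""
    type_checks = {
        "INT": lambda v: isinstance(v, int),
        "FLOAT": lambda v: isinstance(v, (int, float)),
        "BYTES": lambda v: isinstance(v, str),
    }
    anomalies = {}
    for name, spec in schema.items():
        problems = []
        values = columns.get(name)
        if values is None:
            problems.append("feature missing from data")
        else:
            present = [v for v in values if v is not None]
            if spec.get("always_present") and len(present) < len(values):
                problems.append("missing values in a required feature")
            if any(not type_checks[spec["type"]](v) for v in present):
                problems.append(f"values not of type {spec['type']}")
            domain = spec.get("domain")
            if domain is not None:
                unexpected = sorted({v for v in present if v not in domain}, key=str)
                if unexpected:
                    problems.append(f"values outside domain: {unexpected}")
        if problems:
            anomalies[name] = problems
    return anomalies


SCHEMA = {
    "payment_type": {"type": "BYTES", "always_present": True,
                     "domain": ["Cash", "Credit Card", "Unknown"]},
    "trip_seconds": {"type": "INT", "always_present": True},
}

if __name__ == "__main__":
    # Data that matches the schema produces no anomalies.
    print(find_anomalies(SCHEMA, {"payment_type": ["Cash"], "trip_seconds": [300]}))  # {}
    # A hypothetical drifted batch: an unseen payment type and a missing value.
    print(find_anomalies(SCHEMA, {"payment_type": ["Mobile"], "trip_seconds": [None]}))
```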

Transform

The Transform component performs feature engineering for both training and serving. It uses the TensorFlow Transform library.

Transform will take as input the data from ExampleGen, the schema from SchemaGen, as well as a module that contains user-defined Transform code.
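Doing feature engineering "for both training and serving" means that statistics such as a feature's mean and standard deviation are computed once, in a full pass over the training data, and then applied unchanged to every example, including at serving time, so the model never sees inconsistently scaled inputs. Below is a minimal plain-Python sketch of that two-phase pattern for z-score scaling, the operation the preprocessing_fn below applies to numerical features via tft.scale_to_z_score. The `analyze`/`transform` helpers are hypothetical, and the zero-variance guard is this sketch's own choice, not necessarily TF Transform's behavior.

```python
import statistics


def analyze(train_values):
    """Analyze phase: one full pass over the training split."""
    return {"mean": statistics.fmean(train_values),
            "std": statistics.pstdev(train_values)}


def transform(value, params):
    """Transform phase: apply the frozen constants to a single value."""
    if params["std"] == 0:
        return value - params["mean"]  # this sketch's own zero-variance guard
    return (value - params["mean"]) / params["std"]


if __name__ == "__main__":
    # Fares from the nine rows printed by `head` earlier stand in for training data.
    train_fares = [12.45, 0.0, 27.05, 5.85, 16.65, 16.45, 32.05, 38.45, 14.65]
    params = analyze(train_fares)  # computed once, at training time
    scaled_train = [transform(f, params) for f in train_fares]
    serving_fare = 20.0  # hypothetical request at serving time
    print(params, transform(serving_fare, params))
```

The key design point is that `params` is computed once and reused; recomputing the mean at serving time from a single request would produce different, skewed features.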

Let's see an example of user-defined Transform code below (for an introduction to the TensorFlow Transform APIs, see the tutorial). First, we define a few constants for feature engineering:

_taxi_constants_module_file = 'taxi_constants.py'
%%writefile {_taxi_constants_module_file}

NUMERICAL_FEATURES = ['trip_miles', 'fare', 'trip_seconds']

BUCKET_FEATURES = [
    'pickup_latitude', 'pickup_longitude', 'dropoff_latitude',
    'dropoff_longitude'
]
# Number of buckets used by tf.transform for encoding each feature.
FEATURE_BUCKET_COUNT = 10

CATEGORICAL_NUMERICAL_FEATURES = [
    'trip_start_hour', 'trip_start_day', 'trip_start_month',
    'pickup_census_tract', 'dropoff_census_tract', 'pickup_community_area',
    'dropoff_community_area'
]

CATEGORICAL_STRING_FEATURES = [
    'payment_type',
    'company',
]

# Number of vocabulary terms used for encoding categorical features.
VOCAB_SIZE = 1000

# Count of out-of-vocab buckets in which unrecognized categorical values are hashed.
OOV_SIZE = 10

# Keys
LABEL_KEY = 'tips'
FARE_KEY = 'fare'

def t_name(key):
  """
  Rename the feature keys so that they don't clash with the raw keys when
  running the Evaluator component.
  Args:
    key: The original feature key
  Returns:
    key with '_xf' appended
  """
  return key + '_xf'
Writing taxi_constants.py
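FEATURE_BUCKET_COUNT controls how the coordinate features in BUCKET_FEATURES are discretized: tft.bucketize computes quantile boundaries over the data so that each bucket holds roughly the same number of examples. The following plain-Python sketch computes exact quantile cut points and assigns bucket indices with `bisect`. It only approximates the idea (TF Transform uses approximate quantiles), and the cut-point rule and helper names are assumptions.

```python
import bisect

FEATURE_BUCKET_COUNT = 10  # same value as in taxi_constants.py


def quantile_boundaries(values, num_buckets):
    """Return num_buckets - 1 cut points that split the sorted values evenly."""
    ordered = sorted(values)
    n = len(ordered)
    return [ordered[(i * n) // num_buckets] for i in range(1, num_buckets)]


def bucketize(value, boundaries):
    """Return a bucket index in [0, len(boundaries)] for `value`."""
    return bisect.bisect_right(boundaries, value)


if __name__ == "__main__":
    # Hypothetical evenly spread training values, so each bucket gets 10 values.
    train_values = [i / 10 for i in range(100)]
    cuts = quantile_boundaries(train_values, FEATURE_BUCKET_COUNT)
    print([bucketize(v, cuts) for v in (0.0, 2.5, 9.9)])  # [0, 2, 9]
```

Values outside the training range still map to the first or last bucket, so serving-time coordinates never fall off the ends.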

Next, we write a preprocessing_fn that takes in raw data as input, and returns transformed features that our model can train on:

_taxi_transform_module_file = 'taxi_transform.py'
%%writefile {_taxi_transform_module_file}

import tensorflow as tf
import tensorflow_transform as tft

# Imported files such as taxi_constants are normally cached, so changes are
# not honored after the first import.  Normally this is good for efficiency, but
# during development when we may be iterating code it can be a problem. To
# avoid this problem during development, reload the file.
import taxi_constants
import sys
if 'google.colab' in sys.modules:  # Testing to see if we're doing development
  import importlib
  importlib.reload(taxi_constants)

_NUMERICAL_FEATURES = taxi_constants.NUMERICAL_FEATURES
_BUCKET_FEATURES = taxi_constants.BUCKET_FEATURES
_FEATURE_BUCKET_COUNT = taxi_constants.FEATURE_BUCKET_COUNT
_CATEGORICAL_NUMERICAL_FEATURES = taxi_constants.CATEGORICAL_NUMERICAL_FEATURES
_CATEGORICAL_STRING_FEATURES = taxi_constants.CATEGORICAL_STRING_FEATURES
_VOCAB_SIZE = taxi_constants.VOCAB_SIZE
_OOV_SIZE = taxi_constants.OOV_SIZE
_FARE_KEY = taxi_constants.FARE_KEY
_LABEL_KEY = taxi_constants.LABEL_KEY


def _make_one_hot(x, key):
  """Make a one-hot tensor to encode categorical features.
  Args:
    x: A dense tensor
    key: A string key for the feature in the input
  Returns:
    A dense one-hot tensor as a float list
  """
  integerized = tft.compute_and_apply_vocabulary(x,
          top_k=_VOCAB_SIZE,
          num_oov_buckets=_OOV_SIZE,
          vocab_filename=key, name=key)
  depth = (
      tft.experimental.get_vocabulary_size_by_name(key) + _OOV_SIZE)
  one_hot_encoded = tf.one_hot(
      integerized,
      depth=tf.cast(depth, tf.int32),
      on_value=1.0,
      off_value=0.0)
  return tf.reshape(one_hot_encoded, [-1, depth])


def _fill_in_missing(x):
  """Replace missing values in a SparseTensor.
  Fills in missing values of `x` with '' or 0, and converts to a dense tensor.
  Args:
    x: A `SparseTensor` of rank 2.  Its dense shape should have size at most 1
      in the second dimension.
  Returns:
    A rank 1 tensor where missing values of `x` have been filled in.
  """
  if not isinstance(x, tf.sparse.SparseTensor):
    return x

  default_value = '' if x.dtype == tf.string else 0
  return tf.squeeze(
      tf.sparse.to_dense(
          tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]),
          default_value),
      axis=1)


def preprocessing_fn(inputs):
  """tf.transform's callback function for preprocessing inputs.
  Args:
    inputs: map from feature keys to raw not-yet-transformed features.
  Returns:
    Map from string feature key to transformed feature operations.
  """
  outputs = {}
  for key in _NUMERICAL_FEATURES:
    # If sparse, make it dense (filling missing values with 0 or ''), then apply z-score scaling.
    outputs[taxi_constants.t_name(key)] = tft.scale_to_z_score(
        _fill_in_missing(inputs[key]), name=key)

  for key in _BUCKET_FEATURES:
    outputs[taxi_constants.t_name(key)] = tf.cast(tft.bucketize(
            _fill_in_missing(inputs[key]), _FEATURE_BUCKET_COUNT, name=key),
            dtype=tf.float32)

  for key in _CATEGORICAL_STRING_FEATURES:
    outputs[taxi_constants.t_name(key)] = _make_one_hot(_fill_in_missing(inputs[key]), key)

  for key in _CATEGORICAL_NUMERICAL_FEATURES:
    outputs[taxi_constants.t_name(key)] = _make_one_hot(tf.strings.strip(
        tf.strings.as_string(_fill_in_missing(inputs[key]))), key)

  # Was this passenger a big tipper?
  taxi_fare = _fill_in_missing(inputs[_FARE_KEY])
  tips = _fill_in_missing(inputs[_LABEL_KEY])
  outputs[_LABEL_KEY] = tf.where(
      tf.math.is_nan(taxi_fare),
      tf.cast(tf.zeros_like(taxi_fare), tf.int64),
      # Test if the tip was > 20% of the fare.
      tf.cast(
          tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))), tf.int64))

  return outputs
Writing taxi_transform.py
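The `_make_one_hot` helper above builds a vocabulary of at most VOCAB_SIZE values, sends unseen values to one of OOV_SIZE hashed out-of-vocabulary buckets, and one-hot encodes the resulting index with depth equal to the vocabulary size plus OOV_SIZE. A plain-Python sketch of those three steps follows. The frequency ordering with alphabetical tie-breaking and the MD5 hash are assumptions standing in for TF Transform's internals, and the helper names are hypothetical. Numerical categorical features such as trip_start_hour go through the same path after being converted to strings.

```python
import hashlib
from collections import Counter


def build_vocabulary(values, top_k):
    """Keep the top_k most frequent values (ties broken alphabetically)."""
    counts = Counter(values)
    ranked = sorted(counts, key=lambda v: (-counts[v], v))
    return {v: i for i, v in enumerate(ranked[:top_k])}


def to_index(value, vocab, num_oov_buckets):
    """Map a value to its vocabulary id, or hash unknown values into OOV ids."""
    if value in vocab:
        return vocab[value]
    digest = hashlib.md5(value.encode("utf-8")).hexdigest()
    return len(vocab) + int(digest, 16) % num_oov_buckets


def one_hot(index, depth):
    """Return a dense one-hot list of floats (1.0 on, 0.0 off)."""
    return [1.0 if i == index else 0.0 for i in range(depth)]


if __name__ == "__main__":
    VOCAB_SIZE, OOV_SIZE = 1000, 10  # same values as in taxi_constants.py
    # payment_type values from the nine rows printed by `head` earlier.
    payment_types = ["Credit Card", "Unknown", "Cash", "Cash", "Cash",
                     "Cash", "Cash", "Cash", "Cash"]
    vocab = build_vocabulary(payment_types, top_k=VOCAB_SIZE)
    depth = len(vocab) + OOV_SIZE
    print(vocab)  # {'Cash': 0, 'Credit Card': 1, 'Unknown': 2}
    print(one_hot(to_index("Cash", vocab, OOV_SIZE), depth))
    print(to_index("Mobile", vocab, OOV_SIZE))  # hypothetical unseen value -> OOV id
```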

Now, we pass in this feature engineering code to the Transform component and run it to transform your data.

transform = tfx.components.Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath(_taxi_transform_module_file))
context.run(transform, enable_cache=True)
INFO:absl:Generating ephemeral wheel package for '/tmpfs/src/temp/docs/tutorials/tfx/taxi_transform.py' (including modules: ['taxi_constants', 'taxi_transform']).
INFO:absl:User module package has hash fingerprint version d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '/tmpfs/tmp/tmplkbmtkqk/_tfx_generated_setup.py', 'bdist_wheel', '--bdist-dir', '/tmpfs/tmp/tmpu10gi8bp', '--dist-dir', '/tmpfs/tmp/tmpc82xos_7']
running bdist_wheel
running build
running build_py
creating build
creating build/lib
copying taxi_constants.py -> build/lib
copying taxi_transform.py -> build/lib
installing to /tmpfs/tmp/tmpu10gi8bp
running install
running install_lib
copying build/lib/taxi_constants.py -> /tmpfs/tmp/tmpu10gi8bp
copying build/lib/taxi_transform.py -> /tmpfs/tmp/tmpu10gi8bp
running install_egg_info
running egg_info
creating tfx_user_code_Transform.egg-info
writing tfx_user_code_Transform.egg-info/PKG-INFO
writing dependency_links to tfx_user_code_Transform.egg-info/dependency_links.txt
writing top-level names to tfx_user_code_Transform.egg-info/top_level.txt
writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
INFO:absl:Successfully built user code wheel distribution at '/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl'; target user module is 'taxi_transform'.
INFO:absl:Full user module path is 'taxi_transform@/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl'
INFO:absl:Running driver for Transform
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for Transform
INFO:absl:Analyze the 'train' split and transform all splits when splits_config is not set.
INFO:absl:udf_utils.get_fn {'module_file': None, 'module_path': 'taxi_transform@/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl', 'preprocessing_fn': None} 'preprocessing_fn'
INFO:absl:Installing '/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmpfnsf0s7v', '/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl']
reading manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
Copying tfx_user_code_Transform.egg-info to /tmpfs/tmp/tmpu10gi8bp/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3.9.egg-info
running install_scripts
creating /tmpfs/tmp/tmpu10gi8bp/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d.dist-info/WHEEL
creating '/tmpfs/tmp/tmpc82xos_7/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl' and adding '/tmpfs/tmp/tmpu10gi8bp' to it
adding 'taxi_constants.py'
adding 'taxi_transform.py'
adding 'tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d.dist-info/METADATA'
adding 'tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d.dist-info/WHEEL'
adding 'tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d.dist-info/top_level.txt'
adding 'tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d.dist-info/RECORD'
removing /tmpfs/tmp/tmpu10gi8bp
Processing /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl
INFO:absl:Successfully installed '/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl'.
INFO:absl:udf_utils.get_fn {'module_file': None, 'module_path': 'taxi_transform@/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl', 'stats_options_updater_fn': None} 'stats_options_updater_fn'
INFO:absl:Installing '/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmpii24n_fl', '/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl']
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d
Processing /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl
INFO:absl:Successfully installed '/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl'.
INFO:absl:Installing '/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmphgrq7y7z', '/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl']
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d
Processing /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl
INFO:absl:Successfully installed '/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d-py3-none-any.whl'.
INFO:absl:Feature company has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature fare has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature payment_type has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature tips has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_miles has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_seconds has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_day has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_hour has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_month has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_timestamp has no shape. Setting to varlen_sparse_tensor.
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+d7f32accc04453d93cd29bae5b4d879eb83d8a54c7e01d354a58158f2f84251d
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:Feature company has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature fare has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature payment_type has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature tips has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_miles has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_seconds has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_day has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_hour has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_month has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_timestamp has no shape. Setting to varlen_sparse_tensor.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:Feature company has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature fare has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature payment_type has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature tips has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_miles has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_seconds has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_day has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_hour has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_month has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_timestamp has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature company has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature fare has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature payment_type has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature tips has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_miles has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_seconds has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_day has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_hour has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_month has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature company has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature fare has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature payment_type has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature tips has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_miles has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_seconds has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_day has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_hour has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_month has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_timestamp has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature company has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature fare has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature payment_type has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature tips has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_miles has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_seconds has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_day has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_hour has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_month has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_timestamp has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature company has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature fare has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature payment_type has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature tips has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_miles has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_seconds has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_day has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_hour has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_month has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_timestamp has no shape. Setting to varlen_sparse_tensor.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: payment_type/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: company/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: trip_start_hour/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: trip_start_day/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: trip_start_month/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: pickup_census_tract/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: dropoff_census_tract/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: pickup_community_area/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: dropoff_community_area/apply_vocab/text_file_init/InitializeTableFromTextFileV2
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: payment_type/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: company/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: trip_start_hour/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: trip_start_day/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: trip_start_month/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: pickup_census_tract/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: dropoff_census_tract/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: pickup_community_area/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: dropoff_community_area/apply_vocab/text_file_init/InitializeTableFromTextFileV2
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:Feature company has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature fare has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature payment_type has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature tips has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_miles has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_seconds has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_day has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_hour has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_month has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature company has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature fare has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature payment_type has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature tips has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_miles has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_seconds has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_day has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_hour has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_month has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_timestamp has no shape. Setting to varlen_sparse_tensor.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/Transform/transform_graph/5/.temp_path/tftransform_tmp/7f3848af1e304086a0e894dc76bcefd4/assets
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/Transform/transform_graph/5/.temp_path/tftransform_tmp/d4872ce15de74c5a9555b686a5543e8a/assets
INFO:absl:If the number of unique tokens is smaller than the provided top_k or approximation error is acceptable, consider using tft.experimental.approximate_vocabulary for a potentially more efficient implementation.
INFO:absl:Feature company_xf has a shape dim {
  size: 55
}
. Setting to DenseTensor.
INFO:absl:Feature dropoff_census_tract_xf has a shape dim {
  size: 216
}
. Setting to DenseTensor.
INFO:absl:Feature dropoff_community_area_xf has a shape dim {
  size: 79
}
. Setting to DenseTensor.
INFO:absl:Feature dropoff_latitude_xf has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_longitude_xf has a shape . Setting to DenseTensor.
INFO:absl:Feature fare_xf has a shape . Setting to DenseTensor.
INFO:absl:Feature payment_type_xf has a shape dim {
  size: 16
}
. Setting to DenseTensor.
INFO:absl:Feature pickup_census_tract_xf has a shape dim {
  size: 11
}
. Setting to DenseTensor.
INFO:absl:Feature pickup_community_area_xf has a shape dim {
  size: 66
}
. Setting to DenseTensor.
INFO:absl:Feature pickup_latitude_xf has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_longitude_xf has a shape . Setting to DenseTensor.
INFO:absl:Feature tips has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_miles_xf has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_seconds_xf has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_day_xf has a shape dim {
  size: 17
}
. Setting to DenseTensor.
INFO:absl:Feature trip_start_hour_xf has a shape dim {
  size: 34
}
. Setting to DenseTensor.
INFO:absl:Feature trip_start_month_xf has a shape dim {
  size: 22
}
. Setting to DenseTensor.
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:absl:Running publisher for Transform
INFO:absl:MetadataStore with DB connection initialized

Let's examine the output artifacts of Transform. This component produces several outputs, but the two most important ones are:

  • transform_graph is the graph that can perform the preprocessing operations (this graph will be included in the serving and evaluation models).
  • transformed_examples represents the preprocessed training and evaluation data.
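To make the split between these two outputs concrete, here is a plain-Python analogy. It is not the tf.Transform API. Judging from the printed examples further down, `fare_xf` looks like a z-scored value and `company_xf` like a one-hot vector over a vocabulary plus out-of-vocabulary slots, and the sketch assumes exactly that. `analyze` plays the role of the full pass over the data whose results end up frozen in `transform_graph`, while `transform` is the per-row function that produces `transformed_examples`. The function names, the single OOV bucket, and the toy rows are all hypothetical.

```python
# Plain-Python analogy of Transform's analyze/transform split.
# Hypothetical sketch: these are NOT tf.Transform functions.
from collections import Counter
from statistics import mean, pstdev


def analyze(rows):
    """Full pass over the training rows: compute the constants the transform needs."""
    fares = [row["fare"] for row in rows]
    counts = Counter(row["company"] for row in rows)
    return {
        "fare_mean": mean(fares),
        # Guard against a zero standard deviation on degenerate data.
        "fare_std": pstdev(fares) or 1.0,
        # In this sketch the vocabulary is ordered from most to least frequent.
        "vocab": {name: i for i, (name, _) in enumerate(counts.most_common())},
    }


def transform(row, consts, num_oov_buckets=1):
    """Apply the frozen constants to one row (same code for training and serving)."""
    fare_xf = (row["fare"] - consts["fare_mean"]) / consts["fare_std"]
    vocab = consts["vocab"]
    depth = len(vocab) + num_oov_buckets
    # Values never seen during analysis map to the first OOV slot after the vocabulary.
    index = vocab.get(row["company"], len(vocab))
    company_xf = [0.0] * depth
    company_xf[index] = 1.0
    return {"fare_xf": fare_xf, "company_xf": company_xf}


train_rows = [
    {"fare": 10.0, "company": "A"},
    {"fare": 20.0, "company": "B"},
    {"fare": 30.0, "company": "A"},
]
consts = analyze(train_rows)
print(transform({"fare": 20.0, "company": "A"}, consts))
# → {'fare_xf': 0.0, 'company_xf': [1.0, 0.0, 0.0]}
print(transform({"fare": 20.0, "company": "Z"}, consts))  # unseen company
# → {'fare_xf': 0.0, 'company_xf': [0.0, 0.0, 1.0]}
```

Keeping the analysis constants separate from the per-row function is what lets identical preprocessing run at training time and at serving time, which is why `transform_graph` is bundled into the serving and evaluation models.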
transform.outputs
{'transform_graph': OutputChannel(artifact_type=TransformGraph, producer_component_id=Transform, output_key=transform_graph, additional_properties={}, additional_custom_properties={}, _input_trigger=None,
 'transformed_examples': OutputChannel(artifact_type=Examples, producer_component_id=Transform, output_key=transformed_examples, additional_properties={}, additional_custom_properties={}, _input_trigger=None,
 'updated_analyzer_cache': OutputChannel(artifact_type=TransformCache, producer_component_id=Transform, output_key=updated_analyzer_cache, additional_properties={}, additional_custom_properties={}, _input_trigger=None,
 'pre_transform_schema': OutputChannel(artifact_type=Schema, producer_component_id=Transform, output_key=pre_transform_schema, additional_properties={}, additional_custom_properties={}, _input_trigger=None,
 'pre_transform_stats': OutputChannel(artifact_type=ExampleStatistics, producer_component_id=Transform, output_key=pre_transform_stats, additional_properties={}, additional_custom_properties={}, _input_trigger=None,
 'post_transform_schema': OutputChannel(artifact_type=Schema, producer_component_id=Transform, output_key=post_transform_schema, additional_properties={}, additional_custom_properties={}, _input_trigger=None,
 'post_transform_stats': OutputChannel(artifact_type=ExampleStatistics, producer_component_id=Transform, output_key=post_transform_stats, additional_properties={}, additional_custom_properties={}, _input_trigger=None,
 'post_transform_anomalies': OutputChannel(artifact_type=ExampleAnomalies, producer_component_id=Transform, output_key=post_transform_anomalies, additional_properties={}, additional_custom_properties={}, _input_trigger=None}
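Each entry above is an output channel. After the component has run, this notebook reads an artifact's location with `.get()[0].uri`, the same pattern used in the cells that follow. The helper below is a hypothetical convenience, not part of TFX: it applies that pattern to every output so you can see all artifact directories at once, and it assumes each channel holds exactly one artifact, as in this interactive run. `FakeChannel` and the paths are made-up stand-ins so the sketch runs without TFX.

```python
from types import SimpleNamespace


def output_uris(outputs):
    """Map each output key to the URI of its first artifact.

    Hypothetical helper: assumes every channel's .get() returns a list of
    artifacts that expose a .uri attribute.
    """
    uris = {}
    for key, channel in outputs.items():
        artifacts = list(channel.get())
        if not artifacts:
            raise ValueError(f"output {key!r} has no artifacts; has the component run?")
        uris[key] = artifacts[0].uri
    return uris


class FakeChannel:
    """Stand-in for a TFX output channel, for running this sketch offline."""

    def __init__(self, uri):
        self._artifacts = [SimpleNamespace(uri=uri)]

    def get(self):
        return self._artifacts


fake_outputs = {
    "transform_graph": FakeChannel("/tmp/Transform/transform_graph/5"),
    "transformed_examples": FakeChannel("/tmp/Transform/transformed_examples/5"),
}
for key, uri in output_uris(fake_outputs).items():
    print(f"{key}: {uri}")
```

In the notebook itself you would pass the real mapping instead, e.g. `output_uris(transform.outputs)`.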

Take a peek at the transform_graph artifact. It points to a directory containing three subdirectories.

transform_graph_uri = transform.outputs['transform_graph'].get()[0].uri
os.listdir(transform_graph_uri)
['transform_fn', 'metadata', 'transformed_metadata']

The transformed_metadata subdirectory contains the schema of the preprocessed data. The transform_fn subdirectory contains the actual preprocessing graph. The metadata subdirectory contains the schema of the original data.
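The same layout can be checked programmatically. The hypothetical helper below confirms that the three expected subdirectories are present and reports what each holds; the role strings simply restate the paragraph above. In real training code, the `TFTransformOutput` class from tensorflow_transform is the usual way to load this directory rather than inspecting it by hand.

```python
import os
import tempfile

# Roles as described in the paragraph above.
EXPECTED_SUBDIRS = {
    "transform_fn": "preprocessing graph",
    "transformed_metadata": "schema of the preprocessed data",
    "metadata": "schema of the original data",
}


def describe_transform_graph_dir(path):
    """Return {subdir: role} for the expected subdirectories under path.

    Hypothetical helper; raises FileNotFoundError if any expected one is missing.
    """
    present = {name for name in os.listdir(path)
               if os.path.isdir(os.path.join(path, name))}
    missing = sorted(set(EXPECTED_SUBDIRS) - present)
    if missing:
        raise FileNotFoundError(f"missing subdirectories: {missing}")
    return {name: EXPECTED_SUBDIRS[name] for name in sorted(EXPECTED_SUBDIRS)}


# Demo on a throwaway directory that mimics the listing above.
with tempfile.TemporaryDirectory() as demo_dir:
    for name in EXPECTED_SUBDIRS:
        os.mkdir(os.path.join(demo_dir, name))
    for name, role in describe_transform_graph_dir(demo_dir).items():
        print(f"{name}: {role}")
```

In the notebook, pass it the URI returned by `transform.outputs['transform_graph'].get()[0].uri`.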

We can also take a look at the first three transformed examples:

# Get the URI of the output artifact representing the transformed examples, which is a directory
train_uri = os.path.join(transform.outputs['transformed_examples'].get()[0].uri, 'Split-train')

# Get the list of files in this directory (all compressed TFRecord files)
tfrecord_filenames = [os.path.join(train_uri, name)
                      for name in os.listdir(train_uri)]

# Create a `TFRecordDataset` to read these files
dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")

# Iterate over the first 3 records and decode them.
for tfrecord in dataset.take(3):
  serialized_example = tfrecord.numpy()
  example = tf.train.Example()
  example.ParseFromString(serialized_example)
  pp.pprint(example)
features {
  feature {
    key: "company_xf"
    value {
      float_list {
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "dropoff_census_tract_xf"
    value {
      float_list {
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "dropoff_community_area_xf"
    value {
      float_list {
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "dropoff_latitude_xf"
    value {
      float_list {
        value: 0.0
      }
    }
  }
  feature {
    key: "dropoff_longitude_xf"
    value {
      float_list {
        value: 9.0
      }
    }
  }
  feature {
    key: "fare_xf"
    value {
      float_list {
        value: 0.061060599982738495
      }
    }
  }
  feature {
    key: "payment_type_xf"
    value {
      float_list {
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "pickup_census_tract_xf"
    value {
      float_list {
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "pickup_community_area_xf"
    value {
      float_list {
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "pickup_latitude_xf"
    value {
      float_list {
        value: 0.0
      }
    }
  }
  feature {
    key: "pickup_longitude_xf"
    value {
      float_list {
        value: 9.0
      }
    }
  }
  feature {
    key: "tips"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "trip_miles_xf"
    value {
      float_list {
        value: -0.15886740386486053
      }
    }
  }
  feature {
    key: "trip_seconds_xf"
    value {
      float_list {
        value: -0.7118487358093262
      }
    }
  }
  feature {
    key: "trip_start_day_xf"
    value {
      float_list {
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "trip_start_hour_xf"
    value {
      float_list {
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "trip_start_month_xf"
    value {
      float_list {
        value: 0.0
        value: 0.0
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
}

features {
  feature {
    key: "company_xf"
    value {
      float_list {
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "dropoff_census_tract_xf"
    value {
      float_list {
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "dropoff_community_area_xf"
    value {
      float_list {
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "dropoff_latitude_xf"
    value {
      float_list {
        value: 0.0
      }
    }
  }
  feature {
    key: "dropoff_longitude_xf"
    value {
      float_list {
        value: 9.0
      }
    }
  }
  feature {
    key: "fare_xf"
    value {
      float_list {
        value: 1.2521240711212158
      }
    }
  }
  feature {
    key: "payment_type_xf"
    value {
      float_list {
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "pickup_census_tract_xf"
    value {
      float_list {
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "pickup_community_area_xf"
    value {
      float_list {
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "pickup_latitude_xf"
    value {
      float_list {
        value: 0.0
      }
    }
  }
  feature {
    key: "pickup_longitude_xf"
    value {
      float_list {
        value: 3.0
      }
    }
  }
  feature {
    key: "tips"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "trip_miles_xf"
    value {
      float_list {
        value: 0.532160758972168
      }
    }
  }
  feature {
    key: "trip_seconds_xf"
    value {
      float_list {
        value: 0.5509493350982666
      }
    }
  }
  feature {
    key: "trip_start_day_xf"
    value {
      float_list {
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "trip_start_hour_xf"
    value {
      float_list {
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "trip_start_month_xf"
    value {
      float_list {
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
}

features {
  feature {
    key: "company_xf"
    value {
      float_list {
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "dropoff_census_tract_xf"
    value {
      float_list {
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "dropoff_community_area_xf"
    value {
      float_list {
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "dropoff_latitude_xf"
    value {
      float_list {
        value: 0.0
      }
    }
  }
  feature {
    key: "dropoff_longitude_xf"
    value {
      float_list {
        value: 9.0
      }
    }
  }
  feature {
    key: "fare_xf"
    value {
      float_list {
        value: 0.3873794376850128
      }
    }
  }
  feature {
    key: "payment_type_xf"
    value {
      float_list {
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "pickup_census_tract_xf"
    value {
      float_list {
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "pickup_community_area_xf"
    value {
      float_list {
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "pickup_latitude_xf"
    value {
      float_list {
        value: 9.0
      }
    }
  }
  feature {
    key: "pickup_longitude_xf"
    value {
      float_list {
        value: 0.0
      }
    }
  }
  feature {
    key: "tips"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "trip_miles_xf"
    value {
      float_list {
        value: 0.21955278515815735
      }
    }
  }
  feature {
    key: "trip_seconds_xf"
    value {
      float_list {
        value: 0.0019067146349698305
      }
    }
  }
  feature {
    key: "trip_start_day_xf"
    value {
      float_list {
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "trip_start_hour_xf"
    value {
      float_list {
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
  feature {
    key: "trip_start_month_xf"
    value {
      float_list {
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 1.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
        value: 0.0
      }
    }
  }
}
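The printed examples are easier to read once you spot the pattern in their float lists. Most *_xf features are long vectors with a single 1.0, which look like one-hot encodings. The latitude and longitude features each hold one small whole number, which looks like a bucket index. fare_xf, trip_miles_xf, and trip_seconds_xf each hold one scaled float. The sketch below is a hypothetical pure-Python helper (decode_feature and one_hot are our own names, not TFX APIs) that recovers the hot position and the vector width from such a list. The position presumably indexes a vocabulary computed by Transform, so it need not equal the raw value.

```python
from typing import Dict, List, Tuple, Union

# Result for one feature: ("one_hot", index, width) or ("scalar", value).
Decoded = Union[Tuple[str, int, int], Tuple[str, float]]


def decode_feature(values: List[float]) -> Decoded:
    """Interprets one transformed float_list from the printout above.

    Assumption: a list with more than one entry is a one-hot vector, so the
    position of its single 1.0 is the category index. A one-element list is
    returned as a plain scalar (a bucket index or a scaled value).
    """
    if len(values) == 1:
        return ("scalar", values[0])
    hot = [i for i, v in enumerate(values) if v == 1.0]
    if len(hot) != 1 or any(v not in (0.0, 1.0) for v in values):
        raise ValueError("expected exactly one 1.0 and zeros elsewhere")
    return ("one_hot", hot[0], len(values))


def one_hot(index: int, width: int) -> List[float]:
    """Builds a one-hot float list; used here only to recreate sample data."""
    return [1.0 if i == index else 0.0 for i in range(width)]


# A few features from the first printed example, re-typed by hand.
example: Dict[str, List[float]] = {
    "company_xf": one_hot(0, 55),
    "dropoff_community_area_xf": one_hot(7, 79),
    "dropoff_longitude_xf": [9.0],
    "fare_xf": [1.2521240711212158],
}

for name, values in example.items():
    print(name, decode_feature(values))
```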

After the Transform component has transformed your data into features, the next step is to train a model.

Trainer

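The Trainer component trains a model that you define in TensorFlow. In older TFX releases, the default Trainer supported only the Estimator API, and using the Keras API required selecting the generic executor by setting custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor) in the Trainer's constructor. With the tfx.v1 API used in this notebook, the Trainer calls the run_fn defined in your module file, which is why the Trainer call below does not set custom_executor_spec.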

Trainer takes as input the schema from SchemaGen, the transformed data and graph from Transform, training parameters, as well as a module that contains user-defined model code.
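Conceptually, the Trainer treats your module file as a plug-in: it packages the file (the log below shows it being built into a wheel), loads it, and looks up a function by name, here run_fn, which it calls with an FnArgs object carrying file paths and step counts. The following is a simplified, hypothetical sketch of that lookup-and-call pattern in plain Python. FakeFnArgs and call_user_entry_point are illustrative names, not TFX APIs, and only a few of the FnArgs attributes that run_fn reads are modeled.

```python
import types
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class FakeFnArgs:
    """Hypothetical stand-in for tfx.components.FnArgs (a small subset)."""
    train_files: List[str] = field(default_factory=list)
    eval_files: List[str] = field(default_factory=list)
    transform_output: str = ""
    serving_model_dir: str = ""
    train_steps: int = 0
    eval_steps: int = 0


def call_user_entry_point(module: types.ModuleType, fn_name: str,
                          fn_args: Any) -> Any:
    """Looks up a function by name in a user module and calls it.

    This mirrors the general idea of the Trainer invoking run_fn; it is not
    TFX's actual loading code.
    """
    fn = getattr(module, fn_name, None)
    if not callable(fn):
        raise AttributeError(
            f"module {module.__name__!r} has no callable {fn_name!r}")
    return fn(fn_args)


def _toy_run_fn(fn_args: FakeFnArgs) -> str:
    # A real run_fn would build, train, and export a model here.
    return (f"training {fn_args.train_steps} steps, "
            f"exporting to {fn_args.serving_model_dir}")


# Build a tiny in-memory module instead of writing a file to disk.
user_module = types.ModuleType("toy_trainer")
user_module.run_fn = _toy_run_fn

args = FakeFnArgs(train_steps=10000, eval_steps=5000,
                  serving_model_dir="/tmp/serving_model")
print(call_user_entry_point(user_module, "run_fn", args))
```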

Let's see an example of user-defined model code below (for an introduction to the TensorFlow Keras APIs, see the tutorial):

_taxi_trainer_module_file = 'taxi_trainer.py'
%%writefile {_taxi_trainer_module_file}

from typing import Dict, List, Text

import os
import glob
from absl import logging

import datetime
import tensorflow as tf
import tensorflow_transform as tft

from tfx import v1 as tfx
from tfx_bsl.public import tfxio
from tensorflow_transform import TFTransformOutput

# Imported files such as taxi_constants are normally cached, so changes are
# not honored after the first import.  Normally this is good for efficiency, but
# during development when we may be iterating code it can be a problem. To
# avoid this problem during development, reload the file.
import taxi_constants
import sys
if 'google.colab' in sys.modules:  # Testing to see if we're doing development
  import importlib
  importlib.reload(taxi_constants)

_LABEL_KEY = taxi_constants.LABEL_KEY

_BATCH_SIZE = 40


def _input_fn(file_pattern: List[Text],
              data_accessor: tfx.components.DataAccessor,
              tf_transform_output: tft.TFTransformOutput,
              batch_size: int = 200) -> tf.data.Dataset:
  """Generates features and label for tuning/training.

  Args:
    file_pattern: List of paths or patterns of input tfrecord files.
    data_accessor: DataAccessor for converting input to RecordBatch.
    tf_transform_output: A TFTransformOutput.
    batch_size: The number of consecutive elements of the returned dataset to
      combine in a single batch.

  Returns:
    A dataset that contains (features, indices) tuples, where features is a
      dictionary of Tensors and indices is a single Tensor of label indices.
  """
  return data_accessor.tf_dataset_factory(
      file_pattern,
      tfxio.TensorFlowDatasetOptions(
          batch_size=batch_size, label_key=_LABEL_KEY),
      tf_transform_output.transformed_metadata.schema)

def _get_tf_examples_serving_signature(model, tf_transform_output):
  """Returns a serving signature that accepts `tensorflow.Example`."""

  # We need to track the layers in the model in order to save it.
  # TODO(b/162357359): Revise once the bug is resolved.
  model.tft_layer_inference = tf_transform_output.transform_features_layer()

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')
  ])
  def serve_tf_examples_fn(serialized_tf_example):
    """Returns the output to be used in the serving signature."""
    raw_feature_spec = tf_transform_output.raw_feature_spec()
    # Remove label feature since these will not be present at serving time.
    raw_feature_spec.pop(_LABEL_KEY)
    raw_features = tf.io.parse_example(serialized_tf_example, raw_feature_spec)
    transformed_features = model.tft_layer_inference(raw_features)
    logging.info('serve_transformed_features = %s', transformed_features)

    outputs = model(transformed_features)
    # TODO(b/154085620): Convert the predicted labels from the model using a
    # reverse-lookup (opposite of transform.py).
    return {'outputs': outputs}

  return serve_tf_examples_fn


def _get_transform_features_signature(model, tf_transform_output):
  """Returns a serving signature that applies tf.Transform to features."""

  # We need to track the layers in the model in order to save it.
  # TODO(b/162357359): Revise once the bug is resolved.
  model.tft_layer_eval = tf_transform_output.transform_features_layer()

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')
  ])
  def transform_features_fn(serialized_tf_example):
    """Returns the transformed_features to be fed as input to evaluator."""
    raw_feature_spec = tf_transform_output.raw_feature_spec()
    raw_features = tf.io.parse_example(serialized_tf_example, raw_feature_spec)
    transformed_features = model.tft_layer_eval(raw_features)
    logging.info('eval_transformed_features = %s', transformed_features)
    return transformed_features

  return transform_features_fn


def export_serving_model(tf_transform_output, model, output_dir):
  """Exports a keras model for serving.
  Args:
    tf_transform_output: Wrapper around output of tf.Transform.
    model: A keras model to export for serving.
    output_dir: A directory where the model will be exported to.
  """
  # The layer has to be saved to the model for Keras tracking purposes.
  model.tft_layer = tf_transform_output.transform_features_layer()

  signatures = {
      'serving_default':
          _get_tf_examples_serving_signature(model, tf_transform_output),
      'transform_features':
          _get_transform_features_signature(model, tf_transform_output),
  }

  model.save(output_dir, save_format='tf', signatures=signatures)


def _build_keras_model(tf_transform_output: TFTransformOutput
                       ) -> tf.keras.Model:
  """Creates a DNN Keras model for classifying taxi data.

  Args:
    tf_transform_output: [TFTransformOutput], the outputs from Transform

  Returns:
    A keras Model.
  """
  feature_spec = tf_transform_output.transformed_feature_spec().copy()
  feature_spec.pop(_LABEL_KEY)

  inputs = {}
  for key, spec in feature_spec.items():
    if isinstance(spec, tf.io.VarLenFeature):
      inputs[key] = tf.keras.layers.Input(
          shape=[None], name=key, dtype=spec.dtype, sparse=True)
    elif isinstance(spec, tf.io.FixedLenFeature):
      # TODO(b/208879020): Move into schema such that spec.shape is [1] and not
      # [] for scalars.
      inputs[key] = tf.keras.layers.Input(
          shape=spec.shape or [1], name=key, dtype=spec.dtype)
    else:
      raise ValueError('Spec type is not supported: ', key, spec)

  output = tf.keras.layers.Concatenate()(tf.nest.flatten(inputs))
  output = tf.keras.layers.Dense(100, activation='relu')(output)
  output = tf.keras.layers.Dense(70, activation='relu')(output)
  output = tf.keras.layers.Dense(50, activation='relu')(output)
  output = tf.keras.layers.Dense(20, activation='relu')(output)
  output = tf.keras.layers.Dense(1)(output)
  return tf.keras.Model(inputs=inputs, outputs=output)


# TFX Trainer will call this function.
def run_fn(fn_args: tfx.components.FnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

  train_dataset = _input_fn(fn_args.train_files, fn_args.data_accessor, 
                            tf_transform_output, _BATCH_SIZE)
  eval_dataset = _input_fn(fn_args.eval_files, fn_args.data_accessor, 
                           tf_transform_output, _BATCH_SIZE)

  model = _build_keras_model(tf_transform_output)

  model.compile(
      loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
      metrics=[tf.keras.metrics.BinaryAccuracy()])

  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=fn_args.model_run_dir, update_freq='batch')

  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps,
      callbacks=[tensorboard_callback])

  # Export the model.
  export_serving_model(tf_transform_output, model, fn_args.serving_model_dir)
Writing taxi_trainer.py
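To get a feel for the size of the network built by _build_keras_model, you can count its parameters by hand. The sketch below takes the width of each transformed input from the lengths of the float lists in the printed examples earlier in this notebook. This assumes those lengths match the shapes in the transformed schema; they depend on the vocabularies computed from this dataset, so they may differ on other data. It excludes tips, which appears to be the label since it is the only feature without an _xf suffix. Each Dense layer has fan_in * units weights plus units biases. Because the final Dense(1) layer has no activation and the loss uses from_logits=True, the model's single output, and therefore the 'outputs' value returned by the serving signature, is a logit; the sigmoid helper shows how to turn it into a probability.

```python
import math
from typing import Dict, List

# Widths of the transformed inputs, read off the printed examples
# (length of each float_list; the "tips" label is excluded).
FEATURE_WIDTHS: Dict[str, int] = {
    "company_xf": 55,
    "dropoff_census_tract_xf": 216,
    "dropoff_community_area_xf": 79,
    "dropoff_latitude_xf": 1,
    "dropoff_longitude_xf": 1,
    "fare_xf": 1,
    "payment_type_xf": 16,
    "pickup_census_tract_xf": 11,
    "pickup_community_area_xf": 66,
    "pickup_latitude_xf": 1,
    "pickup_longitude_xf": 1,
    "trip_miles_xf": 1,
    "trip_seconds_xf": 1,
    "trip_start_day_xf": 17,
    "trip_start_hour_xf": 34,
    "trip_start_month_xf": 22,
}

# Units of the Dense layers in _build_keras_model, in order.
DENSE_UNITS: List[int] = [100, 70, 50, 20, 1]


def dense_param_counts(input_width: int, units: List[int]) -> List[int]:
    """Returns weights plus biases for each layer of a plain Dense stack."""
    counts = []
    fan_in = input_width
    for n in units:
        counts.append(fan_in * n + n)  # kernel (fan_in x n) plus bias (n)
        fan_in = n
    return counts


def sigmoid(logit: float) -> float:
    """Converts a logit to a probability, avoiding overflow for large |logit|."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


input_width = sum(FEATURE_WIDTHS.values())  # width after Concatenate
counts = dense_param_counts(input_width, DENSE_UNITS)
print(input_width, counts, sum(counts))  # 523 [52400, 7070, 3550, 1020, 21] 64061
print(sigmoid(0.0))  # 0.5
```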

Now, we pass in this model code to the Trainer component and run it to train the model.
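In run_fn, _input_fn is called with _BATCH_SIZE (40), and model.fit receives steps_per_epoch=fn_args.train_steps and validation_steps=fn_args.eval_steps, which come from the TrainArgs and EvalArgs in the Trainer call that follows. Assuming, as the tfx_bsl documentation describes, that leaving num_epochs unset in TensorFlowDatasetOptions makes the dataset cycle indefinitely, these step counts are what bound training: with Keras's default of one epoch, model.fit runs 10,000 training steps, or 400,000 examples at 40 per batch. The pure-Python sketch below (toy_dataset is a hypothetical name, not the tfxio implementation) imitates that behavior: it splits the label out of each record and cycles through the records forever, so the caller decides how many batches to take.

```python
import itertools
from typing import Dict, Iterator, List, Tuple

Record = Dict[str, float]
Batch = Tuple[Dict[str, List[float]], List[float]]


def toy_dataset(records: List[Record], label_key: str,
                batch_size: int) -> Iterator[Batch]:
    """Yields (features, labels) batches, cycling through records forever."""
    if not records:
        raise ValueError("records must not be empty")
    stream = itertools.cycle(records)
    while True:
        chunk = list(itertools.islice(stream, batch_size))
        # Every key except the label becomes a feature column.
        features = {k: [r[k] for r in chunk] for k in chunk[0] if k != label_key}
        # The label is split out, as label_key=_LABEL_KEY does in _input_fn.
        labels = [r[label_key] for r in chunk]
        yield features, labels


# Five toy records; "tips" plays the role of the label.
records = [{"fare_xf": float(i), "tips": i % 2} for i in range(5)]

# Like steps_per_epoch, take a fixed number of batches from the endless stream.
# The third batch wraps around to the first record.
batches = list(itertools.islice(toy_dataset(records, "tips", batch_size=2), 3))
for features, labels in batches:
    print(features, labels)

# Examples consumed with the settings used in this notebook.
BATCH_SIZE = 40      # _BATCH_SIZE in taxi_trainer.py
TRAIN_STEPS = 10000  # TrainArgs(num_steps=10000)
EVAL_STEPS = 5000    # EvalArgs(num_steps=5000)
print(TRAIN_STEPS * BATCH_SIZE, EVAL_STEPS * BATCH_SIZE)  # 400000 200000
```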

trainer = tfx.components.Trainer(
    module_file=os.path.abspath(_taxi_trainer_module_file),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=tfx.proto.TrainArgs(num_steps=10000),
    eval_args=tfx.proto.EvalArgs(num_steps=5000))
context.run(trainer, enable_cache=True)
INFO:absl:Generating ephemeral wheel package for '/tmpfs/src/temp/docs/tutorials/tfx/taxi_trainer.py' (including modules: ['taxi_constants', 'taxi_transform', 'taxi_trainer']).
INFO:absl:User module package has hash fingerprint version b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '/tmpfs/tmp/tmp41x4q5n1/_tfx_generated_setup.py', 'bdist_wheel', '--bdist-dir', '/tmpfs/tmp/tmpvus1roh7', '--dist-dir', '/tmpfs/tmp/tmpupkud42b']
running bdist_wheel
running build
running build_py
creating build
creating build/lib
copying taxi_constants.py -> build/lib
copying taxi_transform.py -> build/lib
copying taxi_trainer.py -> build/lib
installing to /tmpfs/tmp/tmpvus1roh7
running install
running install_lib
copying build/lib/taxi_constants.py -> /tmpfs/tmp/tmpvus1roh7
copying build/lib/taxi_transform.py -> /tmpfs/tmp/tmpvus1roh7
copying build/lib/taxi_trainer.py -> /tmpfs/tmp/tmpvus1roh7
running install_egg_info
running egg_info
creating tfx_user_code_Trainer.egg-info
writing tfx_user_code_Trainer.egg-info/PKG-INFO
writing dependency_links to tfx_user_code_Trainer.egg-info/dependency_links.txt
writing top-level names to tfx_user_code_Trainer.egg-info/top_level.txt
writing manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
INFO:absl:Successfully built user code wheel distribution at '/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Trainer-0.0+b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78-py3-none-any.whl'; target user module is 'taxi_trainer'.
INFO:absl:Full user module path is 'taxi_trainer@/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Trainer-0.0+b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78-py3-none-any.whl'
INFO:absl:Running driver for Trainer
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for Trainer
INFO:absl:Train on the 'train' split when train_args.splits is not set.
INFO:absl:Evaluate on the 'eval' split when eval_args.splits is not set.
WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE
INFO:absl:udf_utils.get_fn {'train_args': '{\n  "num_steps": 10000\n}', 'eval_args': '{\n  "num_steps": 5000\n}', 'module_file': None, 'run_fn': None, 'trainer_fn': None, 'custom_config': 'null', 'module_path': 'taxi_trainer@/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Trainer-0.0+b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78-py3-none-any.whl'} 'run_fn'
INFO:absl:Installing '/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Trainer-0.0+b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmprnox1jiu', '/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Trainer-0.0+b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78-py3-none-any.whl']
reading manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
writing manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
Copying tfx_user_code_Trainer.egg-info to /tmpfs/tmp/tmpvus1roh7/tfx_user_code_Trainer-0.0+b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78-py3.9.egg-info
running install_scripts
creating /tmpfs/tmp/tmpvus1roh7/tfx_user_code_Trainer-0.0+b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78.dist-info/WHEEL
creating '/tmpfs/tmp/tmpupkud42b/tfx_user_code_Trainer-0.0+b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78-py3-none-any.whl' and adding '/tmpfs/tmp/tmpvus1roh7' to it
adding 'taxi_constants.py'
adding 'taxi_trainer.py'
adding 'taxi_transform.py'
adding 'tfx_user_code_Trainer-0.0+b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78.dist-info/METADATA'
adding 'tfx_user_code_Trainer-0.0+b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78.dist-info/WHEEL'
adding 'tfx_user_code_Trainer-0.0+b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78.dist-info/top_level.txt'
adding 'tfx_user_code_Trainer-0.0+b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78.dist-info/RECORD'
removing /tmpfs/tmp/tmpvus1roh7
Processing /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Trainer-0.0+b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78-py3-none-any.whl
INFO:absl:Successfully installed '/tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/_wheels/tfx_user_code_Trainer-0.0+b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78-py3-none-any.whl'.
INFO:absl:Training model.
INFO:absl:Feature company_xf has a shape dim {
  size: 55
}
. Setting to DenseTensor.
INFO:absl:Feature dropoff_census_tract_xf has a shape dim {
  size: 216
}
. Setting to DenseTensor.
INFO:absl:Feature dropoff_community_area_xf has a shape dim {
  size: 79
}
. Setting to DenseTensor.
INFO:absl:Feature dropoff_latitude_xf has a shape . Setting to DenseTensor.
INFO:absl:Feature dropoff_longitude_xf has a shape . Setting to DenseTensor.
INFO:absl:Feature fare_xf has a shape . Setting to DenseTensor.
INFO:absl:Feature payment_type_xf has a shape dim {
  size: 16
}
. Setting to DenseTensor.
INFO:absl:Feature pickup_census_tract_xf has a shape dim {
  size: 11
}
. Setting to DenseTensor.
INFO:absl:Feature pickup_community_area_xf has a shape dim {
  size: 66
}
. Setting to DenseTensor.
INFO:absl:Feature pickup_latitude_xf has a shape . Setting to DenseTensor.
INFO:absl:Feature pickup_longitude_xf has a shape . Setting to DenseTensor.
INFO:absl:Feature tips has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_miles_xf has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_seconds_xf has a shape . Setting to DenseTensor.
INFO:absl:Feature trip_start_day_xf has a shape dim {
  size: 17
}
. Setting to DenseTensor.
INFO:absl:Feature trip_start_hour_xf has a shape dim {
  size: 34
}
. Setting to DenseTensor.
INFO:absl:Feature trip_start_month_xf has a shape dim {
  size: 22
}
. Setting to DenseTensor.
Installing collected packages: tfx-user-code-Trainer
Successfully installed tfx-user-code-Trainer-0.0+b9709bb91d12b3ca961d7e09adb386fe9a4b2739bd1f503e449ed3ebdf707d78
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tfx_bsl/tfxio/tf_example_record.py:339: parse_example_dataset (from tensorflow.python.data.experimental.ops.parsing_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.map(tf.io.parse_example(...))` instead.
10000/10000 [==============================] - 100s 10ms/step - loss: 0.0773 - binary_accuracy: 0.9632 - val_loss: 1.2753 - val_binary_accuracy: 0.8684
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:absl:Feature company has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature dropoff_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature fare has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature payment_type has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_census_tract has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_community_area has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_latitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature pickup_longitude has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature tips has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_miles has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_seconds has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_day has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_hour has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_month has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature trip_start_timestamp has no shape. Setting to varlen_sparse_tensor.
INFO:absl:serve_transformed_features = {'trip_start_month_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:16' shape=(None, 22) dtype=float32>, 'trip_miles_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:12' shape=(None,) dtype=float32>, 'dropoff_latitude_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:3' shape=(None,) dtype=float32>, 'company_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:0' shape=(None, 55) dtype=float32>, 'pickup_community_area_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:8' shape=(None, 66) dtype=float32>, 'payment_type_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:6' shape=(None, 16) dtype=float32>, 'trip_seconds_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:13' shape=(None,) dtype=float32>, 'pickup_longitude_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:10' shape=(None,) dtype=float32>, 'pickup_latitude_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:9' shape=(None,) dtype=float32>, 'trip_start_day_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:14' shape=(None, 17) dtype=float32>, 'dropoff_longitude_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:4' shape=(None,) dtype=float32>, 'dropoff_census_tract_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:1' shape=(None, 216) dtype=float32>, 'trip_start_hour_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:15' shape=(None, 34) dtype=float32>, 'pickup_census_tract_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:7' shape=(None, 11) dtype=float32>, 'fare_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:5' shape=(None,) dtype=float32>, 'dropoff_community_area_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:2' shape=(None, 79) dtype=float32>}
INFO:absl:eval_transformed_features = {'trip_start_month_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:16' shape=(None, 22) dtype=float32>, 'trip_miles_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:12' shape=(None,) dtype=float32>, 'dropoff_latitude_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:3' shape=(None,) dtype=float32>, 'company_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:0' shape=(None, 55) dtype=float32>, 'pickup_community_area_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:8' shape=(None, 66) dtype=float32>, 'payment_type_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:6' shape=(None, 16) dtype=float32>, 'trip_seconds_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:13' shape=(None,) dtype=float32>, 'pickup_longitude_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:10' shape=(None,) dtype=float32>, 'pickup_latitude_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:9' shape=(None,) dtype=float32>, 'trip_start_day_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:14' shape=(None, 17) dtype=float32>, 'dropoff_longitude_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:4' shape=(None,) dtype=float32>, 'dropoff_census_tract_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:1' shape=(None, 216) dtype=float32>, 'trip_start_hour_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:15' shape=(None, 34) dtype=float32>, 'pickup_census_tract_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:7' shape=(None, 11) dtype=float32>, 'fare_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:5' shape=(None,) dtype=float32>, 'tips': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:11' shape=(None,) dtype=int64>, 'dropoff_community_area_xf': <tf.Tensor 'transform_features_layer/StatefulPartitionedCall:2' shape=(None, 79) dtype=float32>}
INFO:tensorflow:Assets written to: /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/Trainer/model/6/Format-Serving/assets
INFO:absl:Training complete. Model written to /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/Trainer/model/6/Format-Serving. ModelRun written to /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/Trainer/model_run/6
INFO:absl:Running publisher for Trainer
INFO:absl:MetadataStore with DB connection initialized

Analyze Training with TensorBoard

Take a peek at the trainer artifact. It points to a directory containing the model subdirectories.

model_artifact_dir = trainer.outputs['model'].get()[0].uri
pp.pprint(os.listdir(model_artifact_dir))
model_dir = os.path.join(model_artifact_dir, 'Format-Serving')
pp.pprint(os.listdir(model_dir))
['Format-Serving']
['keras_metadata.pb', 'assets', 'fingerprint.pb', 'variables', 'saved_model.pb']
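The listing shows the usual SavedModel layout: a serialized graph in saved_model.pb, weights under variables/, and extra files under assets/. As a rough sanity check before pointing a serving system or another component at model_dir, a small standard-library helper like the one below can confirm that the two essential pieces are present. It is a hypothetical heuristic, not part of the TFX or TensorFlow API, and it deliberately ignores the other entries such as keras_metadata.pb and fingerprint.pb.

```python
import os


def looks_like_saved_model(path):
    """Heuristic check (hypothetical helper) for the SavedModel layout.

    Returns True when `path` contains a serialized graph file
    (saved_model.pb or saved_model.pbtxt) and a variables/ subdirectory.
    Other entries (assets/, keras_metadata.pb, fingerprint.pb) are ignored.
    """
    has_graph = any(
        os.path.isfile(os.path.join(path, name))
        for name in ("saved_model.pb", "saved_model.pbtxt")
    )
    has_variables = os.path.isdir(os.path.join(path, "variables"))
    return has_graph and has_variables
```

In the notebook, `looks_like_saved_model(model_dir)` should return True for the directory listed above.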

Optionally, we can connect TensorBoard to the Trainer to analyze our model's training curves.

model_run_artifact_dir = trainer.outputs['model_run'].get()[0].uri

%load_ext tensorboard
%tensorboard --logdir {model_run_artifact_dir}
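TensorBoard reads event files, which TensorFlow summary writers conventionally name events.out.tfevents.*; Keras's TensorBoard callback typically places them in train/ and validation/ subdirectories of its log directory. If the dashboard comes up empty, a quick check like the hypothetical helper below (plain Python, assuming that naming convention) shows whether model_run_artifact_dir contains any event files at all.

```python
import os
from typing import List


def find_event_files(logdir: str) -> List[str]:
    """Return sorted paths of files under `logdir` (searched recursively)
    whose names follow the assumed 'events.out.tfevents.*' convention."""
    matches = []
    for root, _dirs, files in os.walk(logdir):
        for name in files:
            if name.startswith("events.out.tfevents."):
                matches.append(os.path.join(root, name))
    return sorted(matches)
```

For example, `pp.pprint(find_event_files(model_run_artifact_dir))` lists the event files TensorBoard would load; an empty list suggests the training callback never wrote summaries there.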

Evaluator

The Evaluator component computes model performance metrics over the evaluation set using the TensorFlow Model Analysis (TFMA) library. The Evaluator can also optionally validate that a newly trained model is better than the previous model. This is useful in a production pipeline where you may automatically train and validate a model every day. In this notebook we train only one model, so there is no previous model to compare against, and the Evaluator will automatically label the model as "good" (blessed) as long as it meets the absolute value thresholds in the configuration.

The Evaluator takes as input the examples from ExampleGen, the trained model from Trainer, and a slicing configuration. The slicing configuration allows you to slice your metrics on feature values (e.g. how does your model perform on taxi trips that start at 8am versus 8pm?). See an example of this configuration below:

# Imported files such as taxi_constants are normally cached, so changes are
# not honored after the first import.  Normally this is good for efficiency, but
# during development when we may be iterating code it can be a problem. To
# avoid this problem during development, reload the file.
import taxi_constants
import sys
if 'google.colab' in sys.modules:  # Reload only when running in Colab, i.e. during development
  import importlib
  importlib.reload(taxi_constants)

eval_config = tfma.EvalConfig(
    model_specs=[
        # This assumes a serving model with signature 'serving_default'. If
        # using estimator based EvalSavedModel, add signature_name: 'eval' and
        # remove the label_key.
        tfma.ModelSpec(
            signature_name='serving_default',
            label_key=taxi_constants.LABEL_KEY,
            preprocessing_function_names=['transform_features'],
            )
        ],
    metrics_specs=[
        tfma.MetricsSpec(
            # The metrics added here are in addition to those saved with the
            # model (assuming either a keras model or EvalSavedModel is used).
            # Any metrics added into the saved model (for example using
            # model.compile(..., metrics=[...]), etc) will be computed
            # automatically.
            # To add validation thresholds for metrics saved with the model,
            # add them keyed by metric name to the thresholds map.
            metrics=[
                tfma.MetricConfig(class_name='ExampleCount'),
                tfma.MetricConfig(class_name='BinaryAccuracy',
                  threshold=tfma.MetricThreshold(
                      value_threshold=tfma.GenericValueThreshold(
                          lower_bound={'value': 0.5}),
                      # Change threshold will be ignored if there is no
                      # baseline model resolved from MLMD (first run).
                      change_threshold=tfma.GenericChangeThreshold(
                          direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                          absolute={'value': -1e-10})))
            ]
        )
    ],
    slicing_specs=[
        # An empty slice spec means the overall slice, i.e. the whole dataset.
        tfma.SlicingSpec(),
        # Data can be sliced along a feature column. In this case, data is
        # sliced along feature column trip_start_hour.
        tfma.SlicingSpec(
            feature_keys=['trip_start_hour'])
    ])
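The reload at the top of the cell above works around Python's module cache: the first import of a module executes it and stores it in sys.modules, later import statements return that cached object without re-reading the file, and importlib.reload re-executes the module's source. The self-contained demonstration below makes this visible with a throwaway module; demo_constants is a hypothetical stand-in for taxi_constants, written to a temporary directory and cleaned up afterwards.

```python
import importlib
import os
import sys
import tempfile


def demo_reload():
    """Return LABEL_KEY after the first import, a repeat import, and a reload."""
    with tempfile.TemporaryDirectory() as tmp:
        # Hypothetical stand-in for taxi_constants.py.
        path = os.path.join(tmp, "demo_constants.py")
        with open(path, "w") as f:
            f.write("LABEL_KEY = 'tips'\n")
        sys.path.insert(0, tmp)
        try:
            import demo_constants  # first import executes the file
            first = demo_constants.LABEL_KEY

            # Edit the file on disk, as you might while iterating on code.
            with open(path, "w") as f:
                f.write("LABEL_KEY = 'tips_v2'\n")

            import demo_constants  # served from sys.modules; file is not re-read
            cached = demo_constants.LABEL_KEY

            # reload re-executes the (edited) source in the same module object.
            demo_constants = importlib.reload(demo_constants)
            reloaded = demo_constants.LABEL_KEY
        finally:
            sys.path.remove(tmp)
            sys.modules.pop("demo_constants", None)
    return first, cached, reloaded
```

Calling `demo_reload()` returns `('tips', 'tips', 'tips_v2')`: the edit is invisible to the second import and only takes effect after the reload.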

Next, we give this configuration to Evaluator and run it.

# Use TFMA to compute evaluation statistics over features of a model and
# validate them against a baseline.

# The model resolver is only required if performing model validation in addition
# to evaluation. In this case we validate against the latest blessed model. If
# no model has been blessed before (as in this case) the evaluator will make our
# candidate the first blessed model.
model_resolver = tfx.dsl.Resolver(
      strategy_class=tfx.dsl.experimental.LatestBlessedModelStrategy,
      model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model),
      model_blessing=tfx.dsl.Channel(
          type=tfx.types.standard_artifacts.ModelBlessing)).with_id(
              'latest_blessed_model_resolver')
context.run(model_resolver, enable_cache=True)

evaluator = tfx.components.Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    baseline_model=model_resolver.outputs['model'],
    eval_config=eval_config)
context.run(evaluator, enable_cache=True)
INFO:absl:Running driver for latest_blessed_model_resolver
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running publisher for latest_blessed_model_resolver
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running driver for Evaluator
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for Evaluator
INFO:absl:udf_utils.get_fn {'eval_config': '{\n  "metrics_specs": [\n    {\n      "metrics": [\n        {\n          "class_name": "ExampleCount"\n        },\n        {\n          "class_name": "BinaryAccuracy",\n          "threshold": {\n            "change_threshold": {\n              "absolute": -1e-10,\n              "direction": "HIGHER_IS_BETTER"\n            },\n            "value_threshold": {\n              "lower_bound": 0.5\n            }\n          }\n        }\n      ]\n    }\n  ],\n  "model_specs": [\n    {\n      "label_key": "tips",\n      "preprocessing_function_names": [\n        "transform_features"\n      ],\n      "signature_name": "serving_default"\n    }\n  ],\n  "slicing_specs": [\n    {},\n    {\n      "feature_keys": [\n        "trip_start_hour"\n      ]\n    }\n  ]\n}', 'feature_slicing_spec': None, 'fairness_indicator_thresholds': 'null', 'example_splits': 'null', 'module_file': None, 'module_path': None} 'custom_eval_shared_model'
INFO:absl:Request was made to ignore the baseline ModelSpec and any change thresholds. This is likely because a baseline model was not provided: updated_config=
model_specs {
  signature_name: "serving_default"
  label_key: "tips"
  preprocessing_function_names: "transform_features"
}
slicing_specs {
}
slicing_specs {
  feature_keys: "trip_start_hour"
}
metrics_specs {
  metrics {
    class_name: "ExampleCount"
  }
  metrics {
    class_name: "BinaryAccuracy"
    threshold {
      value_threshold {
        lower_bound {
          value: 0.5
        }
      }
    }
  }
}

INFO:absl:Using /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/Trainer/model/6/Format-Serving as  model.
WARNING:tensorflow:Inconsistent references when loading the checkpoint into this object graph. For example, in the saved checkpoint object, `model.layer.weight` and `model.layer_copy.weight` reference the same variable, while in the current object these are two different variables. The referenced variables are:(<keras.saving.legacy.saved_model.load.TensorFlowTransform>TransformFeaturesLayer object at 0x7f9970614f40> and <keras.engine.input_layer.InputLayer object at 0x7f98d9921340>).
INFO:absl:The 'example_splits' parameter is not set, using 'eval' split.
INFO:absl:Evaluating model.
INFO:absl:udf_utils.get_fn {'eval_config': '{\n  "metrics_specs": [\n    {\n      "metrics": [\n        {\n          "class_name": "ExampleCount"\n        },\n        {\n          "class_name": "BinaryAccuracy",\n          "threshold": {\n            "change_threshold": {\n              "absolute": -1e-10,\n              "direction": "HIGHER_IS_BETTER"\n            },\n            "value_threshold": {\n              "lower_bound": 0.5\n            }\n          }\n        }\n      ]\n    }\n  ],\n  "model_specs": [\n    {\n      "label_key": "tips",\n      "preprocessing_function_names": [\n        "transform_features"\n      ],\n      "signature_name": "serving_default"\n    }\n  ],\n  "slicing_specs": [\n    {},\n    {\n      "feature_keys": [\n        "trip_start_hour"\n      ]\n    }\n  ]\n}', 'feature_slicing_spec': None, 'fairness_indicator_thresholds': 'null', 'example_splits': 'null', 'module_file': None, 'module_path': None} 'custom_extractors'
INFO:absl:Request was made to ignore the baseline ModelSpec and any change thresholds. This is likely because a baseline model was not provided: updated_config=
model_specs {
  signature_name: "serving_default"
  label_key: "tips"
  preprocessing_function_names: "transform_features"
}
slicing_specs {
}
slicing_specs {
  feature_keys: "trip_start_hour"
}
metrics_specs {
  metrics {
    class_name: "ExampleCount"
  }
  metrics {
    class_name: "BinaryAccuracy"
    threshold {
      value_threshold {
        lower_bound {
          value: 0.5
        }
      }
    }
  }
  model_names: ""
}

WARNING:tensorflow:Inconsistent references when loading the checkpoint into this object graph. For example, in the saved checkpoint object, `model.layer.weight` and `model.layer_copy.weight` reference the same variable, while in the current object these are two different variables. The referenced variables are:(<keras.saving.legacy.saved_model.load.TensorFlowTransform>TransformFeaturesLayer object at 0x7f99cc69edc0> and <keras.engine.input_layer.InputLayer object at 0x7f99e406f2b0>).
WARNING:tensorflow:Inconsistent references when loading the checkpoint into this object graph. For example, in the saved checkpoint object, `model.layer.weight` and `model.layer_copy.weight` reference the same variable, while in the current object these are two different variables. The referenced variables are:(<keras.saving.legacy.saved_model.load.TensorFlowTransform>TransformFeaturesLayer object at 0x7f98d9d54280> and <keras.engine.input_layer.InputLayer object at 0x7f9994264eb0>).
WARNING:tensorflow:Inconsistent references when loading the checkpoint into this object graph. For example, in the saved checkpoint object, `model.layer.weight` and `model.layer_copy.weight` reference the same variable, while in the current object these are two different variables. The referenced variables are:(<keras.saving.legacy.saved_model.load.TensorFlowTransform>TransformFeaturesLayer object at 0x7f99e439b6a0> and <keras.engine.input_layer.InputLayer object at 0x7f99701c5f40>).
WARNING:tensorflow:Inconsistent references when loading the checkpoint into this object graph. For example, in the saved checkpoint object, `model.layer.weight` and `model.layer_copy.weight` reference the same variable, while in the current object these are two different variables. The referenced variables are:(<keras.saving.legacy.saved_model.load.TensorFlowTransform>TransformFeaturesLayer object at 0x7f9b1588a7f0> and <keras.engine.input_layer.InputLayer object at 0x7f9a54093c70>).
WARNING:tensorflow:Inconsistent references when loading the checkpoint into this object graph. For example, in the saved checkpoint object, `model.layer.weight` and `model.layer_copy.weight` reference the same variable, while in the current object these are two different variables. The referenced variables are:(<keras.saving.legacy.saved_model.load.TensorFlowTransform>TransformFeaturesLayer object at 0x7f99cc243fd0> and <keras.engine.input_layer.InputLayer object at 0x7f99ac7102b0>).
WARNING:tensorflow:Inconsistent references when loading the checkpoint into this object graph. For example, in the saved checkpoint object, `model.layer.weight` and `model.layer_copy.weight` reference the same variable, while in the current object these are two different variables. The referenced variables are:(<keras.saving.legacy.saved_model.load.TensorFlowTransform>TransformFeaturesLayer object at 0x7f98d93418e0> and <keras.engine.input_layer.InputLayer object at 0x7f98d93832e0>).
WARNING:tensorflow:Inconsistent references when loading the checkpoint into this object graph. For example, in the saved checkpoint object, `model.layer.weight` and `model.layer_copy.weight` reference the same variable, while in the current object these are two different variables. The referenced variables are:(<keras.saving.legacy.saved_model.load.TensorFlowTransform>TransformFeaturesLayer object at 0x7f94ec111e80> and <keras.engine.input_layer.InputLayer object at 0x7f98d939d610>).
INFO:absl:Evaluation complete. Results written to /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/Evaluator/evaluation/8.
INFO:absl:Checking validation results.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py:110: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`
INFO:absl:Blessing result True written to /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/Evaluator/blessing/8.
INFO:absl:Running publisher for Evaluator
INFO:absl:MetadataStore with DB connection initialized

Now let's examine the output artifacts of Evaluator.

evaluator.outputs
{'evaluation': OutputChannel(artifact_type=ModelEvaluation, producer_component_id=Evaluator, output_key=evaluation, additional_properties={}, additional_custom_properties={}, _input_trigger=None,
 'blessing': OutputChannel(artifact_type=ModelBlessing, producer_component_id=Evaluator, output_key=blessing, additional_properties={}, additional_custom_properties={}, _input_trigger=None}

Using the evaluation output, we can show the default visualization of global metrics on the entire evaluation set.

context.show(evaluator.outputs['evaluation'])
SlicingMetricsViewer(config={'weightedExamplesColumn': 'example_count'}, data=[{'slice': 'Overall', 'metrics':…

To see the visualization for sliced evaluation metrics, we can directly call the TensorFlow Model Analysis library.

import tensorflow_model_analysis as tfma

# Get the TFMA output result path and load the result.
PATH_TO_RESULT = evaluator.outputs['evaluation'].get()[0].uri
tfma_result = tfma.load_eval_result(PATH_TO_RESULT)

# Show data sliced along feature column trip_start_hour.
tfma.view.render_slicing_metrics(
    tfma_result, slicing_column='trip_start_hour')
SlicingMetricsViewer(config={'weightedExamplesColumn': 'example_count'}, data=[{'slice': 'trip_start_hour:19',…

This visualization shows the same metrics, but computed at every feature value of trip_start_hour instead of on the entire evaluation set.
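Conceptually, slicing partitions the evaluation set by the value of one feature and recomputes each metric on every partition, alongside the "Overall" result for the whole set. The sketch below illustrates that idea in plain Python. sliced_accuracy is a hypothetical helper and the toy examples are invented; TFMA's real implementation runs on Apache Beam and supports many more metrics and slicing specifications.

```python
from collections import defaultdict


def accuracy(pairs):
    """Fraction of (label, prediction) pairs where the prediction matches the label."""
    return sum(1 for label, pred in pairs if label == pred) / len(pairs)


def sliced_accuracy(examples, slice_key):
    """Compute accuracy on the whole set and for every distinct value of slice_key."""
    slices = defaultdict(list)
    for ex in examples:
        slices[ex[slice_key]].append((ex['label'], ex['prediction']))

    overall = [(ex['label'], ex['prediction']) for ex in examples]
    results = {'Overall': accuracy(overall)}
    # One entry per feature value, named like the slices in the TFMA viewer.
    for value, pairs in sorted(slices.items()):
        results[f'{slice_key}:{value}'] = accuracy(pairs)
    return results


# Hypothetical toy examples, not taken from the Chicago Taxi dataset.
examples = [
    {'trip_start_hour': 19, 'label': 1, 'prediction': 1},
    {'trip_start_hour': 19, 'label': 0, 'prediction': 1},
    {'trip_start_hour': 8, 'label': 0, 'prediction': 0},
    {'trip_start_hour': 8, 'label': 1, 'prediction': 1},
]

for name, value in sliced_accuracy(examples, 'trip_start_hour').items():
    print(name, value)
```

A slice can look much better or worse than the overall number (here hour 8 scores 1.0 while hour 19 scores 0.5), which is exactly what the sliced view is meant to surface.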

TensorFlow Model Analysis supports many other visualizations, such as Fairness Indicators and plotting a time series of model performance. To learn more, see the TensorFlow Model Analysis tutorial.

Since we added thresholds to our config, validation output is also available. The presence of a blessing artifact indicates that our model passed validation. Since this is the first validation being performed, the candidate is automatically blessed.

blessing_uri = evaluator.outputs['blessing'].get()[0].uri
!ls -l {blessing_uri}
total 0
-rw-rw-r-- 1 kbuilder kbuilder 0 Jul 28 09:33 BLESSED
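The ls output above shows that the blessing is recorded as an empty marker file named BLESSED inside the artifact's URI. A script that needs this decision could check for that file directly. The helper below is a hypothetical sketch, assuming the marker-file convention shown above (TFX writes NOT_BLESSED instead when validation fails, as we understand it); the demo uses a temporary directory in place of blessing_uri.

```python
import os
import tempfile


def is_blessed(blessing_dir):
    """Return True if the blessing directory contains a BLESSED marker file.

    Hypothetical helper: relies on the empty BLESSED file seen in the ls
    output, not on any TFX API.
    """
    return os.path.isfile(os.path.join(blessing_dir, 'BLESSED'))


# Demo with a temporary directory standing in for blessing_uri.
with tempfile.TemporaryDirectory() as fake_blessing_uri:
    print(is_blessed(fake_blessing_uri))  # False: no marker file yet
    open(os.path.join(fake_blessing_uri, 'BLESSED'), 'w').close()
    print(is_blessed(fake_blessing_uri))  # True: marker file present
```

In the notebook itself you would call is_blessed(blessing_uri).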

We can also verify success by loading the validation result record:

PATH_TO_RESULT = evaluator.outputs['evaluation'].get()[0].uri
print(tfma.load_validation_result(PATH_TO_RESULT))
validation_ok: true
validation_details {
  slicing_details {
    slicing_spec {
    }
    num_matching_slices: 25
  }
}
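The validation_ok field is derived from the thresholds in the evaluation config. The function below is a simplified, hypothetical model of such a check, not TFMA's implementation: it assumes one metric guarded by a lower-bound value threshold plus a change threshold measured against a baseline model, and the function name and numbers are invented. It also suggests why a first run can pass: with no baseline model, there is nothing for a change threshold to compare against, so in this sketch only the value threshold applies.

```python
def validate(candidate, lower_bound, baseline=None, min_change=0.0):
    """Return True if a candidate metric passes both threshold checks.

    - Value threshold: candidate must be at least lower_bound.
    - Change threshold: if a baseline metric exists, candidate must not
      fall more than min_change below it (skipped when baseline is None).
    """
    if candidate < lower_bound:
        return False
    if baseline is not None and candidate - baseline < min_change:
        return False
    return True


# Hypothetical accuracy values.
print(validate(0.72, lower_bound=0.5))                 # True: first run, no baseline
print(validate(0.72, lower_bound=0.5, baseline=0.75))  # False: regresses vs. baseline
```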

Pusher

The Pusher component is usually at the end of a TFX pipeline. It checks whether a model has passed validation and, if so, exports the model to _serving_model_dir.

pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory=_serving_model_dir)))
context.run(pusher, enable_cache=True)
INFO:absl:Running driver for Pusher
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for Pusher
INFO:absl:Model version: 1690536819
INFO:absl:Model written to serving path /tmpfs/tmp/tmppujsjo7w/serving_model/taxi_simple/1690536819.
INFO:absl:Model pushed to /tmpfs/tmp/tfx-interactive-2023-07-28T09_30_08.401614-fwsxqb86/Pusher/pushed_model/9.
INFO:absl:Running publisher for Pusher
INFO:absl:MetadataStore with DB connection initialized
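The log lines above show the Pusher writing the model to a numbered subdirectory of the serving path (taxi_simple/1690536819). The following is a simplified, hypothetical sketch of that gate-then-copy behavior in plain Python. push_if_blessed is an invented helper, and using the current Unix time as the version number is an assumption inferred from the logged version. The real Pusher also copies the model into its own pushed_model artifact, as the last log line shows, and records it in ML Metadata.

```python
import os
import shutil
import tempfile
import time


def push_if_blessed(model_dir, blessing_dir, base_directory):
    """Copy model_dir to base_directory/<version> only if the model is blessed.

    Returns the destination path, or None if the model was not pushed.
    """
    if not os.path.isfile(os.path.join(blessing_dir, 'BLESSED')):
        return None
    # Assumed versioning scheme: Unix time in seconds, like 1690536819.
    version = str(int(time.time()))
    destination = os.path.join(base_directory, version)
    # copytree also creates base_directory if it does not exist yet.
    shutil.copytree(model_dir, destination)
    return destination


# Demo with temporary directories standing in for the real artifacts.
with tempfile.TemporaryDirectory() as root:
    model_dir = os.path.join(root, 'model')
    blessing_dir = os.path.join(root, 'blessing')
    serving_dir = os.path.join(root, 'serving_model')
    os.makedirs(model_dir)
    os.makedirs(blessing_dir)
    open(os.path.join(model_dir, 'saved_model.pb'), 'w').close()

    print(push_if_blessed(model_dir, blessing_dir, serving_dir))  # None: not blessed
    open(os.path.join(blessing_dir, 'BLESSED'), 'w').close()
    pushed = push_if_blessed(model_dir, blessing_dir, serving_dir)
    print(sorted(os.listdir(pushed)))  # ['saved_model.pb']
```

Writing each push to a new numbered directory, instead of overwriting a fixed path, lets a model server keep serving the previous version while a new one is being copied in.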

Let's examine the output artifacts of Pusher.

pusher.outputs
{'pushed_model': OutputChannel(artifact_type=PushedModel, producer_component_id=Pusher, output_key=pushed_model, additional_properties={}, additional_custom_properties={}, _input_trigger=None}

In particular, the Pusher will export your model in the SavedModel format, which looks like this:

push_uri = pusher.outputs['pushed_model'].get()[0].uri
model = tf.saved_model.load(push_uri)

for item in model.signatures.items():
  pp.pprint(item)
('serving_default',
 <ConcreteFunction signature_wrapper(*, examples) at 0x7F9BD24D7130>)
('transform_features',
 <ConcreteFunction signature_wrapper(*, examples) at 0x7F94E657B580>)

We've finished our tour of the built-in TFX components!