TensorFlow Model Analysis

An Example of a Key Component of TensorFlow Extended (TFX)

TensorFlow Model Analysis (TFMA) is a library for performing model evaluation across different slices of data. TFMA performs its computations in a distributed manner over large amounts of data using Apache Beam.
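Distributed evaluation of this kind generally relies on metrics that can be computed from small partial results. Each worker builds its partial result on its own shard of data, and the partial results are then merged. The sketch below mimics the four-method shape of an Apache Beam CombineFn (create, add, merge, extract) in plain Python for a mean-prediction metric. It does not import Beam or TFMA, and the class and function names are hypothetical.

```python
from dataclasses import dataclass
from functools import reduce
from typing import Iterable, List


@dataclass
class MeanAccumulator:
    """Partial sums for a mean; small and cheap to merge across workers."""
    total: float = 0.0
    count: int = 0


class MeanPredictionCombine:
    """Hypothetical CombineFn-style metric (plain Python, not Beam or TFMA code)."""

    def create_accumulator(self) -> MeanAccumulator:
        return MeanAccumulator()

    def add_input(self, acc: MeanAccumulator, prediction: float) -> MeanAccumulator:
        acc.total += prediction
        acc.count += 1
        return acc

    def merge_accumulators(self, accs: Iterable[MeanAccumulator]) -> MeanAccumulator:
        merged = MeanAccumulator()
        for acc in accs:
            merged.total += acc.total
            merged.count += acc.count
        return merged

    def extract_output(self, acc: MeanAccumulator) -> float:
        # An empty input has no mean; report NaN rather than dividing by zero.
        return acc.total / acc.count if acc.count else float("nan")


def combine_shards(fn: MeanPredictionCombine, shards: List[List[float]]) -> float:
    """Reduce each shard independently (as separate workers would), then merge."""
    partials = [reduce(fn.add_input, shard, fn.create_accumulator()) for shard in shards]
    return fn.extract_output(fn.merge_accumulators(partials))


if __name__ == "__main__":
    fn = MeanPredictionCombine()
    print(combine_shards(fn, [[0.2, 0.4], [0.6], [0.8]]))  # approximately 0.5
```

The result does not depend on how the data is split into shards, apart from floating-point rounding. That property is what allows the work to be spread over many machines.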

This example colab notebook illustrates how TFMA can be used to investigate and visualize the performance of a model with respect to characteristics of the dataset. We'll use a model that we trained previously, and now you get to play with the results! The model we trained was for the Chicago Taxi Example, which uses the Taxi Trips dataset released by the City of Chicago. Explore the full dataset in the BigQuery UI.

As a modeler and developer, think about how this data is used and the potential benefits and harm a model's predictions can cause. A model like this could reinforce societal biases and disparities. Is a feature relevant to the problem you want to solve or will it introduce bias? For more information, read about ML fairness.

The columns in the dataset are:

pickup_community_area   fare                     trip_start_month
trip_start_hour         trip_start_day           trip_start_timestamp
pickup_latitude         pickup_longitude         dropoff_latitude
dropoff_longitude       trip_miles               pickup_census_tract
dropoff_census_tract    payment_type             company
trip_seconds            dropoff_community_area   tips

Install Jupyter Extensions

jupyter nbextension enable --py widgetsnbextension --sys-prefix 
jupyter nbextension install --py --symlink tensorflow_model_analysis --sys-prefix 
jupyter nbextension enable --py tensorflow_model_analysis --sys-prefix 

Install TensorFlow Model Analysis (TFMA)

This pulls in all the dependencies and will take a minute.

# Upgrade pip to the latest, and install TFMA.
pip install -U pip
pip install tensorflow-model-analysis

Now you must restart the runtime before running the cells below.

# This setup was tested with TF 2.10 and TFMA 0.41 (using colab), but it should
# also work with the latest release.
import sys

# Confirm that we're using Python 3
assert sys.version_info.major==3, 'This notebook must be run using Python 3.'

import tensorflow as tf
print('TF version: {}'.format(tf.__version__))
import apache_beam as beam
print('Beam version: {}'.format(beam.__version__))
import tensorflow_model_analysis as tfma
print('TFMA version: {}'.format(tfma.__version__))
TF version: 2.12.1
Beam version: 2.48.0
TFMA version: 0.44.0

Load The Files

We'll download a tar file that has everything we need. That includes:

  • Training and evaluation datasets
  • Data schema
  • Training and serving saved models (keras and estimator) and eval saved models (estimator).
# Download the tar file from GCP and extract it
import io, os, tempfile
TAR_NAME = 'saved_models-2.2'
BASE_DIR = tempfile.mkdtemp()
DATA_DIR = os.path.join(BASE_DIR, TAR_NAME, 'data')
MODELS_DIR = os.path.join(BASE_DIR, TAR_NAME, 'models')
SCHEMA = os.path.join(BASE_DIR, TAR_NAME, 'schema.pbtxt')
OUTPUT_DIR = os.path.join(BASE_DIR, 'output')

!curl -O https://storage.googleapis.com/artifacts.tfx-oss-public.appspot.com/datasets/{TAR_NAME}.tar
!tar xf {TAR_NAME}.tar
!mv {TAR_NAME} {BASE_DIR}
!rm {TAR_NAME}.tar

print("Here's what we downloaded:")
!ls -R {BASE_DIR}
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 6800k  100 6800k    0     0  18.9M      0 --:--:-- --:--:-- --:--:-- 18.9M
Here's what we downloaded:
/tmpfs/tmp/tmpi1xr1bb5:
saved_models-2.2

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2:
data  models  schema.pbtxt

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/data:
eval  train

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/data/eval:
data.csv

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/data/train:
data.csv

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models:
estimator  keras

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/estimator:
eval_model_dir  serving_model_dir

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/estimator/eval_model_dir:
1591221811

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/estimator/eval_model_dir/1591221811:
saved_model.pb  tmp.pbtxt  variables

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/estimator/eval_model_dir/1591221811/variables:
variables.data-00000-of-00001  variables.index

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/estimator/serving_model_dir:
checkpoint
eval_chicago-taxi-eval
events.out.tfevents.1591221780.my-pipeline-b57vp-237544850
export
graph.pbtxt
model.ckpt-100.data-00000-of-00001
model.ckpt-100.index
model.ckpt-100.meta

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/estimator/serving_model_dir/eval_chicago-taxi-eval:
events.out.tfevents.1591221799.my-pipeline-b57vp-237544850

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/estimator/serving_model_dir/export:
chicago-taxi

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/estimator/serving_model_dir/export/chicago-taxi:
1591221801

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/estimator/serving_model_dir/export/chicago-taxi/1591221801:
saved_model.pb  variables

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/estimator/serving_model_dir/export/chicago-taxi/1591221801/variables:
variables.data-00000-of-00001  variables.index

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/keras:
0  1  2

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/keras/0:
saved_model.pb  variables

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/keras/0/variables:
variables.data-00000-of-00001  variables.index

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/keras/1:
saved_model.pb  variables

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/keras/1/variables:
variables.data-00000-of-00001  variables.index

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/keras/2:
saved_model.pb  variables

/tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/keras/2/variables:
variables.data-00000-of-00001  variables.index

Parse the Schema

Among the things we downloaded was a schema for our data that was created by TensorFlow Data Validation. Let's parse that now so that we can use it with TFMA.

import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.lib.io import file_io
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow.core.example import example_pb2

schema = schema_pb2.Schema()
contents = file_io.read_file_to_string(SCHEMA)
schema = text_format.Parse(contents, schema)

Use the Schema to Create TFRecords

We need to give TFMA access to our dataset, so let's create a TFRecords file. We can use our schema to create it, since it gives us the correct type for each feature.
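The conversion rule is simple. The schema's declared type decides which typed list a CSV string goes into, and an empty string becomes an empty (missing) value. Here is a stdlib-only sketch of that rule. It assumes hypothetical string tags ("FLOAT", "INT", "BYTES") in place of the schema_pb2 enum values used in the loop below.

```python
from typing import List, Union

Value = Union[float, int, bytes]


def convert_csv_value(raw: str, feature_type: str) -> List[Value]:
    """Convert one CSV cell to a typed value list.

    `feature_type` is a hypothetical stand-in for schema_pb2.FLOAT/INT/BYTES.
    An empty cell is treated as missing and yields an empty list.
    """
    if raw == "":
        return []
    if feature_type == "FLOAT":
        return [float(raw)]
    if feature_type == "INT":
        return [int(raw)]
    if feature_type == "BYTES":
        return [raw.encode("utf8")]
    # The notebook loop silently skips unknown types; this sketch raises instead.
    raise ValueError(f"unsupported feature type: {feature_type}")
```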

import csv

datafile = os.path.join(DATA_DIR, 'eval', 'data.csv')
examples = []
with open(datafile, 'r', newline='') as f:
  reader = csv.DictReader(f)
  for line in reader:
    example = example_pb2.Example()
    # Store each CSV cell in the typed list declared by the schema;
    # empty cells become empty (missing) values.
    for feature in schema.feature:
      key = feature.name
      if feature.type == schema_pb2.FLOAT:
        example.features.feature[key].float_list.value[:] = (
            [float(line[key])] if len(line[key]) > 0 else [])
      elif feature.type == schema_pb2.INT:
        example.features.feature[key].int64_list.value[:] = (
            [int(line[key])] if len(line[key]) > 0 else [])
      elif feature.type == schema_pb2.BYTES:
        example.features.feature[key].bytes_list.value[:] = (
            [line[key].encode('utf8')] if len(line[key]) > 0 else [])
    # Add a new column 'big_tipper' that indicates whether tips were > 20% of the fare.
    # TODO(b/157064428): Remove after label transformation is supported for Keras.
    big_tipper = float(line['tips']) > float(line['fare']) * 0.2
    example.features.feature['big_tipper'].float_list.value[:] = [float(big_tipper)]
    examples.append(example)

# Serialize every Example into a single TFRecord file.
tfrecord_file = os.path.join(BASE_DIR, 'train_data.rio')
with tf.io.TFRecordWriter(tfrecord_file) as writer:
  for example in examples:
    writer.write(example.SerializeToString())

!ls {tfrecord_file}
/tmpfs/tmp/tmpi1xr1bb5/train_data.rio
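The .rio file written above uses the TFRecord container format. Each record is framed as four parts:

- an 8-byte little-endian length
- a masked CRC-32C of that length
- the payload bytes
- a masked CRC-32C of the payload

The sketch below writes and reads that framing with only the standard library, to illustrate what tf.io.TFRecordWriter produces when no compression is set. It is an illustration only; tf.io.TFRecordWriter and tf.data.TFRecordDataset remain the tools to use in practice.

```python
import io
import struct
from typing import BinaryIO, Iterator

_CRC32C_POLY = 0x82F63B78  # reflected Castagnoli polynomial
_MASK_DELTA = 0xA282EAD8   # constant TFRecord adds when masking a CRC


def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (slow but dependency-free)."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (_CRC32C_POLY if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF


def masked_crc32c(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant, as TFRecord does."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + _MASK_DELTA) & 0xFFFFFFFF


def write_record(stream: BinaryIO, payload: bytes) -> None:
    length = struct.pack("<Q", len(payload))  # uint64 length, little-endian
    stream.write(length)
    stream.write(struct.pack("<I", masked_crc32c(length)))
    stream.write(payload)
    stream.write(struct.pack("<I", masked_crc32c(payload)))


def read_records(stream: BinaryIO) -> Iterator[bytes]:
    while True:
        length = stream.read(8)
        if not length:
            return  # clean end of file
        (n,) = struct.unpack("<Q", length)
        (length_crc,) = struct.unpack("<I", stream.read(4))
        if length_crc != masked_crc32c(length):
            raise ValueError("corrupted length header")
        payload = stream.read(n)
        (payload_crc,) = struct.unpack("<I", stream.read(4))
        if payload_crc != masked_crc32c(payload):
            raise ValueError("corrupted payload")
        yield payload


if __name__ == "__main__":
    buf = io.BytesIO()
    for record in [b"first", b"", b"third record"]:
        write_record(buf, record)
    buf.seek(0)
    print(list(read_records(buf)))  # [b'first', b'', b'third record']
```

The per-record checksums let a reader detect corruption without parsing the serialized tf.train.Example payloads.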

Setup and Run TFMA

TFMA supports a number of different model types, including TF keras models, models based on generic TF2 signature APIs, and TF estimator-based models. The get_started guide has the full list of supported model types and any restrictions. For this example, we show how to configure a keras-based model as well as an estimator-based model that was saved as an EvalSavedModel. See the FAQ for examples of other configurations.

TFMA supports calculating both metrics that were used at training time (i.e. built-in metrics) and metrics defined after the model was saved as part of the TFMA configuration settings. For our keras setup, we will demonstrate adding our metrics and plots manually as part of our configuration (see the metrics guide for information on the metrics and plots that are supported). For the estimator setup, we will use the built-in metrics that were saved with the model. Our setups also include a number of slicing specs, which are discussed in more detail in the following sections.

After creating a tfma.EvalConfig and a tfma.EvalSharedModel, we can run TFMA using tfma.run_model_analysis. This creates a tfma.EvalResult, which we can use later to render our metrics and plots.
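Both setups below use three kinds of slicing_specs:

- an empty spec, which gives the overall slice
- feature_keys, which gives one slice per distinct value, or per value combination, of the listed features
- feature_values, which keeps only the slice matching a fixed value

The plain-Python sketch below shows how a small list of examples would be grouped under each kind of spec. It illustrates the semantics only and is not TFMA's slicing implementation. The function name and the tuple-based slice keys are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Tuple

Example = Dict[str, str]
SliceKey = Tuple[Tuple[str, str], ...]  # () stands for the overall slice


def slice_examples(
    examples: Sequence[Example],
    feature_keys: Sequence[str] = (),
    feature_values: Optional[Dict[str, str]] = None,
) -> Dict[SliceKey, List[Example]]:
    """Group examples as one slicing spec would, conceptually.

    - no keys and no values: a single overall slice keyed by ()
    - feature_keys: one slice per distinct combination of those features' values
    - feature_values: only examples matching every fixed value
    """
    feature_values = feature_values or {}
    slices: Dict[SliceKey, List[Example]] = defaultdict(list)
    for ex in examples:
        if any(ex.get(k) != v for k, v in feature_values.items()):
            continue  # does not match the fixed value(s)
        if any(k not in ex for k in feature_keys):
            continue  # in this sketch, a missing feature means no slice
        key = tuple(sorted([(k, ex[k]) for k in feature_keys]
                           + list(feature_values.items())))
        slices[key].append(ex)
    return dict(slices)
```

Metrics are then computed separately for each group. This is why a single evaluation run yields one row per slice.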

Keras

import tensorflow_model_analysis as tfma

# Setup tfma.EvalConfig settings
keras_eval_config = text_format.Parse("""
  ## Model information
  model_specs {
    # For keras (and serving models) we need to add a `label_key`.
    label_key: "big_tipper"
  }

  ## Post training metric information. These will be merged with any built-in
  ## metrics from training.
  metrics_specs {
    metrics { class_name: "ExampleCount" }
    metrics { class_name: "AUC" }
    metrics { class_name: "Precision" }
    metrics { class_name: "Recall" }
    metrics { class_name: "MeanPrediction" }
    metrics { class_name: "Calibration" }
    metrics { class_name: "CalibrationPlot" }
    metrics { class_name: "ConfusionMatrixPlot" }
    # ... add additional metrics and plots ...
  }

  ## Slicing information
  slicing_specs {}  # overall slice
  slicing_specs {
    feature_keys: ["trip_start_hour"]
  }
  slicing_specs {
    feature_keys: ["trip_start_day"]
  }
  slicing_specs {
    feature_values: {
      key: "trip_start_month"
      value: "1"
    }
  }
""", tfma.EvalConfig())

# Create a tfma.EvalSharedModel that points at our keras model.
keras_model_path = os.path.join(MODELS_DIR, 'keras', '2')
keras_eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path=keras_model_path,
    eval_config=keras_eval_config)

keras_output_path = os.path.join(OUTPUT_DIR, 'keras')

# Run TFMA
keras_eval_result = tfma.run_model_analysis(
    eval_shared_model=keras_eval_shared_model,
    eval_config=keras_eval_config,
    data_location=tfrecord_file,
    output_path=keras_output_path)
WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
WARNING:absl:Tensorflow version (2.12.1) found. Note that TFMA support for TF 2.0 is currently in beta
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py:110: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`

Estimator

import tensorflow_model_analysis as tfma

# Setup tfma.EvalConfig settings
estimator_eval_config = text_format.Parse("""
  ## Model information
  model_specs {
    # To use EvalSavedModel set `signature_name` to "eval".
    signature_name: "eval"
  }

  ## Post training metric information. These will be merged with any built-in
  ## metrics from training.
  metrics_specs {
    metrics { class_name: "ConfusionMatrixPlot" }
    # ... add additional metrics and plots ...
  }

  ## Slicing information
  slicing_specs {}  # overall slice
  slicing_specs {
    feature_keys: ["trip_start_hour"]
  }
  slicing_specs {
    feature_keys: ["trip_start_day"]
  }
  slicing_specs {
    feature_values: {
      key: "trip_start_month"
      value: "1"
    }
  }
""", tfma.EvalConfig())

# Create a tfma.EvalSharedModel that points at our eval saved model.
estimator_base_model_path = os.path.join(
    MODELS_DIR, 'estimator', 'eval_model_dir')
estimator_model_path = os.path.join(
    estimator_base_model_path, os.listdir(estimator_base_model_path)[0])
estimator_eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path=estimator_model_path,
    eval_config=estimator_eval_config)

estimator_output_path = os.path.join(OUTPUT_DIR, 'estimator')

# Run TFMA
estimator_eval_result = tfma.run_model_analysis(
    eval_shared_model=estimator_eval_shared_model,
    eval_config=estimator_eval_config,
    data_location=tfrecord_file,
    output_path=estimator_output_path)
WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
WARNING:absl:Tensorflow version (2.12.1) found. Note that TFMA support for TF 2.0 is currently in beta
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_model_analysis/eval_saved_model/load.py:163: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpi1xr1bb5/saved_models-2.2/models/estimator/eval_model_dir/1591221811/variables/variables
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_model_analysis/eval_saved_model/graph_ref.py:183: get_tensor_from_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This API was designed for TensorFlow v1. See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2.
2023-07-28 10:48:35.821825: W tensorflow/core/common_runtime/type_inference.cc:339] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT64
    }
  }
}
 is neither a subtype nor a supertype of the combined inputs preceding it:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT32
    }
  }
}

    while inferring type of node 'dnn/zero_fraction/cond/output/_9'
2023-07-28 10:48:36.196438: W tensorflow/c/c_api.cc:300] Operation '{name:'head/metrics/true_positives_1/Assign' id:674 op device:{requested: '', assigned: ''} def:{ { {node head/metrics/true_positives_1/Assign} } = AssignVariableOp[_has_manual_control_dependencies=true, dtype=DT_FLOAT, validate_shape=false](head/metrics/true_positives_1, head/metrics/true_positives_1/Initializer/zeros)} }' was changed by setting attribute after it was run by a session. This mutation will have no effect, and will trigger an error in the future. Either don't modify nodes after running them or create a new session.
2023-07-28 10:48:36.350067: W tensorflow/c/c_api.cc:300] Operation '{name:'head/metrics/true_positives_1/Assign' id:674 op device:{requested: '', assigned: ''} def:{ { {node head/metrics/true_positives_1/Assign} } = AssignVariableOp[_has_manual_control_dependencies=true, dtype=DT_FLOAT, validate_shape=false](head/metrics/true_positives_1, head/metrics/true_positives_1/Initializer/zeros)} }' was changed by setting attribute after it was run by a session. This mutation will have no effect, and will trigger an error in the future. Either don't modify nodes after running them or create a new session.

Visualizing Metrics and Plots

Now that we've run the evaluation, let's look at our visualizations using TFMA. For the following examples, we visualize the results of evaluating the keras model. To view the estimator-based model instead, point eval_result_path at estimator_output_path and eval_result at estimator_eval_result (the commented-out lines below).

eval_result_path = keras_output_path
# eval_result_path = estimator_output_path

eval_result = keras_eval_result
# eval_result = estimator_eval_result

Rendering Metrics

TFMA provides DataFrame APIs in tfma.experimental.dataframe to load the materialized output as Pandas DataFrames. To view metrics, you can use metrics_as_dataframes(tfma.load_metrics(eval_path)). This returns an object that may contain several DataFrames, one for each metric value type (double_value, confusion_matrix_at_thresholds, bytes_value, and array_value). Which DataFrames are populated depends on the eval result. Here, we show the double_value DataFrame as an example.

import tensorflow_model_analysis.experimental.dataframe as tfma_dataframe
dfs = tfma_dataframe.metrics_as_dataframes(
  tfma.load_metrics(eval_result_path))

display(dfs.double_value.head())

Each of the DataFrames has a column multi-index with the top-level columns slices, metric_keys, and metric_values. The exact columns in each group can vary with the payload, and we can use the DataFrame.columns API to inspect all the multi-index columns. For example, the slices columns are 'Overall', 'trip_start_day', 'trip_start_hour', and 'trip_start_month', which are configured by the slicing_specs in the eval_config.

print(dfs.double_value.columns)
MultiIndex([(       'slices',  'trip_start_hour'),
            (       'slices',   'trip_start_day'),
            (       'slices',          'Overall'),
            (       'slices', 'trip_start_month'),
            (  'metric_keys',             'name'),
            (  'metric_keys',       'model_name'),
            (  'metric_keys',      'output_name'),
            (  'metric_keys', 'example_weighted'),
            (  'metric_keys',          'is_diff'),
            ('metric_values',     'double_value')],
           )

Auto pivoting

The DataFrame is verbose by design, so that no information from the payload is lost. However, for direct consumption we might want to organize the information in a more concise but lossy form: slices as rows and metrics as columns. TFMA provides the auto_pivot API for this purpose. By default, it pivots on all of the non-unique columns inside metric_keys and condenses all the slices into a single stringified_slices column.

tfma_dataframe.auto_pivot(dfs.double_value).head()
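Conceptually, pivoting turns long-format rows (one row per slice and metric) into a wide table with one row per slice and one column per metric. Below is a stdlib sketch of that reshaping. It uses a hypothetical pivot_metrics helper and toy values rather than results from this notebook, and the "feature:value" slice labels are an assumption about formatting.

```python
from typing import Dict, Iterable, Tuple

LongRow = Tuple[str, str, float]  # (stringified slice, metric name, value)


def pivot_metrics(rows: Iterable[LongRow]) -> Dict[str, Dict[str, float]]:
    """Reshape long rows into {slice: {metric: value}}.

    This sketch keys columns by metric name only; auto_pivot also pivots on
    any other metric_keys columns (e.g. model_name) whose values vary.
    """
    wide: Dict[str, Dict[str, float]] = {}
    for slice_name, metric, value in rows:
        wide.setdefault(slice_name, {})[metric] = value
    return wide


# Toy long-format rows (made up for illustration, not notebook output).
long_rows = [
    ("Overall", "auc", 0.81),
    ("Overall", "example_count", 100.0),
    ("trip_start_hour:1", "auc", 0.74),
    ("trip_start_hour:1", "example_count", 12.0),
]
```

Calling pivot_metrics(long_rows) yields one dictionary per slice. The lossy part is that any key column not used for pivoting disappears from the result.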

Filtering slices

Since the outputs are DataFrames, any native DataFrame APIs can be used to slice and dice the DataFrame. For example, if we are only interested in trip_start_hour of 1, 3, 5, 7 and not in trip_start_day, we can use DataFrame's .loc filtering logic. Again, we use the auto_pivot function to re-organize the DataFrame in the slice vs. metrics view after the filtering is performed.

df_double = dfs.double_value
df_filtered = (df_double
  .loc[df_double.slices.trip_start_hour.isin([1,3,5,7])]
)
display(tfma_dataframe.auto_pivot(df_filtered))
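The .isin filter also drops rows that belong to other slicing specs. Those rows have no trip_start_hour value (NaN), and NaN is never "in" the list. A plain-Python sketch of the same behaviour is shown below. It uses hypothetical flat records, where a missing key plays the role of NaN.

```python
from typing import Dict, Iterable, List

Record = Dict[str, object]  # hypothetical flat record: slice columns plus a metric


def filter_slices(records: Iterable[Record], feature: str,
                  allowed: Iterable[object]) -> List[Record]:
    """Keep records whose `feature` slice column is one of `allowed`.

    Records from other slicing specs lack `feature` (get() returns None),
    so they are excluded, mirroring how .isin() treats NaN.
    """
    allowed_set = set(allowed)
    return [r for r in records if r.get(feature) in allowed_set]
```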

Sorting by metric values

We can also sort slices by metric value. As an example, we sort the slices in df_double by ascending AUC so that we can find poorly performing slices. This involves two steps. First, auto-pivot so that slices are rows and metrics are columns. Then sort the pivoted DataFrame by the auc column.

# Pivoted table sorted by AUC in ascending order.
df_sorted = (
    tfma_dataframe.auto_pivot(df_double)
    .sort_values(by='auc', ascending=True)
    )
display(df_sorted.head())
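To see what the sorted AUC column measures, here is a stdlib sketch. It computes ROC AUC per slice as the probability that a randomly chosen positive example is scored above a randomly chosen negative one, counting ties as one half. It then sorts slices from worst to best. This quadratic-time version uses toy data and hypothetical function names. It is not how TFMA computes AUC: TFMA's AUC metric is approximated from confusion matrices over a set of thresholds, so its values can differ slightly.

```python
from typing import Dict, List, Sequence, Tuple


def pairwise_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC as P(score_pos > score_neg), counting ties as 0.5."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        return float("nan")  # AUC is undefined without both classes
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def slices_by_ascending_auc(
    data: Dict[str, Tuple[List[int], List[float]]]
) -> List[Tuple[str, float]]:
    """Return (slice, auc) pairs with the weakest slices first; NaN AUCs go last."""
    aucs = [(name, pairwise_auc(labels, scores)) for name, (labels, scores) in data.items()]
    return sorted(aucs, key=lambda item: (item[1] != item[1], item[1]))
```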

Rendering Plots

Any plots that were added to the tfma.EvalConfig as post-training metrics_specs can be displayed using tfma.view.render_plot.

As with metrics, plots can be viewed by slice. Unlike metrics, however, plots can only be displayed for one particular slice value. The tfma.SlicingSpec passed in must therefore specify both a slice feature name and a value. If no slice is provided, the plots for the Overall slice are used.

In the example below we are displaying the CalibrationPlot and ConfusionMatrixPlot plots that were computed for the trip_start_hour:1 slice.

tfma.view.render_plot(
    eval_result,
    tfma.SlicingSpec(feature_values={'trip_start_hour': '1'}))
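Both plots summarize simple per-slice statistics:

- A ConfusionMatrixPlot is built from true/false positive and negative counts computed at a range of decision thresholds.
- A CalibrationPlot compares the average prediction with the observed positive rate within buckets of predicted probability.

The stdlib sketch below computes both for one slice of toy data. The function names, the "positive when prediction > threshold" convention, and the bucket count are assumptions made for illustration. TFMA's own threshold and bucketing defaults may differ.

```python
from typing import Dict, List, Sequence, Tuple


def confusion_at_thresholds(
    labels: Sequence[int], preds: Sequence[float], thresholds: Sequence[float]
) -> List[Dict[str, float]]:
    """Confusion counts plus precision/recall, predicting positive when pred > threshold."""
    rows = []
    for t in thresholds:
        tp = sum(1 for y, p in zip(labels, preds) if p > t and y == 1)
        fp = sum(1 for y, p in zip(labels, preds) if p > t and y == 0)
        fn = sum(1 for y, p in zip(labels, preds) if p <= t and y == 1)
        tn = sum(1 for y, p in zip(labels, preds) if p <= t and y == 0)
        rows.append({
            "threshold": t, "tp": tp, "fp": fp, "fn": fn, "tn": tn,
            "precision": tp / (tp + fp) if tp + fp else float("nan"),
            "recall": tp / (tp + fn) if tp + fn else float("nan"),
        })
    return rows


def calibration_buckets(
    labels: Sequence[int], preds: Sequence[float], num_buckets: int = 4
) -> List[Tuple[int, float, float, int]]:
    """(bucket, mean prediction, observed positive rate, count) for each non-empty bucket."""
    totals: Dict[int, List[float]] = {}  # bucket -> [prediction sum, label sum, count]
    for y, p in zip(labels, preds):
        bucket = min(int(p * num_buckets), num_buckets - 1)  # keep p == 1.0 in the top bucket
        acc = totals.setdefault(bucket, [0.0, 0.0, 0.0])
        acc[0] += p
        acc[1] += y
        acc[2] += 1
    return [(b, acc[0] / acc[2], acc[1] / acc[2], int(acc[2]))
            for b, acc in sorted(totals.items())]
```

A well-calibrated model has a mean prediction close to the observed positive rate in every bucket.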

Tracking Model Performance Over Time

Your training dataset is used to train your model and will hopefully be representative of your test dataset and of the data sent to your model in production. However, the data in inference requests may initially resemble your training data, yet in many cases it drifts enough over time to change your model's performance.

That means that you need to monitor and measure your model's performance on an ongoing basis, so that you can be aware of and react to changes. Let's take a look at how TFMA can help.
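In its simplest form, ongoing monitoring means recording a metric for each evaluation run. You then flag any run where the metric drops by more than a tolerance relative to the previous run. The sketch below does exactly that with toy AUC values. The function name and the 0.01 tolerance are hypothetical choices, not TFMA settings.

```python
from typing import List, Sequence, Tuple


def flag_regressions(values: Sequence[float], tolerance: float = 0.01) -> List[Tuple[int, float]]:
    """Return (run index, drop) for runs whose metric fell by more than `tolerance`
    compared with the immediately preceding run (higher is assumed better)."""
    flagged = []
    for i in range(1, len(values)):
        drop = values[i - 1] - values[i]
        if drop > tolerance:
            flagged.append((i, drop))
    return flagged
```

For example, flag_regressions([0.80, 0.81, 0.77]) flags only the third run (index 2), whose AUC fell by about 0.04.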

Let's load 3 different model runs and use TFMA to see how they compare using render_time_series.

# Note this re-uses the EvalConfig from the keras setup.

# Run eval on each saved model
output_paths = []
for i in range(3):
  # Create a tfma.EvalSharedModel that points at our saved model.
  eval_shared_model = tfma.default_eval_shared_model(
      eval_saved_model_path=os.path.join(MODELS_DIR, 'keras', str(i)),
      eval_config=keras_eval_config)

  output_path = os.path.join(OUTPUT_DIR, 'time_series', str(i))
  output_paths.append(output_path)

  # Run TFMA
  tfma.run_model_analysis(eval_shared_model=eval_shared_model,
                          eval_config=keras_eval_config,
                          data_location=tfrecord_file,
                          output_path=output_path)
WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
WARNING:absl:Tensorflow version (2.12.1) found. Note that TFMA support for TF 2.0 is currently in beta
WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
WARNING:absl:Tensorflow version (2.12.1) found. Note that TFMA support for TF 2.0 is currently in beta
WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
WARNING:absl:Tensorflow version (2.12.1) found. Note that TFMA support for TF 2.0 is currently in beta
WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.

First, imagine that we trained and deployed our model yesterday, and now we want to see how it's doing on the new data coming in today. The visualization starts by displaying AUC. From the UI you can:

  • Add other metrics using the "Add metric series" menu.
  • Close unwanted graphs by clicking on x.
  • Hover over data points (the ends of line segments in the graph) to get more details.
eval_results_from_disk = tfma.load_eval_results(output_paths[:2])

tfma.view.render_time_series(eval_results_from_disk)

Now we'll imagine that another day has passed and we want to see how it's doing on the new data coming in today, compared to the previous two days:

eval_results_from_disk = tfma.load_eval_results(output_paths)

tfma.view.render_time_series(eval_results_from_disk)

Model Validation

TFMA can be configured to evaluate multiple models at the same time. Typically this is done to compare a new model against a baseline (such as the currently serving model) and determine how its metrics (e.g. AUC) differ from the baseline's. When thresholds are configured, TFMA produces a tfma.ValidationResult record indicating whether the performance matches expectations.

Let's re-configure our keras evaluation to compare two models: a candidate and a baseline. We will also validate the candidate's performance against the baseline by setting a tfma.MetricThreshold on the AUC metric.
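According to the comments in the config below, the AUC threshold combines two checks. The first is an absolute floor (value_threshold). The second limits how far the candidate may fall relative to the baseline (change_threshold with HIGHER_IS_BETTER). The plain-Python sketch below encodes those two checks exactly as the comments state them. The function and dataclass are hypothetical, and TFMA's exact boundary handling (strict versus inclusive comparison) may differ from the strict ">" used here.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class AucThreshold:
    lower_bound: float = 0.9             # mirrors value_threshold.lower_bound
    min_absolute_change: float = -1e-10  # mirrors change_threshold.absolute (HIGHER_IS_BETTER)


def validate_auc(candidate: float, baseline: float, threshold: AucThreshold) -> List[str]:
    """Return failure messages; an empty list means the candidate passes both checks."""
    failures = []
    if not candidate > threshold.lower_bound:
        failures.append(f"AUC {candidate:.4f} is not above {threshold.lower_bound}")
    change = candidate - baseline
    if not change > threshold.min_absolute_change:
        failures.append(
            f"AUC change {change:+.4f} vs. baseline is not above {threshold.min_absolute_change}")
    return failures
```

A candidate must pass both checks. A model can clear the 0.9 floor and still fail because it is worse than the baseline, and the reverse is also possible.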

# Set up tfma.EvalConfig settings
eval_config_with_thresholds = text_format.Parse("""
  ## Model information
  model_specs {
    name: "candidate"
    # For keras we need to add a `label_key`.
    label_key: "big_tipper"
  }
  model_specs {
    name: "baseline"
    # For keras we need to add a `label_key`.
    label_key: "big_tipper"
    is_baseline: true
  }

  ## Post training metric information
  metrics_specs {
    metrics { class_name: "ExampleCount" }
    metrics { class_name: "BinaryAccuracy" }
    metrics { class_name: "BinaryCrossentropy" }
    metrics {
      class_name: "AUC"
      threshold {
        # Ensure that AUC is always > 0.9
        value_threshold {
          lower_bound { value: 0.9 }
        }
        # Ensure that AUC does not drop by more than a small epsilon
        # e.g. (candidate - baseline) > -1e-10 or candidate > baseline - 1e-10
        change_threshold {
          direction: HIGHER_IS_BETTER
          absolute { value: -1e-10 }
        }
      }
    }
    metrics { class_name: "AUCPrecisionRecall" }
    metrics { class_name: "Precision" }
    metrics { class_name: "Recall" }
    metrics { class_name: "MeanLabel" }
    metrics { class_name: "MeanPrediction" }
    metrics { class_name: "Calibration" }
    metrics { class_name: "CalibrationPlot" }
    metrics { class_name: "ConfusionMatrixPlot" }
    # ... add additional metrics and plots ...
  }

  ## Slicing information
  slicing_specs {}  # overall slice
  slicing_specs {
    feature_keys: ["trip_start_hour"]
  }
  slicing_specs {
    feature_keys: ["trip_start_day"]
  }
  slicing_specs {
    feature_keys: ["trip_start_month"]
  }
  slicing_specs {
    feature_keys: ["trip_start_hour", "trip_start_day"]
  }
""", tfma.EvalConfig())

# Create tfma.EvalSharedModels that point at the candidate and baseline Keras models.
candidate_model_path = os.path.join(MODELS_DIR, 'keras', '2')
baseline_model_path = os.path.join(MODELS_DIR, 'keras', '1')
eval_shared_models = [
  tfma.default_eval_shared_model(
      model_name=tfma.CANDIDATE_KEY,
      eval_saved_model_path=candidate_model_path,
      eval_config=eval_config_with_thresholds),
  tfma.default_eval_shared_model(
      model_name=tfma.BASELINE_KEY,
      eval_saved_model_path=baseline_model_path,
      eval_config=eval_config_with_thresholds),
]

validation_output_path = os.path.join(OUTPUT_DIR, 'validation')

# Run TFMA
eval_result_with_validation = tfma.run_model_analysis(
    eval_shared_models,
    eval_config=eval_config_with_thresholds,
    data_location=tfrecord_file,
    output_path=validation_output_path)
WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
WARNING:absl:Tensorflow version (2.12.1) found. Note that TFMA support for TF 2.0 is currently in beta
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_model_analysis/metrics/confusion_matrix_metrics.py:520: RuntimeWarning: invalid value encountered in true_divide
  prec_slope = dtp / np.maximum(dp, 0)
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_model_analysis/metrics/confusion_matrix_metrics.py:524: RuntimeWarning: divide by zero encountered in true_divide
  p[:num_thresholds - 1] / np.maximum(p[1:], 0), np.ones_like(p[1:]))
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_model_analysis/metrics/confusion_matrix_metrics.py:524: RuntimeWarning: invalid value encountered in true_divide
  p[:num_thresholds - 1] / np.maximum(p[1:], 0), np.ones_like(p[1:]))
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_model_analysis/metrics/confusion_matrix_metrics.py:531: RuntimeWarning: invalid value encountered in true_divide
  recall = tp / (tp + fn)
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_model_analysis/metrics/confusion_matrix_metrics.py:526: RuntimeWarning: invalid value encountered in true_divide
  prec_slope * (dtp + intercept * np.log(safe_p_ratio)) /
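The slicing_specs in the EvalConfig above determine which subsets of the data each metric is computed over. As a rough sketch in plain Python (not TFMA's API), each example can be thought of as contributing to the overall slice plus one slice per spec, keyed by a tuple of (feature, value) pairs. The helpers slice_keys and group_by_slice are hypothetical, and skipping specs whose features are missing from an example is an assumption.

```python
from collections import defaultdict

# Slicing specs mirroring the EvalConfig above; () is the overall slice.
SLICING_SPECS = [
    (),
    ("trip_start_hour",),
    ("trip_start_day",),
    ("trip_start_month",),
    ("trip_start_hour", "trip_start_day"),
]


def slice_keys(example: dict, specs=SLICING_SPECS) -> list:
    """Return one slice key per spec as a tuple of (feature, value) pairs."""
    keys = []
    for feature_keys in specs:
        # Assumption: an example only joins a slice if it has every feature.
        if all(k in example for k in feature_keys):
            keys.append(tuple((k, example[k]) for k in feature_keys))
    return keys


def group_by_slice(examples, specs=SLICING_SPECS) -> dict:
    """Map each slice key to the list of examples that fall into it."""
    groups = defaultdict(list)
    for ex in examples:
        for key in slice_keys(ex, specs):
            groups[key].append(ex)
    return dict(groups)


# Usage with two hypothetical trips:
trips = [
    {"trip_start_hour": 8, "trip_start_day": 2, "trip_start_month": 5},
    {"trip_start_hour": 8, "trip_start_day": 3, "trip_start_month": 5},
]
groups = group_by_slice(trips)
print(len(groups[()]))  # 2: the overall slice sees every trip
print(len(groups[(("trip_start_hour", 8), ("trip_start_day", 2))]))  # 1
```

The feature-cross spec produces one slice per observed (hour, day) combination, which is why crossed slices tend to contain far fewer examples than single-feature slices.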

When running evaluations with one or more models against a baseline, TFMA automatically adds diff metrics for all of the metrics computed during the evaluation. Each diff metric takes the name of the corresponding metric with _diff appended.
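A simplified dictionary model of this naming convention is sketched below. The helper add_diff_metrics and the metric values are hypothetical; TFMA computes diff metrics internally, and the candidate-minus-baseline sign convention follows the comment in the AUC change threshold above.

```python
def add_diff_metrics(candidate: dict, baseline: dict) -> dict:
    """Return candidate metrics plus a '<name>_diff' entry per shared metric."""
    result = dict(candidate)
    for name, value in candidate.items():
        if name in baseline:
            # Diff is candidate minus baseline, so positive means "higher".
            result[name + "_diff"] = value - baseline[name]
    return result


# Hypothetical metric values, not results from this notebook:
candidate = {"auc": 0.93, "binary_accuracy": 0.80}
baseline = {"auc": 0.92, "binary_accuracy": 0.81}
for name, value in sorted(add_diff_metrics(candidate, baseline).items()):
    print(f"{name}: {value:+.4f}")
```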

Let's take a look at the metrics produced by our run:

tfma.view.render_time_series(eval_result_with_validation)

Now let's look at the output from our validation checks. To view the validation results we use tfma.load_validation_result. For our example, the validation fails because AUC is below the threshold.

validation_result = tfma.load_validation_result(validation_output_path)
print(validation_result.validation_ok)
False
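For intuition about why a single failing check turns validation_ok into False, here is a hypothetical plain-Python sketch. It models results as a dict from slice key to per-metric pass/fail flags; this layout and the helper summarize_validation are assumptions for illustration, not the structure of the tfma.ValidationResult proto.

```python
# Hypothetical per-slice results: slice key -> {metric name: passed?}.

def summarize_validation(per_slice: dict) -> tuple:
    """Return (validation_ok, failures), where failures lists (slice, metric)."""
    failures = [
        (slice_key, metric)
        for slice_key, checks in per_slice.items()
        for metric, passed in checks.items()
        if not passed
    ]
    # Validation succeeds only if no check failed on any slice.
    return (not failures, failures)


per_slice = {
    (): {"auc": False},
    (("trip_start_hour", 8),): {"auc": True},
}
ok, failures = summarize_validation(per_slice)
print(ok)        # False
print(failures)  # [((), 'auc')]
```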