Better ML Engineering with ML Metadata

Assume a scenario where you set up a production ML pipeline to classify pictures of iris flowers. The pipeline ingests your training data, trains and evaluates a model, and pushes it to production.

However, when you later try using this model with a larger dataset that contains images of different kinds of flowers, you observe that your model does not behave as expected and starts classifying roses and lilies as types of irises.

At this point, you are interested in knowing:

  • What is the most efficient way to debug the model when the only available artifact is the model in production?
  • Which training dataset was used to train the model?
  • Which training run led to this erroneous model?
  • Where are the model evaluation results?
  • Where to begin debugging?

ML Metadata (MLMD) is a library that leverages the metadata associated with ML models to help you answer these questions and more. A helpful analogy is to think of this metadata as the equivalent of logging in software development. MLMD enables you to reliably track the artifacts and lineage associated with the various components of your ML pipeline.
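
To make the tracking idea concrete, here is a minimal sketch (separate from the pipeline you build below) of how MLMD records an artifact in a metadata store. It uses an in-memory fake database, and the 'DataSet' type and uri are made up purely for illustration.

import ml_metadata as mlmd
from ml_metadata.proto import metadata_store_pb2

# An in-memory MLMD store, useful only for experimentation; nothing persists.
connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.fake_database.SetInParent()
store = mlmd.MetadataStore(connection_config)

# Register an artifact type, then record one artifact of that type.
data_type = metadata_store_pb2.ArtifactType(name='DataSet')
data_type.properties['split'] = metadata_store_pb2.STRING
data_type_id = store.put_artifact_type(data_type)

data_artifact = metadata_store_pb2.Artifact(type_id=data_type_id, uri='path/to/data')
data_artifact.properties['split'].string_value = 'train'
[artifact_id] = store.put_artifacts([data_artifact])

print(store.get_artifacts_by_id([artifact_id]))

In the rest of this tutorial, TFX writes entries like these for you as each pipeline component runs.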

In this tutorial, you set up a TFX Pipeline to create a model that classifies Iris flowers into three species - Iris setosa, Iris virginica, and Iris versicolor - based on the length and width measurements of their petals and sepals. You then use MLMD to track the lineage of pipeline components.

TFX Pipelines in Colab

Colab is a lightweight development environment that differs significantly from a production environment. In production, you may have various pipeline components like data ingestion, transformation, model training, run histories, etc. spread across multiple, distributed systems. For this tutorial, be aware that orchestration and metadata storage differ significantly: both are handled locally within Colab. Learn more about TFX in Colab here.
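
For example, a production deployment typically points MLMD at a shared database server, while this notebook uses a local SQLite file that the InteractiveContext creates for you. The sketch below contrasts the two kinds of connection configs; the host, database, and credential values are placeholders, not a real deployment.

from ml_metadata.proto import metadata_store_pb2

# Production-style config: a shared MySQL-backed MLMD store (placeholder values).
production_config = metadata_store_pb2.ConnectionConfig()
production_config.mysql.host = 'mlmd-db.example.com'
production_config.mysql.port = 3306
production_config.mysql.database = 'mlmd'
production_config.mysql.user = 'pipeline'
production_config.mysql.password = '...'

# Colab-style config: a local SQLite file, which is what this tutorial uses.
colab_config = metadata_store_pb2.ConnectionConfig()
colab_config.sqlite.filename_uri = '/tmp/metadata.sqlite'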

Setup

Import all required libraries.

Install and import TFX

pip install --quiet tfx==0.23.0
ERROR: After October 2020 you may experience errors when installing or updating packages. This is because pip will change the way that it resolves dependency conflicts.

We recommend you use --use-feature=2020-resolver to test your packages with the new resolver before it becomes the default.

libcst 0.3.14 requires pyyaml>=5.2, but you'll have pyyaml 3.12 which is incompatible.
apache-beam 2.25.0 requires httplib2<0.18.0,>=0.8, but you'll have httplib2 0.18.1 which is incompatible.

You must restart the Colab runtime after installing TFX. Select Runtime > Restart runtime from the Colab menu.

Do not proceed with the rest of this tutorial without first restarting the runtime.
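
After the runtime restarts, you can run a quick sanity check to confirm that the expected TFX version is installed (a minimal check, assuming the pip install above completed successfully):

from tfx import version

print('TFX version: {}'.format(version.__version__))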

Import other libraries

import base64
import csv
import json
import os
import requests
import tempfile
import urllib
import pprint
import numpy as np
import pandas as pd

pp = pprint.PrettyPrinter()
import tensorflow as tf
import tfx

Import TFX component classes.

from tfx.components.evaluator.component import Evaluator
from tfx.components.example_gen.csv_example_gen.component import CsvExampleGen
from tfx.components.pusher.component import Pusher
from tfx.components.schema_gen.component import SchemaGen
from tfx.components.statistics_gen.component import StatisticsGen
from tfx.components.trainer.component import Trainer
from tfx.components.base import executor_spec
from tfx.components.trainer.executor import GenericExecutor
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
from tfx.proto import evaluator_pb2
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.utils.dsl_utils import external_input

from tensorflow_metadata.proto.v0 import anomalies_pb2
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_metadata.proto.v0 import statistics_pb2

import tensorflow_model_analysis as tfma

from tfx.components import ResolverNode
from tfx.dsl.experimental import latest_blessed_model_resolver
from tfx.types import Channel
from tfx.types.standard_artifacts import Model
from tfx.types.standard_artifacts import ModelBlessing

Import the MLMD library.

import ml_metadata as mlmd
from ml_metadata.proto import metadata_store_pb2

Download the dataset

Download the Iris dataset to use in this tutorial. The dataset contains the length and width measurements of sepals and petals for 150 Iris flowers. You use this data to classify irises into one of three species - Iris setosa, Iris virginica, and Iris versicolor.

DATA_PATH = 'https://raw.githubusercontent.com/tensorflow/tfx/v0.24.0/tfx/examples/iris/data/iris.csv'
_data_root = tempfile.mkdtemp(prefix='tfx-data')
_data_filepath = os.path.join(_data_root, "iris.csv")
urllib.request.urlretrieve(DATA_PATH, _data_filepath)
('/tmp/tfx-datakmvtalov/iris.csv', <http.client.HTTPMessage at 0x7fa643ded978>)

Create an InteractiveContext

To run TFX components interactively in this notebook, create an InteractiveContext. The InteractiveContext uses a temporary directory with an ephemeral MLMD database instance. Note that calls to InteractiveContext are no-ops outside the Colab environment.

In general, it is a good practice to group similar pipeline runs under a Context.

interactive_context = InteractiveContext()
WARNING:absl:InteractiveContext pipeline_root argument not provided: using temporary directory /tmp/tfx-interactive-2020-11-23T10_16_06.265013-tu4qjfvx as root for pipeline outputs.
WARNING:absl:InteractiveContext metadata_connection_config not provided: using SQLite ML Metadata database at /tmp/tfx-interactive-2020-11-23T10_16_06.265013-tu4qjfvx/metadata.sqlite.
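
The rest of this tutorial uses the default context created above. If you prefer the pipeline outputs and the MLMD database to live in a location you control rather than the temporary directories shown in these warnings, you can pass them explicitly; a sketch, with a placeholder output directory:

import os

from tfx.orchestration import metadata
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

# Placeholder location for pipeline outputs; any writable directory works.
_pipeline_root = os.path.join(os.getcwd(), 'pipeline_outputs')
os.makedirs(_pipeline_root, exist_ok=True)

# Keep the MLMD SQLite database alongside the pipeline outputs instead of in /tmp.
_metadata_config = metadata.sqlite_metadata_connection_config(
    os.path.join(_pipeline_root, 'metadata.sqlite'))

explicit_context = InteractiveContext(
    pipeline_root=_pipeline_root,
    metadata_connection_config=_metadata_config)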

Construct the TFX Pipeline

A TFX pipeline consists of several components that perform different aspects of the ML workflow. In this notebook, you create and run the ExampleGen, StatisticsGen, SchemaGen, and Trainer components and use the Evaluator and Pusher components to evaluate and push the trained model.

Refer to the components tutorial for more information on TFX pipeline components.
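
For reference, outside a notebook you would typically assemble the same components into a pipeline object and hand it to an orchestrator such as Apache Beam, Airflow, or Kubeflow Pipelines. The sketch below shows that pattern with placeholder names; it assumes the components are constructed as in the cells that follow and is not needed to run this tutorial.

from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

def _create_pipeline(pipeline_name, pipeline_root, metadata_path, components):
  # `components` is the list of component instances created below, for example
  # [example_gen, statistics_gen, infer_schema, trainer, model_resolver, evaluator, pusher].
  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=components,
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path),
      enable_cache=True)

# BeamDagRunner().run(_create_pipeline(...))  # how you would launch it outside Colab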

Instantiate and run the ExampleGen Component

input_data = external_input(_data_root)
example_gen = CsvExampleGen(input=input_data)

# Run the ExampleGen component using the InteractiveContext
interactive_context.run(example_gen)
WARNING:tensorflow:From <ipython-input-1-3f24e2b7621a>:1: external_input (from tfx.utils.dsl_utils) is deprecated and will be removed in a future version.
Instructions for updating:
external_input is deprecated, directly pass the uri to ExampleGen.

WARNING:absl:The "input" argument to the CsvExampleGen component has been deprecated by "input_base". Please update your usage as support for this argument will be removed soon.
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.

WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.

Instantiate and run the StatisticsGen Component

statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

# Run the StatisticsGen component using the InteractiveContext
interactive_context.run(statistics_gen)

Instantiate and run the SchemaGen Component

infer_schema = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                         infer_feature_shape=True)

# Run the SchemaGen component using the InteractiveContext
interactive_context.run(infer_schema)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_data_validation/utils/stats_util.py:229: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`

Instantiate and run the Trainer Component

# Define the module file for the Trainer component
trainer_module_file = 'iris_trainer.py'

%%writefile {trainer_module_file}

# Define the training algorithm for the Trainer module file
import os
from typing import List, Text

import tensorflow as tf
from tensorflow import keras

from tfx.components.trainer.executor import TrainerFnArgs
from tfx.components.trainer.fn_args_utils import FnArgs

# The iris dataset has 150 records, and is split into training and evaluation 
# datasets in a 2:1 split

_TRAIN_DATA_SIZE = 100
_EVAL_DATA_SIZE = 50
_TRAIN_BATCH_SIZE = 100
_EVAL_BATCH_SIZE = 50

# Features to parse from the dataset - sepal length and width, petal length and
# width - plus 'variety' (the species of flower), which serves as the label

_FEATURES = {
    'sepal_length': tf.io.FixedLenFeature([], dtype=tf.float32, default_value=0),
    'sepal_width': tf.io.FixedLenFeature([], dtype=tf.float32, default_value=0),
    'petal_length': tf.io.FixedLenFeature([], dtype=tf.float32, default_value=0),
    'petal_width': tf.io.FixedLenFeature([], dtype=tf.float32, default_value=0),
    'variety': tf.io.FixedLenFeature([], dtype=tf.int64, default_value=0)
}

_LABEL_KEY = 'variety'

_FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

def _gzip_reader_fn(filenames):
  return tf.data.TFRecordDataset(filenames, compression_type='GZIP')

def _input_fn(file_pattern: List[Text],
              batch_size: int = 200):
  dataset = tf.data.experimental.make_batched_features_dataset(
            file_pattern=file_pattern,
            batch_size=batch_size,
            features=_FEATURES,
            reader=_gzip_reader_fn,
            label_key=_LABEL_KEY)

  return dataset

def _build_keras_model():
  inputs = [keras.layers.Input(shape=(1,), name=f) for f in _FEATURE_KEYS]
  d = keras.layers.concatenate(inputs)
  d = keras.layers.Dense(8, activation='relu')(d)
  d = keras.layers.Dense(8, activation='relu')(d)
  outputs = keras.layers.Dense(3, activation='softmax')(d)
  model = keras.Model(inputs=inputs, outputs=outputs)
  model.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=[keras.metrics.SparseCategoricalAccuracy()])
  return model

def run_fn(fn_args: TrainerFnArgs):
  train_dataset = _input_fn(fn_args.train_files, batch_size=_TRAIN_BATCH_SIZE)
  eval_dataset = _input_fn(fn_args.eval_files, batch_size=_EVAL_BATCH_SIZE)

  model = _build_keras_model()

  steps_per_epoch = _TRAIN_DATA_SIZE // _TRAIN_BATCH_SIZE

  model.fit(train_dataset, 
            epochs=int(fn_args.train_steps / steps_per_epoch),
            steps_per_epoch=steps_per_epoch,
            validation_data=eval_dataset,
            validation_steps=fn_args.eval_steps)
  model.save(fn_args.serving_model_dir, save_format='tf')
Writing iris_trainer.py

Run the Trainer component.

trainer = Trainer(
    module_file=os.path.abspath(trainer_module_file),
    custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=100),
    eval_args=trainer_pb2.EvalArgs(num_steps=50))

interactive_context.run(trainer)
Epoch 1/100
1/1 [==============================] - 0s 213ms/step - loss: 2.1738 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.9316 - val_sparse_categorical_accuracy: 0.3564
Epoch 2/100
1/1 [==============================] - 0s 107ms/step - loss: 2.1393 - sparse_categorical_accuracy: 0.3200 - val_loss: 1.9030 - val_sparse_categorical_accuracy: 0.3556
Epoch 3/100
1/1 [==============================] - 0s 107ms/step - loss: 2.1919 - sparse_categorical_accuracy: 0.3000 - val_loss: 1.8707 - val_sparse_categorical_accuracy: 0.3556
Epoch 4/100
1/1 [==============================] - 0s 103ms/step - loss: 1.9875 - sparse_categorical_accuracy: 0.3600 - val_loss: 1.8387 - val_sparse_categorical_accuracy: 0.3548
Epoch 5/100
1/1 [==============================] - 0s 107ms/step - loss: 2.1086 - sparse_categorical_accuracy: 0.2800 - val_loss: 1.8052 - val_sparse_categorical_accuracy: 0.3564
Epoch 6/100
1/1 [==============================] - 0s 107ms/step - loss: 1.9383 - sparse_categorical_accuracy: 0.3500 - val_loss: 1.7790 - val_sparse_categorical_accuracy: 0.3544
Epoch 7/100
1/1 [==============================] - 0s 106ms/step - loss: 2.0362 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.7467 - val_sparse_categorical_accuracy: 0.3560
Epoch 8/100
1/1 [==============================] - 0s 107ms/step - loss: 1.9771 - sparse_categorical_accuracy: 0.3000 - val_loss: 1.7210 - val_sparse_categorical_accuracy: 0.3548
Epoch 9/100
1/1 [==============================] - 0s 107ms/step - loss: 1.8635 - sparse_categorical_accuracy: 0.3500 - val_loss: 1.6906 - val_sparse_categorical_accuracy: 0.3552
Epoch 10/100
1/1 [==============================] - 0s 104ms/step - loss: 1.9302 - sparse_categorical_accuracy: 0.2600 - val_loss: 1.6573 - val_sparse_categorical_accuracy: 0.3560
Epoch 11/100
1/1 [==============================] - 0s 104ms/step - loss: 1.7571 - sparse_categorical_accuracy: 0.3800 - val_loss: 1.6326 - val_sparse_categorical_accuracy: 0.3552
Epoch 12/100
1/1 [==============================] - 0s 103ms/step - loss: 1.7127 - sparse_categorical_accuracy: 0.3600 - val_loss: 1.6035 - val_sparse_categorical_accuracy: 0.3552
Epoch 13/100
1/1 [==============================] - 0s 105ms/step - loss: 1.7665 - sparse_categorical_accuracy: 0.2900 - val_loss: 1.5761 - val_sparse_categorical_accuracy: 0.3552
Epoch 14/100
1/1 [==============================] - 0s 104ms/step - loss: 1.6161 - sparse_categorical_accuracy: 0.3700 - val_loss: 1.5508 - val_sparse_categorical_accuracy: 0.3556
Epoch 15/100
1/1 [==============================] - 0s 105ms/step - loss: 1.9868 - sparse_categorical_accuracy: 0.2500 - val_loss: 1.5230 - val_sparse_categorical_accuracy: 0.3564
Epoch 16/100
1/1 [==============================] - 0s 106ms/step - loss: 1.4698 - sparse_categorical_accuracy: 0.3500 - val_loss: 1.4969 - val_sparse_categorical_accuracy: 0.3568
Epoch 17/100
1/1 [==============================] - 0s 106ms/step - loss: 1.7664 - sparse_categorical_accuracy: 0.3000 - val_loss: 1.4762 - val_sparse_categorical_accuracy: 0.3548
Epoch 18/100
1/1 [==============================] - 0s 105ms/step - loss: 1.5376 - sparse_categorical_accuracy: 0.3600 - val_loss: 1.4480 - val_sparse_categorical_accuracy: 0.3564
Epoch 19/100
1/1 [==============================] - 0s 105ms/step - loss: 1.6296 - sparse_categorical_accuracy: 0.3000 - val_loss: 1.4250 - val_sparse_categorical_accuracy: 0.3556
Epoch 20/100
1/1 [==============================] - 0s 107ms/step - loss: 1.5085 - sparse_categorical_accuracy: 0.3400 - val_loss: 1.4034 - val_sparse_categorical_accuracy: 0.3548
Epoch 21/100
1/1 [==============================] - 0s 107ms/step - loss: 1.5454 - sparse_categorical_accuracy: 0.3200 - val_loss: 1.3796 - val_sparse_categorical_accuracy: 0.3556
Epoch 22/100
1/1 [==============================] - 0s 106ms/step - loss: 1.5426 - sparse_categorical_accuracy: 0.3100 - val_loss: 1.3577 - val_sparse_categorical_accuracy: 0.3556
Epoch 23/100
1/1 [==============================] - 0s 106ms/step - loss: 1.4540 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.3329 - val_sparse_categorical_accuracy: 0.3564
Epoch 24/100
1/1 [==============================] - 0s 104ms/step - loss: 1.4245 - sparse_categorical_accuracy: 0.3500 - val_loss: 1.3119 - val_sparse_categorical_accuracy: 0.3564
Epoch 25/100
1/1 [==============================] - 0s 105ms/step - loss: 1.4109 - sparse_categorical_accuracy: 0.2900 - val_loss: 1.2928 - val_sparse_categorical_accuracy: 0.3556
Epoch 26/100
1/1 [==============================] - 0s 105ms/step - loss: 1.4113 - sparse_categorical_accuracy: 0.3500 - val_loss: 1.2726 - val_sparse_categorical_accuracy: 0.3548
Epoch 27/100
1/1 [==============================] - 0s 107ms/step - loss: 1.3752 - sparse_categorical_accuracy: 0.3200 - val_loss: 1.2543 - val_sparse_categorical_accuracy: 0.3548
Epoch 28/100
1/1 [==============================] - 0s 104ms/step - loss: 1.3221 - sparse_categorical_accuracy: 0.3200 - val_loss: 1.2318 - val_sparse_categorical_accuracy: 0.3564
Epoch 29/100
1/1 [==============================] - 0s 103ms/step - loss: 1.4254 - sparse_categorical_accuracy: 0.2900 - val_loss: 1.2143 - val_sparse_categorical_accuracy: 0.3548
Epoch 30/100
1/1 [==============================] - 0s 105ms/step - loss: 1.2473 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1940 - val_sparse_categorical_accuracy: 0.3552
Epoch 31/100
1/1 [==============================] - 0s 102ms/step - loss: 1.2610 - sparse_categorical_accuracy: 0.3100 - val_loss: 1.1765 - val_sparse_categorical_accuracy: 0.3556
Epoch 32/100
1/1 [==============================] - 0s 104ms/step - loss: 1.3022 - sparse_categorical_accuracy: 0.3500 - val_loss: 1.1588 - val_sparse_categorical_accuracy: 0.3556
Epoch 33/100
1/1 [==============================] - 0s 103ms/step - loss: 1.1578 - sparse_categorical_accuracy: 0.3700 - val_loss: 1.1417 - val_sparse_categorical_accuracy: 0.3556
Epoch 34/100
1/1 [==============================] - 0s 104ms/step - loss: 1.2511 - sparse_categorical_accuracy: 0.2800 - val_loss: 1.1236 - val_sparse_categorical_accuracy: 0.3564
Epoch 35/100
1/1 [==============================] - 0s 106ms/step - loss: 1.1921 - sparse_categorical_accuracy: 0.3500 - val_loss: 1.1089 - val_sparse_categorical_accuracy: 0.3556
Epoch 36/100
1/1 [==============================] - 0s 107ms/step - loss: 1.2322 - sparse_categorical_accuracy: 0.2900 - val_loss: 1.0947 - val_sparse_categorical_accuracy: 0.3548
Epoch 37/100
1/1 [==============================] - 0s 108ms/step - loss: 1.1761 - sparse_categorical_accuracy: 0.3200 - val_loss: 1.0772 - val_sparse_categorical_accuracy: 0.3564
Epoch 38/100
1/1 [==============================] - 0s 110ms/step - loss: 1.0737 - sparse_categorical_accuracy: 0.3600 - val_loss: 1.0656 - val_sparse_categorical_accuracy: 0.3552
Epoch 39/100
1/1 [==============================] - 0s 106ms/step - loss: 1.1380 - sparse_categorical_accuracy: 0.3000 - val_loss: 1.0512 - val_sparse_categorical_accuracy: 0.3560
Epoch 40/100
1/1 [==============================] - 0s 104ms/step - loss: 1.1456 - sparse_categorical_accuracy: 0.3200 - val_loss: 1.0395 - val_sparse_categorical_accuracy: 0.3564
Epoch 41/100
1/1 [==============================] - 0s 108ms/step - loss: 1.0861 - sparse_categorical_accuracy: 0.3400 - val_loss: 1.0300 - val_sparse_categorical_accuracy: 0.3560
Epoch 42/100
1/1 [==============================] - 0s 104ms/step - loss: 1.0927 - sparse_categorical_accuracy: 0.3200 - val_loss: 1.0202 - val_sparse_categorical_accuracy: 0.3552
Epoch 43/100
1/1 [==============================] - 0s 108ms/step - loss: 1.0975 - sparse_categorical_accuracy: 0.3100 - val_loss: 1.0114 - val_sparse_categorical_accuracy: 0.3556
Epoch 44/100
1/1 [==============================] - 0s 110ms/step - loss: 1.0389 - sparse_categorical_accuracy: 0.3500 - val_loss: 1.0035 - val_sparse_categorical_accuracy: 0.3556
Epoch 45/100
1/1 [==============================] - 0s 108ms/step - loss: 1.0658 - sparse_categorical_accuracy: 0.3100 - val_loss: 0.9955 - val_sparse_categorical_accuracy: 0.3564
Epoch 46/100
1/1 [==============================] - 0s 105ms/step - loss: 1.0533 - sparse_categorical_accuracy: 0.3100 - val_loss: 0.9904 - val_sparse_categorical_accuracy: 0.3548
Epoch 47/100
1/1 [==============================] - 0s 109ms/step - loss: 1.0256 - sparse_categorical_accuracy: 0.3700 - val_loss: 0.9832 - val_sparse_categorical_accuracy: 0.3560
Epoch 48/100
1/1 [==============================] - 0s 106ms/step - loss: 1.0665 - sparse_categorical_accuracy: 0.2500 - val_loss: 0.9779 - val_sparse_categorical_accuracy: 0.3560
Epoch 49/100
1/1 [==============================] - 0s 101ms/step - loss: 1.0182 - sparse_categorical_accuracy: 0.3400 - val_loss: 0.9746 - val_sparse_categorical_accuracy: 0.3548
Epoch 50/100
1/1 [==============================] - 0s 103ms/step - loss: 1.0190 - sparse_categorical_accuracy: 0.3600 - val_loss: 0.9688 - val_sparse_categorical_accuracy: 0.3564
Epoch 51/100
1/1 [==============================] - 0s 106ms/step - loss: 1.0155 - sparse_categorical_accuracy: 0.2900 - val_loss: 0.9659 - val_sparse_categorical_accuracy: 0.3544
Epoch 52/100
1/1 [==============================] - 0s 105ms/step - loss: 0.9900 - sparse_categorical_accuracy: 0.3800 - val_loss: 0.9606 - val_sparse_categorical_accuracy: 0.3564
Epoch 53/100
1/1 [==============================] - 0s 105ms/step - loss: 1.0392 - sparse_categorical_accuracy: 0.2700 - val_loss: 0.9570 - val_sparse_categorical_accuracy: 0.3556
Epoch 54/100
1/1 [==============================] - 0s 106ms/step - loss: 0.9828 - sparse_categorical_accuracy: 0.3300 - val_loss: 0.9540 - val_sparse_categorical_accuracy: 0.3548
Epoch 55/100
1/1 [==============================] - 0s 105ms/step - loss: 0.9919 - sparse_categorical_accuracy: 0.3100 - val_loss: 0.9507 - val_sparse_categorical_accuracy: 0.3552
Epoch 56/100
1/1 [==============================] - 0s 104ms/step - loss: 0.9920 - sparse_categorical_accuracy: 0.3200 - val_loss: 0.9467 - val_sparse_categorical_accuracy: 0.3560
Epoch 57/100
1/1 [==============================] - 0s 105ms/step - loss: 0.9652 - sparse_categorical_accuracy: 0.3400 - val_loss: 0.9441 - val_sparse_categorical_accuracy: 0.3540
Epoch 58/100
1/1 [==============================] - 0s 105ms/step - loss: 0.9720 - sparse_categorical_accuracy: 0.3400 - val_loss: 0.9399 - val_sparse_categorical_accuracy: 0.3564
Epoch 59/100
1/1 [==============================] - 0s 104ms/step - loss: 0.9785 - sparse_categorical_accuracy: 0.3300 - val_loss: 0.9372 - val_sparse_categorical_accuracy: 0.3560
Epoch 60/100
1/1 [==============================] - 0s 104ms/step - loss: 0.9731 - sparse_categorical_accuracy: 0.3100 - val_loss: 0.9353 - val_sparse_categorical_accuracy: 0.3544
Epoch 61/100
1/1 [==============================] - 0s 105ms/step - loss: 0.9682 - sparse_categorical_accuracy: 0.3200 - val_loss: 0.9325 - val_sparse_categorical_accuracy: 0.3552
Epoch 62/100
1/1 [==============================] - 0s 105ms/step - loss: 0.9528 - sparse_categorical_accuracy: 0.3400 - val_loss: 0.9292 - val_sparse_categorical_accuracy: 0.3560
Epoch 63/100
1/1 [==============================] - 0s 105ms/step - loss: 0.9567 - sparse_categorical_accuracy: 0.3200 - val_loss: 0.9271 - val_sparse_categorical_accuracy: 0.3560
Epoch 64/100
1/1 [==============================] - 0s 102ms/step - loss: 0.9486 - sparse_categorical_accuracy: 0.3300 - val_loss: 0.9252 - val_sparse_categorical_accuracy: 0.3544
Epoch 65/100
1/1 [==============================] - 0s 102ms/step - loss: 0.9634 - sparse_categorical_accuracy: 0.3100 - val_loss: 0.9224 - val_sparse_categorical_accuracy: 0.3556
Epoch 66/100
1/1 [==============================] - 0s 105ms/step - loss: 0.9481 - sparse_categorical_accuracy: 0.3100 - val_loss: 0.9204 - val_sparse_categorical_accuracy: 0.3556
Epoch 67/100
1/1 [==============================] - 0s 101ms/step - loss: 0.9308 - sparse_categorical_accuracy: 0.3600 - val_loss: 0.9188 - val_sparse_categorical_accuracy: 0.3540
Epoch 68/100
1/1 [==============================] - 0s 101ms/step - loss: 0.9433 - sparse_categorical_accuracy: 0.3200 - val_loss: 0.9162 - val_sparse_categorical_accuracy: 0.3552
Epoch 69/100
1/1 [==============================] - 0s 106ms/step - loss: 0.9513 - sparse_categorical_accuracy: 0.2900 - val_loss: 0.9143 - val_sparse_categorical_accuracy: 0.3552
Epoch 70/100
1/1 [==============================] - 0s 104ms/step - loss: 0.9305 - sparse_categorical_accuracy: 0.3400 - val_loss: 0.9121 - val_sparse_categorical_accuracy: 0.3560
Epoch 71/100
1/1 [==============================] - 0s 103ms/step - loss: 0.9328 - sparse_categorical_accuracy: 0.3400 - val_loss: 0.9108 - val_sparse_categorical_accuracy: 0.3540
Epoch 72/100
1/1 [==============================] - 0s 104ms/step - loss: 0.9423 - sparse_categorical_accuracy: 0.2700 - val_loss: 0.9084 - val_sparse_categorical_accuracy: 0.3564
Epoch 73/100
1/1 [==============================] - 0s 104ms/step - loss: 0.9189 - sparse_categorical_accuracy: 0.3600 - val_loss: 0.9067 - val_sparse_categorical_accuracy: 0.3560
Epoch 74/100
1/1 [==============================] - 0s 104ms/step - loss: 0.9308 - sparse_categorical_accuracy: 0.3700 - val_loss: 0.9051 - val_sparse_categorical_accuracy: 0.3552
Epoch 75/100
1/1 [==============================] - 0s 105ms/step - loss: 0.9156 - sparse_categorical_accuracy: 0.2900 - val_loss: 0.9034 - val_sparse_categorical_accuracy: 0.3556
Epoch 76/100
1/1 [==============================] - 0s 103ms/step - loss: 0.9283 - sparse_categorical_accuracy: 0.3200 - val_loss: 0.9016 - val_sparse_categorical_accuracy: 0.3564
Epoch 77/100
1/1 [==============================] - 0s 101ms/step - loss: 0.9228 - sparse_categorical_accuracy: 0.3400 - val_loss: 0.9001 - val_sparse_categorical_accuracy: 0.3560
Epoch 78/100
1/1 [==============================] - 0s 106ms/step - loss: 0.9098 - sparse_categorical_accuracy: 0.3400 - val_loss: 0.8986 - val_sparse_categorical_accuracy: 0.3556
Epoch 79/100
1/1 [==============================] - 0s 104ms/step - loss: 0.9190 - sparse_categorical_accuracy: 0.3000 - val_loss: 0.8970 - val_sparse_categorical_accuracy: 0.3556
Epoch 80/100
1/1 [==============================] - 0s 104ms/step - loss: 0.9310 - sparse_categorical_accuracy: 0.2400 - val_loss: 0.8956 - val_sparse_categorical_accuracy: 0.3552
Epoch 81/100
1/1 [==============================] - 0s 102ms/step - loss: 0.8997 - sparse_categorical_accuracy: 0.3300 - val_loss: 0.8940 - val_sparse_categorical_accuracy: 0.3556
Epoch 82/100
1/1 [==============================] - 0s 104ms/step - loss: 0.9114 - sparse_categorical_accuracy: 0.2700 - val_loss: 0.8925 - val_sparse_categorical_accuracy: 0.3332
Epoch 83/100
1/1 [==============================] - 0s 106ms/step - loss: 0.9068 - sparse_categorical_accuracy: 0.2400 - val_loss: 0.8908 - val_sparse_categorical_accuracy: 0.2880
Epoch 84/100
1/1 [==============================] - 0s 104ms/step - loss: 0.9088 - sparse_categorical_accuracy: 0.2300 - val_loss: 0.8892 - val_sparse_categorical_accuracy: 0.2900
Epoch 85/100
1/1 [==============================] - 0s 106ms/step - loss: 0.9048 - sparse_categorical_accuracy: 0.2100 - val_loss: 0.8879 - val_sparse_categorical_accuracy: 0.2876
Epoch 86/100
1/1 [==============================] - 0s 108ms/step - loss: 0.9029 - sparse_categorical_accuracy: 0.2100 - val_loss: 0.8865 - val_sparse_categorical_accuracy: 0.2660
Epoch 87/100
1/1 [==============================] - 0s 105ms/step - loss: 0.8985 - sparse_categorical_accuracy: 0.2200 - val_loss: 0.8849 - val_sparse_categorical_accuracy: 0.3332
Epoch 88/100
1/1 [==============================] - 0s 105ms/step - loss: 0.8969 - sparse_categorical_accuracy: 0.2700 - val_loss: 0.8835 - val_sparse_categorical_accuracy: 0.4220
Epoch 89/100
1/1 [==============================] - 0s 106ms/step - loss: 0.9025 - sparse_categorical_accuracy: 0.3900 - val_loss: 0.8821 - val_sparse_categorical_accuracy: 0.5112
Epoch 90/100
1/1 [==============================] - 0s 104ms/step - loss: 0.8964 - sparse_categorical_accuracy: 0.4000 - val_loss: 0.8804 - val_sparse_categorical_accuracy: 0.5112
Epoch 91/100
1/1 [==============================] - 0s 104ms/step - loss: 0.9025 - sparse_categorical_accuracy: 0.3900 - val_loss: 0.8792 - val_sparse_categorical_accuracy: 0.5124
Epoch 92/100
1/1 [==============================] - 0s 103ms/step - loss: 0.8880 - sparse_categorical_accuracy: 0.4200 - val_loss: 0.8777 - val_sparse_categorical_accuracy: 0.4892
Epoch 93/100
1/1 [==============================] - 0s 105ms/step - loss: 0.8965 - sparse_categorical_accuracy: 0.3800 - val_loss: 0.8765 - val_sparse_categorical_accuracy: 0.4656
Epoch 94/100
1/1 [==============================] - 0s 104ms/step - loss: 0.8821 - sparse_categorical_accuracy: 0.4000 - val_loss: 0.8751 - val_sparse_categorical_accuracy: 0.4652
Epoch 95/100
1/1 [==============================] - 0s 105ms/step - loss: 0.8922 - sparse_categorical_accuracy: 0.3700 - val_loss: 0.8736 - val_sparse_categorical_accuracy: 0.4440
Epoch 96/100
1/1 [==============================] - 0s 104ms/step - loss: 0.8857 - sparse_categorical_accuracy: 0.4300 - val_loss: 0.8720 - val_sparse_categorical_accuracy: 0.4228
Epoch 97/100
1/1 [==============================] - 0s 106ms/step - loss: 0.8851 - sparse_categorical_accuracy: 0.4000 - val_loss: 0.8706 - val_sparse_categorical_accuracy: 0.4436
Epoch 98/100
1/1 [==============================] - 0s 106ms/step - loss: 0.8812 - sparse_categorical_accuracy: 0.4100 - val_loss: 0.8691 - val_sparse_categorical_accuracy: 0.4892
Epoch 99/100
1/1 [==============================] - 0s 103ms/step - loss: 0.8887 - sparse_categorical_accuracy: 0.4000 - val_loss: 0.8676 - val_sparse_categorical_accuracy: 0.4664
Epoch 100/100
1/1 [==============================] - 0s 101ms/step - loss: 0.8756 - sparse_categorical_accuracy: 0.4600 - val_loss: 0.8662 - val_sparse_categorical_accuracy: 0.4664
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2020-11-23T10_16_06.265013-tu4qjfvx/Trainer/model/4/serving_model_dir/assets

Evaluate and push the model

Use the Evaluator component to evaluate and 'bless' the model before using the Pusher component to push the model to a serving directory.

_serving_model_dir = os.path.join(tempfile.mkdtemp(), 'serving_model/iris_classification')

eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='variety')],
    metrics_specs=[
        tfma.MetricsSpec(metrics=[
            tfma.MetricConfig(class_name='ExampleCount'),
            tfma.MetricConfig(
                class_name='BinaryAccuracy',
                threshold=tfma.MetricThreshold(
                    value_threshold=tfma.GenericValueThreshold(
                        lower_bound={'value': 0.5}),
                    change_threshold=tfma.GenericChangeThreshold(
                        direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                        absolute={'value': -1e-10})))
        ])
    ],
    slicing_specs=[tfma.SlicingSpec(),
                   tfma.SlicingSpec(feature_keys=['sepal_length'])])
model_resolver = ResolverNode(
      instance_name='latest_blessed_model_resolver',
      resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=Channel(type=Model),
      model_blessing=Channel(type=ModelBlessing))
interactive_context.run(model_resolver)

evaluator = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    baseline_model=model_resolver.outputs['model'],
    eval_config=eval_config)
interactive_context.run(evaluator)
pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=_serving_model_dir)))
interactive_context.run(pusher)

Running the TFX pipeline populates the MLMD Database. In the next section, you use the MLMD API to query this database for metadata information.

Query the MLMD Database

The MLMD database stores three types of metadata:

  • Metadata about the pipeline and lineage information associated with the pipeline components
  • Metadata about artifacts that were generated during the pipeline run
  • Metadata about the executions of the pipeline

A typical production pipeline serves multiple models as new data arrives. When you encounter erroneous results from served models, you can query the MLMD database to isolate the erroneous models. You can then trace the lineage of the pipeline components that correspond to these models to debug them.

Use the connection config from the InteractiveContext you defined previously to set up the metadata (MD) store and query the MLMD database.

store = mlmd.MetadataStore(interactive_context.metadata_connection_config)

# All TFX artifacts are stored in the base directory
base_dir = interactive_context.metadata_connection_config.sqlite.filename_uri.split('metadata.sqlite')[0]
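
As a quick sanity check, you can count how many entries of each kind the pipeline run recorded; the exact numbers depend on which components you ran.

# Each kind of metadata described above has a corresponding read API.
print(len(store.get_artifacts()), 'artifacts')    # artifacts generated by the pipeline run
print(len(store.get_executions()), 'executions')  # component execution records
print(len(store.get_contexts()), 'contexts')      # pipeline and run grouping information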

Create some helper functions to view the data from the MD store.

def display_types(types):
  # Helper function to render dataframes for the artifact and execution types
  table = {'id': [], 'name': []}
  for a_type in types:
    table['id'].append(a_type.id)
    table['name'].append(a_type.name)
  return pd.DataFrame(data=table)

def display_artifacts(store, artifacts):
  # Helper function to render dataframes for the input artifacts
  table = {'artifact id': [], 'type': [], 'uri': []}
  for a in artifacts:
    table['artifact id'].append(a.id)
    artifact_type = store.get_artifact_types_by_id([a.type_id])[0]
    table['type'].append(artifact_type.name)
    table['uri'].append(a.uri.replace(base_dir, './'))
  return pd.DataFrame(data=table)

def display_properties(store, node):
  # Helper function to render dataframes for artifact and execution properties
  table = {'property': [], 'value': []}
  for k, v in node.properties.items():
    table['property'].append(k)
    table['value'].append(
        v.string_value if v.HasField('string_value') else v.int_value)
  for k, v in node.custom_properties.items():
    table['property'].append(k)
    table['value'].append(
        v.string_value if v.HasField('string_value') else v.int_value)
  return pd.DataFrame(data=table)

First, query the MD store for a list of all its stored ArtifactTypes.

display_types(store.get_artifact_types())

Next, query all PushedModel artifacts.

pushed_models = store.get_artifacts_by_type("PushedModel")
display_artifacts(store, pushed_models)

Query the MD store for the latest pushed model. This tutorial has only one pushed model.

pushed_model = pushed_models[-1]
display_properties(store, pushed_model)

One of the first steps in debugging a pushed model is to look at which trained model was pushed and which training data was used to train that model.

MLMD provides traversal APIs to walk through the provenance graph, which you can use to analyze the model provenance.

def get_one_hop_parent_artifacts(store, artifacts):
  # Get a list of artifacts within a 1-hop neighborhood of the artifacts of interest
  artifact_ids = [artifact.id for artifact in artifacts]
  executions_ids = set(
      event.execution_id
      for event in store.get_events_by_artifact_ids(artifact_ids)
      if event.type == metadata_store_pb2.Event.OUTPUT)
  artifacts_ids = set(
      event.artifact_id
      for event in store.get_events_by_execution_ids(executions_ids)
      if event.type == metadata_store_pb2.Event.INPUT) 
  return [artifact for artifact in store.get_artifacts_by_id(artifacts_ids)]

Query the parent artifacts for the pushed model.

parent_artifacts = get_one_hop_parent_artifacts(store, [pushed_model])
display_artifacts(store, parent_artifacts)

Query the properties for the model.

exported_model = parent_artifacts[0]
display_properties(store, exported_model)

Query the upstream artifacts for the model.

model_parents = get_one_hop_parent_artifacts(store, [exported_model])
display_artifacts(store, model_parents)

Get the training data the model trained with.

used_data = model_parents[0]
display_properties(store, used_data)

Now that you have the training data that the model trained with, query the database again to find the training step (execution). Query the MD store for a list of the registered execution types.

display_types(store.get_execution_types())

The training step is the ExecutionType named tfx.components.trainer.component.Trainer. Traverse the MD store to get the trainer run that corresponds to the pushed model.

def find_producer_execution(store, artifact):
  executions_ids = set(
      event.execution_id
      for event in store.get_events_by_artifact_ids([artifact.id])
      if event.type == metadata_store_pb2.Event.OUTPUT)
  return store.get_executions_by_id(executions_ids)[0]

trainer = find_producer_execution(store, exported_model)
display_properties(store, trainer)

Summary

In this tutorial, you learned how to leverage MLMD to trace the lineage of your TFX pipeline components and resolve issues.

To learn more about how to use MLMD, check out these additional resources: