This tutorial describes graph regularization from the Neural Structured Learning framework and demonstrates an end-to-end workflow for sentiment classification in a TFX pipeline.
Overview
This notebook classifies movie reviews as positive or negative using the text of the review. This is an example of binary classification, an important and widely applicable kind of machine learning problem.
We will demonstrate the use of graph regularization in this notebook by building a graph from the given input. The general recipe for building a graph-regularized model using the Neural Structured Learning (NSL) framework when the input does not contain an explicit graph is as follows:
- Create embeddings for each text sample in the input. This can be done using pre-trained models such as word2vec, Swivel, BERT etc.
- Build a graph based on these embeddings by using a similarity metric such as the 'L2' distance, 'cosine' distance, etc. Nodes in the graph correspond to samples and edges in the graph correspond to similarity between pairs of samples.
- Generate training data from the above synthesized graph and sample features. The resulting training data will contain neighbor features in addition to the original node features.
- Create a neural network as a base model using Estimators.
- Wrap the base model with the add_graph_regularization wrapper function, which is provided by the NSL framework, to create a new graph Estimator model. This new model will include a graph regularization loss as the regularization term in its training objective.
- Train and evaluate the graph Estimator model.
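To give a rough feel for what the graph regularization term adds to the training objective, the sketch below computes a supervised loss plus a weighted penalty on the distance between each sample's embedding and those of its graph neighbors. It is plain Python and does not use the NSL API; the function names, the squared L2 distance, and the multiplier `alpha` are assumptions chosen for illustration only.

```python
def squared_l2(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def graph_regularized_loss(supervised_loss, embeddings, edges, alpha=0.1):
    """Adds a neighbor-distance penalty to a supervised loss.

    Args:
      supervised_loss: the ordinary (e.g. cross-entropy) loss value.
      embeddings: dict mapping node ID -> embedding vector (hypothetical data).
      edges: list of (node_id, neighbor_id, weight) tuples.
      alpha: illustrative multiplier applied to the graph term.
    """
    graph_term = sum(
        weight * squared_l2(embeddings[src], embeddings[dst])
        for src, dst, weight in edges)
    return supervised_loss + alpha * graph_term


# Toy example: node "a" has two neighbors with different edge weights.
embeddings = {"a": [1.0, 0.0], "b": [0.0, 0.0], "c": [1.0, 1.0]}
edges = [("a", "b", 1.0), ("a", "c", 0.5)]
total = graph_regularized_loss(0.7, embeddings, edges, alpha=0.1)
print(total)
```

The penalty grows when connected samples end up far apart in embedding space, which nudges the model to produce similar representations for similar reviews.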
In this tutorial, we integrate the above workflow in a TFX pipeline using several custom TFX components as well as a custom graph-regularized trainer component.
Below is the schematic for our TFX pipeline. Orange boxes represent off-the-shelf TFX components and pink boxes represent custom TFX components.
Upgrade Pip
To avoid upgrading Pip in a system when running locally, check to make sure that we're running in Colab. Local systems can of course be upgraded separately.
try:
  import colab
  !pip install --upgrade pip
except ImportError:
  # Not running in Colab, so leave the local Pip installation untouched.
  pass
Install Required Packages
!pip install -q -U \
  tfx==0.23.0 \
  neural-structured-learning \
  tensorflow-hub \
  tensorflow-datasets
Did you restart the runtime?
If you are using Google Colab, the first time that you run the cell above, you must restart the runtime (Runtime > Restart runtime ...). This is because of the way that Colab loads packages.
Dependencies and imports
import apache_beam as beam
import gzip as gzip_lib
import numpy as np
import os
import pprint
import shutil
import tempfile
import urllib
import uuid
pp = pprint.PrettyPrinter()
import tensorflow as tf
import neural_structured_learning as nsl
import tfx
from tfx.components.evaluator.component import Evaluator
from tfx.components.example_gen.import_example_gen.component import ImportExampleGen
from tfx.components.example_validator.component import ExampleValidator
from tfx.components.model_validator.component import ModelValidator
from tfx.components.pusher.component import Pusher
from tfx.components.schema_gen.component import SchemaGen
from tfx.components.statistics_gen.component import StatisticsGen
from tfx.components.trainer.component import Trainer
from tfx.components.transform.component import Transform
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
from tfx.proto import evaluator_pb2
from tfx.proto import example_gen_pb2
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.utils.dsl_utils import external_input
from tfx.types import artifact
from tfx.types import artifact_utils
from tfx.types import channel
from tfx.types import standard_artifacts
from tfx.types.standard_artifacts import Examples
from tfx.dsl.component.experimental.annotations import InputArtifact
from tfx.dsl.component.experimental.annotations import OutputArtifact
from tfx.dsl.component.experimental.annotations import Parameter
from tfx.dsl.component.experimental.decorators import component
from tensorflow_metadata.proto.v0 import anomalies_pb2
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_metadata.proto.v0 import statistics_pb2
import tensorflow_data_validation as tfdv
import tensorflow_transform as tft
import tensorflow_model_analysis as tfma
import tensorflow_hub as hub
import tensorflow_datasets as tfds
print("TF Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
print("NSL Version: ", nsl.__version__)
print("TFX Version: ", tfx.__version__)
print("TFDV version: ", tfdv.__version__)
print("TFT version: ", tft.__version__)
print("TFMA version: ", tfma.__version__)
print("Hub version: ", hub.__version__)
print("Beam version: ", beam.__version__)
TF Version: 2.3.1
Eager mode: True
GPU is available
NSL Version: 1.3.1
TFX Version: 0.23.0
TFDV version: 0.23.1
TFT version: 0.23.0
TFMA version: 0.23.0
Hub version: 0.9.0
Beam version: 2.24.0
IMDB dataset
The IMDB dataset contains the text of 50,000 movie reviews from the Internet Movie Database. These are split into 25,000 reviews for training and 25,000 reviews for testing. The training and testing sets are balanced, meaning they contain an equal number of positive and negative reviews. Moreover, there are 50,000 additional unlabeled movie reviews.
Download preprocessed IMDB dataset
The following code downloads the IMDB dataset (or uses a cached copy if it has already been downloaded) using TFDS. To speed up this notebook we will use only 10,000 labeled reviews and 10,000 unlabeled reviews for training, and 10,000 test reviews for evaluation.
train_set, eval_set = tfds.load(
    "imdb_reviews:1.0.0",
    split=["train[:10000]+unsupervised[:10000]", "test[:10000]"],
    shuffle_files=False)
Let's look at a few reviews from the training set:
for tfrecord in train_set.take(4):
  print("Review: {}".format(tfrecord["text"].numpy().decode("utf-8")[:300]))
  print("Label: {}\n".format(tfrecord["label"].numpy()))
Review: This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda pi
Label: 0

Review: I have been known to fall asleep during films, but this is usually due to a combination of things including, really tired, being warm and comfortable on the sette and having just eaten a lot. However on this occasion I fell asleep because the film was rubbish. The plot development was constant. Cons
Label: 0

Review: Mann photographs the Alberta Rocky Mountains in a superb fashion, and Jimmy Stewart and Walter Brennan give enjoyable performances as they always seem to do. <br /><br />But come on Hollywood - a Mountie telling the people of Dawson City, Yukon to elect themselves a marshal (yes a marshal!) and to e
Label: 0

Review: This is the kind of film for a snowy Sunday afternoon when the rest of the world can go ahead with its own business as you descend into a big arm-chair and mellow for a couple of hours. Wonderful performances from Cher and Nicolas Cage (as always) gently row the plot along. There are no rapids to cr
Label: 1
def _dict_to_example(instance):
  """Converts a dict of NumPy arrays to a tf.train.Example."""
  feature = {}
  for key, value in instance.items():
    if value is None:
      feature[key] = tf.train.Feature()
    # np.issubdtype matches any integer width (int32, int64, ...).
    elif np.issubdtype(value.dtype, np.integer):
      feature[key] = tf.train.Feature(
          int64_list=tf.train.Int64List(value=value.tolist()))
    elif value.dtype == np.float32:
      feature[key] = tf.train.Feature(
          float_list=tf.train.FloatList(value=value.tolist()))
    else:
      feature[key] = tf.train.Feature(
          bytes_list=tf.train.BytesList(value=value.tolist()))
  return tf.train.Example(features=tf.train.Features(feature=feature))
examples_path = tempfile.mkdtemp(prefix="tfx-data")
train_path = os.path.join(examples_path, "train.tfrecord")
eval_path = os.path.join(examples_path, "eval.tfrecord")
for path, dataset in [(train_path, train_set), (eval_path, eval_set)]:
  with tf.io.TFRecordWriter(path) as writer:
    for example in dataset:
      writer.write(
          _dict_to_example({
              "label": np.array([example["label"].numpy()]),
              "text": np.array([example["text"].numpy()]),
          }).SerializeToString())
Run TFX Components Interactively
In the cells that follow you will construct TFX components and run each one interactively within the InteractiveContext to obtain ExecutionResult objects. This mirrors the process of an orchestrator running components in a TFX DAG based on when the dependencies for each component are met.
context = InteractiveContext()
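As a rough illustration of how an orchestrator decides when each component can run, the sketch below topologically sorts a dependency map using Python's standard library. The map is an assumption that mirrors how the components in this notebook consume each other's outputs; it is not produced by TFX, and a real orchestrator may pick any order that satisfies the same constraints.

```python
from graphlib import TopologicalSorter  # Requires Python 3.9 or newer.

# Each component maps to the set of components whose outputs it consumes.
dependencies = {
    "ImportExampleGen": set(),
    "IdentifyExamples": {"ImportExampleGen"},
    "StatisticsGen": {"IdentifyExamples"},
    "SchemaGen": {"StatisticsGen"},
    "ExampleValidator": {"StatisticsGen", "SchemaGen"},
    "SynthesizeGraph": {"IdentifyExamples"},
    "Transform": {"IdentifyExamples", "SchemaGen"},
}

# static_order() yields every node only after all of its dependencies.
run_order = list(TopologicalSorter(dependencies).static_order())
for step, name in enumerate(run_order, start=1):
    print(step, name)
```

Running the cells of this notebook from top to bottom is one valid ordering of this kind.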
The ExampleGen Component
In any ML development process the first step when starting code development is to ingest the training and test datasets. The ExampleGen component brings data into the TFX pipeline.
Create an ExampleGen component and run it.
input_data = external_input(examples_path)
input_config = example_gen_pb2.Input(splits=[
    example_gen_pb2.Input.Split(name='train', pattern='train.tfrecord'),
    example_gen_pb2.Input.Split(name='eval', pattern='eval.tfrecord')
])
example_gen = ImportExampleGen(input=input_data, input_config=input_config)
context.run(example_gen, enable_cache=True)
for artifact in example_gen.outputs['examples'].get():
  print(artifact)
print('\nexample_gen.outputs is a {}'.format(type(example_gen.outputs)))
print(example_gen.outputs)
print(example_gen.outputs['examples'].get()[0].split_names)
Artifact(artifact: id: 1 type_id: 5 uri: "/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/ImportExampleGen/examples/1" properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "input_fingerprint" value { string_value: "split:train,num_files:1,total_bytes:27706811,xor_checksum:1602753958,sum_checksum:1602753958\nsplit:eval,num_files:1,total_bytes:13374744,xor_checksum:1602753960,sum_checksum:1602753960" } } custom_properties { key: "name" value { string_value: "examples" } } custom_properties { key: "payload_format" value { string_value: "FORMAT_TF_EXAMPLE" } } custom_properties { key: "pipeline_name" value { string_value: "interactive-2020-10-15T09_26_00.686186" } } custom_properties { key: "producer_component" value { string_value: "ImportExampleGen" } } custom_properties { key: "span" value { string_value: "0" } } custom_properties { key: "state" value { string_value: "published" } } , artifact_type: id: 5 name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } ) example_gen.outputs is a <class 'tfx.types.node_common._PropertyDictWrapper'> {'examples': Channel( type_name: Examples artifacts: [Artifact(artifact: id: 1 type_id: 5 uri: "/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/ImportExampleGen/examples/1" properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "input_fingerprint" value { string_value: "split:train,num_files:1,total_bytes:27706811,xor_checksum:1602753958,sum_checksum:1602753958\nsplit:eval,num_files:1,total_bytes:13374744,xor_checksum:1602753960,sum_checksum:1602753960" } } custom_properties { key: "name" value { string_value: "examples" } } custom_properties { key: "payload_format" value { string_value: "FORMAT_TF_EXAMPLE" } } custom_properties { key: "pipeline_name" value { string_value: "interactive-2020-10-15T09_26_00.686186" } } 
custom_properties { key: "producer_component" value { string_value: "ImportExampleGen" } } custom_properties { key: "span" value { string_value: "0" } } custom_properties { key: "state" value { string_value: "published" } } , artifact_type: id: 5 name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } )] )} ["train", "eval"]
The component's output is a single Examples artifact that contains two splits:
- the training examples (10,000 labeled reviews + 10,000 unlabeled reviews)
- the eval examples (10,000 labeled reviews)
The IdentifyExamples Custom Component
To use NSL, we will need each instance to have a unique ID. We create a custom component that adds such a unique ID to all instances across all splits. We leverage Apache Beam to be able to easily scale to large datasets if needed.
def make_example_with_unique_id(example, id_feature_name):
  """Adds a unique ID to the given `tf.train.Example` proto.

  This function uses Python's 'uuid' module to generate a universally unique
  identifier for each example.

  Args:
    example: An instance of a `tf.train.Example` proto.
    id_feature_name: The name of the feature in the resulting `tf.train.Example`
      that will contain the unique identifier.

  Returns:
    A new `tf.train.Example` proto that includes a unique identifier as an
    additional feature.
  """
  result = tf.train.Example()
  result.CopyFrom(example)
  unique_id = uuid.uuid4()
  result.features.feature.get_or_create(
      id_feature_name).bytes_list.MergeFrom(
          tf.train.BytesList(value=[str(unique_id).encode('utf-8')]))
  return result
@component
def IdentifyExamples(orig_examples: InputArtifact[Examples],
                     identified_examples: OutputArtifact[Examples],
                     id_feature_name: Parameter[str],
                     component_name: Parameter[str]) -> None:
  # Get a list of the splits in orig_examples.
  splits_list = artifact_utils.decode_split_names(
      split_names=orig_examples.split_names)
  for split in splits_list:
    input_dir = os.path.join(orig_examples.uri, split)
    output_dir = os.path.join(identified_examples.uri, split)
    os.mkdir(output_dir)
    with beam.Pipeline() as pipeline:
      (pipeline
       | 'ReadExamples' >> beam.io.ReadFromTFRecord(
           os.path.join(input_dir, '*'),
           coder=beam.coders.coders.ProtoCoder(tf.train.Example))
       | 'AddUniqueId' >> beam.Map(make_example_with_unique_id, id_feature_name)
       | 'WriteIdentifiedExamples' >> beam.io.WriteToTFRecord(
           file_path_prefix=os.path.join(output_dir, 'data_tfrecord'),
           coder=beam.coders.coders.ProtoCoder(tf.train.Example),
           file_name_suffix='.gz'))
  # For completeness, encode the splits names and payload_format.
  # We could also just use orig_examples.split_names.
  identified_examples.split_names = artifact_utils.encode_split_names(
      splits=splits_list)
  # TODO(b/168616829): Remove populating payload_format after tfx 0.25.0.
  identified_examples.set_string_custom_property(
      "payload_format",
      orig_examples.get_string_custom_property("payload_format"))
  return
identify_examples = IdentifyExamples(
    orig_examples=example_gen.outputs['examples'],
    component_name=u'IdentifyExamples',
    id_feature_name=u'id')
context.run(identify_examples, enable_cache=False)
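The ID-assignment idea can be seen in isolation from Beam and TFX in the minimal sketch below. It works on plain dictionaries rather than `tf.train.Example` protos; the helper name `add_unique_id` and the sample records are hypothetical.

```python
import uuid


def add_unique_id(record, id_feature_name="id"):
    """Returns a copy of `record` with a fresh UUID4 string added."""
    result = dict(record)  # Copy so the input record is left unchanged.
    result[id_feature_name] = str(uuid.uuid4())
    return result


records = [
    {"text": "great film", "label": 1},
    {"text": "dull plot", "label": 0},
]
identified = [add_unique_id(record) for record in records]
for record in identified:
    print(record)
```

Because each ID is generated independently and at random, the function can be applied to examples in parallel without coordination, which is what makes it a natural fit for a `beam.Map` step.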
The StatisticsGen Component
The StatisticsGen component computes descriptive statistics for your dataset. The statistics that it generates can be visualized for review, and are used for example validation and to infer a schema.
Create a StatisticsGen component and run it.
# Computes statistics over data for visualization and example validation.
statistics_gen = StatisticsGen(
    examples=identify_examples.outputs["identified_examples"])
context.run(statistics_gen, enable_cache=True)
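For intuition about the kind of per-feature summary that descriptive statistics provide, here is a minimal plain-Python sketch. It is not how StatisticsGen or TFDV compute statistics; the `describe` helper and the toy examples are assumptions for illustration.

```python
from collections import Counter


def describe(examples, feature):
    """Summarizes one feature across a list of dict-shaped examples."""
    values = [example.get(feature) for example in examples]
    present = [value for value in values if value is not None]
    return {
        "num_examples": len(values),
        "num_missing": len(values) - len(present),
        "value_counts": Counter(present),
    }


# Toy data: the last example has no label at all.
examples = [{"label": 0}, {"label": 1}, {"label": 1}, {}]
label_stats = describe(examples, "label")
print(label_stats)
```

Summaries like these are what make it possible to infer a schema and to detect anomalies in later steps.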
The SchemaGen Component
The SchemaGen component generates a schema for your data based on the statistics from StatisticsGen. It tries to infer the data types of each of your features, and the ranges of legal values for categorical features.
Create a SchemaGen component and run it.
# Generates schema based on statistics files.
schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'])
context.run(schema_gen, enable_cache=True)
The generated artifact is just a schema.pbtxt containing a text representation of a schema_pb2.Schema protobuf:
train_uri = schema_gen.outputs['schema'].get()[0].uri
schema_filename = os.path.join(train_uri, 'schema.pbtxt')
schema = tfx.utils.io_utils.parse_pbtxt_file(
    file_name=schema_filename, message=schema_pb2.Schema())
It can be visualized using tfdv.display_schema() (we will look at this in more detail in a subsequent lab):
tfdv.display_schema(schema)
The ExampleValidator Component
The ExampleValidator performs anomaly detection, based on the statistics from StatisticsGen and the schema from SchemaGen. It looks for problems such as missing values, values of the wrong type, or categorical values outside of the domain of acceptable values.
Create an ExampleValidator component and run it.
# Performs anomaly detection based on statistics and data schema.
validate_stats = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'])
context.run(validate_stats, enable_cache=False)
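The three kinds of problems mentioned above (missing values, wrong types, and out-of-domain categorical values) can be illustrated with a small plain-Python check. This is only a sketch: the dict-based schema format and the `find_anomalies` helper are hypothetical and unrelated to the `schema_pb2.Schema` proto or the ExampleValidator implementation, which works from aggregate statistics rather than individual examples.

```python
def find_anomalies(example, schema):
    """Returns a list of human-readable problems found in one example.

    `schema` maps a feature name to {"type": python_type, "domain": set},
    where "domain" is optional.
    """
    problems = []
    for name, spec in schema.items():
        value = example.get(name)
        if value is None:
            problems.append(f"{name}: missing value")
        elif not isinstance(value, spec["type"]):
            problems.append(f"{name}: expected {spec['type'].__name__}")
        elif spec.get("domain") is not None and value not in spec["domain"]:
            problems.append(f"{name}: value {value!r} outside domain")
    return problems


schema = {
    "text": {"type": str},
    "label": {"type": int, "domain": {0, 1}},
}
ok = find_anomalies({"text": "fine", "label": 1}, schema)
bad = find_anomalies({"text": None, "label": 7}, schema)
print(ok)
print(bad)
```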
The SynthesizeGraph Component
Graph construction involves creating embeddings for text samples and then using a similarity function to compare the embeddings.
We will use pretrained Swivel embeddings to create embeddings in the tf.train.Example format for each sample in the input. We will store the resulting embeddings in the TFRecord format along with the sample's ID. This is important and will allow us to match sample embeddings with corresponding nodes in the graph later.
Once we have the sample embeddings, we will use them to build a similarity graph, i.e., nodes in this graph will correspond to samples and edges in this graph will correspond to similarity between pairs of nodes.
Neural Structured Learning provides a graph building library to build a graph based on sample embeddings. It uses cosine similarity as the similarity measure to compare embeddings and build edges between them. It also allows us to specify a similarity threshold, which can be used to discard dissimilar edges from the final graph. In the following example, using 0.99 as the similarity threshold, we end up with a graph that has 115,368 bi-directional edges.
swivel_url = 'https://hub.tensorflow.google.cn/google/tf2-preview/gnews-swivel-20dim/1'
hub_layer = hub.KerasLayer(swivel_url, input_shape=[], dtype=tf.string)
def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def create_embedding_example(example):
  """Create tf.Example containing the sample's embedding and its ID."""
  sentence_embedding = hub_layer(tf.sparse.to_dense(example['text']))
  # Flatten the sentence embedding back to 1-D.
  sentence_embedding = tf.reshape(sentence_embedding, shape=[-1])
  feature_dict = {
      'id': _bytes_feature(tf.sparse.to_dense(example['id']).numpy()),
      'embedding': _float_feature(sentence_embedding.numpy().tolist())
  }
  return tf.train.Example(features=tf.train.Features(feature=feature_dict))

def create_dataset(uri):
  tfrecord_filenames = [os.path.join(uri, name) for name in os.listdir(uri)]
  return tf.data.TFRecordDataset(tfrecord_filenames, compression_type='GZIP')

def create_embeddings(train_path, output_path):
  dataset = create_dataset(train_path)
  embeddings_path = os.path.join(output_path, 'embeddings.tfr')
  feature_map = {
      'label': tf.io.FixedLenFeature([], tf.int64),
      'id': tf.io.VarLenFeature(tf.string),
      'text': tf.io.VarLenFeature(tf.string)
  }
  with tf.io.TFRecordWriter(embeddings_path) as writer:
    for tfrecord in dataset:
      tensor_dict = tf.io.parse_single_example(tfrecord, feature_map)
      embedding_example = create_embedding_example(tensor_dict)
      writer.write(embedding_example.SerializeToString())

def build_graph(output_path, similarity_threshold):
  embeddings_path = os.path.join(output_path, 'embeddings.tfr')
  graph_path = os.path.join(output_path, 'graph.tfv')
  nsl.tools.build_graph([embeddings_path], graph_path, similarity_threshold)
"""Custom Artifact type"""
class SynthesizedGraph(tfx.types.artifact.Artifact):
  """Output artifact of the SynthesizeGraph component"""
  TYPE_NAME = 'SynthesizedGraphPath'
  PROPERTIES = {
      'span': standard_artifacts.SPAN_PROPERTY,
      'split_names': standard_artifacts.SPLIT_NAMES_PROPERTY,
  }
@component
def SynthesizeGraph(identified_examples: InputArtifact[Examples],
                    synthesized_graph: OutputArtifact[SynthesizedGraph],
                    similarity_threshold: Parameter[float],
                    component_name: Parameter[str]) -> None:
  # Get a list of the splits in identified_examples.
  splits_list = artifact_utils.decode_split_names(
      split_names=identified_examples.split_names)
  # We build a graph only based on the 'train' split which includes both
  # labeled and unlabeled examples.
  train_input_examples_uri = os.path.join(identified_examples.uri, 'train')
  output_graph_uri = os.path.join(synthesized_graph.uri, 'train')
  os.mkdir(output_graph_uri)
  print('Creating embeddings...')
  create_embeddings(train_input_examples_uri, output_graph_uri)
  print('Synthesizing graph...')
  build_graph(output_graph_uri, similarity_threshold)
  synthesized_graph.split_names = artifact_utils.encode_split_names(
      splits=['train'])
  return
synthesize_graph = SynthesizeGraph(
    identified_examples=identify_examples.outputs['identified_examples'],
    component_name=u'SynthesizeGraph',
    similarity_threshold=0.99)
context.run(synthesize_graph, enable_cache=False)
Creating embeddings... Synthesizing graph...
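To make the thresholding step concrete, the sketch below builds a similarity graph from a handful of toy embeddings using brute-force, all-pairs cosine similarity. It is a plain-Python illustration and makes no claim about how `nsl.tools.build_graph` is implemented internally; the helper names and the sample vectors are hypothetical.

```python
import itertools
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def build_similarity_edges(embeddings, similarity_threshold):
    """Returns (source, target, weight) edges, one per direction."""
    edges = []
    for (id_a, vec_a), (id_b, vec_b) in itertools.combinations(
            embeddings.items(), 2):
        weight = cosine_similarity(vec_a, vec_b)
        # Pairs below the threshold are dropped from the graph.
        if weight >= similarity_threshold:
            edges.append((id_a, id_b, weight))
            edges.append((id_b, id_a, weight))
    return edges


embeddings = {
    "n1": [1.0, 0.0],
    "n2": [0.99, 0.05],
    "n3": [0.0, 1.0],
}
edges = build_similarity_edges(embeddings, similarity_threshold=0.99)
for source, target, weight in edges:
    print(source, target, round(weight, 6))
```

Only the `n1`/`n2` pair survives the 0.99 threshold here, and it is written once in each direction. Comparing all pairs scales quadratically with the number of samples, so this approach is only suitable for small inputs.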
train_uri = synthesize_graph.outputs["synthesized_graph"].get()[0].uri
os.listdir(train_uri)
['train']
graph_path = os.path.join(train_uri, "train", "graph.tfv")
print("node 1\t\t\t\t\tnode 2\t\t\t\t\tsimilarity")
!head {graph_path}
print("...")
!tail {graph_path}
node 1 node 2 similarity
c54d7b6d-5522-4c7f-80e8-63aefb40518d 48dc5b8a-2941-4de3-a92c-9a6829821632 0.991918
48dc5b8a-2941-4de3-a92c-9a6829821632 c54d7b6d-5522-4c7f-80e8-63aefb40518d 0.991918
4be77993-5b51-40fc-9ebd-ea4185243e0f 352566d1-7ecc-4299-8226-7ce88160661d 0.991171
352566d1-7ecc-4299-8226-7ce88160661d 4be77993-5b51-40fc-9ebd-ea4185243e0f 0.991171
4be77993-5b51-40fc-9ebd-ea4185243e0f f57a5e51-2960-493e-980d-395826c35ee0 0.992568
f57a5e51-2960-493e-980d-395826c35ee0 4be77993-5b51-40fc-9ebd-ea4185243e0f 0.992568
3630bfa5-2c97-47c4-acfd-bec08a96bc4a 00dc9419-28f2-4852-8ed1-604384254f8c 0.993089
00dc9419-28f2-4852-8ed1-604384254f8c 3630bfa5-2c97-47c4-acfd-bec08a96bc4a 0.993089
3630bfa5-2c97-47c4-acfd-bec08a96bc4a 21e41556-c9ad-4c5e-a580-e5f2772c1ba4 0.991987
21e41556-c9ad-4c5e-a580-e5f2772c1ba4 3630bfa5-2c97-47c4-acfd-bec08a96bc4a 0.991987
...
2f63416d-12e9-40d1-970b-ab978a6d1e93 c4d07e3b-b991-42ba-9848-57a489080bab 0.993670
c4d07e3b-b991-42ba-9848-57a489080bab 2f63416d-12e9-40d1-970b-ab978a6d1e93 0.993670
829c875d-66ef-43d2-ab35-51fc7448b61d f8fb2876-afc4-4e4b-af80-2c432377191b 0.990820
f8fb2876-afc4-4e4b-af80-2c432377191b 829c875d-66ef-43d2-ab35-51fc7448b61d 0.990820
bf272722-d225-4640-9edd-ec7674fc5734 97b231c3-7952-4826-bb01-a2935991a4e7 0.991107
97b231c3-7952-4826-bb01-a2935991a4e7 bf272722-d225-4640-9edd-ec7674fc5734 0.991107
f8fb2876-afc4-4e4b-af80-2c432377191b 42c5d827-4c77-4519-b128-48543104770f 0.990005
42c5d827-4c77-4519-b128-48543104770f f8fb2876-afc4-4e4b-af80-2c432377191b 0.990005
9bbbebb5-eb68-42d5-a0a3-700a6ac33a78 e74d91e7-170a-46a9-abf6-d176ef810e54 0.993868
e74d91e7-170a-46a9-abf6-d176ef810e54 9bbbebb5-eb68-42d5-a0a3-700a6ac33a78 0.993868
!wc -l {graph_path}
230736 /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/SynthesizeGraph/synthesized_graph/6/train/graph.tfv
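The file has 230,736 lines because each bi-directional edge is written once per direction, so the number of distinct edges is 230,736 / 2 = 115,368. The sketch below shows one way to count distinct undirected edges, assuming each line holds a source ID, a target ID, and a weight separated by tabs; the helper name and the shortened sample lines are hypothetical.

```python
def count_undirected_edges(lines):
    """Counts distinct undirected edges in `source<TAB>target<TAB>weight` lines."""
    seen = set()
    for line in lines:
        source, target, _weight = line.rstrip("\n").split("\t")
        # A frozenset ignores direction, so (a, b) and (b, a) collapse.
        seen.add(frozenset((source, target)))
    return len(seen)


sample_lines = [
    "a\tb\t0.991918\n",
    "b\ta\t0.991918\n",
    "a\tc\t0.992568\n",
    "c\ta\t0.992568\n",
]
num_edges = count_undirected_edges(sample_lines)
print(num_edges)
```

To apply this to the real graph, the lines could be read with `open(graph_path)` instead of the in-memory sample.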
The Transform Component
The Transform component performs data transformations and feature engineering. The results include an input TensorFlow graph which is used during both training and serving to preprocess the data before training or inference. This graph becomes part of the SavedModel that is the result of model training. Since the same input graph is used for both training and serving, the preprocessing will always be the same, and only needs to be written once.
The Transform component requires more code than many other components because of the arbitrary complexity of the feature engineering that you may need for the data and/or model that you're working with. It requires code files to be available which define the processing needed.
Each sample will include the following three features:
- id: The node ID of the sample.
- text_xf: An int64 list containing word IDs.
- label_xf: A singleton int64 identifying the target class of the review: 0=negative, 1=positive.
Let's define a module containing the preprocessing_fn() function that we will pass to the Transform component:
_transform_module_file = 'imdb_transform.py'
%%writefile {_transform_module_file}
import tensorflow as tf
import tensorflow_transform as tft

SEQUENCE_LENGTH = 100
VOCAB_SIZE = 10000
OOV_SIZE = 100

def tokenize_reviews(reviews, sequence_length=SEQUENCE_LENGTH):
  reviews = tf.strings.lower(reviews)
  reviews = tf.strings.regex_replace(reviews, r" '| '|^'|'$", " ")
  reviews = tf.strings.regex_replace(reviews, "[^a-z' ]", " ")
  tokens = tf.strings.split(reviews)[:, :sequence_length]
  start_tokens = tf.fill([tf.shape(reviews)[0], 1], "<START>")
  end_tokens = tf.fill([tf.shape(reviews)[0], 1], "<END>")
  tokens = tf.concat([start_tokens, tokens, end_tokens], axis=1)
  tokens = tokens[:, :sequence_length]
  tokens = tokens.to_tensor(default_value="<PAD>")
  pad = sequence_length - tf.shape(tokens)[1]
  tokens = tf.pad(tokens, [[0, 0], [0, pad]], constant_values="<PAD>")
  return tf.reshape(tokens, [-1, sequence_length])

def preprocessing_fn(inputs):
  """tf.transform's callback function for preprocessing inputs.

  Args:
    inputs: map from feature keys to raw not-yet-transformed features.

  Returns:
    Map from string feature key to transformed feature operations.
  """
  outputs = {}
  outputs["id"] = inputs["id"]
  tokens = tokenize_reviews(_fill_in_missing(inputs["text"], ''))
  outputs["text_xf"] = tft.compute_and_apply_vocabulary(
      tokens,
      top_k=VOCAB_SIZE,
      num_oov_buckets=OOV_SIZE)
  outputs["label_xf"] = _fill_in_missing(inputs["label"], -1)
  return outputs

def _fill_in_missing(x, default_value):
  """Replace missing values in a SparseTensor.

  Fills in missing values of `x` with the default_value.

  Args:
    x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1
      in the second dimension.
    default_value: the value with which to replace the missing values.

  Returns:
    A rank 1 tensor where missing values of `x` have been filled in.
  """
  return tf.squeeze(
      tf.sparse.to_dense(
          tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]),
          default_value),
      axis=1)
Writing imdb_transform.py
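To see what the tokenization and vocabulary steps in the module do to a single review, here is a plain-Python sketch that mirrors them without TensorFlow. It is an approximation under stated assumptions: the sequence length is shortened to 8 for readability, the vocabulary is built from one review, and the CRC32-based bucket assignment is a stand-in, since the ordering and hashing used by `tft.compute_and_apply_vocabulary` are not reproduced here.

```python
import re
import zlib
from collections import Counter

SEQUENCE_LENGTH = 8  # The module above uses 100; shortened for this sketch.


def tokenize_review(review, sequence_length=SEQUENCE_LENGTH):
    """Lowercases, strips punctuation, adds markers, truncates, and pads."""
    review = review.lower()
    review = re.sub(r" '| '|^'|'$", " ", review)
    review = re.sub(r"[^a-z' ]", " ", review)
    tokens = review.split()[:sequence_length]
    tokens = (["<START>"] + tokens + ["<END>"])[:sequence_length]
    return tokens + ["<PAD>"] * (sequence_length - len(tokens))


def build_vocabulary(token_lists, top_k):
    """Maps the `top_k` most frequent tokens to IDs 0..top_k-1."""
    counts = Counter(t for tokens in token_lists for t in tokens)
    return {tok: i for i, (tok, _) in enumerate(counts.most_common(top_k))}


def lookup(token, vocab, num_oov_buckets):
    """Returns the vocabulary ID, or a hashed out-of-vocabulary bucket ID."""
    if token in vocab:
        return vocab[token]
    # Illustrative hash only; OOV IDs start right after the vocabulary.
    return len(vocab) + zlib.crc32(token.encode("utf-8")) % num_oov_buckets


tokens = tokenize_review("Don't miss it: GREAT film!")
vocab = build_vocabulary([tokens], top_k=3)
ids = [lookup(token, vocab, num_oov_buckets=2) for token in tokens]
print(tokens)
print(ids)
```

Every review ends up as a fixed-length list of integers: in-vocabulary tokens get their own ID, while rare tokens share a small number of out-of-vocabulary buckets.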
Create and run the Transform component, referring to the files that were created above.
# Performs transformations and feature engineering in training and serving.
transform = Transform(
    examples=identify_examples.outputs['identified_examples'],
    schema=schema_gen.outputs['schema'],
    # TODO(b/169218106): Remove transformed_examples kwargs after bugfix is released.
    transformed_examples=channel.Channel(
        type=standard_artifacts.Examples,
        artifacts=[standard_artifacts.Examples()]),
    module_file=_transform_module_file)
context.run(transform)
The Transform component has 2 types of outputs:
- transform_graph is the graph that can perform the preprocessing operations (this graph will be included in the serving and evaluation models).
- transformed_examples represents the preprocessed training and evaluation data.
transform.outputs
{'transform_graph': Channel( type_name: TransformGraph artifacts: [Artifact(artifact: id: 7 type_id: 16 uri: "/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Transform/transform_graph/7" custom_properties { key: "name" value { string_value: "transform_graph" } } custom_properties { key: "pipeline_name" value { string_value: "interactive-2020-10-15T09_26_00.686186" } } custom_properties { key: "producer_component" value { string_value: "Transform" } } custom_properties { key: "state" value { string_value: "published" } } , artifact_type: id: 16 name: "TransformGraph" )] ), 'transformed_examples': Channel( type_name: Examples artifacts: [Artifact(artifact: id: 8 type_id: 5 uri: "/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Transform/transformed_examples/7" properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "name" value { string_value: "transformed_examples" } } custom_properties { key: "pipeline_name" value { string_value: "interactive-2020-10-15T09_26_00.686186" } } custom_properties { key: "producer_component" value { string_value: "Transform" } } custom_properties { key: "state" value { string_value: "published" } } , artifact_type: id: 5 name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } )] )}
Take a peek at the transform_graph artifact: it points to a directory containing 3 subdirectories:
train_uri = transform.outputs['transform_graph'].get()[0].uri
os.listdir(train_uri)
['transform_fn', 'transformed_metadata', 'metadata']
The transform_fn subdirectory contains the actual preprocessing graph. The metadata subdirectory contains the schema of the original data. The transformed_metadata subdirectory contains the schema of the preprocessed data.
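As a quick sanity check on the layout described above, the sketch below reports which of the three expected subdirectories are missing from a given directory. The helper name `missing_transform_subdirs` and the `EXPECTED_SUBDIRS` mapping are hypothetical and not part of TFX; the subdirectory names are taken from the listing above.

```python
import os

# Subdirectory names as listed above, with a short note on what each holds.
# This mapping is an illustrative assumption, not a TFX constant.
EXPECTED_SUBDIRS = {
    'transform_fn': 'the preprocessing graph',
    'metadata': 'schema of the original data',
    'transformed_metadata': 'schema of the preprocessed data',
}


def missing_transform_subdirs(transform_graph_uri):
  """Returns the sorted names of expected subdirectories that are absent."""
  present = set(os.listdir(transform_graph_uri))
  return sorted(name for name in EXPECTED_SUBDIRS if name not in present)
```

In the notebook, this could be called with the same URI that is passed to os.listdir above; an empty list means all three subdirectories are present.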
Take a look at some of the transformed examples and check that they are indeed processed as intended.
def pprint_examples(artifact, n_examples=3):
  print("artifact:", artifact)
  uri = os.path.join(artifact.uri, "train")
  print("uri:", uri)
  tfrecord_filenames = [os.path.join(uri, name) for name in os.listdir(uri)]
  print("tfrecord_filenames:", tfrecord_filenames)
  dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")
  for tfrecord in dataset.take(n_examples):
    serialized_example = tfrecord.numpy()
    example = tf.train.Example.FromString(serialized_example)
    pp.pprint(example)

pprint_examples(transform.outputs['transformed_examples'].get()[0])
artifact: Artifact(artifact: id: 8 type_id: 5 uri: "/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Transform/transformed_examples/7" properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "name" value { string_value: "transformed_examples" } } custom_properties { key: "pipeline_name" value { string_value: "interactive-2020-10-15T09_26_00.686186" } } custom_properties { key: "producer_component" value { string_value: "Transform" } } custom_properties { key: "state" value { string_value: "published" } } , artifact_type: id: 5 name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } ) uri: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Transform/transformed_examples/7/train tfrecord_filenames: ['/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Transform/transformed_examples/7/train/transformed_examples-00000-of-00001.gz'] features { feature { key: "id" value { bytes_list { value: "08903146-1233-49d7-ac8e-ac126c0a8b14" } } } feature { key: "label_xf" value { int64_list { value: 0 } } } feature { key: "text_xf" value { int64_list { value: 13 value: 8 value: 14 value: 32 value: 338 value: 310 value: 15 value: 95 value: 27 value: 10001 value: 9 value: 31 value: 1173 value: 3153 value: 43 value: 495 value: 10060 value: 214 value: 26 value: 71 value: 142 value: 19 value: 8 value: 204 value: 339 value: 27 value: 74 value: 181 value: 238 value: 9 value: 440 value: 67 value: 74 value: 71 value: 94 value: 100 value: 22 value: 5442 value: 8 value: 1573 value: 607 value: 530 value: 8 value: 15 value: 6 value: 32 value: 378 value: 6292 value: 207 value: 2276 value: 388 value: 0 value: 84 value: 1023 value: 154 value: 65 value: 155 value: 52 value: 0 value: 10080 value: 7871 value: 65 value: 250 value: 74 value: 3202 value: 20 value: 10000 value: 3720 value: 10020 value: 10008 value: 1282 value: 3862 value: 3 value: 
53 value: 3952 value: 110 value: 1879 value: 17 value: 3153 value: 14 value: 166 value: 19 value: 2 value: 1023 value: 1007 value: 9405 value: 9 value: 2 value: 15 value: 12 value: 14 value: 4504 value: 4 value: 109 value: 158 value: 1202 value: 7 value: 174 value: 505 value: 12 } } } } features { feature { key: "id" value { bytes_list { value: "71e3f765-3bfd-4754-92fb-c258c43f78dc" } } } feature { key: "label_xf" value { int64_list { value: 0 } } } feature { key: "text_xf" value { int64_list { value: 13 value: 7 value: 23 value: 75 value: 494 value: 5 value: 748 value: 2155 value: 307 value: 91 value: 19 value: 8 value: 6 value: 499 value: 763 value: 5 value: 2 value: 1690 value: 4 value: 200 value: 593 value: 57 value: 1244 value: 120 value: 2364 value: 3 value: 4407 value: 21 value: 0 value: 10081 value: 3 value: 263 value: 42 value: 6947 value: 2 value: 169 value: 185 value: 21 value: 8 value: 5143 value: 7 value: 1339 value: 2155 value: 81 value: 0 value: 18 value: 14 value: 1468 value: 0 value: 86 value: 986 value: 14 value: 2259 value: 1790 value: 562 value: 3 value: 284 value: 200 value: 401 value: 5 value: 668 value: 19 value: 17 value: 58 value: 1934 value: 4 value: 45 value: 14 value: 4212 value: 113 value: 43 value: 135 value: 7 value: 753 value: 7 value: 224 value: 23 value: 1155 value: 179 value: 4 value: 0 value: 18 value: 19 value: 7 value: 191 value: 0 value: 2047 value: 4 value: 10 value: 3 value: 283 value: 42 value: 401 value: 5 value: 668 value: 4 value: 90 value: 234 value: 10023 value: 227 } } } } features { feature { key: "id" value { bytes_list { value: "eaad5638-befe-4556-8ef8-1b5061aaab34" } } } feature { key: "label_xf" value { int64_list { value: 0 } } } feature { key: "text_xf" value { int64_list { value: 13 value: 4577 value: 7158 value: 0 value: 10047 value: 3778 value: 3346 value: 9 value: 2 value: 758 value: 1915 value: 3 value: 2280 value: 1511 value: 3 value: 2003 value: 10020 value: 225 value: 786 value: 382 value: 16 value: 39 
value: 203 value: 361 value: 5 value: 93 value: 11 value: 11 value: 19 value: 220 value: 21 value: 341 value: 2 value: 10000 value: 966 value: 0 value: 77 value: 4 value: 6677 value: 464 value: 10071 value: 5 value: 10042 value: 630 value: 2 value: 10044 value: 404 value: 2 value: 10044 value: 3 value: 5 value: 10008 value: 0 value: 1259 value: 630 value: 106 value: 10042 value: 6721 value: 10 value: 49 value: 21 value: 0 value: 2071 value: 20 value: 1292 value: 4 value: 0 value: 431 value: 11 value: 11 value: 166 value: 67 value: 2342 value: 5815 value: 12 value: 575 value: 21 value: 0 value: 1691 value: 537 value: 4 value: 0 value: 3605 value: 307 value: 0 value: 10054 value: 1563 value: 3115 value: 467 value: 4577 value: 3 value: 1069 value: 1158 value: 5 value: 23 value: 4279 value: 6677 value: 464 value: 20 value: 10004 } } } }
The GraphAugmentation Component
Since we have the sample features and the synthesized graph, we can generate the augmented training data for Neural Structured Learning. The NSL framework provides a library to combine the graph and the sample features to produce the final training data for graph regularization. The resulting training data will include original sample features as well as features of their corresponding neighbors.
In this tutorial, we consider undirected edges and use a maximum of 3 neighbors per sample to augment training data with graph neighbors.
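To make the idea of neighbor augmentation concrete, here is a minimal pure-Python sketch of what packing neighbors conceptually does. It is not the implementation of `nsl.tools.pack_nbrs`: the function name `pack_neighbors`, the use of plain dictionaries instead of TFRecords, and the "strongest edges first" selection rule are assumptions made for illustration. The output key names (`NL_nbr_<i>_<feature>`, `NL_nbr_<i>_weight`, `NL_num_nbrs`) follow the ones visible in the augmented examples printed later in this notebook.

```python
from collections import defaultdict


def pack_neighbors(labeled, unlabeled, edges, max_nbrs=3,
                   add_undirected_edges=True):
  """Merges neighbor features into each labeled example (illustrative only).

  Args:
    labeled: dict mapping example id -> feature dict (used as seeds).
    unlabeled: dict mapping example id -> feature dict (neighbors only).
    edges: list of (source_id, target_id, weight) tuples.
    max_nbrs: maximum number of neighbors merged into each seed example.
    add_undirected_edges: if True, every edge is also usable in reverse.

  Returns:
    A list of feature dicts, one per labeled example.
  """
  all_examples = {**unlabeled, **labeled}
  adjacency = defaultdict(dict)
  for source, target, weight in edges:
    adjacency[source][target] = weight
    if add_undirected_edges:
      # Only fill in the reverse edge if it was not given explicitly.
      adjacency[target].setdefault(source, weight)

  packed = []
  for example_id, features in labeled.items():
    merged = dict(features)
    # Assumption: keep the neighbors with the largest edge weights.
    nbrs = sorted(
        ((weight, nbr_id) for nbr_id, weight in adjacency[example_id].items()
         if nbr_id in all_examples),
        reverse=True)[:max_nbrs]
    for i, (weight, nbr_id) in enumerate(nbrs):
      for key, value in all_examples[nbr_id].items():
        merged['NL_nbr_%d_%s' % (i, key)] = value
      merged['NL_nbr_%d_weight' % i] = weight
    merged['NL_num_nbrs'] = len(nbrs)
    packed.append(merged)
  return packed
```

With `add_undirected_edges=True`, an edge listed only as (unlabeled, labeled) still gives the labeled example a neighbor, which is why the setting matters when the graph file stores each edge in one direction.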
def split_train_and_unsup(input_uri):
  'Separate the labeled and unlabeled instances.'

  tmp_dir = tempfile.mkdtemp(prefix='tfx-data')
  tfrecord_filenames = [
      os.path.join(input_uri, filename) for filename in os.listdir(input_uri)
  ]
  train_path = os.path.join(tmp_dir, 'train.tfrecord')
  unsup_path = os.path.join(tmp_dir, 'unsup.tfrecord')
  with tf.io.TFRecordWriter(train_path) as train_writer, \
       tf.io.TFRecordWriter(unsup_path) as unsup_writer:
    for tfrecord in tf.data.TFRecordDataset(
        tfrecord_filenames, compression_type='GZIP'):
      example = tf.train.Example()
      example.ParseFromString(tfrecord.numpy())
      # Examples without a label, or with label -1, are treated as unlabeled.
      if ('label_xf' not in example.features.feature or
          example.features.feature['label_xf'].int64_list.value[0] == -1):
        writer = unsup_writer
      else:
        writer = train_writer
      writer.write(tfrecord.numpy())
  return train_path, unsup_path

def gzip(filepath):
  with open(filepath, 'rb') as f_in:
    with gzip_lib.open(filepath + '.gz', 'wb') as f_out:
      shutil.copyfileobj(f_in, f_out)
  os.remove(filepath)


def copy_tfrecords(input_uri, output_uri):
  for filename in os.listdir(input_uri):
    input_filename = os.path.join(input_uri, filename)
    output_filename = os.path.join(output_uri, filename)
    shutil.copyfile(input_filename, output_filename)


@component
def GraphAugmentation(identified_examples: InputArtifact[Examples],
                      synthesized_graph: InputArtifact[SynthesizedGraph],
                      augmented_examples: OutputArtifact[Examples],
                      num_neighbors: Parameter[int],
                      component_name: Parameter[str]) -> None:

  # Get a list of the splits in input_data
  splits_list = artifact_utils.decode_split_names(
      split_names=identified_examples.split_names)

  train_input_uri = os.path.join(identified_examples.uri, 'train')
  eval_input_uri = os.path.join(identified_examples.uri, 'eval')
  train_graph_uri = os.path.join(synthesized_graph.uri, 'train')
  train_output_uri = os.path.join(augmented_examples.uri, 'train')
  eval_output_uri = os.path.join(augmented_examples.uri, 'eval')

  os.mkdir(train_output_uri)
  os.mkdir(eval_output_uri)

  # Separate out the labeled and unlabeled examples from the 'train' split.
  train_path, unsup_path = split_train_and_unsup(train_input_uri)

  output_path = os.path.join(train_output_uri, 'nsl_train_data.tfr')
  pack_nbrs_args = dict(
      labeled_examples_path=train_path,
      unlabeled_examples_path=unsup_path,
      graph_path=os.path.join(train_graph_uri, 'graph.tfv'),
      output_training_data_path=output_path,
      add_undirected_edges=True,
      max_nbrs=num_neighbors)
  print('nsl.tools.pack_nbrs arguments:', pack_nbrs_args)
  nsl.tools.pack_nbrs(**pack_nbrs_args)

  # Downstream components expect gzip'ed TFRecords.
  gzip(output_path)

  # The test examples are left untouched and are simply copied over.
  copy_tfrecords(eval_input_uri, eval_output_uri)

  augmented_examples.split_names = identified_examples.split_names

  return

# Augments training data with graph neighbors.
graph_augmentation = GraphAugmentation(
    identified_examples=transform.outputs['transformed_examples'],
    synthesized_graph=synthesize_graph.outputs['synthesized_graph'],
    component_name=u'GraphAugmentation',
    num_neighbors=3)
context.run(graph_augmentation, enable_cache=False)
nsl.tools.pack_nbrs arguments: {'labeled_examples_path': '/tmp/tfx-datajre7hdjd/train.tfrecord', 'unlabeled_examples_path': '/tmp/tfx-datajre7hdjd/unsup.tfrecord', 'graph_path': '/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/SynthesizeGraph/synthesized_graph/6/train/graph.tfv', 'output_training_data_path': '/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/GraphAugmentation/augmented_examples/8/train/nsl_train_data.tfr', 'add_undirected_edges': True, 'max_nbrs': 3}
pprint_examples(graph_augmentation.outputs['augmented_examples'].get()[0], 6)
artifact: Artifact(artifact: id: 9 type_id: 5 uri: "/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/GraphAugmentation/augmented_examples/8" properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "name" value { string_value: "augmented_examples" } } custom_properties { key: "pipeline_name" value { string_value: "interactive-2020-10-15T09_26_00.686186" } } custom_properties { key: "producer_component" value { string_value: "GraphAugmentation" } } custom_properties { key: "state" value { string_value: "published" } } , artifact_type: id: 5 name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } ) uri: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/GraphAugmentation/augmented_examples/8/train tfrecord_filenames: ['/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/GraphAugmentation/augmented_examples/8/train/nsl_train_data.tfr.gz'] features { feature { key: "NL_num_nbrs" value { int64_list { value: 0 } } } feature { key: "id" value { bytes_list { value: "08903146-1233-49d7-ac8e-ac126c0a8b14" } } } feature { key: "label_xf" value { int64_list { value: 0 } } } feature { key: "text_xf" value { int64_list { value: 13 value: 8 value: 14 value: 32 value: 338 value: 310 value: 15 value: 95 value: 27 value: 10001 value: 9 value: 31 value: 1173 value: 3153 value: 43 value: 495 value: 10060 value: 214 value: 26 value: 71 value: 142 value: 19 value: 8 value: 204 value: 339 value: 27 value: 74 value: 181 value: 238 value: 9 value: 440 value: 67 value: 74 value: 71 value: 94 value: 100 value: 22 value: 5442 value: 8 value: 1573 value: 607 value: 530 value: 8 value: 15 value: 6 value: 32 value: 378 value: 6292 value: 207 value: 2276 value: 388 value: 0 value: 84 value: 1023 value: 154 value: 65 value: 155 value: 52 value: 0 value: 10080 value: 7871 value: 65 value: 250 value: 74 value: 3202 value: 20 value: 10000 value: 
3720 value: 10020 value: 10008 value: 1282 value: 3862 value: 3 value: 53 value: 3952 value: 110 value: 1879 value: 17 value: 3153 value: 14 value: 166 value: 19 value: 2 value: 1023 value: 1007 value: 9405 value: 9 value: 2 value: 15 value: 12 value: 14 value: 4504 value: 4 value: 109 value: 158 value: 1202 value: 7 value: 174 value: 505 value: 12 } } } } features { feature { key: "NL_num_nbrs" value { int64_list { value: 0 } } } feature { key: "id" value { bytes_list { value: "71e3f765-3bfd-4754-92fb-c258c43f78dc" } } } feature { key: "label_xf" value { int64_list { value: 0 } } } feature { key: "text_xf" value { int64_list { value: 13 value: 7 value: 23 value: 75 value: 494 value: 5 value: 748 value: 2155 value: 307 value: 91 value: 19 value: 8 value: 6 value: 499 value: 763 value: 5 value: 2 value: 1690 value: 4 value: 200 value: 593 value: 57 value: 1244 value: 120 value: 2364 value: 3 value: 4407 value: 21 value: 0 value: 10081 value: 3 value: 263 value: 42 value: 6947 value: 2 value: 169 value: 185 value: 21 value: 8 value: 5143 value: 7 value: 1339 value: 2155 value: 81 value: 0 value: 18 value: 14 value: 1468 value: 0 value: 86 value: 986 value: 14 value: 2259 value: 1790 value: 562 value: 3 value: 284 value: 200 value: 401 value: 5 value: 668 value: 19 value: 17 value: 58 value: 1934 value: 4 value: 45 value: 14 value: 4212 value: 113 value: 43 value: 135 value: 7 value: 753 value: 7 value: 224 value: 23 value: 1155 value: 179 value: 4 value: 0 value: 18 value: 19 value: 7 value: 191 value: 0 value: 2047 value: 4 value: 10 value: 3 value: 283 value: 42 value: 401 value: 5 value: 668 value: 4 value: 90 value: 234 value: 10023 value: 227 } } } } features { feature { key: "NL_num_nbrs" value { int64_list { value: 0 } } } feature { key: "id" value { bytes_list { value: "eaad5638-befe-4556-8ef8-1b5061aaab34" } } } feature { key: "label_xf" value { int64_list { value: 0 } } } feature { key: "text_xf" value { int64_list { value: 13 value: 4577 value: 7158 value: 
0 value: 10047 value: 3778 value: 3346 value: 9 value: 2 value: 758 value: 1915 value: 3 value: 2280 value: 1511 value: 3 value: 2003 value: 10020 value: 225 value: 786 value: 382 value: 16 value: 39 value: 203 value: 361 value: 5 value: 93 value: 11 value: 11 value: 19 value: 220 value: 21 value: 341 value: 2 value: 10000 value: 966 value: 0 value: 77 value: 4 value: 6677 value: 464 value: 10071 value: 5 value: 10042 value: 630 value: 2 value: 10044 value: 404 value: 2 value: 10044 value: 3 value: 5 value: 10008 value: 0 value: 1259 value: 630 value: 106 value: 10042 value: 6721 value: 10 value: 49 value: 21 value: 0 value: 2071 value: 20 value: 1292 value: 4 value: 0 value: 431 value: 11 value: 11 value: 166 value: 67 value: 2342 value: 5815 value: 12 value: 575 value: 21 value: 0 value: 1691 value: 537 value: 4 value: 0 value: 3605 value: 307 value: 0 value: 10054 value: 1563 value: 3115 value: 467 value: 4577 value: 3 value: 1069 value: 1158 value: 5 value: 23 value: 4279 value: 6677 value: 464 value: 20 value: 10004 } } } } features { feature { key: "NL_num_nbrs" value { int64_list { value: 0 } } } feature { key: "id" value { bytes_list { value: "11ff10a2-1ea4-4b10-ba91-2ba633b8abd4" } } } feature { key: "label_xf" value { int64_list { value: 1 } } } feature { key: "text_xf" value { int64_list { value: 13 value: 8 value: 6 value: 0 value: 251 value: 4 value: 18 value: 20 value: 2 value: 6783 value: 2295 value: 2338 value: 52 value: 0 value: 468 value: 4 value: 0 value: 189 value: 73 value: 153 value: 1294 value: 17 value: 90 value: 234 value: 935 value: 16 value: 25 value: 10024 value: 92 value: 2 value: 192 value: 4218 value: 3317 value: 3 value: 10098 value: 20 value: 2 value: 356 value: 4 value: 565 value: 334 value: 382 value: 36 value: 6989 value: 3 value: 6065 value: 2510 value: 16 value: 203 value: 7264 value: 2849 value: 0 value: 86 value: 346 value: 50 value: 26 value: 58 value: 10020 value: 5 value: 1464 value: 58 value: 2081 value: 2969 value: 42 
value: 2 value: 2364 value: 3 value: 1402 value: 10062 value: 138 value: 147 value: 614 value: 115 value: 29 value: 90 value: 105 value: 2 value: 223 value: 18 value: 9 value: 160 value: 324 value: 3 value: 24 value: 12 value: 1252 value: 0 value: 2142 value: 10 value: 1832 value: 111 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 } } } } features { feature { key: "NL_num_nbrs" value { int64_list { value: 0 } } } feature { key: "id" value { bytes_list { value: "ed3db659-5524-4410-a5d5-d2bbd550a01f" } } } feature { key: "label_xf" value { int64_list { value: 1 } } } feature { key: "text_xf" value { int64_list { value: 13 value: 16 value: 423 value: 23 value: 1367 value: 30 value: 0 value: 363 value: 12 value: 153 value: 3174 value: 9 value: 8 value: 18 value: 26 value: 667 value: 338 value: 1372 value: 0 value: 86 value: 46 value: 9200 value: 282 value: 0 value: 10091 value: 4 value: 0 value: 694 value: 10028 value: 52 value: 362 value: 26 value: 202 value: 39 value: 216 value: 5 value: 27 value: 5822 value: 19 value: 52 value: 58 value: 362 value: 26 value: 202 value: 39 value: 474 value: 0 value: 10029 value: 4 value: 2 value: 243 value: 143 value: 386 value: 3 value: 0 value: 386 value: 579 value: 2 value: 132 value: 57 value: 725 value: 88 value: 140 value: 30 value: 27 value: 33 value: 1359 value: 29 value: 8 value: 567 value: 35 value: 106 value: 230 value: 60 value: 0 value: 3041 value: 5 value: 7879 value: 28 value: 281 value: 110 value: 111 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 value: 1 } } } } features { feature { key: "NL_nbr_0_id" value { bytes_list { value: "daf1f061-ef48-4476-a047-b9022c372d4e" } } } feature { key: "NL_nbr_0_label_xf" value { int64_list { value: -1 } } } feature { key: "NL_nbr_0_text_xf" value { int64_list { value: 13 value: 7 value: 174 value: 2 value: 1525 value: 4 value: 440 value: 3 
value: 1260 value: 91 value: 108 value: 19 value: 10095 value: 10004 value: 40 value: 2 value: 169 value: 4 value: 4594 value: 84 value: 4 value: 30 value: 8 value: 15 value: 1063 value: 9 value: 54 value: 966 value: 31 value: 926 value: 757 value: 104 value: 3 value: 757 value: 86 value: 986 value: 0 value: 68 value: 4769 value: 9 value: 69 value: 8 value: 18 value: 1252 value: 0 value: 375 value: 31 value: 103 value: 1558 value: 9 value: 9 value: 640 value: 876 value: 3 value: 2551 value: 24 value: 1946 value: 1097 value: 8 value: 15 value: 5 value: 2 value: 2351 value: 1779 value: 19 value: 7 value: 95 value: 118 value: 4 value: 109 value: 2351 value: 9899 value: 12 value: 23 value: 4876 value: 16 value: 63 value: 16 value: 8 value: 24 value: 0 value: 68 value: 104 value: 12 value: 361 value: 5 value: 2257 value: 9 value: 2 value: 1092 value: 97 value: 26 value: 0 value: 2114 value: 10044 value: 10025 value: 3 value: 28 value: 343 value: 6595 } } } feature { key: "NL_nbr_0_weight" value { float_list { value: 0.9909949898719788 } } } feature { key: "NL_num_nbrs" value { int64_list { value: 1 } } } feature { key: "id" value { bytes_list { value: "c6c89c93-e2e9-4c4a-9f52-1221b4467499" } } } feature { key: "label_xf" value { int64_list { value: 1 } } } feature { key: "text_xf" value { int64_list { value: 13 value: 8 value: 6 value: 2 value: 18 value: 69 value: 140 value: 27 value: 83 value: 31 value: 1877 value: 905 value: 9 value: 10057 value: 31 value: 43 value: 2115 value: 36 value: 32 value: 2057 value: 6133 value: 10 value: 6 value: 32 value: 2474 value: 1614 value: 3 value: 2707 value: 990 value: 4 value: 10067 value: 9 value: 2 value: 1532 value: 242 value: 90 value: 3757 value: 3 value: 90 value: 10026 value: 0 value: 242 value: 6 value: 260 value: 31 value: 24 value: 4 value: 0 value: 84 value: 497 value: 177 value: 1151 value: 777 value: 9 value: 397 value: 552 value: 7726 value: 10051 value: 34 value: 14 value: 379 value: 33 value: 1829 value: 9 value: 
123 value: 0 value: 916 value: 10028 value: 7 value: 64 value: 571 value: 12 value: 8 value: 18 value: 27 value: 687 value: 9 value: 30 value: 5609 value: 16 value: 25 value: 99 value: 117 value: 66 value: 2 value: 130 value: 21 value: 8 value: 842 value: 7726 value: 10051 value: 6 value: 338 value: 1107 value: 3 value: 24 value: 10020 value: 29 value: 53 value: 1476 } } } }
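In the printed output above, neighbor features are stored flat alongside the sample's own features, under keys such as NL_nbr_0_id, NL_nbr_0_label_xf and NL_nbr_0_weight, with NL_num_nbrs giving the neighbor count. The sketch below regroups such a flat feature dictionary into the sample's features plus one dictionary per neighbor. It works on plain Python dictionaries rather than tf.train.Example protos, and the helper name `split_neighbor_features` is hypothetical.

```python
import re

# Matches keys like 'NL_nbr_0_text_xf' -> index 0, feature name 'text_xf'.
NBR_KEY_PATTERN = re.compile(r'^NL_nbr_(\d+)_(.+)$')


def split_neighbor_features(features):
  """Splits a flat feature dict into (sample_features, list_of_neighbor_dicts)."""
  sample = {}
  neighbors = {}
  for key, value in features.items():
    match = NBR_KEY_PATTERN.match(key)
    if match:
      index, name = int(match.group(1)), match.group(2)
      neighbors.setdefault(index, {})[name] = value
    elif key != 'NL_num_nbrs':
      sample[key] = value
  # Return neighbors ordered by their index in the key names.
  return sample, [neighbors[i] for i in sorted(neighbors)]
```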
The Trainer Component
The Trainer component trains models using TensorFlow.
Create a Python module containing a trainer_fn function, which must return an estimator. If you prefer creating a Keras model, you can do so and then convert it to an estimator using tf.keras.estimator.model_to_estimator().
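For intuition about the objective that a graph-regularized trainer optimizes, the sketch below adds a neighbor penalty to a supervised loss in plain Python. It is a simplified illustration, not NSL code: it assumes a squared L2 distance and a weighted sum over neighbors, the helper names `squared_l2` and `graph_regularized_loss` are hypothetical, and the library's exact reduction and normalization may differ.

```python
def squared_l2(u, v):
  """Squared Euclidean distance between two equal-length vectors."""
  return sum((a - b) ** 2 for a, b in zip(u, v))


def graph_regularized_loss(supervised_loss, sample_embedding,
                           neighbor_embeddings, neighbor_weights,
                           multiplier=0.1):
  """Supervised loss plus a weighted penalty for differing from neighbors."""
  penalty = sum(
      weight * squared_l2(sample_embedding, nbr_embedding)
      for nbr_embedding, weight in zip(neighbor_embeddings, neighbor_weights))
  return supervised_loss + multiplier * penalty
```

A sample with no neighbors simply keeps its supervised loss, while a sample whose representation is far from a strongly connected neighbor is penalized more. The multiplier plays the role of the graph regularization multiplier hyperparameter used in the trainer module.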
# Setup paths.
_trainer_module_file = 'imdb_trainer.py'
%%writefile {_trainer_module_file}
import neural_structured_learning as nsl
import tensorflow as tf
import tensorflow_model_analysis as tfma
import tensorflow_transform as tft
from tensorflow_transform.tf_metadata import schema_utils
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
LABEL_KEY = 'label'
ID_FEATURE_KEY = 'id'
def _transformed_name(key):
  return key + '_xf'


def _transformed_names(keys):
  return [_transformed_name(key) for key in keys]

# Hyperparameters:
#
# We will use an instance of `HParams` to include various hyperparameters and
# constants used for training and evaluation. We briefly describe each of them
# below:
#
# - max_seq_length: This is the maximum number of words considered from each
# movie review in this example.
# - vocab_size: This is the size of the vocabulary considered for this
# example.
# - oov_size: This is the out-of-vocabulary size considered for this example.
# - distance_type: This is the distance metric used to regularize the sample
# with its neighbors.
# - graph_regularization_multiplier: This controls the relative weight of the
# graph regularization term in the overall
# loss function.
# - num_neighbors: The number of neighbors used for graph regularization. This
# value has to be less than or equal to the `num_neighbors`
# argument used above in the GraphAugmentation component when
# invoking `nsl.tools.pack_nbrs`.
# - num_fc_units: The number of units in the fully connected layer of the
# neural network.
class HParams(object):
  """Hyperparameters used for training."""

  def __init__(self):
    ### dataset parameters
    # The following 3 values should match those defined in the Transform
    # Component.
    self.max_seq_length = 100
    self.vocab_size = 10000
    self.oov_size = 100
    ### Neural Graph Learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    # The following value has to be at most the value of 'num_neighbors' used
    # in the GraphAugmentation component.
    self.num_neighbors = 1
    ### Model Architecture
    self.num_embedding_dims = 16
    self.num_fc_units = 64


HPARAMS = HParams()


def optimizer_fn():
  """Returns an instance of `tf.Optimizer`."""
  return tf.compat.v1.train.RMSPropOptimizer(
      learning_rate=0.0001, decay=1e-6)


def build_train_op(loss, global_step):
  """Builds a train op to optimize the given loss using gradient descent."""
  with tf.name_scope('train'):
    optimizer = optimizer_fn()
    train_op = optimizer.minimize(loss=loss, global_step=global_step)
  return train_op

# Building the model:
#
# A neural network is created by stacking layers—this requires two main
# architectural decisions:
# * How many layers to use in the model?
# * How many *hidden units* to use for each layer?
#
# In this example, the input data consists of an array of word-indices. The
# labels to predict are either 0 or 1. We will use a feed-forward neural network
# as our base model in this tutorial.
def feed_forward_model(features, is_training, reuse=tf.compat.v1.AUTO_REUSE):
  """Builds a simple 2 layer feed forward neural network.

  The layers are effectively stacked sequentially to build the classifier. The
  first layer is an Embedding layer, which takes the integer-encoded vocabulary
  and looks up the embedding vector for each word-index. These vectors are
  learned as the model trains. The vectors add a dimension to the output array.
  The resulting dimensions are: (batch, sequence, embedding). Next is a global
  average pooling 1D layer, which reduces the dimensionality of its inputs from
  3D to 2D. This fixed-length output vector is piped through a fully-connected
  (Dense) layer with 16 hidden units. The last layer is densely connected with a
  single output node. Using the sigmoid activation function, this value is a
  float between 0 and 1, representing a probability, or confidence level.

  Args:
    features: A dictionary containing batch features returned from the
      `input_fn`, that include sample features, corresponding neighbor features,
      and neighbor weights.
    is_training: a Python Boolean value or a Boolean scalar Tensor, indicating
      whether to apply dropout.
    reuse: a Python Boolean value for reusing variable scope.

  Returns:
    logits: Tensor of shape [batch_size, 1].
    representations: Tensor of shape [batch_size, _] for graph regularization.
      This is the representation of each example at the graph regularization
      layer.
  """

  with tf.compat.v1.variable_scope('ff', reuse=reuse):
    inputs = features[_transformed_name('text')]
    embeddings = tf.compat.v1.get_variable(
        'embeddings',
        shape=[
            HPARAMS.vocab_size + HPARAMS.oov_size, HPARAMS.num_embedding_dims
        ])
    embedding_layer = tf.nn.embedding_lookup(embeddings, inputs)

    pooling_layer = tf.compat.v1.layers.AveragePooling1D(
        pool_size=HPARAMS.max_seq_length, strides=HPARAMS.max_seq_length)(
            embedding_layer)
    # Shape of pooling_layer is now [batch_size, 1, HPARAMS.num_embedding_dims]
    pooling_layer = tf.reshape(pooling_layer, [-1, HPARAMS.num_embedding_dims])

    dense_layer = tf.compat.v1.layers.Dense(
        16, activation='relu')(
            pooling_layer)

    output_layer = tf.compat.v1.layers.Dense(
        1, activation='sigmoid')(
            dense_layer)

    # Graph regularization will be done on the penultimate (dense) layer
    # because the output layer is a single floating point number.
    return output_layer, dense_layer

# A note on hidden units:
#
# The above model has two intermediate or "hidden" layers, between the input and
# output, and excluding the Embedding layer. The number of outputs (units,
# nodes, or neurons) is the dimension of the representational space for the
# layer. In other words, the amount of freedom the network is allowed when
# learning an internal representation. If a model has more hidden units
# (a higher-dimensional representation space), and/or more layers, then the
# network can learn more complex representations. However, it makes the network
# more computationally expensive and may lead to learning unwanted
# patterns—patterns that improve performance on training data but not on the
# test data. This is called overfitting.
# This function will be used to generate the embeddings for samples and their
# corresponding neighbors, which will then be used for graph regularization.
def embedding_fn(features, mode):
  """Returns the embedding corresponding to the given features.

  Args:
    features: A dictionary containing batch features returned from the
      `input_fn`, that include sample features, corresponding neighbor features,
      and neighbor weights.
    mode: Specifies if this is training, evaluation, or prediction. See
      tf.estimator.ModeKeys.

  Returns:
    The embedding that will be used for graph regularization.
  """
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  _, embedding = feed_forward_model(features, is_training)
  return embedding
def feed_forward_model_fn(features, labels, mode, params, config):
  """Implementation of the model_fn for the base feed-forward model.

  Args:
    features: This is the first item returned from the `input_fn` passed to
      `train`, `evaluate`, and `predict`. This should be a single `Tensor` or
      `dict` of same.
    labels: This is the second item returned from the `input_fn` passed to
      `train`, `evaluate`, and `predict`. This should be a single `Tensor` or
      `dict` of same (for multi-head models). If mode is `ModeKeys.PREDICT`,
      `labels=None` will be passed. If the `model_fn`'s signature does not
      accept `mode`, the `model_fn` must still be able to handle `labels=None`.
    mode: Optional. Specifies if this is training, evaluation, or prediction.
      See `ModeKeys`.
    params: An HParams instance as returned by get_hyper_parameters().
    config: Optional configuration object. Will receive what is passed to
      Estimator in `config` parameter, or the default `config`. Allows updating
      things in your model_fn based on configuration such as `num_ps_replicas`,
      or `model_dir`. Unused currently.

  Returns:
    A `tf.estimator.EstimatorSpec` for the base feed-forward model. This does
    not include graph-based regularization.
  """
  is_training = mode == tf.estimator.ModeKeys.TRAIN

  # Build the computation graph.
  probabilities, _ = feed_forward_model(features, is_training)
  predictions = tf.round(probabilities)

  if mode == tf.estimator.ModeKeys.PREDICT:
    # labels will be None, and no loss to compute.
    cross_entropy_loss = None
    eval_metric_ops = None
  else:
    # Loss is required in train and eval modes.
    # Flatten 'probabilities' to 1-D.
    probabilities = tf.reshape(probabilities, shape=[-1])
    cross_entropy_loss = tf.compat.v1.keras.losses.binary_crossentropy(
        labels, probabilities)
    eval_metric_ops = {
        'accuracy': tf.compat.v1.metrics.accuracy(labels, predictions)
    }

  if is_training:
    global_step = tf.compat.v1.train.get_or_create_global_step()
    train_op = build_train_op(cross_entropy_loss, global_step)
  else:
    train_op = None

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions={
          'probabilities': probabilities,
          'predictions': predictions
      },
      loss=cross_entropy_loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops)
# Tf.Transform considers these features as "raw"
def _get_raw_feature_spec(schema):
  return schema_utils.schema_as_feature_spec(schema).feature_spec


def _gzip_reader_fn(filenames):
  """Small utility returning a record reader that can read gzip'ed files."""
  return tf.data.TFRecordDataset(
      filenames,
      compression_type='GZIP')
def _example_serving_receiver_fn(tf_transform_output, schema):
  """Build the serving inputs.

  Args:
    tf_transform_output: A TFTransformOutput.
    schema: the schema of the input data.

  Returns:
    Tensorflow graph which parses examples, applying tf-transform to them.
  """
  raw_feature_spec = _get_raw_feature_spec(schema)
  raw_feature_spec.pop(LABEL_KEY)

  # We don't need the ID feature for serving.
  raw_feature_spec.pop(ID_FEATURE_KEY)

  raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      raw_feature_spec, default_batch_size=None)
  serving_input_receiver = raw_input_fn()

  transformed_features = tf_transform_output.transform_raw_features(
      serving_input_receiver.features)

  # Even though LABEL_KEY was removed from 'raw_feature_spec', the transform
  # operation would have injected the transformed LABEL_KEY feature with a
  # default value.
  transformed_features.pop(_transformed_name(LABEL_KEY))
  return tf.estimator.export.ServingInputReceiver(
      transformed_features, serving_input_receiver.receiver_tensors)
def _eval_input_receiver_fn(tf_transform_output, schema):
  """Build everything needed for the tf-model-analysis to run the model.

  Args:
    tf_transform_output: A TFTransformOutput.
    schema: the schema of the input data.

  Returns:
    EvalInputReceiver function, which contains:
      - Tensorflow graph which parses raw untransformed features, applies the
        tf-transform preprocessing operators.
      - Set of raw, untransformed features.
      - Label against which predictions will be compared.
  """
  # Notice that the inputs are raw features, not transformed features here.
  raw_feature_spec = _get_raw_feature_spec(schema)

  # We don't need the ID feature for TFMA.
  raw_feature_spec.pop(ID_FEATURE_KEY)

  raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      raw_feature_spec, default_batch_size=None)
  serving_input_receiver = raw_input_fn()

  transformed_features = tf_transform_output.transform_raw_features(
      serving_input_receiver.features)

  labels = transformed_features.pop(_transformed_name(LABEL_KEY))
  return tfma.export.EvalInputReceiver(
      features=transformed_features,
      receiver_tensors=serving_input_receiver.receiver_tensors,
      labels=labels)
def _augment_feature_spec(feature_spec, num_neighbors):
  """Augments `feature_spec` to include neighbor features.

  Args:
    feature_spec: Dictionary of feature keys mapping to TF feature types.
    num_neighbors: Number of neighbors to use for feature key augmentation.

  Returns:
    An augmented `feature_spec` that includes neighbor feature keys.
  """
  for i in range(num_neighbors):
    feature_spec['{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'id')] = \
        tf.io.VarLenFeature(dtype=tf.string)
    # We don't care about the neighbor features corresponding to
    # _transformed_name(LABEL_KEY) because the LABEL_KEY feature will be
    # removed from the feature spec during training/evaluation.
    feature_spec['{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'text_xf')] = \
        tf.io.FixedLenFeature(shape=[HPARAMS.max_seq_length], dtype=tf.int64,
                              default_value=tf.constant(0, dtype=tf.int64,
                                                        shape=[HPARAMS.max_seq_length]))
    # The 'NL_num_nbrs' feature is currently not used.

  # Set the neighbor weight feature keys.
  for i in range(num_neighbors):
    feature_spec['{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)] = \
        tf.io.FixedLenFeature(shape=[1], dtype=tf.float32, default_value=[0.0])

  return feature_spec
def _input_fn(filenames, tf_transform_output, is_training, batch_size=200):
  """Generates features and labels for training or evaluation.

  Args:
    filenames: [str] list of gzip-compressed TFRecord files to read data from.
    tf_transform_output: A TFTransformOutput.
    is_training: Boolean indicating if we are in training mode.
    batch_size: int First dimension size of the Tensors returned by input_fn

  Returns:
    A (features, indices) tuple where features is a dictionary of
      Tensors, and indices is a single Tensor of label indices.
  """
  transformed_feature_spec = (
      tf_transform_output.transformed_feature_spec().copy())

  # During training, NSL uses augmented training data (which includes features
  # from graph neighbors). So, update the feature spec accordingly. This needs
  # to be done because we are using different schemas for NSL training and eval,
  # but the Trainer Component only accepts a single schema.
  if is_training:
    transformed_feature_spec = _augment_feature_spec(transformed_feature_spec,
                                                     HPARAMS.num_neighbors)

  dataset = tf.data.experimental.make_batched_features_dataset(
      filenames, batch_size, transformed_feature_spec, reader=_gzip_reader_fn)

  transformed_features = tf.compat.v1.data.make_one_shot_iterator(
      dataset).get_next()
  # We pop the label because we do not want to use it as a feature while we're
  # training.
  return transformed_features, transformed_features.pop(
      _transformed_name(LABEL_KEY))
# TFX will call this function
def trainer_fn(hparams, schema):
  """Build the estimator using the high level API.

  Args:
    hparams: Holds hyperparameters used to train the model as name/value pairs.
    schema: Holds the schema of the training examples.

  Returns:
    A dict of the following:
      - estimator: The estimator that will be used for training and eval.
      - train_spec: Spec for training.
      - eval_spec: Spec for eval.
      - eval_input_receiver_fn: Input function for eval.
  """
  train_batch_size = 40
  eval_batch_size = 40

  tf_transform_output = tft.TFTransformOutput(hparams.transform_output)

  train_input_fn = lambda: _input_fn(
      hparams.train_files,
      tf_transform_output,
      is_training=True,
      batch_size=train_batch_size)

  eval_input_fn = lambda: _input_fn(
      hparams.eval_files,
      tf_transform_output,
      is_training=False,
      batch_size=eval_batch_size)

  train_spec = tf.estimator.TrainSpec(
      train_input_fn,
      max_steps=hparams.train_steps)

  serving_receiver_fn = lambda: _example_serving_receiver_fn(
      tf_transform_output, schema)

  exporter = tf.estimator.FinalExporter('imdb', serving_receiver_fn)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=hparams.eval_steps,
      exporters=[exporter],
      name='imdb-eval')

  run_config = tf.estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=1)

  run_config = run_config.replace(model_dir=hparams.serving_model_dir)

  estimator = tf.estimator.Estimator(
      model_fn=feed_forward_model_fn, config=run_config, params=HPARAMS)

  # Create a graph regularization config.
  graph_reg_config = nsl.configs.make_graph_reg_config(
      max_neighbors=HPARAMS.num_neighbors,
      multiplier=HPARAMS.graph_regularization_multiplier,
      distance_type=HPARAMS.distance_type,
      sum_over_axis=-1)

  # Invoke the Graph Regularization Estimator wrapper to incorporate
  # graph-based regularization for training.
  graph_nsl_estimator = nsl.estimator.add_graph_regularization(
      estimator,
      embedding_fn,
      optimizer_fn=optimizer_fn,
      graph_reg_config=graph_reg_config)

  # Create an input receiver for TFMA processing
  receiver_fn = lambda: _eval_input_receiver_fn(
      tf_transform_output, schema)

  return {
      'estimator': graph_nsl_estimator,
      'train_spec': train_spec,
      'eval_spec': eval_spec,
      'eval_input_receiver_fn': receiver_fn
  }
Writing imdb_trainer.py
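To make the effect of `_augment_feature_spec` more concrete, the sketch below lists the extra feature keys it adds for each neighbor. The real values of `NBR_FEATURE_PREFIX` and `NBR_WEIGHT_SUFFIX` are defined outside this section, so the values used here (`'NL_nbr_'` and `'_weight'`) are assumptions, and `augmented_keys` is a hypothetical helper that is not part of the trainer module.

```python
# Assumed values for illustration; the notebook defines the real constants
# elsewhere.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'


def augmented_keys(num_neighbors):
  """Lists neighbor keys in the same order the feature spec is augmented."""
  keys = []
  # One ID key and one transformed-text key per neighbor.
  for i in range(num_neighbors):
    keys.append('{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'id'))
    keys.append('{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'text_xf'))
  # One edge-weight key per neighbor.
  for i in range(num_neighbors):
    keys.append('{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX))
  return keys


print(augmented_keys(2))
```

Under these assumed constants, each neighbor contributes three keys (ID, text, and weight), so a spec augmented for `num_neighbors` neighbors gains `3 * num_neighbors` entries.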
Create and run the Trainer component, passing it the file that we created above.
# Uses user-provided Python function that implements a model using TensorFlow's
# Estimators API.
trainer = Trainer(
    module_file=_trainer_module_file,
    transformed_examples=graph_augmentation.outputs['augmented_examples'],
    schema=schema_gen.outputs['schema'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))
context.run(trainer)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 999, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 1, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Not using Distribute Coordinator. INFO:tensorflow:Running training and evaluation locally (non-distributed). INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 999 or save_checkpoints_secs None. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/rmsprop.py:123: calling Ones.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. 
Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.69318736, step = 0 INFO:tensorflow:global_step/sec: 222.546 INFO:tensorflow:loss = 0.6928638, step = 100 (0.450 sec) INFO:tensorflow:global_step/sec: 291.344 INFO:tensorflow:loss = 0.69281894, step = 200 (0.343 sec) INFO:tensorflow:global_step/sec: 296.443 INFO:tensorflow:loss = 0.6927313, step = 300 (0.337 sec) INFO:tensorflow:global_step/sec: 291.965 INFO:tensorflow:loss = 0.6917414, step = 400 (0.342 sec) INFO:tensorflow:global_step/sec: 298.269 INFO:tensorflow:loss = 0.6905616, step = 500 (0.335 sec) INFO:tensorflow:global_step/sec: 292.315 INFO:tensorflow:loss = 0.6894297, step = 600 (0.342 sec) INFO:tensorflow:global_step/sec: 295.769 INFO:tensorflow:loss = 0.6896509, step = 700 (0.338 sec) INFO:tensorflow:global_step/sec: 296.858 INFO:tensorflow:loss = 0.68861306, step = 800 (0.337 sec) INFO:tensorflow:global_step/sec: 292.735 INFO:tensorflow:loss = 0.68658316, step = 900 (0.342 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 999... INFO:tensorflow:Saving checkpoints for 999 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt. 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/saver.py:971: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file APIs to delete files with this prefix. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 999... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2020-10-15T09:32:00Z INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt-999 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [500/5000] INFO:tensorflow:Evaluation [1000/5000] INFO:tensorflow:Evaluation [1500/5000] INFO:tensorflow:Evaluation [2000/5000] INFO:tensorflow:Evaluation [2500/5000] INFO:tensorflow:Evaluation [3000/5000] INFO:tensorflow:Evaluation [3500/5000] INFO:tensorflow:Evaluation [4000/5000] INFO:tensorflow:Evaluation [4500/5000] INFO:tensorflow:Evaluation [5000/5000] INFO:tensorflow:Inference Time : 5.29909s INFO:tensorflow:Finished evaluation at 2020-10-15-09:32:05 INFO:tensorflow:Saving dict for global step 999: accuracy = 0.7035, global_step = 999, loss = 0.68670774 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 999: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt-999 INFO:tensorflow:global_step/sec: 17.0767 INFO:tensorflow:loss = 0.68894106, step = 1000 (5.855 sec) INFO:tensorflow:global_step/sec: 299.602 INFO:tensorflow:loss = 0.6814944, step = 1100 (0.334 sec) INFO:tensorflow:global_step/sec: 300.889 INFO:tensorflow:loss = 0.6839364, step = 1200 (0.333 sec) INFO:tensorflow:global_step/sec: 302.256 INFO:tensorflow:loss = 0.6763433, step = 1300 (0.331 sec) 
INFO:tensorflow:global_step/sec: 299.199 INFO:tensorflow:loss = 0.6769841, step = 1400 (0.334 sec) INFO:tensorflow:global_step/sec: 299.279 INFO:tensorflow:loss = 0.67444175, step = 1500 (0.334 sec) INFO:tensorflow:global_step/sec: 307.62 INFO:tensorflow:loss = 0.67098206, step = 1600 (0.325 sec) INFO:tensorflow:global_step/sec: 304.262 INFO:tensorflow:loss = 0.665629, step = 1700 (0.329 sec) INFO:tensorflow:global_step/sec: 297.873 INFO:tensorflow:loss = 0.6719124, step = 1800 (0.336 sec) INFO:tensorflow:global_step/sec: 306.605 INFO:tensorflow:loss = 0.65660954, step = 1900 (0.326 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1998... INFO:tensorflow:Saving checkpoints for 1998 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1998... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). INFO:tensorflow:global_step/sec: 254.265 INFO:tensorflow:loss = 0.6726355, step = 2000 (0.393 sec) INFO:tensorflow:global_step/sec: 290.351 INFO:tensorflow:loss = 0.6551316, step = 2100 (0.345 sec) INFO:tensorflow:global_step/sec: 298.852 INFO:tensorflow:loss = 0.67447, step = 2200 (0.335 sec) INFO:tensorflow:global_step/sec: 295.696 INFO:tensorflow:loss = 0.64570725, step = 2300 (0.338 sec) INFO:tensorflow:global_step/sec: 301.494 INFO:tensorflow:loss = 0.6464771, step = 2400 (0.332 sec) INFO:tensorflow:global_step/sec: 304.472 INFO:tensorflow:loss = 0.6501285, step = 2500 (0.329 sec) INFO:tensorflow:global_step/sec: 302.118 INFO:tensorflow:loss = 0.6361262, step = 2600 (0.331 sec) INFO:tensorflow:global_step/sec: 307.043 INFO:tensorflow:loss = 0.64034796, step = 2700 (0.325 sec) WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 2748 vs previous value: 2748. 
You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize. INFO:tensorflow:global_step/sec: 298.63 INFO:tensorflow:loss = 0.62189335, step = 2800 (0.335 sec) INFO:tensorflow:global_step/sec: 293.917 INFO:tensorflow:loss = 0.6147873, step = 2900 (0.340 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2997... INFO:tensorflow:Saving checkpoints for 2997 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2997... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). INFO:tensorflow:global_step/sec: 254.499 INFO:tensorflow:loss = 0.61259216, step = 3000 (0.393 sec) INFO:tensorflow:global_step/sec: 298.886 INFO:tensorflow:loss = 0.6229025, step = 3100 (0.335 sec) INFO:tensorflow:global_step/sec: 305.197 INFO:tensorflow:loss = 0.60436034, step = 3200 (0.328 sec) INFO:tensorflow:global_step/sec: 299.399 INFO:tensorflow:loss = 0.62933403, step = 3300 (0.334 sec) INFO:tensorflow:global_step/sec: 301.028 INFO:tensorflow:loss = 0.60902774, step = 3400 (0.332 sec) INFO:tensorflow:global_step/sec: 300.191 INFO:tensorflow:loss = 0.64181244, step = 3500 (0.333 sec) INFO:tensorflow:global_step/sec: 290.434 INFO:tensorflow:loss = 0.57052743, step = 3600 (0.344 sec) INFO:tensorflow:global_step/sec: 299.378 INFO:tensorflow:loss = 0.60267526, step = 3700 (0.334 sec) INFO:tensorflow:global_step/sec: 307.013 INFO:tensorflow:loss = 0.6107319, step = 3800 (0.326 sec) INFO:tensorflow:global_step/sec: 304.692 INFO:tensorflow:loss = 0.56591743, step = 3900 (0.328 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3996... INFO:tensorflow:Saving checkpoints for 3996 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt. 
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3996... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). INFO:tensorflow:global_step/sec: 255.208 INFO:tensorflow:loss = 0.56774515, step = 4000 (0.392 sec) INFO:tensorflow:global_step/sec: 309.924 INFO:tensorflow:loss = 0.59160006, step = 4100 (0.323 sec) INFO:tensorflow:global_step/sec: 306.066 INFO:tensorflow:loss = 0.5484713, step = 4200 (0.327 sec) INFO:tensorflow:global_step/sec: 301.846 INFO:tensorflow:loss = 0.63335776, step = 4300 (0.332 sec) INFO:tensorflow:global_step/sec: 299.014 INFO:tensorflow:loss = 0.5656133, step = 4400 (0.334 sec) INFO:tensorflow:global_step/sec: 306.259 INFO:tensorflow:loss = 0.5533817, step = 4500 (0.326 sec) INFO:tensorflow:global_step/sec: 300.019 INFO:tensorflow:loss = 0.56391084, step = 4600 (0.333 sec) INFO:tensorflow:global_step/sec: 304.165 INFO:tensorflow:loss = 0.5910115, step = 4700 (0.329 sec) INFO:tensorflow:global_step/sec: 295.489 INFO:tensorflow:loss = 0.5945301, step = 4800 (0.338 sec) INFO:tensorflow:global_step/sec: 297.313 INFO:tensorflow:loss = 0.61218303, step = 4900 (0.336 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4995... INFO:tensorflow:Saving checkpoints for 4995 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4995... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). 
INFO:tensorflow:global_step/sec: 260.352 INFO:tensorflow:loss = 0.5332743, step = 5000 (0.385 sec) INFO:tensorflow:global_step/sec: 304.608 INFO:tensorflow:loss = 0.56679493, step = 5100 (0.328 sec) INFO:tensorflow:global_step/sec: 311.855 INFO:tensorflow:loss = 0.54229665, step = 5200 (0.321 sec) INFO:tensorflow:global_step/sec: 305.253 INFO:tensorflow:loss = 0.52315617, step = 5300 (0.328 sec) INFO:tensorflow:global_step/sec: 299.658 INFO:tensorflow:loss = 0.5793217, step = 5400 (0.334 sec) INFO:tensorflow:global_step/sec: 304.107 INFO:tensorflow:loss = 0.5486561, step = 5500 (0.329 sec) INFO:tensorflow:global_step/sec: 308.079 INFO:tensorflow:loss = 0.49263632, step = 5600 (0.325 sec) INFO:tensorflow:global_step/sec: 313.378 INFO:tensorflow:loss = 0.5385544, step = 5700 (0.319 sec) INFO:tensorflow:global_step/sec: 302.781 INFO:tensorflow:loss = 0.5010498, step = 5800 (0.330 sec) INFO:tensorflow:global_step/sec: 296.805 INFO:tensorflow:loss = 0.47667298, step = 5900 (0.337 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5994... INFO:tensorflow:Saving checkpoints for 5994 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5994... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). 
INFO:tensorflow:global_step/sec: 257.488 INFO:tensorflow:loss = 0.55798185, step = 6000 (0.388 sec) INFO:tensorflow:global_step/sec: 305.947 INFO:tensorflow:loss = 0.43396345, step = 6100 (0.327 sec) INFO:tensorflow:global_step/sec: 296.299 INFO:tensorflow:loss = 0.43670568, step = 6200 (0.338 sec) INFO:tensorflow:global_step/sec: 303.445 INFO:tensorflow:loss = 0.46067405, step = 6300 (0.330 sec) INFO:tensorflow:global_step/sec: 310.182 INFO:tensorflow:loss = 0.5060933, step = 6400 (0.322 sec) INFO:tensorflow:global_step/sec: 298.273 INFO:tensorflow:loss = 0.4996158, step = 6500 (0.335 sec) INFO:tensorflow:global_step/sec: 300.567 INFO:tensorflow:loss = 0.396133, step = 6600 (0.333 sec) INFO:tensorflow:global_step/sec: 297.986 INFO:tensorflow:loss = 0.42002386, step = 6700 (0.336 sec) INFO:tensorflow:global_step/sec: 304.359 INFO:tensorflow:loss = 0.4611571, step = 6800 (0.328 sec) INFO:tensorflow:global_step/sec: 302.25 INFO:tensorflow:loss = 0.44177708, step = 6900 (0.331 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6993... INFO:tensorflow:Saving checkpoints for 6993 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6993... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). 
INFO:tensorflow:global_step/sec: 256.018 INFO:tensorflow:loss = 0.46849436, step = 7000 (0.390 sec) INFO:tensorflow:global_step/sec: 291.076 INFO:tensorflow:loss = 0.41983128, step = 7100 (0.344 sec) INFO:tensorflow:global_step/sec: 296.444 INFO:tensorflow:loss = 0.35345578, step = 7200 (0.337 sec) INFO:tensorflow:global_step/sec: 293.356 INFO:tensorflow:loss = 0.41871148, step = 7300 (0.341 sec) INFO:tensorflow:global_step/sec: 303.596 INFO:tensorflow:loss = 0.47682336, step = 7400 (0.329 sec) INFO:tensorflow:global_step/sec: 303.782 INFO:tensorflow:loss = 0.55223024, step = 7500 (0.329 sec) INFO:tensorflow:global_step/sec: 296.762 INFO:tensorflow:loss = 0.42545128, step = 7600 (0.337 sec) INFO:tensorflow:global_step/sec: 309.21 INFO:tensorflow:loss = 0.43023503, step = 7700 (0.323 sec) INFO:tensorflow:global_step/sec: 306.462 INFO:tensorflow:loss = 0.5604722, step = 7800 (0.326 sec) INFO:tensorflow:global_step/sec: 303.329 INFO:tensorflow:loss = 0.5337108, step = 7900 (0.330 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7992... INFO:tensorflow:Saving checkpoints for 7992 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7992... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). 
INFO:tensorflow:global_step/sec: 255.136 INFO:tensorflow:loss = 0.4013764, step = 8000 (0.392 sec) INFO:tensorflow:global_step/sec: 301.602 INFO:tensorflow:loss = 0.4093078, step = 8100 (0.332 sec) INFO:tensorflow:global_step/sec: 299.313 INFO:tensorflow:loss = 0.41223857, step = 8200 (0.334 sec) INFO:tensorflow:global_step/sec: 296.211 INFO:tensorflow:loss = 0.4117222, step = 8300 (0.338 sec) INFO:tensorflow:global_step/sec: 299.752 INFO:tensorflow:loss = 0.39056668, step = 8400 (0.334 sec) INFO:tensorflow:global_step/sec: 302.187 INFO:tensorflow:loss = 0.391355, step = 8500 (0.331 sec) INFO:tensorflow:global_step/sec: 295.599 INFO:tensorflow:loss = 0.46732607, step = 8600 (0.338 sec) INFO:tensorflow:global_step/sec: 297.524 INFO:tensorflow:loss = 0.44837368, step = 8700 (0.336 sec) INFO:tensorflow:global_step/sec: 298.751 INFO:tensorflow:loss = 0.5095719, step = 8800 (0.335 sec) INFO:tensorflow:global_step/sec: 300.476 INFO:tensorflow:loss = 0.3573585, step = 8900 (0.333 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8991... INFO:tensorflow:Saving checkpoints for 8991 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8991... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). 
INFO:tensorflow:global_step/sec: 259.517 INFO:tensorflow:loss = 0.38418576, step = 9000 (0.385 sec) INFO:tensorflow:global_step/sec: 300.71 INFO:tensorflow:loss = 0.3826803, step = 9100 (0.333 sec) INFO:tensorflow:global_step/sec: 301.991 INFO:tensorflow:loss = 0.36049247, step = 9200 (0.331 sec) INFO:tensorflow:global_step/sec: 298.252 INFO:tensorflow:loss = 0.31363297, step = 9300 (0.335 sec) INFO:tensorflow:global_step/sec: 297.207 INFO:tensorflow:loss = 0.3982248, step = 9400 (0.337 sec) INFO:tensorflow:global_step/sec: 301.999 INFO:tensorflow:loss = 0.34949106, step = 9500 (0.331 sec) INFO:tensorflow:global_step/sec: 301.815 INFO:tensorflow:loss = 0.40354735, step = 9600 (0.331 sec) INFO:tensorflow:global_step/sec: 300.948 INFO:tensorflow:loss = 0.47522005, step = 9700 (0.333 sec) INFO:tensorflow:global_step/sec: 299.78 INFO:tensorflow:loss = 0.4353662, step = 9800 (0.333 sec) INFO:tensorflow:global_step/sec: 300.752 INFO:tensorflow:loss = 0.45311904, step = 9900 (0.333 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9990... INFO:tensorflow:Saving checkpoints for 9990 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9990... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10000... INFO:tensorflow:Saving checkpoints for 10000 into /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10000... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2020-10-15T09:32:36Z INFO:tensorflow:Graph was finalized. 
INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt-10000 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [500/5000] INFO:tensorflow:Evaluation [1000/5000] INFO:tensorflow:Evaluation [1500/5000] INFO:tensorflow:Evaluation [2000/5000] INFO:tensorflow:Evaluation [2500/5000] INFO:tensorflow:Evaluation [3000/5000] INFO:tensorflow:Evaluation [3500/5000] INFO:tensorflow:Evaluation [4000/5000] INFO:tensorflow:Evaluation [4500/5000] INFO:tensorflow:Evaluation [5000/5000] INFO:tensorflow:Inference Time : 5.22927s INFO:tensorflow:Finished evaluation at 2020-10-15-09:32:41 INFO:tensorflow:Saving dict for global step 10000: accuracy = 0.8, global_step = 10000, loss = 0.4427957 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10000: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt-10000 INFO:tensorflow:Performing the final export in the end of training. WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef" value: "\n\013\n\tConst_1:0\022-vocab_compute_and_apply_vocabulary_vocabulary" INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Signatures INCLUDED in export for Classify: None INFO:tensorflow:Signatures INCLUDED in export for Regress: None INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default'] INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt-10000 INFO:tensorflow:Assets added to graph. 
INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/export/imdb/temp-1602754361/assets INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/export/imdb/temp-1602754361/saved_model.pb INFO:tensorflow:Loss for final step: 0.4515194. WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef" value: "\n\013\n\tConst_1:0\022-vocab_compute_and_apply_vocabulary_vocabulary" INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Signatures INCLUDED in export for Classify: None INFO:tensorflow:Signatures INCLUDED in export for Regress: None INFO:tensorflow:Signatures INCLUDED in export for Predict: None INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Eval: ['eval'] WARNING:tensorflow:Export includes no default signature! INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/serving_model_dir/model.ckpt-10000 INFO:tensorflow:Assets added to graph. INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/eval_model_dir/temp-1602754361/assets INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2020-10-15T09_26_00.686186-pmz3k9ac/Trainer/model_run/9/eval_model_dir/temp-1602754361/saved_model.pb Warning:absl:Support for estimator-based executor and model export will be deprecated soon. Please use export structure <ModelExportPath>/serving_model_dir/saved_model.pb" WARNING:absl:Support for estimator-based executor and model export will be deprecated soon. Please use export structure <ModelExportPath>/eval_model_dir/saved_model.pb"
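The log above shows checkpoints at step 0, at every multiple of 999, and at the final step 10000, which follows from `save_checkpoints_steps=999` in the `RunConfig` and `num_steps=10000` in `TrainArgs`. The sketch below reproduces that schedule with plain arithmetic. `checkpoint_steps` is a hypothetical helper that only models the initial, periodic, and final saves; it is not how the Estimator decides when to save.

```python
def checkpoint_steps(save_every, max_steps):
  """Returns the steps at which a checkpoint would be written."""
  # Step 0 plus every multiple of `save_every` up to `max_steps`.
  steps = list(range(0, max_steps + 1, save_every))
  # Training also saves once more when it stops at `max_steps`.
  if steps[-1] != max_steps:
    steps.append(max_steps)
  return steps


steps = checkpoint_steps(save_every=999, max_steps=10000)
print(steps)
```

Evaluation is tied to checkpoints but throttled to once per 600 seconds, as the log reports. Because the whole run finishes in well under that time, the model is evaluated only after the checkpoint at step 999 and again at the end of training.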
Take a peek at the trained model that was exported from the Trainer component.
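# Locate the serving SavedModel inside the Trainer's 'model' output artifact.
train_uri = trainer.outputs['model'].get()[0].uri
serving_model_path = os.path.join(train_uri, 'serving_model_dir')
exported_model = tf.saved_model.load(serving_model_path)
# Show the first ten operations of the exported graph.
exported_model.graph.get_operations()[:10] + ["..."]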
[<tf.Operation 'global_step/Initializer/zeros' type=Const>, <tf.Operation 'global_step' type=VarHandleOp>, <tf.Operation 'global_step/IsInitialized/VarIsInitializedOp' type=VarIsInitializedOp>, <tf.Operation 'global_step/Assign' type=AssignVariableOp>, <tf.Operation 'global_step/Read/ReadVariableOp' type=ReadVariableOp>, <tf.Operation 'input_example_tensor' type=Placeholder>, <tf.Operation 'ParseExample/ParseExampleV2/names' type=Const>, <tf.Operation 'ParseExample/ParseExampleV2/sparse_keys' type=Const>, <tf.Operation 'ParseExample/ParseExampleV2/dense_keys' type=Const>, <tf.Operation 'ParseExample/ParseExampleV2/ragged_keys' type=Const>, '...']
Let's visualize the model's metrics using TensorBoard.
# Get the URI of the output artifact representing the training logs,
# which is a directory
model_run_dir = trainer.outputs['model_run'].get()[0].uri
%load_ext tensorboard
%tensorboard --logdir {model_run_dir}
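If TensorBoard comes up empty, one quick sanity check is to confirm that the log directory actually contains event files. The helper below is a minimal sketch and not part of TFX or TensorBoard: the name find_event_files is hypothetical, and it relies only on the fact that TensorFlow summary writers name their files with the prefix events.out.tfevents.

```python
import os


def find_event_files(logdir):
    """Return sorted paths of TensorBoard event files found under `logdir`.

    Hypothetical helper for illustration; it only inspects file names.
    """
    matches = []
    for root, _dirs, files in os.walk(logdir):
        for name in files:
            # TensorFlow summary writers use the "events.out.tfevents" prefix.
            if name.startswith("events.out.tfevents"):
                matches.append(os.path.join(root, name))
    return sorted(matches)
```

In this notebook, calling find_event_files(model_run_dir) should return at least one path once the Trainer has run; an empty list suggests that the log directory passed to TensorBoard is wrong.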
Model Serving
Graph regularization affects only the training workflow: it adds a regularization term to the loss function. The model evaluation and serving workflows therefore remain unchanged. For the same reason, we have omitted the downstream TFX components that typically follow the Trainer component, such as the Evaluator and Pusher.
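The split between training and serving can be illustrated with a small, framework-free sketch. This is not NSL's implementation: the base model here is a toy logistic regression, the neighbor penalty is a squared difference between predictions (NSL supports other distance choices), and the names predict, supervised_loss, graph_regularized_loss, and multiplier are all hypothetical. The point is only that neighbors enter through the loss, while prediction uses the sample's own features.

```python
import math


def predict(weights, features):
    """Toy base model: logistic regression over one feature vector."""
    z = sum(w * x for w, x in zip(weights, features))
    return 1.0 / (1.0 + math.exp(-z))


def supervised_loss(weights, features, label):
    """Binary cross-entropy for a single labeled sample."""
    p = predict(weights, features)
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))


def graph_regularized_loss(weights, features, label, neighbors, multiplier=0.1):
    """Supervised loss plus a weighted penalty on disagreement with neighbors.

    `neighbors` is a list of (neighbor_features, edge_weight) pairs.
    """
    p = predict(weights, features)
    penalty = sum(
        edge_weight * (p - predict(weights, neighbor_features)) ** 2
        for neighbor_features, edge_weight in neighbors
    )
    return supervised_loss(weights, features, label) + multiplier * penalty


# Training time: the loss sees the sample and its graph neighbors.
weights = [0.5, -0.25]
sample, label = [1.0, 2.0], 1
neighbors = [([1.0, 1.0], 1.0)]
training_loss = graph_regularized_loss(weights, sample, label, neighbors)

# Serving time: only the base model runs; no neighbor inputs are needed.
serving_prediction = predict(weights, sample)
```

Because predict never takes neighbor features, a model trained this way can be exported and served exactly like one trained without the extra term.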
Conclusion
We have demonstrated graph regularization with the Neural Structured Learning (NSL) framework in a TFX pipeline, even when the input does not contain an explicit graph. We considered the task of sentiment classification of IMDB movie reviews, for which we synthesized a similarity graph based on review embeddings. We encourage users to experiment further by using different embeddings for graph construction, varying hyperparameters, changing the amount of supervision, and defining different model architectures.