TFX Pipeline for Fine-Tuning a Large Language Model (LLM)

# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

This codelab demonstrates how to use Keras 3, KerasNLP, and TFX pipelines to fine-tune a pre-trained GPT-2 model on the IMDb movie reviews dataset.

Why is this pipeline useful?

TFX pipelines provide a structured approach to building and managing machine learning workflows, including those involving large language models. Compared with ad hoc Python scripts, they offer several advantages:

  1. Enhanced Reproducibility: TFX pipelines record every step and its inputs. This makes results far easier to reproduce than with manual workflows.

  2. Scalability and Modularity: TFX breaks complex workflows into manageable, reusable components. The same pipeline definition can be run by different orchestrators as the workload grows.

  3. Streamlined Fine-Tuning and Conversion: The pipeline structure streamlines the fine-tuning and conversion processes of large language models, significantly reducing manual effort and time.

  4. Comprehensive Lineage Tracking: Through metadata tracking, TFX pipelines provide a clear understanding of data and model provenance, making debugging, auditing, and performance analysis much easier and more efficient.

By leveraging the benefits of TFX pipelines, organizations can effectively manage the complexity of large language model development and deployment, achieving greater efficiency and control over their machine learning processes.

Note

GPT-2 is used here only to demonstrate the end-to-end process; the techniques and tooling introduced in this codelab are potentially transferable to other generative language models such as Google T5.

Before You Begin

Colab offers different kinds of runtimes. Go to Runtime -> Change runtime type and choose a GPU hardware accelerator, since you will fine-tune the GPT-2 model.

This tutorial's interactive pipeline is designed to function seamlessly with free Colab GPUs. However, for users opting to run the pipeline using the LocalDagRunner orchestrator (code provided at the end of this tutorial), a more substantial amount of GPU memory is required. Therefore, Colab Pro or a local machine equipped with a higher-capacity GPU is recommended for this approach.

Set Up

We first install the required Python packages.

Upgrade Pip

To avoid upgrading Pip in a system when running locally, check to make sure that we are running in Colab. Local systems can of course be upgraded separately.

try:
  import colab
  !pip install --upgrade pip

except ImportError:
  pass

Install TFX, Keras 3, KerasNLP and required libraries

!pip install -q tfx tensorflow-text more_itertools tensorflow_datasets
!pip install -q --upgrade keras-nlp
!pip install -q --upgrade keras

Did you restart the runtime?

If you are using Google Colab, the first time you run the cell above you must restart the runtime. You can do this by clicking the "RESTART SESSION" button above or through the "Runtime > Restart session" menu. This is because of the way that Colab loads packages.

Let's check the TensorFlow, Keras, KerasNLP, and TFX library versions.

import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
print('TensorFlow version: {}'.format(tf.__version__))
from tfx import v1 as tfx
print('TFX version: {}'.format(tfx.__version__))
import keras
print('Keras version: {}'.format(keras.__version__))
import keras_nlp
print('Keras NLP version: {}'.format(keras_nlp.__version__))

keras.mixed_precision.set_global_policy("mixed_float16")
TensorFlow version: 2.15.1
TFX version: 1.15.1
Keras version: 3.3.3
Keras NLP version: 0.12.1

Using TFX Interactive Context

An interactive context is used to provide global context when running a TFX pipeline in a notebook without using a runner or orchestrator such as Apache Airflow or Kubeflow. This style of development is only useful when developing the code for a pipeline, and cannot currently be used to deploy a working pipeline to production.

from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
context = InteractiveContext()
WARNING:absl:InteractiveContext pipeline_root argument not provided: using temporary directory /tmpfs/tmp/tfx-interactive-2024-06-19T10_25_02.997883-vluhbwr6 as root for pipeline outputs.
WARNING:absl:InteractiveContext metadata_connection_config not provided: using SQLite ML Metadata database at /tmpfs/tmp/tfx-interactive-2024-06-19T10_25_02.997883-vluhbwr6/metadata.sqlite.

Pipeline Overview

The pipeline is built from the following concepts:

  • Artifacts are data that a component produces or consumes. Custom artifacts are artifact types that we define for this pipeline. The artifacts themselves are written under the pipeline root, while their metadata (URI, properties, and lineage) is tracked in ML Metadata (MLMD).

  • Components are implementations of ML tasks that you can use as steps in your pipeline.

  • Parameters are non-artifact arguments passed into components to configure their behavior.

ExampleGen

We create a custom ExampleGen component which we use to load a TensorFlow Datasets (TFDS) dataset. This uses a custom executor in a FileBasedExampleGen.

from typing import Any, Dict, List, Text
import tensorflow_datasets as tfds
import apache_beam as beam
import json
from tfx.components.example_gen.base_example_gen_executor import BaseExampleGenExecutor
from tfx.components.example_gen.component import FileBasedExampleGen
from tfx.components.example_gen import utils
from tfx.dsl.components.base import executor_spec
import os
import pprint
pp = pprint.PrettyPrinter()
@beam.ptransform_fn
@beam.typehints.with_input_types(beam.Pipeline)
@beam.typehints.with_output_types(tf.train.Example)
def _TFDatasetToExample(
    pipeline: beam.Pipeline,
    exec_properties: Dict[str, Any],
    split_pattern: str
    ) -> beam.pvalue.PCollection:
    """Read a TensorFlow Dataset and create tf.Examples"""
    custom_config = json.loads(exec_properties['custom_config'])
    dataset_name = custom_config['dataset']
    split_name = custom_config['split']

    builder = tfds.builder(dataset_name)
    builder.download_and_prepare()

    return (pipeline
            | 'MakeExamples' >> tfds.beam.ReadFromTFDS(builder, split=split_name)
            | 'AsNumpy' >> beam.Map(tfds.as_numpy)
            | 'ToDict' >> beam.Map(dict)
            | 'ToTFExample' >> beam.Map(utils.dict_to_example)
            )

class TFDSExecutor(BaseExampleGenExecutor):
  def GetInputSourceToExamplePTransform(self) -> beam.PTransform:
    """Returns PTransform for TF Dataset to TF examples."""
    return _TFDatasetToExample

For this demonstration, we use only the first 20% of the IMDb reviews train split ('split':'train[:20%]' in custom_config), which keeps training manageable. ExampleGen then divides these examples into train and eval splits using its default 2:1 ratio. You can change the split value in custom_config to use more data, up to the full split, depending on your computational resources.
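Here is a quick size check. It assumes the imdb_reviews train split holds 25,000 reviews and that ExampleGen applies its default 2:1 hash-based train/eval split to the 20% slice:

$$25{,}000 \times 0.20 = 5{,}000, \qquad 5{,}000 \times \tfrac{2}{3} \approx 3{,}333 \ (\text{train}), \qquad 5{,}000 \times \tfrac{1}{3} \approx 1{,}667 \ (\text{eval})$$

With the trainer's batch size of 20, roughly 3,333 training examples give about 167 steps per epoch. The training log later in this notebook reports 166, which is consistent with a hash split that is only approximately 2:1.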

example_gen = FileBasedExampleGen(
    input_base='dummy',
    custom_config={'dataset':'imdb_reviews', 'split':'train[:20%]'},
    custom_executor_spec=executor_spec.BeamExecutorSpec(TFDSExecutor))
context.run(example_gen, enable_cache=False)

The helper below prints the first few tf.train.Example records from an ExampleGen output split. For the reviews dataset, each record contains a text feature and a label feature.

def inspect_examples(component,
                     channel_name='examples',
                     split_name='train',
                     num_examples=1):
  # Get the URI of the output artifact, which is a directory
  full_split_name = 'Split-{}'.format(split_name)
  print('channel_name: {}, split_name: {} (\"{}\"), num_examples: {}\n'.format(
      channel_name, split_name, full_split_name, num_examples))
  train_uri = os.path.join(
      component.outputs[channel_name].get()[0].uri, full_split_name)
  print('train_uri: {}'.format(train_uri))

  # Get the list of files in this directory (all compressed TFRecord files)
  tfrecord_filenames = [os.path.join(train_uri, name)
                        for name in os.listdir(train_uri)]

  # Create a `TFRecordDataset` to read these files
  dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")

  # Iterate over the records and print them
  print()
  for tfrecord in dataset.take(num_examples):
    serialized_example = tfrecord.numpy()
    example = tf.train.Example()
    example.ParseFromString(serialized_example)
    pp.pprint(example)
inspect_examples(example_gen, num_examples=1, split_name='eval')
channel_name: examples, split_name: eval ("Split-eval"), num_examples: 1

train_uri: /tmpfs/tmp/tfx-interactive-2024-06-19T10_25_02.997883-vluhbwr6/FileBasedExampleGen/examples/1/Split-eval

features {
  feature {
    key: "label"
    value {
    }
  }
  feature {
    key: "text"
    value {
      bytes_list {
        value: "This was an absolutely terrible movie. Don\'t be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie\'s ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor\'s like Christopher Walken\'s good name. I could barely sit through it."
      }
    }
  }
}

StatisticsGen

The StatisticsGen component computes statistics over your dataset for data analysis, such as the number of examples, the number of features, and the data types of the features. It uses the TensorFlow Data Validation library and takes as input the dataset we just ingested with ExampleGen.

Note that StatisticsGen is designed mainly for tabular data. The statistics it produces for a free-text dataset such as the one in this LLM tutorial are therefore of limited use.

from tfx.components import StatisticsGen
statistics_gen = tfx.components.StatisticsGen(
    examples=example_gen.outputs['examples'], exclude_splits=['eval']
)
context.run(statistics_gen, enable_cache=False)
context.show(statistics_gen.outputs['statistics'])
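StatisticsGen's summaries are geared toward tabular features. For an LLM, a few text-specific numbers such as review length can be more informative, for example when choosing a sequence length for fine-tuning.

The sketch below is a hypothetical pure-Python helper; review_length_stats is not part of TFX or TFDV. It assumes the reviews are already decoded into Python strings. It counts whitespace-separated words, which only roughly approximates the BPE tokens GPT-2 actually sees.

```python
import statistics
from typing import Dict, List


def review_length_stats(reviews: List[str], word_budget: int = 128) -> Dict[str, float]:
    """Hypothetical helper: word-count statistics for raw review strings.

    Words are whitespace-separated chunks, only a rough proxy for GPT-2's
    BPE tokens. word_budget is a hypothetical cutoff used to estimate how
    many reviews a fixed sequence length would truncate.
    """
    if not reviews:
        raise ValueError("reviews must be a non-empty list")
    lengths = [len(review.split()) for review in reviews]
    over_budget = sum(1 for n in lengths if n > word_budget)
    return {
        "num_reviews": len(lengths),
        "min_words": min(lengths),
        "max_words": max(lengths),
        "mean_words": statistics.mean(lengths),
        "median_words": statistics.median(lengths),
        "frac_over_budget": over_budget / len(lengths),
    }


# Usage with two toy reviews (made up for illustration, not taken from IMDb).
toy_reviews = ["A warm and witty film.", "Terrible. I could barely sit through it."]
print(review_length_stats(toy_reviews, word_budget=6))
```

On the real data, a large frac_over_budget would suggest that many reviews get truncated at the chosen length.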

SchemaGen

The SchemaGen component generates a schema based on your data statistics. (A schema defines the expected bounds, types, and properties of the features in your dataset.) It also uses the TensorFlow Data Validation library.

SchemaGen will take as input the statistics that we generated with StatisticsGen, looking at the training split by default.

schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=False,
    exclude_splits=['eval'],
)
context.run(schema_gen, enable_cache=False)
context.show(schema_gen.outputs['schema'])

ExampleValidator

The ExampleValidator component detects anomalies in your data, based on the expectations defined by the schema. It also uses the TensorFlow Data Validation library.

ExampleValidator will take as input the statistics from StatisticsGen, and the schema from SchemaGen.

example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'],
    exclude_splits=['eval'],
)
context.run(example_validator, enable_cache=False)

After ExampleValidator finishes running, we can visualize the anomalies as a table.

context.show(example_validator.outputs['anomalies'])
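Conceptually, ExampleValidator checks the data against the schema and reports anything that violates it. TFDV works on aggregated statistics and supports far richer constraints. The following is therefore only a toy, per-example analog in pure Python. The find_anomalies function and the two-feature schema dict are hypothetical illustrations, not TFDV's API.

```python
from typing import Any, Dict, List


def find_anomalies(examples: List[Dict[str, Any]], schema: Dict[str, type]) -> List[str]:
    """Toy schema check: report missing features, wrong types, and unexpected features."""
    anomalies = []
    for i, example in enumerate(examples):
        # Every feature declared in the schema must be present with the right type.
        for name, expected_type in schema.items():
            if name not in example:
                anomalies.append(f"example {i}: missing feature '{name}'")
            elif not isinstance(example[name], expected_type):
                anomalies.append(
                    f"example {i}: feature '{name}' has type "
                    f"{type(example[name]).__name__}, expected {expected_type.__name__}"
                )
        # Features that the schema does not know about are also flagged.
        for name in example:
            if name not in schema:
                anomalies.append(f"example {i}: unexpected feature '{name}'")
    return anomalies


schema = {"text": bytes, "label": int}
examples = [{"text": b"Great film.", "label": 1}, {"text": "not bytes", "extra": 0}]
for line in find_anomalies(examples, schema):
    print(line)
```

The first example conforms, so only the second one produces anomalies.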

Transform

For a structured and repeatable TFX pipeline, we need a scalable approach to feature engineering. The Transform component performs feature engineering consistently for both training and serving, using the TensorFlow Transform library.

The Transform component reads the user's feature engineering code from a module file, so our first step is to create that module file. We keep only the review text feature, which the module outputs under the name summary.

import os
if not os.path.exists("modules"):
  os.mkdir("modules")
_transform_module_file = 'modules/_transform_module.py'
%%writefile {_transform_module_file}

import tensorflow as tf

def _fill_in_missing(x, default_value):
  """Replace missing values in a SparseTensor.

  Fills in missing values of `x` with the default_value.

  Args:
    x: A `SparseTensor` of rank 2.  Its dense shape should have size at most 1
      in the second dimension.
    default_value: the value with which to replace the missing values.

  Returns:
    A rank 1 tensor where missing values of `x` have been filled in.
  """
  if not isinstance(x, tf.sparse.SparseTensor):
    return x
  return tf.squeeze(
      tf.sparse.to_dense(
          tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]),
          default_value),
      axis=1)

def preprocessing_fn(inputs):
  """Keeps only the review text, densified and renamed to `summary`."""
  outputs = {}
  outputs["summary"] = _fill_in_missing(inputs["text"], "")
  return outputs
Writing modules/_transform_module.py
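SchemaGen ran with infer_feature_shape=False, so TensorFlow Transform parses features without an inferred shape as SparseTensors. The _fill_in_missing function turns such a [batch, 1] sparse column into a dense rank-1 tensor.

The sketch below mirrors that behavior on the same (row, col) indices and values, so you can see it without running TensorFlow. It is a hypothetical pure-Python analog; fill_in_missing_py is not part of TensorFlow Transform.

```python
from typing import List, Sequence, Tuple


def fill_in_missing_py(
    indices: Sequence[Tuple[int, int]],
    values: Sequence[bytes],
    num_rows: int,
    default_value: bytes = b"",
) -> List[bytes]:
    """Hypothetical pure-Python analog of _fill_in_missing.

    Treats (indices, values, num_rows) like a SparseTensor of dense shape
    [num_rows, 1] and returns one value per row, using default_value for
    rows that have no entry.
    """
    dense = [default_value] * num_rows
    for (row, col), value in zip(indices, values):
        if col != 0:
            raise ValueError("expected at most one value per row (column 0 only)")
        dense[row] = value
    return dense


# Row 1 has no text, so it is filled with the empty default.
print(fill_in_missing_py([(0, 0), (2, 0)], [b"Great film.", b"Awful."], num_rows=3))
```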
preprocessor = tfx.components.Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath(_transform_module_file))
context.run(preprocessor, enable_cache=False)
running bdist_wheel
running build
running build_py
creating build
creating build/lib
copying _transform_module.py -> build/lib
installing to /tmpfs/tmp/tmpng5n_dum
running install
running install_lib
copying build/lib/_transform_module.py -> /tmpfs/tmp/tmpng5n_dum
running install_egg_info
running egg_info
creating tfx_user_code_Transform.egg-info
writing tfx_user_code_Transform.egg-info/PKG-INFO
writing dependency_links to tfx_user_code_Transform.egg-info/dependency_links.txt
writing top-level names to tfx_user_code_Transform.egg-info/top_level.txt
writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
reading manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
Copying tfx_user_code_Transform.egg-info to /tmpfs/tmp/tmpng5n_dum/tfx_user_code_Transform-0.0+1c2de159578a848ee3ef5ef0cff9839b74e136b49dc22db018288c0f9abdee29-py3.9.egg-info
running install_scripts
creating /tmpfs/tmp/tmpng5n_dum/tfx_user_code_Transform-0.0+1c2de159578a848ee3ef5ef0cff9839b74e136b49dc22db018288c0f9abdee29.dist-info/WHEEL
creating '/tmpfs/tmp/tmpyps6sws4/tfx_user_code_Transform-0.0+1c2de159578a848ee3ef5ef0cff9839b74e136b49dc22db018288c0f9abdee29-py3-none-any.whl' and adding '/tmpfs/tmp/tmpng5n_dum' to it
adding '_transform_module.py'
adding 'tfx_user_code_Transform-0.0+1c2de159578a848ee3ef5ef0cff9839b74e136b49dc22db018288c0f9abdee29.dist-info/METADATA'
adding 'tfx_user_code_Transform-0.0+1c2de159578a848ee3ef5ef0cff9839b74e136b49dc22db018288c0f9abdee29.dist-info/WHEEL'
adding 'tfx_user_code_Transform-0.0+1c2de159578a848ee3ef5ef0cff9839b74e136b49dc22db018288c0f9abdee29.dist-info/top_level.txt'
adding 'tfx_user_code_Transform-0.0+1c2de159578a848ee3ef5ef0cff9839b74e136b49dc22db018288c0f9abdee29.dist-info/RECORD'
removing /tmpfs/tmp/tmpng5n_dum
Processing /tmpfs/tmp/tfx-interactive-2024-06-19T10_25_02.997883-vluhbwr6/_wheels/tfx_user_code_Transform-0.0+1c2de159578a848ee3ef5ef0cff9839b74e136b49dc22db018288c0f9abdee29-py3-none-any.whl
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+1c2de159578a848ee3ef5ef0cff9839b74e136b49dc22db018288c0f9abdee29
INFO:tensorflow:Assets written to: /tmpfs/tmp/tfx-interactive-2024-06-19T10_25_02.997883-vluhbwr6/Transform/transform_graph/5/.temp_path/tftransform_tmp/16aec2c799b44aacabe0e367f06d0a6e/assets
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.

Let's take a look at some of the transformed examples and check that they are indeed processed as intended.

def pprint_examples(artifact, n_examples=2):
  print("artifact:", artifact, "\n")
  uri = os.path.join(artifact.uri, "Split-eval")
  print("uri:", uri, "\n")
  tfrecord_filenames = [os.path.join(uri, name) for name in os.listdir(uri)]
  print("tfrecord_filenames:", tfrecord_filenames, "\n")
  dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")
  for tfrecord in dataset.take(n_examples):
    serialized_example = tfrecord.numpy()
    example = tf.train.Example.FromString(serialized_example)
    pp.pprint(example)
pprint_examples(preprocessor.outputs['transformed_examples'].get()[0])
artifact: Artifact(artifact: id: 6
type_id: 14
uri: "/tmpfs/tmp/tfx-interactive-2024-06-19T10_25_02.997883-vluhbwr6/Transform/transformed_examples/5"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "name"
  value {
    string_value: "transformed_examples:2024-06-19T10:25:14.991872"
  }
}
custom_properties {
  key: "producer_component"
  value {
    string_value: "Transform"
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
name: "transformed_examples:2024-06-19T10:25:14.991872"
, artifact_type: id: 14
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
) 

uri: /tmpfs/tmp/tfx-interactive-2024-06-19T10_25_02.997883-vluhbwr6/Transform/transformed_examples/5/Split-eval 

tfrecord_filenames: ['/tmpfs/tmp/tfx-interactive-2024-06-19T10_25_02.997883-vluhbwr6/Transform/transformed_examples/5/Split-eval/transformed_examples-00000-of-00001.gz'] 

features {
  feature {
    key: "summary"
    value {
      bytes_list {
        value: "This was an absolutely terrible movie. Don\'t be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie\'s ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor\'s like Christopher Walken\'s good name. I could barely sit through it."
      }
    }
  }
}

features {
  feature {
    key: "summary"
    value {
      bytes_list {
        value: "This is the kind of film for a snowy Sunday afternoon when the rest of the world can go ahead with its own business as you descend into a big arm-chair and mellow for a couple of hours. Wonderful performances from Cher and Nicolas Cage (as always) gently row the plot along. There are no rapids to cross, no dangerous waters, just a warm and witty paddle through New York life at its best. A family film in every sense and one that deserves the praise it received."
      }
    }
  }
}

Trainer

The Trainer component trains an ML model and requires model definition code from users.

The run_fn function is the entry point that TFX's Trainer component calls to train a machine learning model. It is a user-supplied function that receives a single FnArgs argument and returns nothing. The Trainer records whatever run_fn saves to fn_args.serving_model_dir as its output Model artifact.

The run_fn function is responsible for:

  • Building the machine learning model.
  • Training the model on the training data.
  • Saving the trained model to the serving model directory.
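Here is a toy sketch that makes this contract concrete without GPT-2 or a GPU. FakeFnArgs and toy_run_fn are hypothetical. The dataclass mimics only two real FnArgs fields, train_files and serving_model_dir, and it treats train_files as plain text files. In the real pipeline, train_files holds file patterns over GZIP-compressed TFRecords. The "model" here is just a word-count table.

```python
import dataclasses
import os
import tempfile
from collections import Counter
from typing import List


@dataclasses.dataclass
class FakeFnArgs:
    """Hypothetical stand-in for tfx FnArgs, keeping only two of its fields."""
    train_files: List[str]
    serving_model_dir: str


def toy_run_fn(fn_args: FakeFnArgs) -> None:
    """Toy run_fn: build and train a 'model', save it, and return None."""
    # Build + train: count words across all training files
    # (a stand-in for fine-tuning GPT-2).
    counts = Counter()
    for path in fn_args.train_files:
        with open(path, encoding="utf-8") as f:
            counts.update(f.read().split())
    # Save: write the trained "model" under serving_model_dir.
    os.makedirs(fn_args.serving_model_dir, exist_ok=True)
    model_path = os.path.join(fn_args.serving_model_dir, "model.tsv")
    with open(model_path, "w", encoding="utf-8") as f:
        for word, count in sorted(counts.items()):
            f.write(f"{word}\t{count}\n")


# Usage: run the toy function against a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    data_path = os.path.join(tmp, "train.txt")
    with open(data_path, "w", encoding="utf-8") as f:
        f.write("good movie good")
    toy_run_fn(FakeFnArgs([data_path], os.path.join(tmp, "serving_model")))
    with open(os.path.join(tmp, "serving_model", "model.tsv"), encoding="utf-8") as f:
        print(f.read())
```

The real run_fn below follows the same shape: it reads fn_args.train_files, trains, and writes its output to fn_args.serving_model_dir.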

Write model training code

We fine-tune a pre-trained GPT-2 model together with its matching GPT-2 preprocessor. TFX Trainer expects run_fn to be defined in a module, so we first set the path of that module file and the import path of its run_fn.

model_file = "modules/model.py"
model_fn = "modules.model.run_fn"

Now, we write the run_fn function.

This run_fn function first reads the training articles from the files given by fn_args.train_files. Next, it loads the pre-trained GPT-2 model (the gpt2_base_en preset) along with its preprocessor. The model is then fine-tuned on the training articles with the model.fit() method. Finally, the fine-tuned model weights are saved to the directory given by fn_args.serving_model_dir.

Now we are going to work with KerasNLP's GPT-2 model! You can learn about the full GPT-2 implementation in KerasNLP on GitHub. You can also read about and interactively test the model in the Google I/O 2023 Colab notebook.
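GPT2CausalLMPreprocessor turns each review into next-token-prediction training pairs. It packs the token IDs to sequence_length + 1 positions, uses positions [0, L) as inputs and [1, L + 1) as labels, and uses the labels' padding mask as sample weights. As a result, padded positions count toward neither the loss nor the weighted accuracy.

The pure-Python sketch below mirrors only that shifting step. The causal_lm_pairs helper, the token IDs, and pad_id=0 are hypothetical. The real preprocessor also tokenizes with GPT-2's BPE vocabulary and can add start and end tokens.

```python
from typing import List, Tuple


def causal_lm_pairs(
    token_ids: List[int], sequence_length: int, pad_id: int = 0
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: build (inputs, labels, sample_weights) for next-token prediction."""
    if sequence_length < 1:
        raise ValueError("sequence_length must be at least 1")
    # Pack to sequence_length + 1 tokens: truncate long inputs, pad short ones.
    packed = token_ids[: sequence_length + 1]
    mask = [1] * len(packed)
    num_pad = sequence_length + 1 - len(packed)
    packed = packed + [pad_id] * num_pad
    mask = mask + [0] * num_pad
    # Inputs drop the last position and labels drop the first,
    # so labels[i] is the token that follows inputs[i].
    inputs = packed[:-1]
    labels = packed[1:]
    # Padded label positions get weight 0 and are ignored by the loss.
    sample_weights = mask[1:]
    return inputs, labels, sample_weights


print(causal_lm_pairs([11, 12, 13], sequence_length=4))
# → ([11, 12, 13, 0], [12, 13, 0, 0], [1, 1, 0, 0])
```

This is also why a small _SEQUENCE_LENGTH saves memory: longer reviews are simply truncated before the shift.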

import keras_nlp
import keras
import tensorflow as tf
%%writefile {model_file}

import os
import time
from absl import logging
import keras_nlp
import more_itertools
import pandas as pd
import tensorflow as tf
import keras
import tfx
import tfx.components.trainer.fn_args_utils
import gc


_EPOCH = 1
_BATCH_SIZE = 20
_INITIAL_LEARNING_RATE = 5e-5
_END_LEARNING_RATE = 0.0
_SEQUENCE_LENGTH = 128 # default value is 256

def _input_fn(file_pattern: str) -> list:
  """Retrieves training data and returns a list of articles for training.

  For each record in the GZIP-compressed TFRecord files produced by the
  upstream Transform component, parse the record into a tf.train.Example
  and decode its `summary` feature from bytes into a string.

  Args:
    file_pattern: Path to the TFRecord file of the training dataset.

  Returns:
    A list of training articles.

  Raises:
    FileNotFoundError: If TFRecord dataset is not found in the file_pattern
    directory.
  """

  if os.path.basename(file_pattern) == '*':
    file_loc = os.path.dirname(file_pattern)

  else:
    raise FileNotFoundError(
        f"Expected a file pattern ending in '*', got: '{file_pattern}'."
    )

  file_paths = [os.path.join(file_loc, name) for name in os.listdir(file_loc)]
  train_articles = []
  parsed_dataset = tf.data.TFRecordDataset(file_paths, compression_type="GZIP")
  for raw_record in parsed_dataset:
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    train_articles.append(
        example.features.feature["summary"].bytes_list.value[0].decode('utf-8')
    )
  return train_articles

def run_fn(fn_args: tfx.components.trainer.fn_args_utils.FnArgs) -> None:
  """Trains the model and outputs the trained model to a the desired location given by FnArgs.

  Args:
    FnArgs :  Args to pass to user defined training/tuning function(s)
  """

  train_articles =  pd.Series(_input_fn(
          fn_args.train_files[0],
      ))
  tf_train_ds = tf.data.Dataset.from_tensor_slices(train_articles)

  gpt2_preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
      'gpt2_base_en',
      sequence_length=_SEQUENCE_LENGTH,
      add_end_token=True,
  )
  gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
      'gpt2_base_en', preprocessor=gpt2_preprocessor
  )

  processed_ds = (
      tf_train_ds
      .batch(_BATCH_SIZE)
      .cache()
      .prefetch(tf.data.AUTOTUNE)
  )

  gpt2_lm.include_preprocessing = False

  lr = keras.optimizers.schedules.PolynomialDecay(
      _INITIAL_LEARNING_RATE,
      decay_steps=processed_ds.cardinality() * _EPOCH,
      end_learning_rate=_END_LEARNING_RATE,
  )
  loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

  gpt2_lm.compile(
      optimizer=keras.optimizers.Adam(lr),
      loss=loss,
      weighted_metrics=['accuracy'],
  )

  gpt2_lm.fit(processed_ds, epochs=_EPOCH)
  os.makedirs(fn_args.serving_model_dir, exist_ok=True)
  gpt2_lm.save_weights(
      filepath=os.path.join(fn_args.serving_model_dir, "model_weights.weights.h5")
  )
  del gpt2_lm, gpt2_preprocessor, processed_ds, tf_train_ds
  gc.collect()
Writing modules/model.py
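The PolynomialDecay schedule in run_fn starts at 5e-5. With the default power=1.0, it decays linearly to 0.0 over decay_steps, which is the number of batches times _EPOCH.

Below is a pure-Python sketch of the documented non-cycling formula. It uses decay_steps=166, taken from the single-epoch training log below, as an example value. The polynomial_decay function is a hypothetical helper, not the Keras class.

```python
def polynomial_decay(
    step: int,
    initial_lr: float = 5e-5,
    decay_steps: int = 166,
    end_lr: float = 0.0,
    power: float = 1.0,
) -> float:
    """Hypothetical helper: non-cycling polynomial decay of the learning rate."""
    if decay_steps <= 0:
        raise ValueError("decay_steps must be positive")
    # After decay_steps the rate stays at end_lr.
    step = min(step, decay_steps)
    return (initial_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr


for step in (0, 83, 166):
    print(step, polynomial_decay(step))
```

With power=1.0, the learning rate is halfway to end_lr at the midpoint of training and reaches 0.0 on the last step.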
trainer = tfx.components.Trainer(
    run_fn=model_fn,
    examples=preprocessor.outputs['transformed_examples'],
    train_args=tfx.proto.TrainArgs(splits=['train']),
    eval_args=tfx.proto.EvalArgs(splits=['train']),
    schema=schema_gen.outputs['schema'],
)
context.run(trainer, enable_cache=False)
WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/metadata.json...
100%|██████████| 141/141 [00:00<00:00, 263kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/preprocessor.json...
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/tokenizer.json...
100%|██████████| 448/448 [00:00<00:00, 822kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/assets/tokenizer/vocabulary.json...
100%|██████████| 0.99M/0.99M [00:00<00:00, 5.48MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/assets/tokenizer/merges.txt...
100%|██████████| 446k/446k [00:00<00:00, 2.91MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/task.json...
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/config.json...
100%|██████████| 484/484 [00:00<00:00, 703kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/model.weights.h5...
100%|██████████| 475M/475M [00:11<00:00, 44.3MB/s]
2024-06-19 10:26:26.872190: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1718792868.731678  163724 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
165/166 ━━━━━━━━━━━━━━━━━━━━ 0s 316ms/step - accuracy: 0.3164 - loss: 3.6999
2024-06-19 10:28:42.110354: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert
166/166 ━━━━━━━━━━━━━━━━━━━━ 251s 804ms/step - accuracy: 0.3164 - loss: 3.6994
W0000 00:00:1718793001.686917  163732 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update

Inference and Evaluation

With our model fine-tuned, let's evaluate its performance by running inference. To capture and preserve the results, we'll create a custom EvaluationMetric artifact.

from tfx.types import artifact
from tfx import types

Property = artifact.Property
PropertyType = artifact.PropertyType

DURATION_PROPERTY = Property(type=PropertyType.FLOAT)
EVAL_OUTPUT_PROPERTY = Property(type=PropertyType.STRING)

class EvaluationMetric(types.Artifact):
  """Artifact that contains metrics for a model.

  * Properties:

     - 'model_prediction_time' : time it took for the model to make predictions
     based on the input text.
     - 'model_evaluation_output_path' : saves the path to the CSV file that
     contains the model's prediction based on the testing inputs.
  """
  TYPE_NAME = 'Evaluation_Metric'
  PROPERTIES = {
      'model_prediction_time': DURATION_PROPERTY,
      'model_evaluation_output_path': EVAL_OUTPUT_PROPERTY,
  }

The helper functions below prepare data for evaluating the language model (LLM). The input_fn function reads the articles from a directory of GZIP-compressed TFRecord files, while the trim_sentence function ensures consistency by limiting each text to a maximum number of words (20 by default). The evaluation itself relies on perplexity, a key metric reflecting the model's ability to predict the next token in a sequence. A lower perplexity indicates higher prediction confidence and generally better model performance.

"""This is an evaluation component for the LLM pipeline takes in a
standard trainer artifact and outputs a custom evaluation artifact.
It displays the evaluation output in the colab notebook.
"""
import os
import time
import keras_nlp
import numpy as np
import pandas as pd
import tensorflow as tf
import tfx.v1 as tfx

def input_fn(file_pattern: str) -> list[str]:
  """Reads the text of every example in a directory of TFRecord files.

  Args:
    file_pattern: Path to a directory containing the GZIP-compressed TFRecord
      files of one dataset split (for example, 'Split-eval').

  Returns:
    A list of strings, one per example, taken from the 'summary' feature.

  Raises:
    FileNotFoundError: If the directory does not exist.
  """
  if os.path.exists(file_pattern):
    file_paths = [os.path.join(file_pattern, name) for name in os.listdir(file_pattern)]
    test_articles = []
    parsed_dataset = tf.data.TFRecordDataset(file_paths, compression_type="GZIP")
    for raw_record in parsed_dataset:
      example = tf.train.Example()
      example.ParseFromString(raw_record.numpy())
      test_articles.append(
          example.features.feature["summary"].bytes_list.value[0].decode('utf-8')
      )
    return test_articles
  else:
    raise FileNotFoundError(f'File path "{file_pattern}" does not exist.')

def trim_sentence(sentence: str, max_words: int = 20) -> str:
  """Trims the sentence to include up to the given number of words.

  Args:
    sentence: The sentence to trim.
    max_words: The maximum number of words to include in the trimmed sentence.

  Returns:
    The trimmed sentence.
  """
  words = sentence.split(' ')
  if len(words) <= max_words:
    return sentence
  return ' '.join(words[:max_words])
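As a quick illustration, the snippet below reuses the trim_sentence logic to split one review into a short prompt and a longer reference text, the way the Evaluator does with _INPUT_LENGTH and max_length. The helper make_prompt_and_reference and the sample review are hypothetical, introduced only for this sketch.

```python
def trim_sentence(sentence: str, max_words: int = 20) -> str:
  """Same logic as the notebook helper: keep at most max_words space-separated words."""
  words = sentence.split(' ')
  if len(words) <= max_words:
    return sentence
  return ' '.join(words[:max_words])


def make_prompt_and_reference(review: str, input_length: int = 10, max_length: int = 50):
  """Hypothetical helper mirroring the Evaluator's use of trim_sentence.

  The short prefix is the generation prompt; the longer prefix is the
  reference text on which perplexity is computed.
  """
  prompt = trim_sentence(review, max_words=input_length)
  reference = trim_sentence(review, max_words=max_length)
  return prompt, reference


review = "This movie was a pleasant surprise and the acting was far better than I expected"
prompt, reference = make_prompt_and_reference(review, input_length=5, max_length=50)
print(prompt)     # This movie was a pleasant
print(reference)  # the 15-word review is shorter than 50 words, so it is unchanged
```

Because trim_sentence splits on a single space, consecutive spaces produce empty strings that count toward max_words; sentence.split() would collapse runs of whitespace instead.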


One useful metric for evaluating a large language model is perplexity, which measures how well the model predicts the next token in a sequence. It is the exponential of the average negative log-likelihood that the model assigns to each token given the tokens before it. A lower perplexity means the model assigns higher probability to the observed text, that is, it is better at predicting the next token.

This is the formula for calculating perplexity.

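$\text{Perplexity} = \exp(\text{Average Negative Log-Likelihood}) = \exp\left(-\frac{1}{T} \sum_{t=1}^{T} \log p(w_t \mid w_{<t})\right)$, where $T$ is the number of tokens and $p(w_t \mid w_{<t})$ is the probability the model assigns to token $w_t$ given the preceding tokens.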
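To make the formula concrete, here is a small worked example in plain Python, assuming a model assigned the made-up probabilities 0.5, 0.25 and 0.125 to three observed tokens. The average negative log-likelihood is (ln 2 + 2 ln 2 + 3 ln 2) / 3 = 2 ln 2, so the perplexity is exp(2 ln 2) = 4, which equals the geometric mean of the inverse probabilities 2, 4 and 8. The function name perplexity_from_probs is hypothetical.

```python
import math


def perplexity_from_probs(token_probs):
  """Perplexity from the probabilities p(w_t | w_<t) assigned to each observed token."""
  # Average negative log-likelihood over the T tokens.
  avg_nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
  # Exponentiate to get perplexity.
  return math.exp(avg_nll)


probs = [0.5, 0.25, 0.125]  # illustrative per-token probabilities
print(perplexity_from_probs(probs))  # approximately 4.0
```

A model that spreads probability uniformly over V choices at every step has perplexity exactly V, which is why perplexity is often read as an "effective number of choices" per token.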

In this notebook, we compute perplexity with the keras_nlp.metrics.Perplexity metric.
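The keras_nlp.metrics.Perplexity metric with from_logits=True takes raw logits rather than probabilities. The NumPy sketch below shows the same calculation under the assumption of unmasked inputs and no sample weights: apply a log-softmax to the logits, pick the log-probability of each target token, and exponentiate the negated mean. It is an illustration, not the library's implementation, and perplexity_from_logits is a hypothetical name.

```python
import numpy as np


def perplexity_from_logits(logits: np.ndarray, targets: np.ndarray) -> float:
  """Perplexity of integer targets under unnormalized logits.

  Args:
    logits: Float array of shape (num_tokens, vocab_size).
    targets: Integer array of shape (num_tokens,) with the observed token IDs.
  """
  # Numerically stable log-softmax: subtract each row's max before exponentiating.
  shifted = logits - logits.max(axis=-1, keepdims=True)
  log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
  # Log-probability the model assigned to each observed token.
  target_log_probs = log_probs[np.arange(len(targets)), targets]
  # Exponentiate the average negative log-likelihood.
  return float(np.exp(-target_log_probs.mean()))


# Uniform logits over a 5-token vocabulary: each token has probability 1/5.
print(perplexity_from_logits(np.zeros((3, 5)), np.array([0, 1, 2])))  # approximately 5.0
```

With uniform logits over a vocabulary of V tokens, the perplexity is exactly V, which makes a handy sanity check. The library metric additionally accepts a mask_token_id argument for ignoring padding tokens.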

Computing Perplexity for Base GPT-2 Model and Finetuned Model

The function below computes the perplexity of a model on a single sentence. It is used later in the notebook to compare the base GPT-2 model with the fine-tuned model.

def calculate_perplexity(gpt2_model, gpt2_tokenizer, sentence) -> float:
  """Calculates perplexity of a model given a sentence.

  Args:
    gpt2_model: GPT-2 Language Model
    gpt2_tokenizer: A GPT-2 tokenizer using Byte-Pair Encoding subword segmentation.
    sentence: Sentence that the model's perplexity is calculated upon.

  Returns:
    A perplexity score.
  """
  # gpt2_tokenizer([sentence])[0] is a tensor holding the token IDs of the sentence.
  tokens = gpt2_tokenizer([sentence])[0].numpy()
  # decoded_sentences holds the prefixes of the sentence, growing by one token at a time.
  # e.g. if the tokens of "I love dogs" are ["I", "love", "dogs"], then
  # decoded_sentences = ["I", "I love"], and each prefix is scored on the token
  # that follows it ("love", then "dogs").
  decoded_sentences = [gpt2_tokenizer.detokenize([tokens[:i]])[0].numpy() for i in range(1, len(tokens))]
  predictions = gpt2_model.predict(decoded_sentences)
  # The preprocessor prepends a start token, so for the prefix tokens[:i]
  # (predictions[i - 1]) the logits at position i predict tokens[i].
  logits = [predictions[i - 1][i] for i in range(1, len(tokens))]
  target = tokens[1:].reshape(len(tokens) - 1, 1)
  perplexity = keras_nlp.metrics.Perplexity(from_logits=True)
  perplexity.update_state(target, logits)
  result = perplexity.result()
  return result.numpy()

def average_perplexity(gpt2_model, gpt2_tokenizer, sentences) -> float:
  """Returns the mean of the per-sentence perplexities over a list of sentences."""
  perplexity_lst = [calculate_perplexity(gpt2_model, gpt2_tokenizer, sent) for sent in sentences]
  return np.mean(perplexity_lst)
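The index arithmetic in calculate_perplexity pairs each prefix with the token that follows it. The pure-Python sketch below, using a hypothetical next_token_pairs helper on word tokens instead of GPT-2 token IDs, makes that alignment explicit: a sentence of n tokens yields n - 1 prefixes, and the first token is never scored.

```python
def next_token_pairs(tokens):
  """Pairs each prefix tokens[:i] with the token it should predict, tokens[i].

  Mirrors the indexing in calculate_perplexity: the prefix of length i
  (predictions[i - 1]) is scored on tokens[i].
  """
  return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]


pairs = next_token_pairs(["I", "love", "dogs"])
print(pairs)  # [(['I'], 'love'), (['I', 'love'], 'dogs')]
```

calculate_perplexity sends all n - 1 prefixes to the model as one batch. Because GPT-2 is a causal model, a single forward pass over the full sentence already produces, at every position, logits that condition only on earlier tokens, so in principle one sequence could replace the prefix batch; the prefix form trades extra computation for indexing that is easy to follow.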

Evaluator

With the helper functions in place, we define the Evaluator component. It runs inference with both the base and the fine-tuned GPT-2 model, computes each model's average perplexity on held-out reviews, and measures the average generation time per prompt. The results are written to CSV files in the output artifact so that the two models can be compared side by side.
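The Evaluator measures inference speed by timing a loop of generate calls and dividing by the number of prompts. The sketch below isolates that pattern with a hypothetical timed_generation helper and a stand-in fake_generate function in place of gpt2_lm.generate. It uses time.perf_counter, a monotonic clock intended for measuring intervals, rather than the time.time calls in the component.

```python
import time
from typing import Callable, List, Tuple


def timed_generation(generate: Callable[[str], str], prompts: List[str]) -> Tuple[List[str], float]:
  """Runs generate on each prompt; returns the outputs and average seconds per prompt."""
  start = time.perf_counter()
  outputs = [generate(prompt) for prompt in prompts]
  elapsed = time.perf_counter() - start
  # Divide by the actual number of prompts rather than a fixed constant.
  return outputs, elapsed / len(prompts)


def fake_generate(prompt: str) -> str:
  """Hypothetical stand-in for gpt2_lm.generate(prompt, max_length)."""
  return prompt + " ..."


outputs, avg_seconds = timed_generation(fake_generate, ["a great film", "a dull plot"])
```

In the logs of the run below, the first generate steps take several seconds while later steps take under 200 ms, which suggests a one-time compilation or warm-up cost (an interpretation, not something the notebook measures separately). Calling the model once on a throwaway prompt before starting the timer would keep that cost out of the average.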

@tfx.dsl.components.component
def Evaluator(
    examples: tfx.dsl.components.InputArtifact[
        tfx.types.standard_artifacts.Examples
    ],
    trained_model: tfx.dsl.components.InputArtifact[
        tfx.types.standard_artifacts.Model
    ],
    max_length: tfx.dsl.components.Parameter[int],
    evaluation: tfx.dsl.components.OutputArtifact[EvaluationMetric],
) -> None:
  """Makes inferences with the base GPT-2 model and the fine-tuned model.

  Args:
    examples: Standard TFX examples artifact used to retrieve the test dataset.
    trained_model: Standard TFX model artifact fine-tuned on the IMDb reviews
      dataset.
    max_length: Maximum length of each generated sequence (passed to
      generate), also used as the number of words kept in each reference text.
    evaluation: An evaluation artifact that stores the models' predictions for
      the test inputs in a CSV file, together with inference speed and
      perplexity metrics.
  """
  _TEST_SIZE = 10
  _INPUT_LENGTH = 10
  _SEQUENCE_LENGTH = 128

  path = os.path.join(examples.uri, 'Split-eval')
  test_data = input_fn(path)
  evaluation_inputs = [
      trim_sentence(article, max_words=_INPUT_LENGTH)
      for article in test_data[:_TEST_SIZE]
  ]
  true_test = [
      trim_sentence(article, max_words=max_length)
      for article in test_data[:_TEST_SIZE]
  ]

  # Loading base model, making inference, and calculating perplexity on the base model.
  gpt2_preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
      'gpt2_base_en',
      sequence_length=_SEQUENCE_LENGTH,
      add_end_token=True,
  )
  gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
      'gpt2_base_en', preprocessor=gpt2_preprocessor
  )
  gpt2_tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset('gpt2_base_en')

  base_average_perplexity = average_perplexity(
      gpt2_lm, gpt2_tokenizer, true_test
  )

  start_base_model = time.time()
  base_evaluation = [
      gpt2_lm.generate(prompt, max_length)
      for prompt in evaluation_inputs
  ]
  end_base_model = time.time()

  # Loading finetuned model and making inferences with the finetuned model.
  model_weights_path = os.path.join(
      trained_model.uri, "Format-Serving", "model_weights.weights.h5"
  )
  gpt2_lm.load_weights(model_weights_path)

  trained_model_average_perplexity = average_perplexity(
      gpt2_lm, gpt2_tokenizer, true_test
  )

  start_trained = time.time()
  trained_evaluation = [
      gpt2_lm.generate(prompt, max_length)
      for prompt in evaluation_inputs
  ]
  end_trained = time.time()

  # Building an inference table.
  inference_data = {
      'input': evaluation_inputs,
      'actual_test_output': true_test,
      'base_model_prediction': base_evaluation,
      'trained_model_prediction': trained_evaluation,
  }

  models = [
      'Base Model',
      'Finetuned Model',
  ]
  inference_time = [
      (end_base_model - start_base_model),
      (end_trained - start_trained),
  ]
  average_inference_time = [duration / len(evaluation_inputs) for duration in inference_time]
  average_perplexity_lst = [
      base_average_perplexity,
      trained_model_average_perplexity,
  ]
  evaluation_data = {
      'Model': models,
      'Average Inference Time (sec)': average_inference_time,
      'Average Perplexity': average_perplexity_lst,
  }

  # Create a directory in the evaluation artifact to store the metric DataFrames.
  metrics_path = os.path.join(evaluation.uri, 'metrics')
  os.makedirs(metrics_path, exist_ok=True)

  evaluation_df = pd.DataFrame(evaluation_data).set_index('Model').transpose()
  evaluation_path = os.path.join(metrics_path, 'evaluation_output.csv')
  evaluation_df.to_csv(evaluation_path)

  inference_df = pd.DataFrame(inference_data)
  inference_path = os.path.join(metrics_path, 'inference_output.csv')
  inference_df.to_csv(inference_path)
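  evaluation.model_evaluation_output_path = inference_path
  # Record the fine-tuned model's total generation time, in seconds.
  evaluation.model_prediction_time = end_trained - start_trained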

evaluator = Evaluator(
    examples=preprocessor.outputs['transformed_examples'],
    trained_model=trainer.outputs['model'],
    max_length=50,
)
context.run(evaluator, enable_cache=False)
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/preprocessor.json...
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/task.json...
W0000 00:00:1718793025.659453  163715 triton_autotuner.cc:656] Slow kernel for triton_gemm_dot.72 took: 1.337449829s. config: block_m: 16 block_n: 16 block_k: 512 split_k: 1 num_stages: 1 num_warps: 4
1/2 ━━━━━━━━━━━━━━━━━━━━ 16s 17s/step
W0000 00:00:1718793030.476132  163715 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
W0000 00:00:1718793040.725059  163713 triton_autotuner.cc:656] Slow kernel for triton_gemm_dot.72 took: 1.18693603525s. config: block_m: 16 block_n: 16 block_k: 512 split_k: 1 num_stages: 1 num_warps: 4
2/2 ━━━━━━━━━━━━━━━━━━━━ 32s 15s/step
2/2 ━━━━━━━━━━━━━━━━━━━━ 12s 11s/step
1/2 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step
W0000 00:00:1718793078.531160  163728 triton_autotuner.cc:656] Slow kernel for triton_gemm_dot.72 took: 1.10094506825s. config: block_m: 16 block_n: 16 block_k: 512 split_k: 1 num_stages: 1 num_warps: 4
2/2 ━━━━━━━━━━━━━━━━━━━━ 13s 12s/step
2/2 ━━━━━━━━━━━━━━━━━━━━ 13s 11s/step
2/2 ━━━━━━━━━━━━━━━━━━━━ 1s 175ms/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 10s 5s/step
W0000 00:00:1718793124.131074  163718 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
3/3 ━━━━━━━━━━━━━━━━━━━━ 9s 4s/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 8s 4s/step
W0000 00:00:1718793152.631304  163718 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
3/3 ━━━━━━━━━━━━━━━━━━━━ 11s 5s/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 166ms/step
W0000 00:00:1718793202.224897  163525 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/saving/saving_lib.py:415: UserWarning: Skipping variable loading for optimizer 'loss_scale_optimizer', because it has 4 variables whereas the saved optimizer has 397 variables. 
  saveable.load_own_variables(weights_store.get(inner_path))
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/saving/saving_lib.py:415: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 393 variables. 
  saveable.load_own_variables(weights_store.get(inner_path))
2/2 ━━━━━━━━━━━━━━━━━━━━ 1s 168ms/step
2/2 ━━━━━━━━━━━━━━━━━━━━ 1s 167ms/step
2/2 ━━━━━━━━━━━━━━━━━━━━ 1s 177ms/step
2/2 ━━━━━━━━━━━━━━━━━━━━ 1s 180ms/step
2/2 ━━━━━━━━━━━━━━━━━━━━ 1s 174ms/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 162ms/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 163ms/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 151ms/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 160ms/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 159ms/step

Evaluator Results

Once the evaluation component has finished running, we load the evaluation metrics from the evaluator's output artifact and display them.

Perplexity Calculation: Perplexity is only one of many ways to evaluate LLMs. LLM evaluation is an active research topic and a comprehensive treatment is beyond the scope of this notebook.

evaluation_path = os.path.join(evaluator.outputs['evaluation'].get()[0].uri, 'metrics')
inference_df = pd.read_csv(os.path.join(evaluation_path, 'inference_output.csv'), index_col=0)
evaluation_df = pd.read_csv(os.path.join(evaluation_path, 'evaluation_output.csv'), index_col=0)

The fine-tuned GPT-2 model exhibits a slight improvement in perplexity compared to the baseline model. Further training with more epochs or a larger dataset may yield more substantial perplexity reductions.
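The size of the improvement can be quantified directly from evaluation_df. The hypothetical helper below assumes the layout written by the Evaluator, with metric names as the index and 'Base Model' and 'Finetuned Model' as columns, and returns the fraction by which fine-tuning reduced average perplexity.

```python
import pandas as pd


def relative_perplexity_reduction(evaluation_df: pd.DataFrame) -> float:
  """Fraction by which the fine-tuned model lowers average perplexity vs. the base model.

  Assumes the Evaluator's CSV layout: metrics as rows, models as columns.
  """
  base = evaluation_df.loc['Average Perplexity', 'Base Model']
  tuned = evaluation_df.loc['Average Perplexity', 'Finetuned Model']
  return float((base - tuned) / base)
```

For example, relative_perplexity_reduction(evaluation_df) returns 0.1 if the fine-tuned model's average perplexity is 10% lower than the base model's; a negative value would mean fine-tuning increased perplexity.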

from IPython import display
display.display(display.HTML(inference_df.to_html()))
display.display(display.HTML(evaluation_df.to_html()))

Running the Entire Pipeline

TFX supports multiple orchestrators to run pipelines. In this tutorial we will use LocalDagRunner, which is included in the TFX Python package and runs pipelines in the local environment. TFX pipelines are often called "DAGs", short for directed acyclic graphs.

LocalDagRunner provides fast iterations for development and debugging. TFX also supports other orchestrators including Kubeflow Pipelines and Apache Airflow which are suitable for production use cases. See TFX on Cloud AI Platform Pipelines or TFX Airflow Tutorial to learn more about other orchestration systems.
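To illustrate what "directed acyclic graph" means here, the sketch below writes down which components consume the outputs of which others in the _create_pipeline function of this section and derives one valid execution order with Python's graphlib module (Python 3.9+). It only illustrates dependency ordering; it is not how LocalDagRunner schedules components internally.

```python
from graphlib import TopologicalSorter

# Upstream dependencies of each component, read off the inputs each one
# consumes in _create_pipeline.
dependencies = {
    "FileBasedExampleGen": set(),
    "StatisticsGen": {"FileBasedExampleGen"},
    "SchemaGen": {"StatisticsGen"},
    "ExampleValidator": {"StatisticsGen", "SchemaGen"},
    "Transform": {"FileBasedExampleGen", "SchemaGen"},
    "Trainer": {"Transform", "SchemaGen"},
    "Evaluator": {"Transform", "Trainer"},
}

# static_order yields every node after all of its predecessors.
order = list(TopologicalSorter(dependencies).static_order())
print(order)
```

Any order returned by static_order runs each component after all of its upstream components. If the graph contained a cycle, graphlib would raise CycleError, which is why pipelines must be acyclic.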

Now we create a LocalDagRunner and pass it a Pipeline object created by the _create_pipeline function defined below. The pipeline runs directly, and the logs show its progress, including model training.

import os

PIPELINE_NAME = "tfx-llm-imdb-reviews"
model_fn = "modules.model.run_fn"
_transform_module_file = "modules/_transform_module.py"

# Output directory to store artifacts generated from the pipeline.
PIPELINE_ROOT = os.path.join('pipelines', PIPELINE_NAME)
# Path to a SQLite DB file to use as an MLMD storage.
METADATA_PATH = os.path.join('metadata', PIPELINE_NAME, 'metadata.db')
# Output directory where created models from the pipeline will be exported.
SERVING_MODEL_DIR = os.path.join('serving_model', PIPELINE_NAME)

from absl import logging
logging.set_verbosity(logging.INFO)  # Set default logging level.


def _create_pipeline(
    pipeline_name: str,
    pipeline_root: str,
    model_fn: str,
    serving_model_dir: str,
    metadata_path: str,
) -> tfx.dsl.Pipeline:
  """Creates a TFX pipeline that fine-tunes and evaluates a large language model."""

  example_gen = FileBasedExampleGen(
      input_base='dummy',
      custom_config={'dataset': 'imdb_reviews', 'split': 'train[:5%]'},
      custom_executor_spec=executor_spec.BeamExecutorSpec(TFDSExecutor),
  )

  statistics_gen = tfx.components.StatisticsGen(
      examples=example_gen.outputs['examples'], exclude_splits=['eval']
  )

  schema_gen = tfx.components.SchemaGen(
      statistics=statistics_gen.outputs['statistics'],
      infer_feature_shape=False,
      exclude_splits=['eval'],
  )

  example_validator = tfx.components.ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=schema_gen.outputs['schema'],
      exclude_splits=['eval'],
  )

  preprocessor = tfx.components.Transform(
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      module_file=_transform_module_file,
  )

  trainer = tfx.components.Trainer(
      run_fn=model_fn,
      examples=preprocessor.outputs['transformed_examples'],
      train_args=tfx.proto.TrainArgs(splits=['train']),
      eval_args=tfx.proto.EvalArgs(splits=['train']),
      schema=schema_gen.outputs['schema'],
  )


  evaluator = Evaluator(
      examples=preprocessor.outputs['transformed_examples'],
      trained_model=trainer.outputs['model'],
      max_length=50,
  )

  # Following 7 components will be included in the pipeline.
  components = [
      example_gen,
      statistics_gen,
      schema_gen,
      example_validator,
      preprocessor,
      trainer,
      evaluator,
  ]

  return tfx.dsl.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=tfx.orchestration.metadata.sqlite_metadata_connection_config(
          metadata_path
      ),
      components=components,
  )
tfx.orchestration.LocalDagRunner().run(
    _create_pipeline(
        pipeline_name=PIPELINE_NAME,
        pipeline_root=PIPELINE_ROOT,
        model_fn=model_fn,
        serving_model_dir=SERVING_MODEL_DIR,
        metadata_path=METADATA_PATH,
    )
)
INFO:absl:Generating ephemeral wheel package for '/tmpfs/src/temp/docs/tutorials/tfx/modules/_transform_module.py' (including modules: ['model', '_transform_module']).
INFO:absl:User module package has hash fingerprint version a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '/tmpfs/tmp/tmpizd25738/_tfx_generated_setup.py', 'bdist_wheel', '--bdist-dir', '/tmpfs/tmp/tmpji_h09mu', '--dist-dir', '/tmpfs/tmp/tmp9w684xql']
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
INFO:absl:Successfully built user code wheel distribution at 'pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl'; target user module is '_transform_module'.
INFO:absl:Full user module path is '_transform_module@pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl'
INFO:absl:Using deployment config:
 executor_specs {
  key: "Evaluator"
  value {
    python_class_executable_spec {
      class_path: "__main__.Evaluator_Executor"
    }
  }
}
executor_specs {
  key: "ExampleValidator"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.example_validator.executor.Executor"
    }
  }
}
executor_specs {
  key: "FileBasedExampleGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "__main__.TFDSExecutor"
      }
    }
  }
}
executor_specs {
  key: "SchemaGen"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.schema_gen.executor.Executor"
    }
  }
}
executor_specs {
  key: "StatisticsGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.statistics_gen.executor.Executor"
      }
    }
  }
}
executor_specs {
  key: "Trainer"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.trainer.executor.GenericExecutor"
    }
  }
}
executor_specs {
  key: "Transform"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.transform.executor.Executor"
      }
    }
  }
}
custom_driver_specs {
  key: "FileBasedExampleGen"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.example_gen.driver.FileBasedDriver"
    }
  }
}
metadata_connection_config {
  database_connection_config {
    sqlite {
      filename_uri: "metadata/tfx-llm-imdb-reviews/metadata.db"
      connection_mode: READWRITE_OPENCREATE
    }
  }
}

INFO:absl:Using connection config:
 sqlite {
  filename_uri: "metadata/tfx-llm-imdb-reviews/metadata.db"
  connection_mode: READWRITE_OPENCREATE
}

INFO:absl:Component FileBasedExampleGen is running.
INFO:absl:Running launcher for node_info {
  type {
    name: "tfx.components.example_gen.component.FileBasedExampleGen"
  }
  id: "FileBasedExampleGen"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2024-06-19T10:34:33.931011"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews.FileBasedExampleGen"
      }
    }
  }
}
outputs {
  outputs {
    key: "examples"
    value {
      artifact_spec {
        type {
          name: "Examples"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          properties {
            key: "version"
            value: INT
          }
          base_type: DATASET
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "custom_config"
    value {
      field_value {
        string_value: "{\"dataset\": \"imdb_reviews\", \"split\": \"train[:5%]\"}"
      }
    }
  }
  parameters {
    key: "input_base"
    value {
      field_value {
        string_value: "dummy"
      }
    }
  }
  parameters {
    key: "input_config"
    value {
      field_value {
        string_value: "{\n  \"splits\": [\n    {\n      \"name\": \"single_split\",\n      \"pattern\": \"*\"\n    }\n  ]\n}"
      }
    }
  }
  parameters {
    key: "output_config"
    value {
      field_value {
        string_value: "{\n  \"split_config\": {\n    \"splits\": [\n      {\n        \"hash_buckets\": 2,\n        \"name\": \"train\"\n      },\n      {\n        \"hash_buckets\": 1,\n        \"name\": \"eval\"\n      }\n    ]\n  }\n}"
      }
    }
  }
  parameters {
    key: "output_data_format"
    value {
      field_value {
        int_value: 6
      }
    }
  }
  parameters {
    key: "output_file_format"
    value {
      field_value {
        int_value: 5
      }
    }
  }
}
downstream_nodes: "StatisticsGen"
downstream_nodes: "Transform"
execution_options {
  caching_options {
  }
}

INFO:absl:MetadataStore with DB connection initialized
INFO:absl:[FileBasedExampleGen] Resolved inputs: ({},)
running bdist_wheel
running build
running build_py
creating build
creating build/lib
copying model.py -> build/lib
copying _transform_module.py -> build/lib
installing to /tmpfs/tmp/tmpji_h09mu
running install
running install_lib
copying build/lib/model.py -> /tmpfs/tmp/tmpji_h09mu
copying build/lib/_transform_module.py -> /tmpfs/tmp/tmpji_h09mu
running install_egg_info
running egg_info
creating tfx_user_code_Transform.egg-info
writing tfx_user_code_Transform.egg-info/PKG-INFO
writing dependency_links to tfx_user_code_Transform.egg-info/dependency_links.txt
writing top-level names to tfx_user_code_Transform.egg-info/top_level.txt
writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
reading manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
Copying tfx_user_code_Transform.egg-info to /tmpfs/tmp/tmpji_h09mu/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3.9.egg-info
running install_scripts
creating /tmpfs/tmp/tmpji_h09mu/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446.dist-info/WHEEL
creating '/tmpfs/tmp/tmp9w684xql/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl' and adding '/tmpfs/tmp/tmpji_h09mu' to it
adding '_transform_module.py'
adding 'model.py'
adding 'tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446.dist-info/METADATA'
adding 'tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446.dist-info/WHEEL'
adding 'tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446.dist-info/top_level.txt'
adding 'tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446.dist-info/RECORD'
removing /tmpfs/tmp/tmpji_h09mu
INFO:absl:select span and version = (0, None)
INFO:absl:latest span and version = (0, None)
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Going to run a new execution 1
INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=1, input_dict={}, output_dict=defaultdict(<class 'list'>, {'examples': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/FileBasedExampleGen/examples/1"
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:0,total_bytes:0,xor_checksum:0,sum_checksum:0"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
, artifact_type: name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)]}), exec_properties={'custom_config': '{"dataset": "imdb_reviews", "split": "train[:5%]"}', 'output_file_format': 5, 'input_config': '{\n  "splits": [\n    {\n      "name": "single_split",\n      "pattern": "*"\n    }\n  ]\n}', 'output_config': '{\n  "split_config": {\n    "splits": [\n      {\n        "hash_buckets": 2,\n        "name": "train"\n      },\n      {\n        "hash_buckets": 1,\n        "name": "eval"\n      }\n    ]\n  }\n}', 'output_data_format': 6, 'input_base': 'dummy', 'span': 0, 'version': None, 'input_fingerprint': 'split:single_split,num_files:0,total_bytes:0,xor_checksum:0,sum_checksum:0'}, execution_output_uri='pipelines/tfx-llm-imdb-reviews/FileBasedExampleGen/.system/executor_execution/1/executor_output.pb', stateful_working_dir='pipelines/tfx-llm-imdb-reviews/FileBasedExampleGen/.system/stateful_working_dir/5102096e-f3a1-4010-98c4-796b5e6dc395', tmp_dir='pipelines/tfx-llm-imdb-reviews/FileBasedExampleGen/.system/executor_execution/1/.temp/', pipeline_node=node_info {
  type {
    name: "tfx.components.example_gen.component.FileBasedExampleGen"
  }
  id: "FileBasedExampleGen"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2024-06-19T10:34:33.931011"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews.FileBasedExampleGen"
      }
    }
  }
}
outputs {
  outputs {
    key: "examples"
    value {
      artifact_spec {
        type {
          name: "Examples"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          properties {
            key: "version"
            value: INT
          }
          base_type: DATASET
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "custom_config"
    value {
      field_value {
        string_value: "{\"dataset\": \"imdb_reviews\", \"split\": \"train[:5%]\"}"
      }
    }
  }
  parameters {
    key: "input_base"
    value {
      field_value {
        string_value: "dummy"
      }
    }
  }
  parameters {
    key: "input_config"
    value {
      field_value {
        string_value: "{\n  \"splits\": [\n    {\n      \"name\": \"single_split\",\n      \"pattern\": \"*\"\n    }\n  ]\n}"
      }
    }
  }
  parameters {
    key: "output_config"
    value {
      field_value {
        string_value: "{\n  \"split_config\": {\n    \"splits\": [\n      {\n        \"hash_buckets\": 2,\n        \"name\": \"train\"\n      },\n      {\n        \"hash_buckets\": 1,\n        \"name\": \"eval\"\n      }\n    ]\n  }\n}"
      }
    }
  }
  parameters {
    key: "output_data_format"
    value {
      field_value {
        int_value: 6
      }
    }
  }
  parameters {
    key: "output_file_format"
    value {
      field_value {
        int_value: 5
      }
    }
  }
}
downstream_nodes: "StatisticsGen"
downstream_nodes: "Transform"
execution_options {
  caching_options {
  }
}
, pipeline_info=id: "tfx-llm-imdb-reviews"
, pipeline_run_id='2024-06-19T10:34:33.931011', top_level_pipeline_run_id=None, frontend_url=None)
INFO:absl:Generating examples.
INFO:absl:No config specified, defaulting to config: imdb_reviews/plain_text
INFO:absl:Load dataset info from gs://tensorflow-datasets/datasets/imdb_reviews/plain_text/1.0.0
INFO:absl:For 'imdb_reviews/plain_text/1.0.0': fields info.[description, release_notes, config_name, config_description, citation, splits, supervised_keys, module_name] differ on disk and in the code. Keeping the one from code.
INFO:absl:Reusing dataset imdb_reviews (gs://tensorflow-datasets/datasets/imdb_reviews/plain_text/1.0.0)
INFO:absl:No config specified, defaulting to config: imdb_reviews/plain_text
INFO:absl:Load dataset info from gs://tensorflow-datasets/datasets/imdb_reviews/plain_text/1.0.0
INFO:absl:For 'imdb_reviews/plain_text/1.0.0': fields info.[description, release_notes, config_name, config_description, citation, splits, supervised_keys, module_name] differ on disk and in the code. Keeping the one from code.
INFO:absl:No config specified, defaulting to config: imdb_reviews/plain_text
INFO:absl:Load dataset info from gs://tensorflow-datasets/datasets/imdb_reviews/plain_text/1.0.0
INFO:absl:For 'imdb_reviews/plain_text/1.0.0': fields info.[description, release_notes, config_name, config_description, citation, splits, supervised_keys, module_name] differ on disk and in the code. Keeping the one from code.
INFO:absl:No config specified, defaulting to config: imdb_reviews/plain_text
INFO:absl:Load dataset info from gs://tensorflow-datasets/datasets/imdb_reviews/plain_text/1.0.0
INFO:absl:For 'imdb_reviews/plain_text/1.0.0': fields info.[description, release_notes, config_name, config_description, citation, splits, supervised_keys, module_name] differ on disk and in the code. Keeping the one from code.
INFO:absl:No config specified, defaulting to config: imdb_reviews/plain_text
INFO:absl:Load dataset info from gs://tensorflow-datasets/datasets/imdb_reviews/plain_text/1.0.0
INFO:absl:For 'imdb_reviews/plain_text/1.0.0': fields info.[description, release_notes, config_name, config_description, citation, splits, supervised_keys, module_name] differ on disk and in the code. Keeping the one from code.
INFO:absl:No config specified, defaulting to config: imdb_reviews/plain_text
INFO:absl:Load dataset info from gs://tensorflow-datasets/datasets/imdb_reviews/plain_text/1.0.0
INFO:absl:For 'imdb_reviews/plain_text/1.0.0': fields info.[description, release_notes, config_name, config_description, citation, splits, supervised_keys, module_name] differ on disk and in the code. Keeping the one from code.
INFO:absl:Constructing tf.data.Dataset imdb_reviews for split train[0shard], from gs://tensorflow-datasets/datasets/imdb_reviews/plain_text/1.0.0
INFO:absl:Examples generated.
INFO:absl:Value type <class 'NoneType'> of key version in exec_properties is not supported, going to drop it
INFO:absl:Value type <class 'list'> of key _beam_pipeline_args in exec_properties is not supported, going to drop it
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 1 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Deleted stateful_working_dir pipelines/tfx-llm-imdb-reviews/FileBasedExampleGen/.system/stateful_working_dir/5102096e-f3a1-4010-98c4-796b5e6dc395
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'examples': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/FileBasedExampleGen/examples/1"
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:0,total_bytes:0,xor_checksum:0,sum_checksum:0"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
, artifact_type: name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)]}) for execution 1
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component FileBasedExampleGen is finished.
INFO:absl:Component StatisticsGen is running.
INFO:absl:Running launcher for node_info {
  type {
    name: "tfx.components.statistics_gen.component.StatisticsGen"
    base_type: PROCESS
  }
  id: "StatisticsGen"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2024-06-19T10:34:33.931011"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews.StatisticsGen"
      }
    }
  }
}
inputs {
  inputs {
    key: "examples"
    value {
      channels {
        producer_node_query {
          id: "FileBasedExampleGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.FileBasedExampleGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Examples"
            base_type: DATASET
          }
        }
        output_key: "examples"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "statistics"
    value {
      artifact_spec {
        type {
          name: "ExampleStatistics"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          base_type: STATISTICS
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "exclude_splits"
    value {
      field_value {
        string_value: "[\"eval\"]"
      }
    }
  }
}
upstream_nodes: "FileBasedExampleGen"
downstream_nodes: "ExampleValidator"
downstream_nodes: "SchemaGen"
execution_options {
  caching_options {
  }
}

INFO:absl:MetadataStore with DB connection initialized
WARNING:absl:ArtifactQuery.property_predicate is not supported.
INFO:absl:[StatisticsGen] Resolved inputs: ({'examples': [Artifact(artifact: id: 1
type_id: 15
uri: "pipelines/tfx-llm-imdb-reviews/FileBasedExampleGen/examples/1"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "file_format"
  value {
    string_value: "tfrecords_gzip"
  }
}
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:0,total_bytes:0,xor_checksum:0,sum_checksum:0"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "payload_format"
  value {
    string_value: "FORMAT_TF_EXAMPLE"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Examples"
create_time_since_epoch: 1718793279187
last_update_time_since_epoch: 1718793279187
, artifact_type: id: 15
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)]},)
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Going to run a new execution 2
INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=2, input_dict={'examples': [Artifact(artifact: id: 1
type_id: 15
uri: "pipelines/tfx-llm-imdb-reviews/FileBasedExampleGen/examples/1"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "file_format"
  value {
    string_value: "tfrecords_gzip"
  }
}
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:0,total_bytes:0,xor_checksum:0,sum_checksum:0"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "payload_format"
  value {
    string_value: "FORMAT_TF_EXAMPLE"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Examples"
create_time_since_epoch: 1718793279187
last_update_time_since_epoch: 1718793279187
, artifact_type: id: 15
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)]}, output_dict=defaultdict(<class 'list'>, {'statistics': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/StatisticsGen/statistics/2"
, artifact_type: name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)]}), exec_properties={'exclude_splits': '["eval"]'}, execution_output_uri='pipelines/tfx-llm-imdb-reviews/StatisticsGen/.system/executor_execution/2/executor_output.pb', stateful_working_dir='pipelines/tfx-llm-imdb-reviews/StatisticsGen/.system/stateful_working_dir/ce022665-b836-4c64-b1c5-4cd0900f76b5', tmp_dir='pipelines/tfx-llm-imdb-reviews/StatisticsGen/.system/executor_execution/2/.temp/', pipeline_node=node_info {
  type {
    name: "tfx.components.statistics_gen.component.StatisticsGen"
    base_type: PROCESS
  }
  id: "StatisticsGen"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2024-06-19T10:34:33.931011"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews.StatisticsGen"
      }
    }
  }
}
inputs {
  inputs {
    key: "examples"
    value {
      channels {
        producer_node_query {
          id: "FileBasedExampleGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.FileBasedExampleGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Examples"
            base_type: DATASET
          }
        }
        output_key: "examples"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "statistics"
    value {
      artifact_spec {
        type {
          name: "ExampleStatistics"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          base_type: STATISTICS
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "exclude_splits"
    value {
      field_value {
        string_value: "[\"eval\"]"
      }
    }
  }
}
upstream_nodes: "FileBasedExampleGen"
downstream_nodes: "ExampleValidator"
downstream_nodes: "SchemaGen"
execution_options {
  caching_options {
  }
}
, pipeline_info=id: "tfx-llm-imdb-reviews"
, pipeline_run_id='2024-06-19T10:34:33.931011', top_level_pipeline_run_id=None, frontend_url=None)
INFO:absl:Generating statistics for split train.
INFO:absl:Statistics for split train written to pipelines/tfx-llm-imdb-reviews/StatisticsGen/statistics/2/Split-train.
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 2 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Deleted stateful_working_dir pipelines/tfx-llm-imdb-reviews/StatisticsGen/.system/stateful_working_dir/ce022665-b836-4c64-b1c5-4cd0900f76b5
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'statistics': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/StatisticsGen/statistics/2"
, artifact_type: name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)]}) for execution 2
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component StatisticsGen is finished.
INFO:absl:Component SchemaGen is running.
INFO:absl:Running launcher for node_info {
  type {
    name: "tfx.components.schema_gen.component.SchemaGen"
    base_type: PROCESS
  }
  id: "SchemaGen"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2024-06-19T10:34:33.931011"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews.SchemaGen"
      }
    }
  }
}
inputs {
  inputs {
    key: "statistics"
    value {
      channels {
        producer_node_query {
          id: "StatisticsGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.StatisticsGen"
            }
          }
        }
        artifact_query {
          type {
            name: "ExampleStatistics"
            base_type: STATISTICS
          }
        }
        output_key: "statistics"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "schema"
    value {
      artifact_spec {
        type {
          name: "Schema"
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "exclude_splits"
    value {
      field_value {
        string_value: "[\"eval\"]"
      }
    }
  }
  parameters {
    key: "infer_feature_shape"
    value {
      field_value {
        int_value: 0
      }
    }
  }
}
upstream_nodes: "StatisticsGen"
downstream_nodes: "ExampleValidator"
downstream_nodes: "Trainer"
downstream_nodes: "Transform"
execution_options {
  caching_options {
  }
}

INFO:absl:MetadataStore with DB connection initialized
WARNING:absl:ArtifactQuery.property_predicate is not supported.
INFO:absl:[SchemaGen] Resolved inputs: ({'statistics': [Artifact(artifact: id: 2
type_id: 17
uri: "pipelines/tfx-llm-imdb-reviews/StatisticsGen/statistics/2"
properties {
  key: "span"
  value {
    int_value: 0
  }
}
properties {
  key: "split_names"
  value {
    string_value: "[\"train\"]"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "stats_dashboard_link"
  value {
    string_value: ""
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "ExampleStatistics"
create_time_since_epoch: 1718793280732
last_update_time_since_epoch: 1718793280732
, artifact_type: id: 17
name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)]},)
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Going to run a new execution 3
INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=3, input_dict={'statistics': [Artifact(artifact: id: 2
type_id: 17
uri: "pipelines/tfx-llm-imdb-reviews/StatisticsGen/statistics/2"
properties {
  key: "span"
  value {
    int_value: 0
  }
}
properties {
  key: "split_names"
  value {
    string_value: "[\"train\"]"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "stats_dashboard_link"
  value {
    string_value: ""
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "ExampleStatistics"
create_time_since_epoch: 1718793280732
last_update_time_since_epoch: 1718793280732
, artifact_type: id: 17
name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)]}, output_dict=defaultdict(<class 'list'>, {'schema': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/SchemaGen/schema/3"
, artifact_type: name: "Schema"
)]}), exec_properties={'infer_feature_shape': 0, 'exclude_splits': '["eval"]'}, execution_output_uri='pipelines/tfx-llm-imdb-reviews/SchemaGen/.system/executor_execution/3/executor_output.pb', stateful_working_dir='pipelines/tfx-llm-imdb-reviews/SchemaGen/.system/stateful_working_dir/ef5fb07e-4c41-429e-a3c2-c11de8cd4be0', tmp_dir='pipelines/tfx-llm-imdb-reviews/SchemaGen/.system/executor_execution/3/.temp/', pipeline_node=node_info {
  type {
    name: "tfx.components.schema_gen.component.SchemaGen"
    base_type: PROCESS
  }
  id: "SchemaGen"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2024-06-19T10:34:33.931011"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews.SchemaGen"
      }
    }
  }
}
inputs {
  inputs {
    key: "statistics"
    value {
      channels {
        producer_node_query {
          id: "StatisticsGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.StatisticsGen"
            }
          }
        }
        artifact_query {
          type {
            name: "ExampleStatistics"
            base_type: STATISTICS
          }
        }
        output_key: "statistics"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "schema"
    value {
      artifact_spec {
        type {
          name: "Schema"
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "exclude_splits"
    value {
      field_value {
        string_value: "[\"eval\"]"
      }
    }
  }
  parameters {
    key: "infer_feature_shape"
    value {
      field_value {
        int_value: 0
      }
    }
  }
}
upstream_nodes: "StatisticsGen"
downstream_nodes: "ExampleValidator"
downstream_nodes: "Trainer"
downstream_nodes: "Transform"
execution_options {
  caching_options {
  }
}
, pipeline_info=id: "tfx-llm-imdb-reviews"
, pipeline_run_id='2024-06-19T10:34:33.931011', top_level_pipeline_run_id=None, frontend_url=None)
INFO:absl:Processing schema from statistics for split train.
INFO:absl:Schema written to pipelines/tfx-llm-imdb-reviews/SchemaGen/schema/3/schema.pbtxt.
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 3 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Deleted stateful_working_dir pipelines/tfx-llm-imdb-reviews/SchemaGen/.system/stateful_working_dir/ef5fb07e-4c41-429e-a3c2-c11de8cd4be0
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'schema': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/SchemaGen/schema/3"
, artifact_type: name: "Schema"
)]}) for execution 3
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component SchemaGen is finished.
INFO:absl:Component ExampleValidator is running.
INFO:absl:Running launcher for node_info {
  type {
    name: "tfx.components.example_validator.component.ExampleValidator"
  }
  id: "ExampleValidator"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2024-06-19T10:34:33.931011"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews.ExampleValidator"
      }
    }
  }
}
inputs {
  inputs {
    key: "schema"
    value {
      channels {
        producer_node_query {
          id: "SchemaGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.SchemaGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Schema"
          }
        }
        output_key: "schema"
      }
      min_count: 1
    }
  }
  inputs {
    key: "statistics"
    value {
      channels {
        producer_node_query {
          id: "StatisticsGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.StatisticsGen"
            }
          }
        }
        artifact_query {
          type {
            name: "ExampleStatistics"
            base_type: STATISTICS
          }
        }
        output_key: "statistics"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "anomalies"
    value {
      artifact_spec {
        type {
          name: "ExampleAnomalies"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "exclude_splits"
    value {
      field_value {
        string_value: "[\"eval\"]"
      }
    }
  }
}
upstream_nodes: "SchemaGen"
upstream_nodes: "StatisticsGen"
execution_options {
  caching_options {
  }
}

INFO:absl:MetadataStore with DB connection initialized
WARNING:absl:ArtifactQuery.property_predicate is not supported.
WARNING:absl:ArtifactQuery.property_predicate is not supported.
INFO:absl:[ExampleValidator] Resolved inputs: ({'statistics': [Artifact(artifact: id: 2
type_id: 17
uri: "pipelines/tfx-llm-imdb-reviews/StatisticsGen/statistics/2"
properties {
  key: "span"
  value {
    int_value: 0
  }
}
properties {
  key: "split_names"
  value {
    string_value: "[\"train\"]"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "stats_dashboard_link"
  value {
    string_value: ""
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "ExampleStatistics"
create_time_since_epoch: 1718793280732
last_update_time_since_epoch: 1718793280732
, artifact_type: id: 17
name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)], 'schema': [Artifact(artifact: id: 3
type_id: 19
uri: "pipelines/tfx-llm-imdb-reviews/SchemaGen/schema/3"
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Schema"
create_time_since_epoch: 1718793280787
last_update_time_since_epoch: 1718793280787
, artifact_type: id: 19
name: "Schema"
)]},)
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Going to run a new execution 4
INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=4, input_dict={'statistics': [Artifact(artifact: id: 2
type_id: 17
uri: "pipelines/tfx-llm-imdb-reviews/StatisticsGen/statistics/2"
properties {
  key: "span"
  value {
    int_value: 0
  }
}
properties {
  key: "split_names"
  value {
    string_value: "[\"train\"]"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "stats_dashboard_link"
  value {
    string_value: ""
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "ExampleStatistics"
create_time_since_epoch: 1718793280732
last_update_time_since_epoch: 1718793280732
, artifact_type: id: 17
name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)], 'schema': [Artifact(artifact: id: 3
type_id: 19
uri: "pipelines/tfx-llm-imdb-reviews/SchemaGen/schema/3"
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Schema"
create_time_since_epoch: 1718793280787
last_update_time_since_epoch: 1718793280787
, artifact_type: id: 19
name: "Schema"
)]}, output_dict=defaultdict(<class 'list'>, {'anomalies': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/ExampleValidator/anomalies/4"
, artifact_type: name: "ExampleAnomalies"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
)]}), exec_properties={'exclude_splits': '["eval"]'}, execution_output_uri='pipelines/tfx-llm-imdb-reviews/ExampleValidator/.system/executor_execution/4/executor_output.pb', stateful_working_dir='pipelines/tfx-llm-imdb-reviews/ExampleValidator/.system/stateful_working_dir/bc03d741-7b17-4523-942a-afe07d163534', tmp_dir='pipelines/tfx-llm-imdb-reviews/ExampleValidator/.system/executor_execution/4/.temp/', pipeline_node=node_info {
  type {
    name: "tfx.components.example_validator.component.ExampleValidator"
  }
  id: "ExampleValidator"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2024-06-19T10:34:33.931011"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews.ExampleValidator"
      }
    }
  }
}
inputs {
  inputs {
    key: "schema"
    value {
      channels {
        producer_node_query {
          id: "SchemaGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.SchemaGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Schema"
          }
        }
        output_key: "schema"
      }
      min_count: 1
    }
  }
  inputs {
    key: "statistics"
    value {
      channels {
        producer_node_query {
          id: "StatisticsGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.StatisticsGen"
            }
          }
        }
        artifact_query {
          type {
            name: "ExampleStatistics"
            base_type: STATISTICS
          }
        }
        output_key: "statistics"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "anomalies"
    value {
      artifact_spec {
        type {
          name: "ExampleAnomalies"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "exclude_splits"
    value {
      field_value {
        string_value: "[\"eval\"]"
      }
    }
  }
}
upstream_nodes: "SchemaGen"
upstream_nodes: "StatisticsGen"
execution_options {
  caching_options {
  }
}
, pipeline_info=id: "tfx-llm-imdb-reviews"
, pipeline_run_id='2024-06-19T10:34:33.931011', top_level_pipeline_run_id=None, frontend_url=None)
INFO:absl:Validating schema against the computed statistics for split train.
INFO:absl:Anomalies alerts created for split train.
INFO:absl:Validation complete for split train. Anomalies written to pipelines/tfx-llm-imdb-reviews/ExampleValidator/anomalies/4/Split-train.
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 4 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Deleted stateful_working_dir pipelines/tfx-llm-imdb-reviews/ExampleValidator/.system/stateful_working_dir/bc03d741-7b17-4523-942a-afe07d163534
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'anomalies': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/ExampleValidator/anomalies/4"
, artifact_type: name: "ExampleAnomalies"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
)]}) for execution 4
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component ExampleValidator is finished.
INFO:absl:Component Transform is running.
INFO:absl:Running launcher for node_info {
  type {
    name: "tfx.components.transform.component.Transform"
    base_type: TRANSFORM
  }
  id: "Transform"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2024-06-19T10:34:33.931011"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews.Transform"
      }
    }
  }
}
inputs {
  inputs {
    key: "examples"
    value {
      channels {
        producer_node_query {
          id: "FileBasedExampleGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.FileBasedExampleGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Examples"
            base_type: DATASET
          }
        }
        output_key: "examples"
      }
      min_count: 1
    }
  }
  inputs {
    key: "schema"
    value {
      channels {
        producer_node_query {
          id: "SchemaGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.SchemaGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Schema"
          }
        }
        output_key: "schema"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "post_transform_anomalies"
    value {
      artifact_spec {
        type {
          name: "ExampleAnomalies"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
        }
      }
    }
  }
  outputs {
    key: "post_transform_schema"
    value {
      artifact_spec {
        type {
          name: "Schema"
        }
      }
    }
  }
  outputs {
    key: "post_transform_stats"
    value {
      artifact_spec {
        type {
          name: "ExampleStatistics"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          base_type: STATISTICS
        }
      }
    }
  }
  outputs {
    key: "pre_transform_schema"
    value {
      artifact_spec {
        type {
          name: "Schema"
        }
      }
    }
  }
  outputs {
    key: "pre_transform_stats"
    value {
      artifact_spec {
        type {
          name: "ExampleStatistics"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          base_type: STATISTICS
        }
      }
    }
  }
  outputs {
    key: "transform_graph"
    value {
      artifact_spec {
        type {
          name: "TransformGraph"
        }
      }
    }
  }
  outputs {
    key: "transformed_examples"
    value {
      artifact_spec {
        type {
          name: "Examples"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          properties {
            key: "version"
            value: INT
          }
          base_type: DATASET
        }
      }
    }
  }
  outputs {
    key: "updated_analyzer_cache"
    value {
      artifact_spec {
        type {
          name: "TransformCache"
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "custom_config"
    value {
      field_value {
        string_value: "null"
      }
    }
  }
  parameters {
    key: "disable_statistics"
    value {
      field_value {
        int_value: 0
      }
    }
  }
  parameters {
    key: "force_tf_compat_v1"
    value {
      field_value {
        int_value: 0
      }
    }
  }
  parameters {
    key: "module_path"
    value {
      field_value {
        string_value: "_transform_module@pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl"
      }
    }
  }
}
upstream_nodes: "FileBasedExampleGen"
upstream_nodes: "SchemaGen"
downstream_nodes: "Evaluator"
downstream_nodes: "Trainer"
execution_options {
  caching_options {
  }
}

INFO:absl:MetadataStore with DB connection initialized
WARNING:absl:ArtifactQuery.property_predicate is not supported.
WARNING:absl:ArtifactQuery.property_predicate is not supported.
INFO:absl:[Transform] Resolved inputs: ({'schema': [Artifact(artifact: id: 3
type_id: 19
uri: "pipelines/tfx-llm-imdb-reviews/SchemaGen/schema/3"
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Schema"
create_time_since_epoch: 1718793280787
last_update_time_since_epoch: 1718793280787
, artifact_type: id: 19
name: "Schema"
)], 'examples': [Artifact(artifact: id: 1
type_id: 15
uri: "pipelines/tfx-llm-imdb-reviews/FileBasedExampleGen/examples/1"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "file_format"
  value {
    string_value: "tfrecords_gzip"
  }
}
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:0,total_bytes:0,xor_checksum:0,sum_checksum:0"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "payload_format"
  value {
    string_value: "FORMAT_TF_EXAMPLE"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Examples"
create_time_since_epoch: 1718793279187
last_update_time_since_epoch: 1718793279187
, artifact_type: id: 15
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)]},)
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Going to run a new execution 5
INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=5, input_dict={'schema': [Artifact(artifact: id: 3
type_id: 19
uri: "pipelines/tfx-llm-imdb-reviews/SchemaGen/schema/3"
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Schema"
create_time_since_epoch: 1718793280787
last_update_time_since_epoch: 1718793280787
, artifact_type: id: 19
name: "Schema"
)], 'examples': [Artifact(artifact: id: 1
type_id: 15
uri: "pipelines/tfx-llm-imdb-reviews/FileBasedExampleGen/examples/1"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "file_format"
  value {
    string_value: "tfrecords_gzip"
  }
}
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:0,total_bytes:0,xor_checksum:0,sum_checksum:0"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "payload_format"
  value {
    string_value: "FORMAT_TF_EXAMPLE"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Examples"
create_time_since_epoch: 1718793279187
last_update_time_since_epoch: 1718793279187
, artifact_type: id: 15
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)]}, output_dict=defaultdict(<class 'list'>, {'post_transform_stats': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/post_transform_stats/5"
, artifact_type: name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)], 'updated_analyzer_cache': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/updated_analyzer_cache/5"
, artifact_type: name: "TransformCache"
)], 'post_transform_schema': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/post_transform_schema/5"
, artifact_type: name: "Schema"
)], 'transform_graph': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/transform_graph/5"
, artifact_type: name: "TransformGraph"
)], 'pre_transform_stats': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/pre_transform_stats/5"
, artifact_type: name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)], 'transformed_examples': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/transformed_examples/5"
, artifact_type: name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)], 'pre_transform_schema': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/pre_transform_schema/5"
, artifact_type: name: "Schema"
)], 'post_transform_anomalies': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/post_transform_anomalies/5"
, artifact_type: name: "ExampleAnomalies"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
)]}), exec_properties={'force_tf_compat_v1': 0, 'module_path': '_transform_module@pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl', 'disable_statistics': 0, 'custom_config': 'null'}, execution_output_uri='pipelines/tfx-llm-imdb-reviews/Transform/.system/executor_execution/5/executor_output.pb', stateful_working_dir='pipelines/tfx-llm-imdb-reviews/Transform/.system/stateful_working_dir/7b93af11-bd3d-44bd-bef8-42527bfb0cf7', tmp_dir='pipelines/tfx-llm-imdb-reviews/Transform/.system/executor_execution/5/.temp/', pipeline_node=node_info {
  type {
    name: "tfx.components.transform.component.Transform"
    base_type: TRANSFORM
  }
  id: "Transform"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2024-06-19T10:34:33.931011"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews.Transform"
      }
    }
  }
}
inputs {
  inputs {
    key: "examples"
    value {
      channels {
        producer_node_query {
          id: "FileBasedExampleGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.FileBasedExampleGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Examples"
            base_type: DATASET
          }
        }
        output_key: "examples"
      }
      min_count: 1
    }
  }
  inputs {
    key: "schema"
    value {
      channels {
        producer_node_query {
          id: "SchemaGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.SchemaGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Schema"
          }
        }
        output_key: "schema"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "post_transform_anomalies"
    value {
      artifact_spec {
        type {
          name: "ExampleAnomalies"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
        }
      }
    }
  }
  outputs {
    key: "post_transform_schema"
    value {
      artifact_spec {
        type {
          name: "Schema"
        }
      }
    }
  }
  outputs {
    key: "post_transform_stats"
    value {
      artifact_spec {
        type {
          name: "ExampleStatistics"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          base_type: STATISTICS
        }
      }
    }
  }
  outputs {
    key: "pre_transform_schema"
    value {
      artifact_spec {
        type {
          name: "Schema"
        }
      }
    }
  }
  outputs {
    key: "pre_transform_stats"
    value {
      artifact_spec {
        type {
          name: "ExampleStatistics"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          base_type: STATISTICS
        }
      }
    }
  }
  outputs {
    key: "transform_graph"
    value {
      artifact_spec {
        type {
          name: "TransformGraph"
        }
      }
    }
  }
  outputs {
    key: "transformed_examples"
    value {
      artifact_spec {
        type {
          name: "Examples"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          properties {
            key: "version"
            value: INT
          }
          base_type: DATASET
        }
      }
    }
  }
  outputs {
    key: "updated_analyzer_cache"
    value {
      artifact_spec {
        type {
          name: "TransformCache"
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "custom_config"
    value {
      field_value {
        string_value: "null"
      }
    }
  }
  parameters {
    key: "disable_statistics"
    value {
      field_value {
        int_value: 0
      }
    }
  }
  parameters {
    key: "force_tf_compat_v1"
    value {
      field_value {
        int_value: 0
      }
    }
  }
  parameters {
    key: "module_path"
    value {
      field_value {
        string_value: "_transform_module@pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl"
      }
    }
  }
}
upstream_nodes: "FileBasedExampleGen"
upstream_nodes: "SchemaGen"
downstream_nodes: "Evaluator"
downstream_nodes: "Trainer"
execution_options {
  caching_options {
  }
}
, pipeline_info=id: "tfx-llm-imdb-reviews"
, pipeline_run_id='2024-06-19T10:34:33.931011', top_level_pipeline_run_id=None, frontend_url=None)
INFO:absl:Analyze the 'train' split and transform all splits when splits_config is not set.
INFO:absl:udf_utils.get_fn {'module_file': None, 'module_path': '_transform_module@pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl', 'preprocessing_fn': None} 'preprocessing_fn'
INFO:absl:Installing 'pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmpxwi57wnw', 'pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl']
Processing ./pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl
INFO:absl:Successfully installed 'pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl'.
INFO:absl:udf_utils.get_fn {'module_file': None, 'module_path': '_transform_module@pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl', 'stats_options_updater_fn': None} 'stats_options_updater_fn'
INFO:absl:Installing 'pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmpeaoopvz4', 'pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl']
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446
Processing ./pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl
INFO:absl:Successfully installed 'pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl'.
INFO:absl:Installing 'pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmp4d_8hid2', 'pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl']
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446
Processing ./pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl
INFO:absl:Successfully installed 'pipelines/tfx-llm-imdb-reviews/_wheels/tfx_user_code_Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446-py3-none-any.whl'.
INFO:absl:Feature label has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature text has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature label has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature text has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature label has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature text has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature text has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature label has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature text has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature label has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature text has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature label has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature text has no shape. Setting to varlen_sparse_tensor.
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+a4329fc2f5ffd915423ecc1d0904712abe0068f1f738e2f30e53175facb53446
INFO:absl:Feature text has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature label has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature text has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature text has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature label has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature text has no shape. Setting to varlen_sparse_tensor.
INFO:tensorflow:Assets written to: pipelines/tfx-llm-imdb-reviews/Transform/transform_graph/5/.temp_path/tftransform_tmp/acd6cbb43ce74a41a356ac1b3e27fc3a/assets
INFO:absl:Writing fingerprint to pipelines/tfx-llm-imdb-reviews/Transform/transform_graph/5/.temp_path/tftransform_tmp/acd6cbb43ce74a41a356ac1b3e27fc3a/fingerprint.pb
INFO:absl:Feature summary has a shape . Setting to DenseTensor.
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 5 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Deleted stateful_working_dir pipelines/tfx-llm-imdb-reviews/Transform/.system/stateful_working_dir/7b93af11-bd3d-44bd-bef8-42527bfb0cf7
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'post_transform_stats': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/post_transform_stats/5"
, artifact_type: name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)], 'updated_analyzer_cache': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/updated_analyzer_cache/5"
, artifact_type: name: "TransformCache"
)], 'post_transform_schema': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/post_transform_schema/5"
, artifact_type: name: "Schema"
)], 'transform_graph': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/transform_graph/5"
, artifact_type: name: "TransformGraph"
)], 'pre_transform_stats': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/pre_transform_stats/5"
, artifact_type: name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)], 'transformed_examples': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/transformed_examples/5"
, artifact_type: name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)], 'pre_transform_schema': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/pre_transform_schema/5"
, artifact_type: name: "Schema"
)], 'post_transform_anomalies': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Transform/post_transform_anomalies/5"
, artifact_type: name: "ExampleAnomalies"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
)]}) for execution 5
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component Transform is finished.
INFO:absl:Component Trainer is running.
INFO:absl:Running launcher for node_info {
  type {
    name: "tfx.components.trainer.component.Trainer"
    base_type: TRAIN
  }
  id: "Trainer"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2024-06-19T10:34:33.931011"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews.Trainer"
      }
    }
  }
}
inputs {
  inputs {
    key: "examples"
    value {
      channels {
        producer_node_query {
          id: "Transform"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.Transform"
            }
          }
        }
        artifact_query {
          type {
            name: "Examples"
            base_type: DATASET
          }
        }
        output_key: "transformed_examples"
      }
      min_count: 1
    }
  }
  inputs {
    key: "schema"
    value {
      channels {
        producer_node_query {
          id: "SchemaGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.SchemaGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Schema"
          }
        }
        output_key: "schema"
      }
    }
  }
}
outputs {
  outputs {
    key: "model"
    value {
      artifact_spec {
        type {
          name: "Model"
          base_type: MODEL
        }
      }
    }
  }
  outputs {
    key: "model_run"
    value {
      artifact_spec {
        type {
          name: "ModelRun"
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "custom_config"
    value {
      field_value {
        string_value: "null"
      }
    }
  }
  parameters {
    key: "eval_args"
    value {
      field_value {
        string_value: "{\n  \"splits\": [\n    \"train\"\n  ]\n}"
      }
    }
  }
  parameters {
    key: "run_fn"
    value {
      field_value {
        string_value: "modules.model.run_fn"
      }
    }
  }
  parameters {
    key: "train_args"
    value {
      field_value {
        string_value: "{\n  \"splits\": [\n    \"train\"\n  ]\n}"
      }
    }
  }
}
upstream_nodes: "SchemaGen"
upstream_nodes: "Transform"
downstream_nodes: "Evaluator"
execution_options {
  caching_options {
  }
}

INFO:absl:MetadataStore with DB connection initialized
WARNING:absl:ArtifactQuery.property_predicate is not supported.
WARNING:absl:ArtifactQuery.property_predicate is not supported.
INFO:absl:[Trainer] Resolved inputs: ({'schema': [Artifact(artifact: id: 3
type_id: 19
uri: "pipelines/tfx-llm-imdb-reviews/SchemaGen/schema/3"
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Schema"
create_time_since_epoch: 1718793280787
last_update_time_since_epoch: 1718793280787
, artifact_type: id: 19
name: "Schema"
)], 'examples': [Artifact(artifact: id: 10
type_id: 15
uri: "pipelines/tfx-llm-imdb-reviews/Transform/transformed_examples/5"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Examples"
create_time_since_epoch: 1718793294917
last_update_time_since_epoch: 1718793294917
, artifact_type: id: 15
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)]},)
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Going to run a new execution 6
INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=6, input_dict={'schema': [Artifact(artifact: id: 3
type_id: 19
uri: "pipelines/tfx-llm-imdb-reviews/SchemaGen/schema/3"
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Schema"
create_time_since_epoch: 1718793280787
last_update_time_since_epoch: 1718793280787
, artifact_type: id: 19
name: "Schema"
)], 'examples': [Artifact(artifact: id: 10
type_id: 15
uri: "pipelines/tfx-llm-imdb-reviews/Transform/transformed_examples/5"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Examples"
create_time_since_epoch: 1718793294917
last_update_time_since_epoch: 1718793294917
, artifact_type: id: 15
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)]}, output_dict=defaultdict(<class 'list'>, {'model_run': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Trainer/model_run/6"
, artifact_type: name: "ModelRun"
)], 'model': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Trainer/model/6"
, artifact_type: name: "Model"
base_type: MODEL
)]}), exec_properties={'run_fn': 'modules.model.run_fn', 'train_args': '{\n  "splits": [\n    "train"\n  ]\n}', 'custom_config': 'null', 'eval_args': '{\n  "splits": [\n    "train"\n  ]\n}'}, execution_output_uri='pipelines/tfx-llm-imdb-reviews/Trainer/.system/executor_execution/6/executor_output.pb', stateful_working_dir='pipelines/tfx-llm-imdb-reviews/Trainer/.system/stateful_working_dir/b27f59b5-d7bf-4cc8-9779-978a38135751', tmp_dir='pipelines/tfx-llm-imdb-reviews/Trainer/.system/executor_execution/6/.temp/', pipeline_node=node_info {
  type {
    name: "tfx.components.trainer.component.Trainer"
    base_type: TRAIN
  }
  id: "Trainer"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2024-06-19T10:34:33.931011"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews.Trainer"
      }
    }
  }
}
inputs {
  inputs {
    key: "examples"
    value {
      channels {
        producer_node_query {
          id: "Transform"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.Transform"
            }
          }
        }
        artifact_query {
          type {
            name: "Examples"
            base_type: DATASET
          }
        }
        output_key: "transformed_examples"
      }
      min_count: 1
    }
  }
  inputs {
    key: "schema"
    value {
      channels {
        producer_node_query {
          id: "SchemaGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.SchemaGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Schema"
          }
        }
        output_key: "schema"
      }
    }
  }
}
outputs {
  outputs {
    key: "model"
    value {
      artifact_spec {
        type {
          name: "Model"
          base_type: MODEL
        }
      }
    }
  }
  outputs {
    key: "model_run"
    value {
      artifact_spec {
        type {
          name: "ModelRun"
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "custom_config"
    value {
      field_value {
        string_value: "null"
      }
    }
  }
  parameters {
    key: "eval_args"
    value {
      field_value {
        string_value: "{\n  \"splits\": [\n    \"train\"\n  ]\n}"
      }
    }
  }
  parameters {
    key: "run_fn"
    value {
      field_value {
        string_value: "modules.model.run_fn"
      }
    }
  }
  parameters {
    key: "train_args"
    value {
      field_value {
        string_value: "{\n  \"splits\": [\n    \"train\"\n  ]\n}"
      }
    }
  }
}
upstream_nodes: "SchemaGen"
upstream_nodes: "Transform"
downstream_nodes: "Evaluator"
execution_options {
  caching_options {
  }
}
, pipeline_info=id: "tfx-llm-imdb-reviews"
, pipeline_run_id='2024-06-19T10:34:33.931011', top_level_pipeline_run_id=None, frontend_url=None)
WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE
WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE
WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE
INFO:absl:udf_utils.get_fn {'run_fn': 'modules.model.run_fn', 'train_args': '{\n  "splits": [\n    "train"\n  ]\n}', 'custom_config': 'null', 'eval_args': '{\n  "splits": [\n    "train"\n  ]\n}'} 'run_fn'
INFO:absl:Training model.
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/preprocessor.json...
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/task.json...
2024-06-19 10:35:35.015847: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert
40/41 ━━━━━━━━━━━━━━━━━━━━ 0s 351ms/step - accuracy: 0.3096 - loss: 3.7683
2024-06-19 10:36:57.859417: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert
W0000 00:00:1718793493.998777  163710 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
41/41 ━━━━━━━━━━━━━━━━━━━━ 195s 2s/step - accuracy: 0.3098 - loss: 3.7670
INFO:absl:Training complete. Model written to pipelines/tfx-llm-imdb-reviews/Trainer/model/6/Format-Serving. ModelRun written to pipelines/tfx-llm-imdb-reviews/Trainer/model_run/6
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 6 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Deleted stateful_working_dir pipelines/tfx-llm-imdb-reviews/Trainer/.system/stateful_working_dir/b27f59b5-d7bf-4cc8-9779-978a38135751
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'model_run': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Trainer/model_run/6"
, artifact_type: name: "ModelRun"
)], 'model': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Trainer/model/6"
, artifact_type: name: "Model"
base_type: MODEL
)]}) for execution 6
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component Trainer is finished.
INFO:absl:Component Evaluator is running.
INFO:absl:Running launcher for node_info {
  type {
    name: "__main__.Evaluator"
  }
  id: "Evaluator"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2024-06-19T10:34:33.931011"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews.Evaluator"
      }
    }
  }
}
inputs {
  inputs {
    key: "examples"
    value {
      channels {
        producer_node_query {
          id: "Transform"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.Transform"
            }
          }
        }
        artifact_query {
          type {
            name: "Examples"
            base_type: DATASET
          }
        }
        output_key: "transformed_examples"
      }
      min_count: 1
    }
  }
  inputs {
    key: "trained_model"
    value {
      channels {
        producer_node_query {
          id: "Trainer"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.Trainer"
            }
          }
        }
        artifact_query {
          type {
            name: "Model"
            base_type: MODEL
          }
        }
        output_key: "model"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "evaluation"
    value {
      artifact_spec {
        type {
          name: "Evaluation_Metric"
          properties {
            key: "model_evaluation_output_path"
            value: STRING
          }
          properties {
            key: "model_prediction_time"
            value: DOUBLE
          }
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "max_length"
    value {
      field_value {
        int_value: 50
      }
    }
  }
}
upstream_nodes: "Trainer"
upstream_nodes: "Transform"
execution_options {
  caching_options {
  }
}

INFO:absl:MetadataStore with DB connection initialized
WARNING:absl:ArtifactQuery.property_predicate is not supported.
WARNING:absl:ArtifactQuery.property_predicate is not supported.
INFO:absl:[Evaluator] Resolved inputs: ({'examples': [Artifact(artifact: id: 10
type_id: 15
uri: "pipelines/tfx-llm-imdb-reviews/Transform/transformed_examples/5"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Examples"
create_time_since_epoch: 1718793294917
last_update_time_since_epoch: 1718793294917
, artifact_type: id: 15
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)], 'trained_model': [Artifact(artifact: id: 14
type_id: 27
uri: "pipelines/tfx-llm-imdb-reviews/Trainer/model/6"
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Model"
create_time_since_epoch: 1718793500184
last_update_time_since_epoch: 1718793500184
, artifact_type: id: 27
name: "Model"
base_type: MODEL
)]},)
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Going to run a new execution 7
INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=7, input_dict={'examples': [Artifact(artifact: id: 10
type_id: 15
uri: "pipelines/tfx-llm-imdb-reviews/Transform/transformed_examples/5"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Examples"
create_time_since_epoch: 1718793294917
last_update_time_since_epoch: 1718793294917
, artifact_type: id: 15
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)], 'trained_model': [Artifact(artifact: id: 14
type_id: 27
uri: "pipelines/tfx-llm-imdb-reviews/Trainer/model/6"
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.15.1"
  }
}
state: LIVE
type: "Model"
create_time_since_epoch: 1718793500184
last_update_time_since_epoch: 1718793500184
, artifact_type: id: 27
name: "Model"
base_type: MODEL
)]}, output_dict=defaultdict(<class 'list'>, {'evaluation': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Evaluator/evaluation/7"
, artifact_type: name: "Evaluation_Metric"
properties {
  key: "model_evaluation_output_path"
  value: STRING
}
properties {
  key: "model_prediction_time"
  value: DOUBLE
}
)]}), exec_properties={'max_length': 50}, execution_output_uri='pipelines/tfx-llm-imdb-reviews/Evaluator/.system/executor_execution/7/executor_output.pb', stateful_working_dir='pipelines/tfx-llm-imdb-reviews/Evaluator/.system/stateful_working_dir/671c8077-1414-4a58-8fe1-eaf3448d5527', tmp_dir='pipelines/tfx-llm-imdb-reviews/Evaluator/.system/executor_execution/7/.temp/', pipeline_node=node_info {
  type {
    name: "__main__.Evaluator"
  }
  id: "Evaluator"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2024-06-19T10:34:33.931011"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "tfx-llm-imdb-reviews.Evaluator"
      }
    }
  }
}
inputs {
  inputs {
    key: "examples"
    value {
      channels {
        producer_node_query {
          id: "Transform"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.Transform"
            }
          }
        }
        artifact_query {
          type {
            name: "Examples"
            base_type: DATASET
          }
        }
        output_key: "transformed_examples"
      }
      min_count: 1
    }
  }
  inputs {
    key: "trained_model"
    value {
      channels {
        producer_node_query {
          id: "Trainer"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2024-06-19T10:34:33.931011"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "tfx-llm-imdb-reviews.Trainer"
            }
          }
        }
        artifact_query {
          type {
            name: "Model"
            base_type: MODEL
          }
        }
        output_key: "model"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "evaluation"
    value {
      artifact_spec {
        type {
          name: "Evaluation_Metric"
          properties {
            key: "model_evaluation_output_path"
            value: STRING
          }
          properties {
            key: "model_prediction_time"
            value: DOUBLE
          }
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "max_length"
    value {
      field_value {
        int_value: 50
      }
    }
  }
}
upstream_nodes: "Trainer"
upstream_nodes: "Transform"
execution_options {
  caching_options {
  }
}
, pipeline_info=id: "tfx-llm-imdb-reviews"
, pipeline_run_id='2024-06-19T10:34:33.931011', top_level_pipeline_run_id=None, frontend_url=None)
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/preprocessor.json...
Downloading from https://www.kaggle.com/api/v1/models/keras/gpt2/keras/gpt2_base_en/2/download/task.json...
1/2 ━━━━━━━━━━━━━━━━━━━━ 4s 5s/step
W0000 00:00:1718793511.410155  163720 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
2/2 ━━━━━━━━━━━━━━━━━━━━ 9s 4s/step
2/2 ━━━━━━━━━━━━━━━━━━━━ 3s 2s/step
2/2 ━━━━━━━━━━━━━━━━━━━━ 3s 2s/step
2/2 ━━━━━━━━━━━━━━━━━━━━ 3s 2s/step
2/2 ━━━━━━━━━━━━━━━━━━━━ 1s 165ms/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 3s 1s/step
W0000 00:00:1718793557.898255  163709 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
3/3 ━━━━━━━━━━━━━━━━━━━━ 4s 1s/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 4s 1s/step
W0000 00:00:1718793576.951148  163726 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
3/3 ━━━━━━━━━━━━━━━━━━━━ 3s 1s/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 153ms/step
W0000 00:00:1718793610.720210  163525 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/saving/saving_lib.py:415: UserWarning: Skipping variable loading for optimizer 'loss_scale_optimizer', because it has 4 variables whereas the saved optimizer has 397 variables. 
  saveable.load_own_variables(weights_store.get(inner_path))
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/saving/saving_lib.py:415: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 393 variables. 
  saveable.load_own_variables(weights_store.get(inner_path))
2/2 ━━━━━━━━━━━━━━━━━━━━ 1s 165ms/step
2/2 ━━━━━━━━━━━━━━━━━━━━ 1s 175ms/step
2/2 ━━━━━━━━━━━━━━━━━━━━ 1s 178ms/step
2/2 ━━━━━━━━━━━━━━━━━━━━ 1s 176ms/step
2/2 ━━━━━━━━━━━━━━━━━━━━ 1s 173ms/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 153ms/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 158ms/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 165ms/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 163ms/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 1s 159ms/step
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 7 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Deleted stateful_working_dir pipelines/tfx-llm-imdb-reviews/Evaluator/.system/stateful_working_dir/671c8077-1414-4a58-8fe1-eaf3448d5527
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'evaluation': [Artifact(artifact: uri: "pipelines/tfx-llm-imdb-reviews/Evaluator/evaluation/7"
, artifact_type: name: "Evaluation_Metric"
properties {
  key: "model_evaluation_output_path"
  value: STRING
}
properties {
  key: "model_prediction_time"
  value: DOUBLE
}
)]}) for execution 7
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component Evaluator is finished.

If the pipeline finished successfully, the last line of the logs should read "INFO:absl:Component Evaluator is finished.". This is the expected final line because Evaluator is the last component in the pipeline.
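If you run the pipeline unattended, you can automate this check. The following is a minimal sketch that assumes you have captured the pipeline's log output as text, for example by redirecting it to a file. The helper `pipeline_finished` is hypothetical and is not part of TFX. It only scans the log text; it does not query ML Metadata to read the actual execution state.

```python
def pipeline_finished(log_text: str, last_component: str = "Evaluator") -> bool:
    """Return True if the captured log reports that `last_component` finished.

    Hypothetical helper: it matches the "INFO:absl:Component <name> is finished."
    line seen in the logs above and does not inspect TFX or ML Metadata state.
    """
    marker = f"INFO:absl:Component {last_component} is finished."
    # Compare each stripped log line against the exact completion message.
    return any(line.strip() == marker for line in log_text.splitlines())


# Usage: a short excerpt modeled on the end of the logs above.
sample_log = """INFO:absl:Component Trainer is finished.
INFO:absl:Component Evaluator is running.
INFO:absl:Component Evaluator is finished."""
print(pipeline_finished(sample_log))  # True
```

Matching the full line, rather than a substring, avoids false positives from other components. For example, "Component Trainer is finished." does not count as the pipeline finishing. If you rename the final component or add components after Evaluator, pass the new last component's id as `last_component`.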