Custom Python function components

Python function-based component definition makes it easier for you to create TFX custom components, by saving you the effort of defining a component specification class, executor class, and component interface class. In this component definition style, you write a function that is annotated with type hints. The type hints describe the input artifacts, output artifacts, and parameters of your component.

Writing your custom component in this style is very straightforward, as in the following example.

from typing import TypedDict
from tfx.dsl.component.experimental.annotations import InputArtifact, OutputArtifact, Parameter
from tfx.dsl.component.experimental.decorators import component
from tfx.types.standard_artifacts import Model

class MyOutput(TypedDict):
  accuracy: float

@component
def MyValidationComponent(
    model: InputArtifact[Model],
    blessing: OutputArtifact[Model],
    accuracy_threshold: Parameter[int] = 10,
) -> MyOutput:
  '''My simple custom model validation component.'''

  accuracy = evaluate_model(model)
  if accuracy >= accuracy_threshold:
    write_output_blessing(blessing)

  return {
    'accuracy': accuracy
  }

Under the hood, this defines a custom component that is a subclass of BaseComponent, along with its corresponding Spec and Executor classes.
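
Once defined, the decorated function can be instantiated like any other TFX component, passing upstream output channels for artifact inputs and concrete values for parameters. A minimal usage sketch, assuming a hypothetical upstream trainer component that produces a Model artifact on its 'model' output channel:

# `trainer` is a hypothetical upstream component producing a Model artifact.
validation = MyValidationComponent(
    model=trainer.outputs['model'],
    accuracy_threshold=7)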

If you want to define a subclass of BaseBeamComponent, so that your component can use a Beam pipeline with TFX pipeline-wide shared configuration (that is, the beam_pipeline_args supplied when compiling the pipeline, as in the Chicago Taxi Pipeline example), set use_beam=True in the decorator and add a BeamComponentParameter argument with a default value of None to your function, as in the following example:

import apache_beam as beam
from tfx.dsl.component.experimental.annotations import BeamComponentParameter, InputArtifact, OutputArtifact
from tfx.dsl.component.experimental.decorators import component
from tfx.types.standard_artifacts import Examples

@component(use_beam=True)
def MyDataProcessor(
    examples: InputArtifact[Examples],
    processed_examples: OutputArtifact[Examples],
    beam_pipeline: BeamComponentParameter[beam.Pipeline] = None,
) -> None:
  '''My simple custom data processing component.'''

  with beam_pipeline as p:
    # Data pipeline definition with beam_pipeline begins here.
    ...
    # Data pipeline definition with beam_pipeline ends here.
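
The pipeline-wide Beam configuration is supplied when the pipeline itself is constructed. A minimal sketch, assuming a placeholder example_gen component and placeholder names and paths; beam_pipeline_args is then shared by every Beam-powered component in the pipeline, including MyDataProcessor above:

import tfx.v1 as tfx

# Placeholder names, paths, and upstream component.
processor = MyDataProcessor(examples=example_gen.outputs['examples'])
pipeline = tfx.dsl.Pipeline(
    pipeline_name='my_pipeline',
    pipeline_root='/tmp/my_pipeline_root',
    components=[example_gen, processor],
    beam_pipeline_args=['--direct_num_workers=2'])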

If you are new to TFX pipelines, learn more about the core concepts of TFX pipelines.

Inputs, outputs, and parameters

In TFX, inputs and outputs are tracked as Artifact objects, which describe the location of, and the metadata properties associated with, the underlying data; this information is stored in ML Metadata. Artifacts can describe complex data types or simple data types such as int, float, bytes, or unicode strings.

A parameter is an argument (int, float, bytes, or unicode string) to a component that is known at pipeline construction time. Parameters are useful for specifying arguments and hyperparameters, such as the training iteration count, the dropout rate, and other configuration for your component. Parameters are stored as properties of component executions when tracked in ML Metadata.

Definition

To create a custom component, write a function that implements your custom logic and decorate it with the @component decorator from the tfx.dsl.component.experimental.decorators module. To define your component’s input and output schema, annotate your function’s arguments and return value using annotations from the tfx.dsl.component.experimental.annotations module:

  • For each artifact input, apply the InputArtifact[ArtifactType] type hint annotation. Replace ArtifactType with the artifact’s type, which is a subclass of tfx.types.Artifact. These inputs can be optional arguments.

  • For each output artifact, apply the OutputArtifact[ArtifactType] type hint annotation. Replace ArtifactType with the artifact’s type, which is a subclass of tfx.types.Artifact. Component output artifacts should be passed as input arguments of the function, so that your component can write outputs to a system-managed location and set appropriate artifact metadata properties. This argument can be optional, or it can be defined with a default value.

  • For each parameter, use the type hint annotation Parameter[T]. Replace T with the type of the parameter. Only primitive Python types are currently supported: bool, int, float, str, or bytes.

  • For the Beam pipeline, use the type hint annotation BeamComponentParameter[beam.Pipeline]. Set the default value to None. The None value is replaced by an instantiated Beam pipeline created by _make_beam_pipeline() of BaseBeamExecutor.

  • For each simple data type input (int, float, str, or bytes) not known at pipeline construction time, use the type hint T. Note that as of the TFX 0.22 release, concrete values cannot be passed at pipeline construction time for this type of input (use the Parameter annotation instead, as described in the previous item). This argument can be optional, or it can be defined with a default value. If your component has simple data type outputs (int, float, str, or bytes), you can return them by using a TypedDict as the return type annotation and returning an appropriate dict object.

In the body of your function, input and output artifacts are passed as tfx.types.Artifact objects; you can inspect an artifact's .uri to get its system-managed location and read or set any of its properties. Input parameters and simple data type inputs are passed as objects of the specified type. Simple data type outputs should be returned as a dictionary, where the keys are the appropriate output names and the values are the desired return values.
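
For instance, a component body might read from an input artifact's .uri and record a custom metadata property on an output artifact. The sketch below assumes the set_string_custom_property setter on tfx.types.Artifact and uses placeholder load_data, train, and save_model helpers:

from tfx.dsl.component.experimental.annotations import InputArtifact, OutputArtifact
from tfx.dsl.component.experimental.decorators import component
from tfx.types.standard_artifacts import Examples, Model

@component
def MyPropertyExample(
    examples: InputArtifact[Examples],
    model: OutputArtifact[Model],
) -> None:
  # Read training data from the input artifact's system-managed location.
  data = load_data(examples.uri)        # load_data is a placeholder helper
  # Write the trained model under the output artifact's URI.
  save_model(train(data), model.uri)    # placeholder helpers
  # Attach a custom string-valued metadata property to the output artifact.
  model.set_string_custom_property('trained_by', 'MyPropertyExample')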

The completed function component can look like this:

from typing import TypedDict
import tfx.v1 as tfx
from tfx.dsl.component.experimental.decorators import component

class MyOutput(TypedDict):
  loss: float
  accuracy: float

@component
def MyTrainerComponent(
    training_data: tfx.dsl.components.InputArtifact[tfx.types.standard_artifacts.Examples],
    model: tfx.dsl.components.OutputArtifact[tfx.types.standard_artifacts.Model],
    dropout_hyperparameter: float,
    num_iterations: tfx.dsl.components.Parameter[int] = 10
) -> MyOutput:
  '''My simple trainer component.'''

  records = read_examples(training_data.uri)
  model_obj = train_model(records, num_iterations, dropout_hyperparameter)
  model_obj.write_to(model.uri)

  return {
    'loss': model_obj.loss,
    'accuracy': model_obj.accuracy
  }

# Example usage in a pipeline graph definition:
# ...
trainer = MyTrainerComponent(
    training_data=example_gen.outputs['examples'],
    dropout_hyperparameter=other_component.outputs['dropout'],
    num_iterations=1000)
pusher = Pusher(model=trainer.outputs['model'])
# ...

The preceding example defines MyTrainerComponent as a Python function-based custom component. This component consumes an examples artifact as its input and produces a model artifact as its output. The component uses each artifact's .uri to read or write the artifact at its system-managed location. The component takes a num_iterations input parameter and a dropout_hyperparameter simple data type value, and it outputs loss and accuracy metrics as simple data type output values. The output model artifact is then used by the Pusher component.
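
To run these components, they can be assembled into a pipeline and handed to an orchestrator. The following is a minimal sketch using the local orchestrator, with placeholder names and paths; example_gen, other_component, and pusher are assumed to be defined elsewhere in the pipeline graph:

import tfx.v1 as tfx

pipeline = tfx.dsl.Pipeline(
    pipeline_name='my_trainer_pipeline',            # placeholder name
    pipeline_root='/tmp/my_trainer_pipeline_root',  # placeholder path
    metadata_connection_config=tfx.orchestration.metadata.sqlite_metadata_connection_config('/tmp/metadata.db'),
    components=[example_gen, other_component, trainer, pusher])

tfx.orchestration.LocalDagRunner().run(pipeline)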