A TFX component for model hyperparameter tuning.
Inherits From: BaseComponent, BaseNode
```python
tfx.components.Tuner(
    examples: tfx.types.Channel = None,
    schema: Optional[tfx.types.Channel] = None,
    transform_graph: Optional[tfx.types.Channel] = None,
    module_file: Optional[Text] = None,
    tuner_fn: Optional[Text] = None,
    train_args: trainer_pb2.TrainArgs = None,
    eval_args: trainer_pb2.EvalArgs = None,
    tune_args: Optional[tuner_pb2.TuneArgs] = None,
    custom_config: Optional[Dict[Text, Any]] = None,
    best_hyperparameters: Optional[tfx.types.Channel] = None,
    instance_name: Optional[Text] = None
)
```
Args | |
---|---|
`examples` | A Channel of type `standard_artifacts.Examples`, serving as the source of examples that are used in tuning (required).
`schema` | An optional Channel of type `standard_artifacts.Schema`, serving as the schema of training and eval data. This is used when raw examples are provided.
`transform_graph` | An optional Channel of type `standard_artifacts.TransformGraph`, serving as the input transform graph if present. This is used when transformed examples are provided.
`module_file` | A path to a Python module file containing the UDF tuner definition. The module file must implement a function named `tuner_fn` at its top level with the signature `def tuner_fn(fn_args: FnArgs) -> TunerFnResult`. Exactly one of `module_file` or `tuner_fn` must be supplied.
`tuner_fn` | A Python path to the UDF model definition function. See `module_file` for the required signature of the UDF. Exactly one of `module_file` or `tuner_fn` must be supplied.
`train_args` | A `trainer_pb2.TrainArgs` instance containing args used for training. Currently only `splits` and `num_steps` are available. The default behavior (when `splits` is empty) is to train on the `train` split.
`eval_args` | A `trainer_pb2.EvalArgs` instance containing args used for eval. Currently only `splits` and `num_steps` are available. The default behavior (when `splits` is empty) is to evaluate on the `eval` split.
`tune_args` | A `tuner_pb2.TuneArgs` instance containing args used for tuning. Currently only `num_parallel_trials` is available.
`custom_config` | A dict containing additional training job parameters that will be passed into the user module.
`best_hyperparameters` | Optional Channel of type `standard_artifacts.HyperParameters` for the result of the best hparams.
`instance_name` | Optional unique instance name. Necessary if multiple Tuner components are declared in the same pipeline.
Attributes | |
---|---|
`component_id` |
`component_type` |
`downstream_nodes` |
`exec_properties` |
`id` | Node id, unique across all TFX nodes in a pipeline.
`inputs` |
`outputs` |
`type` |
`upstream_nodes` |
Child Classes
Methods
add_downstream_node

```python
add_downstream_node(
    downstream_node
)
```

Experimental: Add another component that must run after this one.

This method enables task-based dependencies by enforcing execution order for synchronous pipelines on supported platforms. Currently, the supported platforms are Airflow, Beam, and Kubeflow Pipelines.

Note that this API call should be considered experimental, and may not work with asynchronous pipelines, sub-pipelines, and pipelines with conditional nodes. We also recommend relying on data for capturing dependencies where possible to ensure data lineage is fully captured within MLMD.

It is symmetric with add_upstream_node.

Args | |
---|---|
`downstream_node` | a component that must run after this node.
add_upstream_node

```python
add_upstream_node(
    upstream_node
)
```

Experimental: Add another component that must run before this one.

This method enables task-based dependencies by enforcing execution order for synchronous pipelines on supported platforms. Currently, the supported platforms are Airflow, Beam, and Kubeflow Pipelines.

Note that this API call should be considered experimental, and may not work with asynchronous pipelines, sub-pipelines, and pipelines with conditional nodes. We also recommend relying on data for capturing dependencies where possible to ensure data lineage is fully captured within MLMD.

It is symmetric with add_downstream_node.

Args | |
---|---|
`upstream_node` | a component that must run before this node.
from_json_dict

```python
@classmethod
from_json_dict(
    dict_data: Dict[Text, Any]
) -> Any
```
Convert from dictionary data to an object.
get_id

```python
@classmethod
get_id(
    instance_name: Optional[Text] = None
)
```

Gets the id of a node.

This can be used during pipeline authoring time. For example:

```python
from tfx.components import Trainer

resolver = ResolverNode(
    ...,
    model=Channel(
        type=Model,
        producer_component_id=Trainer.get_id('my_trainer')))
```

Args | |
---|---|
`instance_name` | (Optional) instance name of a node. If given, the instance name will be taken into consideration when generating the id.

Returns | |
---|---|
an id for the node. |
to_json_dict
```python
to_json_dict() -> Dict[Text, Any]
```
Convert from an object to a JSON serializable dictionary.
with_id
```python
with_id(
    id: Text
) -> "BaseNode"
```
with_platform_config
```python
with_platform_config(
    config: message.Message
) -> "BaseComponent"
```

Attaches a proto-form platform config to a component.

The config will be a per-node platform-specific config.

Args | |
---|---|
`config` | platform config to attach to the component.

Returns | |
---|---|
the same component itself. |
Class Variables | |
---|---|
EXECUTOR_SPEC | `tfx.dsl.components.base.executor_spec.ExecutorClassSpec`