Cloud AI Platform Trainer component.
Inherits From: Trainer, BaseComponent, BaseNode
tfx.v1.extensions.google_cloud_ai_platform.Trainer(
    examples: Optional[tfx.v1.dsl.Channel] = None,
    transformed_examples: Optional[tfx.v1.dsl.Channel] = None,
    transform_graph: Optional[tfx.v1.dsl.Channel] = None,
    schema: Optional[tfx.v1.dsl.Channel] = None,
    base_model: Optional[tfx.v1.dsl.Channel] = None,
    hyperparameters: Optional[tfx.v1.dsl.Channel] = None,
    module_file: Optional[Union[str, tfx.v1.dsl.experimental.RuntimeParameter]] = None,
    run_fn: Optional[Union[str, tfx.v1.dsl.experimental.RuntimeParameter]] = None,
    trainer_fn: Optional[Union[str, tfx.v1.dsl.experimental.RuntimeParameter]] = None,
    train_args: Optional[Union[tfx.v1.proto.TrainArgs, tfx.v1.dsl.experimental.RuntimeParameter]] = None,
    eval_args: Optional[Union[tfx.v1.proto.EvalArgs, tfx.v1.dsl.experimental.RuntimeParameter]] = None,
    custom_config: Optional[Dict[str, Any]] = None
)
Args
examples: A Channel of type standard_artifacts.Examples, serving as the source of examples used in training (required). The examples may be raw or transformed.
transformed_examples: Deprecated. Set examples instead.
transform_graph: An optional Channel of type standard_artifacts.TransformGraph, serving as the input transform graph if present.
schema: An optional Channel of type standard_artifacts.Schema, serving as the schema of the training and eval data. The schema is optional when (1) transform_graph is provided, since it contains the schema, or (2) the user module does not use the schema, e.g., because the schema is hardcoded.
base_model: A Channel of type Model, containing a model that will be used during training. This can be used for warm-starting, transfer learning or model ensembling.
hyperparameters: A Channel of type standard_artifacts.HyperParameters, serving as the hyperparameters for the training module. The best hyperparameters output by Tuner can be fed into this.
module_file: A path to a Python module file containing the UDF model definition. For the generic executor (the default), the module_file must implement a function named run_fn at its top level with the signature def run_fn(trainer.fn_args_utils.FnArgs), and the trained model must be saved to FnArgs.serving_model_dir when this function is executed. For the Estimator-based Executor, the module_file must implement a function named trainer_fn at its top level with the following signature:
def trainer_fn(trainer.fn_args_utils.FnArgs, tensorflow_metadata.proto.v0.schema_pb2) -> Dict: ...
where the returned Dict has the following key-values:
'estimator': an instance of tf.estimator.Estimator
'train_spec': an instance of tf.estimator.TrainSpec
'eval_spec': an instance of tf.estimator.EvalSpec
'eval_input_receiver_fn': an instance of tfma EvalInputReceiver
run_fn: A Python path to the UDF model definition function for the generic trainer. See module_file for details. Exactly one of module_file or run_fn must be supplied if Trainer uses GenericExecutor (the default).
trainer_fn: A Python path to the UDF model definition function for the Estimator-based trainer. See module_file for the required signature of the UDF. Exactly one of module_file or trainer_fn must be supplied if Trainer uses the Estimator-based Executor.
train_args: A proto.TrainArgs instance containing the args used for training. Currently only splits and num_steps are available. By default (when splits is empty), training uses the train split.
eval_args: A proto.EvalArgs instance containing the args used for evaluation. Currently only splits and num_steps are available. By default (when splits is empty), evaluation uses the eval split.
custom_config: A dict containing additional training job parameters that will be passed into the user module.
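The run_fn contract described under module_file can be illustrated with a minimal sketch. It assumes a stand-in FnArgs dataclass that mirrors only three fields of trainer.fn_args_utils.FnArgs (serving_model_dir, train_steps, custom_config), and it writes a placeholder JSON file instead of a real SavedModel so that it runs without TFX or TensorFlow installed. The custom_config key learning_rate is hypothetical.

```python
import json
import os
import tempfile
from dataclasses import dataclass, field
from typing import Any, Dict


# Stand-in for trainer.fn_args_utils.FnArgs: only the fields used below.
@dataclass
class FnArgs:
    serving_model_dir: str
    train_steps: int = 0
    custom_config: Dict[str, Any] = field(default_factory=dict)


def run_fn(fn_args: FnArgs) -> None:
    """Top-level function that the generic executor looks up in module_file."""
    # A real module would build and fit a model here; this sketch only
    # records the arguments it received, including the custom_config dict.
    model_summary = {
        "train_steps": fn_args.train_steps,
        "custom_config": fn_args.custom_config,
    }
    # The documented contract: the trained model must be saved to
    # serving_model_dir. A real module would save its model here instead.
    os.makedirs(fn_args.serving_model_dir, exist_ok=True)
    with open(os.path.join(fn_args.serving_model_dir, "model.json"), "w") as f:
        json.dump(model_summary, f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        args = FnArgs(
            serving_model_dir=os.path.join(tmp, "serving_model"),
            train_steps=100,
            custom_config={"learning_rate": 0.01},  # hypothetical key
        )
        run_fn(args)
        print(os.listdir(args.serving_model_dir))  # ['model.json']
```

The only hard requirement shown here is the one the argument description states: whatever run_fn trains must end up under FnArgs.serving_model_dir.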
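The "exactly one" rules for module_file, run_fn and trainer_fn can be expressed as a small validation check. The helper select_udf below is hypothetical (it is not part of the TFX API), and it applies the rule across all three parameters at once, which is one reasonable reading of the two executor-specific rules above. The module paths in the usage lines are placeholders.

```python
from typing import Optional


def select_udf(module_file: Optional[str] = None,
               run_fn: Optional[str] = None,
               trainer_fn: Optional[str] = None) -> str:
    """Return the name of the single UDF source that was supplied.

    Hypothetical helper: raises ValueError unless exactly one of the
    three parameters is set.
    """
    supplied = {
        name: value
        for name, value in (("module_file", module_file),
                            ("run_fn", run_fn),
                            ("trainer_fn", trainer_fn))
        if value
    }
    if len(supplied) != 1:
        raise ValueError(
            "Exactly one of 'module_file', 'run_fn' or 'trainer_fn' must be "
            f"supplied; got {sorted(supplied) or 'none'}.")
    return next(iter(supplied))


print(select_udf(run_fn="my_pkg.model.run_fn"))        # run_fn
print(select_udf(module_file="trainer_module.py"))     # module_file
```

Passing both module_file and run_fn, or none of the three, raises ValueError in this sketch.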
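The default split behavior of train_args and eval_args can be sketched in plain Python. SplitArgs is a hypothetical stand-in for the two fields named above (splits and num_steps); the real proto.TrainArgs and proto.EvalArgs are protocol buffer messages, not dataclasses. The split name holdout is also a placeholder.

```python
from dataclasses import dataclass, field
from typing import List, Optional


# Hypothetical stand-in for the splits/num_steps fields of
# proto.TrainArgs / proto.EvalArgs.
@dataclass
class SplitArgs:
    splits: List[str] = field(default_factory=list)
    num_steps: Optional[int] = None


def resolve_splits(args: SplitArgs, default_split: str) -> List[str]:
    """Return the splits to read, falling back to the documented default."""
    # An empty splits list means "use the default split" ('train' or 'eval').
    return list(args.splits) if args.splits else [default_split]


train_splits = resolve_splits(SplitArgs(num_steps=1000), default_split="train")
eval_splits = resolve_splits(SplitArgs(splits=["eval", "holdout"], num_steps=50),
                             default_split="eval")
print(train_splits)  # ['train']
print(eval_splits)   # ['eval', 'holdout']
```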
Attributes
outputs: The component's output channel dict.
Methods
with_node_execution_options
with_node_execution_options(
node_execution_options: utils.NodeExecutionOptions
) -> typing_extensions.Self
Class Variables
POST_EXECUTABLE_SPEC: None
PRE_EXECUTABLE_SPEC: None