tfm.core.train_lib.run_experiment
Runs train/eval configured by the experiment params.
tfm.core.train_lib.run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: tfm.core.base_task.Task,
    mode: str,
    params: tfm.core.base_trainer.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    train_actions: Optional[List[orbit.Action]] = None,
    eval_actions: Optional[List[orbit.Action]] = None,
    trainer: Optional[tfm.core.base_trainer.Trainer] = None,
    controller_cls=orbit.Controller,
    summary_manager: Optional[orbit.utils.SummaryManager] = None,
    eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
    enable_async_checkpointing: bool = False
) -> Tuple[tf.keras.Model, Mapping[str, Any]]
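A typical call looks roughly like the sketch below. It is written under these assumptions: TensorFlow Model Garden is installed (`import tensorflow_models as tfm`), the example experiment `resnet_imagenet` is used, and its input paths already point at real data. The wrapper name `train_with_model_garden` is hypothetical and is not part of this API. The task is created inside `strategy.scope()`, which mirrors the guidance given for the `trainer` argument. `mode` may be any of 'train', 'eval', 'train_and_eval' or 'continuous_eval'.

```python
def train_with_model_garden(model_dir: str, mode: str = "train_and_eval"):
    """Hypothetical wrapper that runs a registered Model Garden experiment."""
    # Imported inside the function so this sketch loads even without TensorFlow.
    import tensorflow as tf
    import tensorflow_models as tfm

    # Example experiment name (assumption); its data paths must be configured.
    exp_config = tfm.core.exp_factory.get_exp_config("resnet_imagenet")

    # MirroredStrategy uses all visible GPUs, or falls back to the CPU.
    strategy = tf.distribute.MirroredStrategy()

    # Create the task inside the strategy scope.
    with strategy.scope():
        task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)

    # Returns (model, eval_logs); eval_logs is filled because run_post_eval=True.
    return tfm.core.train_lib.run_experiment(
        distribution_strategy=strategy,
        task=task,
        mode=mode,
        params=exp_config,
        model_dir=model_dir,
        run_post_eval=True,
    )
```

Calling `train_with_model_garden("/tmp/model_dir")` would train and evaluate, writing checkpoints and summaries under the given directory.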
Args:
  distribution_strategy: A tf.distribute.Strategy instance.
  task: A Task instance.
  mode: A 'str' specifying the mode: 'train', 'eval', 'train_and_eval'
    or 'continuous_eval'.
  params: An ExperimentConfig instance.
  model_dir: A 'str', the path where model checkpoints and summaries are
    stored.
  run_post_eval: Whether to run one evaluation pass after training. If True,
    the resulting eval metrics logs are returned.
  save_summary: Whether to save train and validation summaries.
  train_actions: Optional list of Orbit train actions.
  eval_actions: Optional list of Orbit eval actions.
  trainer: The base_trainer.Trainer instance. It should be created within
    strategy.scope().
  controller_cls: The controller class that manages the train and eval
    process. Must be an orbit.Controller subclass.
  summary_manager: An instance of a summary manager that overrides the
    default summary manager.
  eval_summary_manager: An instance of an eval summary manager that overrides
    the default eval summary manager.
  enable_async_checkpointing: Optional boolean indicating whether to enable
    asynchronous checkpoint saving.
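`train_actions` and `eval_actions` take callables that Orbit invokes with the output mapping of each train or eval loop. The class below is a minimal, hypothetical action (`MetricHistory` is not part of Orbit or Model Garden) that records selected metrics from each loop. It assumes the output is a mapping of metric names to values, and it also tolerates a None or empty output defensively.

```python
from typing import Any, Dict, List, Mapping, Optional, Sequence


class MetricHistory:
    """Hypothetical Orbit-style action that records metrics from each loop.

    Orbit calls each action as `action(output)`, where `output` is the mapping
    produced by one train or eval loop.
    """

    def __init__(self, keys: Optional[Sequence[str]] = None):
        # Metric names to keep; None keeps every metric.
        self.keys = set(keys) if keys is not None else None
        self.history: List[Dict[str, Any]] = []

    def __call__(self, output: Optional[Mapping[str, Any]]) -> None:
        if not output:
            return  # Nothing to record for this loop.
        self.history.append(
            {k: v for k, v in output.items() if self.keys is None or k in self.keys}
        )
```

For example, `eval_history = MetricHistory(keys=["accuracy"])` could be passed as `eval_actions=[eval_history]`. The available metric names depend on the task, so 'accuracy' is only an assumption.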
Returns:
  A 2-tuple of (model, eval_logs), where:
  model: The tf.keras.Model instance.
  eval_logs: The eval metrics logs when run_post_eval is True; otherwise {}.
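Because `eval_logs` is an empty dict unless `run_post_eval=True`, code that consumes the result should handle that case. The helper below is a hypothetical sketch (`get_metric` is not part of the library). Metric names such as 'accuracy' depend on the task and are assumptions.

```python
from typing import Any, Mapping, Optional


def get_metric(eval_logs: Mapping[str, Any], name: str) -> Optional[float]:
    """Returns eval metric `name` as a float, or None if it is unavailable."""
    if not eval_logs:
        # Empty logs: run_experiment was called with run_post_eval=False.
        return None
    value = eval_logs.get(name)
    # Values may be NumPy or TensorFlow scalars; float() converts either.
    return None if value is None else float(value)
```

For example, `model, eval_logs = run_experiment(...)` followed by `get_metric(eval_logs, "accuracy")` yields None when no post-training evaluation ran.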
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2024-02-02 UTC.