Runs experiment with Orbit training loop.
tfm.core.train_lib.OrbitExperimentRunner(
    distribution_strategy: tf.distribute.Strategy,
    task: tfm.core.base_task.Task,
    mode: str,
    params: tfm.core.base_trainer.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    train_actions: Optional[List[orbit.Action]] = None,
    eval_actions: Optional[List[orbit.Action]] = None,
    trainer: Optional[tfm.core.base_trainer.Trainer] = None,
    controller_cls=orbit.Controller,
    summary_manager: Optional[orbit.utils.SummaryManager] = None,
    eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
    enable_async_checkpointing: bool = False
)
The default experiment runner for Model Garden experiments. Users can customize the experiment pipeline by subclassing this class and replacing components or functions.
For example, an experiment runner with customized checkpoint manager:
class MyExpRunnerWithExporter(OrbitExperimentRunner):
  def _maybe_build_checkpoint_manager(self):
    # Replaces the default CheckpointManager with a customized one.
    return MyCheckpointManager(*args)

# In user code, instead of the original
# `OrbitExperimentRunner(...).run(mode)`, users can now do:
MyExpRunnerWithExporter(**needed_kwargs).run(mode)
Similar overrides can be applied to other components.
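Because the real runner needs a full TensorFlow Model Garden setup (a task, an `ExperimentConfig`, a distribution strategy, and so on), the override mechanism above can be illustrated with a self-contained sketch in which stub classes stand in for the real `tfm` and `orbit` types. All names below are illustrative, not the actual implementation:

```python
class CheckpointManager:
    """Stub standing in for the default checkpoint manager."""
    name = "default"


class MyCheckpointManager(CheckpointManager):
    """Stub standing in for a user-customized checkpoint manager."""
    name = "custom"


class OrbitExperimentRunner:
    """Minimal stand-in for tfm.core.train_lib.OrbitExperimentRunner."""

    def __init__(self):
        # The runner wires its components together at construction time;
        # subclasses swap a component by overriding the builder hook.
        self.checkpoint_manager = self._maybe_build_checkpoint_manager()

    def _maybe_build_checkpoint_manager(self):
        return CheckpointManager()

    def run(self, mode):
        # Mirrors the documented contract: a 2-tuple of (model, eval_logs),
        # where eval_logs is {} unless post-eval is requested.
        model, eval_logs = None, {}
        return model, eval_logs


class MyExpRunnerWithExporter(OrbitExperimentRunner):
    def _maybe_build_checkpoint_manager(self):
        # Replaces the default CheckpointManager with a customized one.
        return MyCheckpointManager()


runner = MyExpRunnerWithExporter()
print(runner.checkpoint_manager.name)  # -> custom
model, eval_logs = runner.run("train_and_eval")
```

The key design point the sketch captures is that the base class calls `self._maybe_build_checkpoint_manager()` rather than constructing the component inline, so a subclass override is picked up automatically without touching the rest of the pipeline.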
Methods
run
run() -> Tuple[tf.keras.Model, Mapping[str, Any]]
Run experiments by mode.
Returns
A 2-tuple of (model, eval_logs).
  model: A tf.keras.Model instance.
  eval_logs: Eval metrics logs when run_post_eval is set to True; otherwise, an empty dict {}.