Implements the common trainer shared for TensorFlow models.
tfm.core.base_trainer.Trainer(
    config: tfm.core.base_trainer.ExperimentConfig,
    task: tfm.core.base_task.Task,
    model: tf.keras.Model,
    optimizer: tf.optimizers.Optimizer,
    train: bool = True,
    evaluate: bool = True,
    train_dataset: Optional[Union[tf.data.Dataset, tf.distribute.DistributedDataset]] = None,
    validation_dataset: Optional[Union[tf.data.Dataset, tf.distribute.DistributedDataset]] = None,
    checkpoint_exporter=None
)
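For example, a Trainer can be constructed by hand as in the minimal sketch below; `MyTask` and `experiment_config` are hypothetical stand-ins for a concrete `tfm.core.base_task.Task` subclass and a populated `ExperimentConfig`, which are typically assembled elsewhere in an experiment definition.

```python
import tensorflow as tf
import tensorflow_models as tfm

# `MyTask` is a hypothetical concrete subclass of tfm.core.base_task.Task;
# `experiment_config` is assumed to be a fully populated ExperimentConfig.
task = MyTask(experiment_config.task)
model = task.build_model()
optimizer = tf.keras.optimizers.Adam()

trainer = tfm.core.base_trainer.Trainer(
    config=experiment_config,
    task=task,
    model=model,
    optimizer=optimizer,
    train=True,
    evaluate=True,
)
```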
Methods
coordinator_for_async
coordinator_for_async() -> tf.distribute.experimental.coordinator.ClusterCoordinator
create_eval_loop_fn
create_eval_loop_fn(
has_state: bool
)
Creates an eval loop from the given step function and options.
create_train_loop_fn
create_train_loop_fn()
Creates a training loop from the given step function and options.
distribute_dataset
distribute_dataset(
dataset_or_fn, *args, **kwargs
)
A utility function to help create a `tf.distribute.DistributedDataset`.
| Args | |
|---|---|
| `dataset_or_fn` | An instance of `tf.data.Dataset`, or a "dataset function" returning a `tf.data.Dataset`. If it is a function, it may optionally have an argument named `input_context` which will be passed a `tf.distribute.InputContext` instance. |
| `*args` | Any positional arguments to pass through to `dataset_or_fn`. |
| `**kwargs` | Any keyword arguments to pass through to `dataset_or_fn`. |

| Returns | |
|---|---|
| A distributed Dataset. | |
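A short usage sketch: `make_dataset` below is a hypothetical dataset function whose optional `input_context` argument is used to shard the pipeline when running under a distribution strategy.

```python
import tensorflow as tf

def make_dataset(batch_size, input_context=None):
  ds = tf.data.Dataset.range(1024).map(lambda x: (x, x % 2))
  if input_context is not None:
    # Shard the data across input pipelines when running distributed.
    ds = ds.shard(input_context.num_input_pipelines,
                  input_context.input_pipeline_id)
  return ds.batch(batch_size)

# `trainer` is an existing Trainer instance.
distributed_ds = trainer.distribute_dataset(make_dataset, batch_size=32)
```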
eval_begin
eval_begin()
Sets up metrics.
eval_end
eval_end(
aggregated_logs=None
)
Processes evaluation results.
eval_reduce
eval_reduce(
state=None, step_outputs=None
)
A function to perform per-step reduction on the evaluation outputs.

This is useful for passing state throughout evaluation, especially in cases where maintaining or accumulating state is hard to accomplish using `tf.metrics.Metric` or other `tf.Variable`-based approaches. For instance, it can be used to easily accumulate all per-example losses from the full evaluation for subsequent processing in `eval_end()`.
| Args | |
|---|---|
| `state` | A state being maintained throughout the evaluation. |
| `step_outputs` | Outputs from the current evaluation step. |

| Returns | |
|---|---|
| An output which is passed as the `state` argument to this function for the next step. After evaluation is finished, the output from the last step will be passed to `eval_end`. | |
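As a sketch of the accumulation pattern described above, a subclass might collect per-example losses across steps; the `per_example_loss` key is an assumption about what the step outputs contain.

```python
import tensorflow_models as tfm

class MyTrainer(tfm.core.base_trainer.Trainer):

  def eval_reduce(self, state=None, step_outputs=None):
    # Accumulate per-example losses across evaluation steps; the final
    # accumulated list is passed to eval_end() after the last step.
    if state is None:
      state = []
    state.append(step_outputs['per_example_loss'])
    return state
```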
eval_step
eval_step(
iterator
)
See base class.
evaluate
evaluate(
num_steps: tf.Tensor
) -> Optional[runner.Output]
Implements `num_steps` steps of evaluation.
| Args | |
|---|---|
| `num_steps` | The number of evaluation steps to run. When this is -1, evaluation proceeds until a call to `eval_step` raises a `StopIteration` or `tf.errors.OutOfRangeError`. |

| Returns | |
|---|---|
| The output of `self.eval_end()`. | |

| Raises | |
|---|---|
| `ValueError` | If `options.use_tf_while_loop` is `True` and `num_steps` is unspecified. |
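For example, assuming an already constructed `trainer`:

```python
import tensorflow as tf

# Run exactly 100 evaluation steps.
eval_logs = trainer.evaluate(tf.convert_to_tensor(100, dtype=tf.int32))

# Run until the evaluation dataset is exhausted (not allowed when
# options.use_tf_while_loop is True).
full_eval_logs = trainer.evaluate(tf.convert_to_tensor(-1, dtype=tf.int32))
```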
init_async
init_async()
Initializes the Async Trainer base class.
initialize
initialize()
A callback function.
This function will be called when no checkpoint is found for the model. If there is a checkpoint, it will be loaded and this function will not be called. Tasks may use this callback to load a pretrained checkpoint saved under a directory other than the `model_dir`.
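A minimal sketch of the task-side logic this refers to, assuming the task config exposes an `init_checkpoint` path (a hypothetical field name here):

```python
import tensorflow as tf
import tensorflow_models as tfm

class MyTask(tfm.core.base_task.Task):

  def initialize(self, model: tf.keras.Model):
    # Warm-start from a pretrained checkpoint stored outside model_dir.
    ckpt_path = getattr(self.task_config, 'init_checkpoint', None)
    if ckpt_path:
      ckpt = tf.train.Checkpoint(model=model)
      ckpt.restore(tf.train.latest_checkpoint(ckpt_path)).expect_partial()
```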
join
join()
Joins all async steps. Only useful in async training.
next_eval_inputs
next_eval_inputs(
iterator
)
Fetches the next inputs for the model during eval.
This method consumes the input iterator and returns the next inputs for the model together with an additional logs dict. The logs dict stays on the host (it is not sent to GPUs/TPUs) and is merged with the model outputs, which are processed later in `aggregate_logs`. This is useful for sending extra logs downstream that are not compatible with the accelerators.
| Args | |
|---|---|
| `iterator` | Dataset iterator to generate the next inputs from. |

| Returns | |
|---|---|
| The inputs to the model, and an additional logs dictionary. The logs are not passed to the model; instead they are merged with the model's output logs. | |
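A sketch of a possible override, assuming each batch is a `(features, labels)` tuple whose `features` dict carries an `example_id` field that should stay on the host:

```python
import tensorflow_models as tfm

class MyTrainer(tfm.core.base_trainer.Trainer):

  def next_eval_inputs(self, iterator):
    features, labels = next(iterator)
    # Keep example IDs on the host; they are merged with the model's
    # output logs rather than being sent to the accelerator.
    passthrough_logs = {'example_id': features.pop('example_id')}
    return (features, labels), passthrough_logs
```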
next_train_inputs
next_train_inputs(
iterator
)
Fetches the next inputs for the model during train.
This method consumes the input iterator and returns the next inputs for the model. It provides a way to control how the next model input is fetched and what data is sent to the model.
| Args | |
|---|---|
| `iterator` | Dataset iterator to generate the next inputs from. |

| Returns | |
|---|---|
| The inputs to the model. | |
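A sketch of a possible override, assuming the input batches carry a host-only `metadata` field that the model does not consume:

```python
import tensorflow_models as tfm

class MyTrainer(tfm.core.base_trainer.Trainer):

  def next_train_inputs(self, iterator):
    features, labels = next(iterator)
    # Drop a field the model does not need before the batch is sent to it.
    features.pop('metadata', None)
    return features, labels
```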
train
train(
num_steps: tf.Tensor
) -> Optional[runner.Output]
Implements `num_steps` steps of training.
| Args | |
|---|---|
| `num_steps` | The number of training steps to run. This corresponds directly to the number of calls made to `train_step`. |

| Returns | |
|---|---|
| The output of `train_loop_end`. | |
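For example, assuming an already constructed `trainer`:

```python
import tensorflow as tf

# Run 1000 training steps, i.e. 1000 calls to train_step.
train_logs = trainer.train(tf.convert_to_tensor(1000, dtype=tf.int32))
```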
train_loop_begin
train_loop_begin()
Called once at the beginning of the training loop.
This method is always called in eager mode, and is a good place to reset metrics that accumulate values over multiple steps of training.
Note that this method is called before dataset iterator creation.
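A sketch of the kind of reset this hook is intended for, assuming the trainer keeps its accumulating Keras metrics in a `train_metrics` list (an assumed attribute):

```python
import tensorflow_models as tfm

class MyTrainer(tfm.core.base_trainer.Trainer):

  def train_loop_begin(self):
    super().train_loop_begin()
    # Reset metrics that accumulate values across training steps.
    for metric in self.train_metrics:
      metric.reset_states()
```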
train_loop_end
train_loop_end()
See base class.
train_step
train_step(
iterator
)
See base class.