Eval action that saves a checkpoint with the average weights when EMA (exponential moving average) is used.
tfm.core.actions.EMACheckpointing(
    export_dir: str,
    optimizer: tf.keras.optimizers.Optimizer,
    checkpoint: tf.train.Checkpoint,
    max_to_keep: int = 1
)
This action swaps the model's weights with the EMA average weights, then saves the checkpoint under export_dir/ema_checkpoints. Because checkpointing is expensive for large models, running this action as part of evaluation, which typically runs less often than training steps, is more efficient than running it during training.
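The swap-then-save flow described above can be sketched without TensorFlow. In this minimal sketch, `ToyEMAOptimizer`, `update_average`, and `ema_checkpoint` are hypothetical names, not part of `tfm` or Keras; weights are plain floats in a dict, and `save_fn` stands in for writing a `tf.train.Checkpoint`. The sketch also swaps the weights back after saving (in a `finally` block) so training resumes from the raw weights; that restore step is an assumption of the sketch, not something stated above.

```python
import copy
import os


class ToyEMAOptimizer:
    """Hypothetical stand-in for an EMA optimizer; not a TF or tfm class."""

    def __init__(self, weights, decay=0.99):
        self.model_weights = dict(weights)    # live (raw) weights
        self.average_weights = dict(weights)  # exponential moving averages
        self.decay = decay

    def update_average(self):
        # EMA update: avg <- decay * avg + (1 - decay) * w
        for name, w in self.model_weights.items():
            avg = self.average_weights[name]
            self.average_weights[name] = self.decay * avg + (1.0 - self.decay) * w

    def swap_weights(self):
        # Exchange the live weights and the averages.
        self.model_weights, self.average_weights = (
            self.average_weights, self.model_weights)


def ema_checkpoint(optimizer, save_fn, export_dir):
    """Swap in the averages, save them, then swap back (assumed restore)."""
    path = os.path.join(export_dir, "ema_checkpoints")
    optimizer.swap_weights()
    try:
        # After the swap, model_weights holds the averages.
        save_fn(path, copy.deepcopy(optimizer.model_weights))
    finally:
        # Assumption: restore the raw weights so training can continue.
        optimizer.swap_weights()
    return path


saved = {}
opt = ToyEMAOptimizer({"w": 1.0}, decay=0.5)
opt.model_weights["w"] = 3.0  # pretend a training step moved the weight
opt.update_average()          # avg = 0.5 * 1.0 + 0.5 * 3.0 = 2.0
ema_checkpoint(opt, lambda path, w: saved.update({path: w}), "/tmp/run")
print(list(saved.values()))  # [{'w': 2.0}]
print(opt.model_weights)     # {'w': 3.0}
```

Using `try`/`finally` keeps the live weights intact even if the save fails, which matters if the same optimizer continues training afterwards.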
Args | |
---|---|
export_dir | str, the export directory for the EMA average weights. |
optimizer | tf.keras.optimizers.Optimizer instance used for training. It is used to swap the model weights with the average weights. |
checkpoint | The tf.train.Checkpoint instance to save. |
max_to_keep | int, the maximum number of checkpoints to keep in the ema_checkpoints subdirectory. |
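To illustrate how export_dir and max_to_keep interact, here is a hypothetical retention sketch in plain Python. It mimics the `max_to_keep` behavior of `tf.train.CheckpointManager` (the oldest checkpoints are dropped once the limit is exceeded); whether `EMACheckpointing` uses that class internally is not stated here, and `ToyRetentionManager` and the `ckpt-<step>` names are illustrative only. No files are written.

```python
import os
from collections import deque


class ToyRetentionManager:
    """Hypothetical sketch: keep only the newest `max_to_keep` checkpoints."""

    def __init__(self, export_dir, max_to_keep=1):
        if max_to_keep < 1:
            raise ValueError("max_to_keep must be >= 1")
        # EMA checkpoints live in a dedicated subdirectory of export_dir.
        self.directory = os.path.join(export_dir, "ema_checkpoints")
        self.max_to_keep = max_to_keep
        self._kept = deque()
        self.deleted = []  # paths a real manager would remove from disk

    def save(self, step):
        # Illustrative naming only; no files are written.
        path = os.path.join(self.directory, f"ckpt-{step}")
        self._kept.append(path)
        # Evict the oldest checkpoints once the limit is exceeded.
        while len(self._kept) > self.max_to_keep:
            self.deleted.append(self._kept.popleft())
        return path

    @property
    def checkpoints(self):
        return list(self._kept)


mgr = ToyRetentionManager("/tmp/run", max_to_keep=2)
for step in (100, 200, 300):
    mgr.save(step)
print([os.path.basename(p) for p in mgr.checkpoints])  # ['ckpt-200', 'ckpt-300']
```

With the default max_to_keep of 1, only the most recent EMA checkpoint would survive each save.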
Methods
__call__
__call__(
    output: orbit.runner.Output
)
Swaps the model weights with the average weights and saves the checkpoint.
Args | |
---|---|
output | The train or eval output. |
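`__call__` follows the action convention used by Orbit: an action is a callable that receives the output of a train or eval run. The hypothetical `evaluate_and_run_actions` loop below shows that contract with a simplified `Output` alias (the real `orbit.runner.Output` type is broader); in Orbit itself, such callables are typically passed to `orbit.Controller` through its `train_actions` or `eval_actions` arguments. The description above does not say that this action uses the output's contents, so `RecordingAction` here simply records what it receives.

```python
from typing import Any, Callable, Dict, List, Optional

# Simplified stand-in for orbit.runner.Output (the real type is broader).
Output = Optional[Dict[str, Any]]


class RecordingAction:
    """Hypothetical action: any callable that accepts the output works."""

    def __init__(self):
        self.calls: List[Output] = []

    def __call__(self, output: Output) -> None:
        self.calls.append(output)


def evaluate_and_run_actions(
        eval_fn: Callable[[], Output],
        actions: List[Callable[[Output], None]]) -> Output:
    """Hypothetical eval step: run eval, then hand its output to each action."""
    output = eval_fn()
    for action in actions:
        action(output)
    return output


log = RecordingAction()
evaluate_and_run_actions(lambda: {"accuracy": 0.91}, [log])
print(log.calls)  # [{'accuracy': 0.91}]
```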