tfm.core.base_trainer.TrainerConfig

Configuration for the trainer.

Inherits From: Config, ParamsDict

optimizer_config: optimizer config; includes the optimizer, learning rate, and warmup schedule configs.
train_tf_while_loop: whether or not to use a tf.while_loop for training.
train_tf_function: whether or not to use tf.function for the training loop.
eval_tf_function: whether or not to use tf.function for eval.
eval_tf_while_loop: whether or not to use a tf.while_loop for eval.
allow_tpu_summary: whether to allow summaries to happen inside the XLA program that runs on TPU, through automatic outside compilation.
steps_per_loop: number of steps per loop at which to report training metrics. This can also be used to reduce host-worker communication in a TPU setup.
summary_interval: number of steps between each summary.
checkpoint_interval: number of steps between checkpoints.
max_to_keep: maximum number of checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between checkpoints; if set to None, continuous eval will wait indefinitely. Only used in the continuous_train_and_eval and continuous_eval modes. Defaults to 3600 (1 hour).
train_steps: number of train steps.
validation_steps: number of eval steps. If -1, the entire eval dataset is used.
validation_interval: number of training steps to run between evaluations.
best_checkpoint_export_subdir: if set, the trainer will keep track of the best evaluation metric and export the corresponding best checkpoint under model_dir/best_checkpoint_export_subdir. Note that this only works if the mode contains eval (such as train_and_eval, continuous_eval, and continuous_train_and_eval).
best_checkpoint_eval_metric: the evaluation metric the trainer should monitor when exporting the best checkpoint. This can be any evaluation metric that appears on TensorBoard.
best_checkpoint_metric_comp: how the trainer should compare evaluation metrics when exporting the best checkpoint. Either 'higher' (higher is better) or 'lower' (lower is better).
validation_summary_subdir: a str; subdirectory for saving eval summaries.
preemption_on_demand_checkpoint: whether or not to save on-demand checkpoints after a preemption.
BUILDER

default_params Dataclass field
restrictions Dataclass field
loss_upper_bound Dataclass field
recovery_begin_steps Dataclass field
recovery_max_trials Dataclass field
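Assembled from the default values listed at the bottom of this page, the trainer block of an experiment YAML override file might look like the following (a sketch; field names match the attributes above):

```yaml
trainer:
  train_steps: 0
  validation_steps: -1
  validation_interval: 1000
  steps_per_loop: 1000
  summary_interval: 1000
  checkpoint_interval: 1000
  max_to_keep: 5
  continuous_eval_timeout: 3600
  best_checkpoint_export_subdir: ''
  best_checkpoint_eval_metric: ''
  best_checkpoint_metric_comp: 'higher'
```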

Methods

as_dict

View source

Returns a dict representation of params_dict.ParamsDict.

For a nested params_dict.ParamsDict, a nested dict will be returned.

from_args

View source

Builds a config from the given list of arguments.

from_json

View source

Wrapper for from_yaml.

from_yaml

View source
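Since JSON is (for these purposes) a subset of YAML, from_json can simply parse the JSON text and hand the resulting nested dict to the same construction path as from_yaml. A stdlib-only sketch of the idea (config_from_json is a hypothetical stand-in, not the library function):

```python
import json


def config_from_json(json_str: str) -> dict:
    """Sketch: parse a JSON config fragment into a nested params dict.

    The real from_json is documented above as a wrapper for from_yaml;
    this illustrates only the parsing step.
    """
    return json.loads(json_str)


parsed = config_from_json('{"trainer": {"train_steps": 10000, "summary_interval": 500}}')
```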

get

View source

Accesses parameters through the built-in dictionary get method.

lock

View source

Makes the ParamsDict immutable.
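The locking behavior can be sketched with a minimal stand-in (LockableParams is a hypothetical illustration of the pattern, not the real ParamsDict, whose exact exception type may differ):

```python
class LockableParams:
    """Minimal sketch of ParamsDict-style locking (illustrative only)."""

    def __init__(self, **kwargs):
        object.__setattr__(self, "_locked", False)
        for key, value in kwargs.items():
            object.__setattr__(self, key, value)

    def lock(self):
        """Make the params immutable."""
        object.__setattr__(self, "_locked", True)

    def __setattr__(self, name, value):
        if self._locked:
            raise ValueError("The params have been locked; no further mutation.")
        object.__setattr__(self, name, value)


p = LockableParams(train_steps=0)
p.train_steps = 10000   # allowed before lock()
p.lock()
try:
    p.train_steps = 20000
except ValueError:
    pass                # mutation after lock() is rejected
```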

override

View source

Overrides the ParamsDict with a given set of params.

Args
override_params a dict or a ParamsDict specifying the parameters to be overridden.
is_strict a boolean specifying whether override is strict or not. If True, keys in override_params must be present in the ParamsDict. If False, keys in override_params can be different from what is currently defined in the ParamsDict. In this case, the ParamsDict will be extended to include the new keys.
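The strict vs. non-strict behavior described above can be sketched with plain nested dicts (a behavioral illustration, not the library implementation):

```python
def override(params: dict, override_params: dict, is_strict: bool = True) -> dict:
    """Sketch of ParamsDict.override semantics on nested dicts."""
    for key, value in override_params.items():
        if key not in params:
            if is_strict:
                raise KeyError(f"Key {key!r} is not defined and is_strict=True.")
            params[key] = value                      # non-strict: extend with the new key
        elif isinstance(value, dict) and isinstance(params[key], dict):
            override(params[key], value, is_strict)  # recurse into nested params
        else:
            params[key] = value
    return params


cfg = {"train_steps": 0, "optimizer_config": {"learning_rate": 0.1}}
override(cfg, {"optimizer_config": {"learning_rate": 0.01}})   # nested override
override(cfg, {"new_key": 1}, is_strict=False)                 # extends the dict
```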

replace

View source

Overrides and returns an unlocked copy, leaving the current config unchanged.

validate

View source

Validates parameter consistency based on the restrictions.

This method validates the internal consistency using the pre-defined list of restrictions. A restriction is defined as a string which specifies a binary operation. The supported binary operations are {'==', '!=', '<', '<=', '>', '>='}. Note that the meaning of these operators is consistent with the underlying Python implementation. Users should make sure the restrictions they define make sense for the types involved.

For example, given a ParamsDict like the following:

a:
  a1: 1
  a2: 2
b:
  bb:
    bb1: 10
    bb2: 20
  ccc:
    a1: 1
    a3: 3

one can define two restrictions like this: ['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']

These enforce:

  • a.a1 = 1 == b.ccc.a1 = 1
  • a.a2 = 2 <= b.bb.bb2 = 20

Raises
KeyError if any of the following happens: (1) any parameter in any restriction is not defined in the ParamsDict; (2) an inconsistency violating a restriction is found.
ValueError if the restriction defined in the string is not supported.
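The restriction check above can be sketched as follows (a simplified illustration of the mechanism on plain nested dicts, not the library code):

```python
import operator

# Supported binary operations, per the docstring above.
_OPS = {"==": operator.eq, "!=": operator.ne, "<": operator.lt,
        "<=": operator.le, ">": operator.gt, ">=": operator.ge}


def _lookup(params: dict, dotted: str):
    """Resolve a dotted path like 'b.bb.bb2' in a nested dict."""
    for part in dotted.split("."):
        if part not in params:
            raise KeyError(f"{dotted!r} is not defined in the params.")
        params = params[part]
    return params


def validate(params: dict, restrictions):
    """Check each 'lhs OP rhs' restriction string against a nested dict."""
    for restriction in restrictions:
        # Check two-character operators before their one-character prefixes.
        for op_str in ("==", "!=", "<=", ">=", "<", ">"):
            if f" {op_str} " in restriction:
                lhs, rhs = restriction.split(f" {op_str} ")
                if not _OPS[op_str](_lookup(params, lhs), _lookup(params, rhs)):
                    raise KeyError(f"Restriction violated: {restriction}")
                break
        else:
            raise ValueError(f"Unsupported restriction: {restriction}")


# The example ParamsDict from the docstring above.
params = {"a": {"a1": 1, "a2": 2},
          "b": {"bb": {"bb1": 10, "bb2": 20}, "ccc": {"a1": 1, "a3": 3}}}
validate(params, ["a.a1 == b.ccc.a1", "a.a2 <= b.bb.bb2"])  # passes silently
```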

__contains__

View source

Implements the membership test operator.

__eq__

IMMUTABLE_TYPES (<class 'str'>, <class 'int'>, <class 'float'>, <class 'bool'>, <class 'NoneType'>)
RESERVED_ATTR ['_locked', '_restrictions']
SEQUENCE_TYPES (<class 'list'>, <class 'tuple'>)
allow_tpu_summary False
best_checkpoint_eval_metric ''
best_checkpoint_export_subdir ''
best_checkpoint_metric_comp 'higher'
checkpoint_interval 1000
continuous_eval_timeout 3600
default_params None
eval_tf_function True
eval_tf_while_loop False
loss_upper_bound 1000000.0
max_to_keep 5
preemption_on_demand_checkpoint True
recovery_begin_steps 0
recovery_max_trials 0
restrictions None
steps_per_loop 1000
summary_interval 1000
train_steps 0
train_tf_function True
train_tf_while_loop True
validation_interval 1000
validation_steps -1
validation_summary_subdir 'validation'