Runs specified rounds of training and optionally evaluates the model.
tff.learning.programs.train_model(
*,
train_process,
initial_train_state=None,
train_data_source,
train_per_round_clients,
train_total_rounds,
should_discard_round=None,
program_state_manager,
model_output_manager,
train_metrics_manager=None,
evaluation_manager,
evaluation_periodicity
) -> bool
This method will create an initial training state and repeatedly call
train_process.next, advancing the state of the training process. Depending
on the configuration of evaluation_manager, asynchronous evaluation loops
will be spawned and executed in parallel.
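
The round loop described above, together with the optional should_discard_round hook, can be pictured with a minimal pure-Python sketch. It does not use TFF and is not the library's implementation: RoundOutput, run_rounds, toy_next, and the max_attempts_per_round guard are hypothetical names introduced only for illustration, and the asynchronous evaluation loops are left out.

```python
import math
from typing import Any, Callable, NamedTuple, Optional


class RoundOutput(NamedTuple):
    """Hypothetical stand-in for the output of one training round."""
    state: Any
    metrics: dict


def run_rounds(
    initial_state: Any,
    next_fn: Callable[[Any, int], RoundOutput],
    total_rounds: int,
    should_discard_round: Optional[Callable[[RoundOutput], bool]] = None,
    max_attempts_per_round: int = 10,  # assumption: a guard against endless retries
) -> Any:
    """Advances the state round by round, retrying rounds that are discarded."""
    state = initial_state
    for round_num in range(1, total_rounds + 1):
        for _ in range(max_attempts_per_round):
            output = next_fn(state, round_num)
            if should_discard_round is None or not should_discard_round(output):
                state = output.state  # accept the round and move on
                break
            # Discarded: `state` still holds the previous round's state, so the
            # next attempt starts from the same point.
        else:
            raise RuntimeError(f"Round {round_num} was discarded too many times.")
    return state


# Toy process: the first attempt at round 2 reports a NaN loss.
attempts = []


def toy_next(state: int, round_num: int) -> RoundOutput:
    attempts.append(round_num)
    bad = round_num == 2 and attempts.count(2) == 1
    loss = float("nan") if bad else 0.1
    return RoundOutput(state=state + 1, metrics={"loss": loss})


final_state = run_rounds(
    0,
    toy_next,
    total_rounds=3,
    should_discard_round=lambda out: math.isnan(out.metrics["loss"]),
)
```

In this toy run, round 2 is attempted twice, and the state produced by the discarded attempt is never adopted.
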
This method will save the initial state (the result of train_process.initialize
or the value passed via initial_train_state) using program_state_manager. If
the state manager is configured to keep the first version (e.g. the keep_first
parameter of tff.program.FileProgramStateManager), then round zero (the
initialization) will be retained so that future experiments can use the same
starting point. If there is a previous program state, this method will load
the latest state and resume from it. In this case, the data source iterator
will be restored from that state instead of being created from the input
train_data_source.
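
The save-or-resume behavior described above can be sketched as follows. This is an illustration under stated assumptions rather than the library's code: InMemoryStateManager and resolve_start are hypothetical, the sketch is synchronous, and it tracks only the training state, omitting the data source iterator that the real program state also carries.

```python
from typing import Any, Callable, Dict, Optional, Tuple


class InMemoryStateManager:
    """Hypothetical stand-in for a program state manager; versions live in a dict."""

    def __init__(self) -> None:
        self._versions: Dict[int, Any] = {}

    def save(self, state: Any, version: int) -> None:
        self._versions[version] = state

    def load_latest(self) -> Tuple[Optional[Any], int]:
        """Returns (state, version), or (None, 0) when nothing has been saved."""
        if not self._versions:
            return None, 0
        latest = max(self._versions)
        return self._versions[latest], latest


def resolve_start(
    manager: InMemoryStateManager,
    initialize_fn: Callable[[], Any],
    initial_train_state: Optional[Any] = None,
) -> Tuple[Any, int]:
    """Returns the state to start from and the number of the next round to run."""
    saved_state, version = manager.load_latest()
    if saved_state is not None:
        # A previous run exists: resume from its latest saved round.
        return saved_state, version + 1
    # Fresh run: prefer the caller-supplied state, else initialize a new one.
    state = initial_train_state if initial_train_state is not None else initialize_fn()
    manager.save(state, version=0)  # round zero is the initialization
    return state, 1
```

Saving the initialization as version 0 is what allows a state manager that keeps its first version to preserve a shared starting point across experiments.
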
If initial_train_state is not None, its type signature should be the same as
the type_signature of the result of train_process.initialize.
Args | |
---|---|
train_process | A tff.learning.templates.LearningProcess to run for training. The state type of the train_process should be a tff.learning.templates.LearningAlgorithmState, and the initial train state can be provided using the initial_train_state argument. |
initial_train_state | (Optional) A tff.learning.templates.LearningAlgorithmState holding the initial state of the train process. Its type signature should match the type_signature of the result of train_process.initialize. If not specified, the result of train_process.initialize is used. |
train_data_source | A tff.program.FederatedDataSource which returns client data used during training. |
train_per_round_clients | The number of clients per round of training. |
train_total_rounds | Total number of rounds of training. |
should_discard_round | (Optional) A Callable that takes the tff.learning.templates.LearningProcessOutput returned by train_process.next and returns whether the round should be discarded. If a round should be discarded, the program will roll back to the state of the previous round and retry this round. |
program_state_manager | A tff.program.ProgramStateManager used to save program state for fault tolerance. |
model_output_manager | A tff.program.ReleaseManager used to release the model. The results can be used for building inference models after training or for warm-starting future training loops. |
train_metrics_manager | (Optional) A tff.program.ReleaseManager used to release metrics of training. Use tff.program.GroupingReleaseManager to supply multiple release managers. |
evaluation_manager | An EvaluationManager used to create a state manager for each evaluation loop that is forked off from the training loop. |
evaluation_periodicity | Either an integer number of rounds or a datetime.timedelta to wait before sending a new training checkpoint to evaluation_manager.start_evaluation. Note that the last training round will always be evaluated even if it does not satisfy the periodicity. |
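
The evaluation_periodicity rule (a round count or a datetime.timedelta, with the last round always evaluated) can be illustrated with a small pure-Python sketch. The function name should_evaluate_round and its parameters are hypothetical. The exact comparisons are assumptions, including whether a time-based period triggers on the first opportunity, and may differ from the library's behavior.

```python
import datetime
from typing import Optional, Union


def should_evaluate_round(
    round_num: int,
    total_rounds: int,
    periodicity: Union[int, datetime.timedelta],
    now: datetime.datetime,
    last_evaluation_time: Optional[datetime.datetime] = None,
) -> bool:
    """Decides whether the checkpoint from `round_num` is sent for evaluation."""
    if round_num == total_rounds:
        # The last training round is evaluated regardless of the periodicity.
        return True
    if isinstance(periodicity, datetime.timedelta):
        if last_evaluation_time is None:
            # Assumption: with no earlier evaluation, evaluate right away.
            return True
        return now - last_evaluation_time >= periodicity
    if periodicity <= 0:
        raise ValueError("An integer periodicity must be positive.")
    # Round-based periodicity: evaluate every `periodicity` rounds.
    return round_num % periodicity == 0
```

Passing the current time in as an argument, instead of reading the clock inside the function, keeps the decision deterministic and easy to test.
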
Raises | |
---|---|
ValueError | If the train state is None. |