Runs specified rounds of training and optionally evaluates the model.
tff.learning.programs.train_model(
    *,
    train_process,
    initial_train_state=None,
    train_data_source,
    train_per_round_clients,
    train_total_rounds,
    program_state_manager,
    model_output_manager,
    train_metrics_manager=None,
    evaluation_manager,
    evaluation_periodicity
) -> bool
This method creates an initial training state and repeatedly calls train_process.next, advancing the state of the training process. Depending on the configuration of evaluation_manager, asynchronous evaluation loops are spawned and executed in parallel with training.
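The control flow described above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not TFF's implementation: initialize, next_fn, and evaluate are hypothetical callables standing in for train_process.initialize, train_process.next, and the evaluation loops that evaluation_manager would start. Each evaluation is scheduled as its own asyncio task so that training does not wait on it inline, and only an integer periodicity is modeled.

```python
import asyncio
from typing import Any, Awaitable, Callable, List, Tuple


async def train_loop(
    initialize: Callable[[], Any],
    next_fn: Callable[[Any, int], Any],
    evaluate: Callable[[int, Any], Awaitable[Any]],
    total_rounds: int,
    eval_every: int,
) -> Tuple[Any, List[Any]]:
    """Hypothetical training loop that forks off evaluations (not the TFF API)."""
    state = initialize()
    eval_tasks = []
    for round_num in range(1, total_rounds + 1):
        # One round of training advances the state.
        state = next_fn(state, round_num)
        # Evaluate periodically, and always after the final round.
        if round_num % eval_every == 0 or round_num == total_rounds:
            # Schedule the evaluation as a separate task instead of awaiting
            # it inline, so the training loop does not block on it.
            eval_tasks.append(asyncio.create_task(evaluate(round_num, state)))
    # Wait for every forked evaluation before returning.
    eval_results = await asyncio.gather(*eval_tasks)
    return state, list(eval_results)
```

With five rounds and an evaluation every two rounds, rounds 2 and 4 are evaluated by periodicity and round 5 because it is the last one.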
This method saves the initial state (the result of train_process.initialize, or the value passed via initial_train_state) using program_state_manager. If the state manager is configured to keep the first version (e.g., the keep_first parameter of tff.program.FileProgramStateManager), round zero (the initialization) is retained so that future experiments can use the same starting point.
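The effect of keeping the first version can be illustrated with a toy retention policy. This sketch is hypothetical: InMemoryStateManager, keep_latest, and versions are illustrative names, not tff.program.ProgramStateManager methods, and the exact retention rules of TFF's file-based manager may differ. Here, version 0 is protected when keep_first is True, and only the keep_latest most recent other versions survive.

```python
from typing import Any, Dict, List


class InMemoryStateManager:
    """Hypothetical state manager; not part of the TFF API.

    Keeps the `keep_latest` most recent versions and, if `keep_first` is
    True, always retains version 0 (the initialization).
    """

    def __init__(self, keep_latest: int = 2, keep_first: bool = True):
        self._keep_latest = keep_latest
        self._keep_first = keep_first
        self._states: Dict[int, Any] = {}

    def save(self, state: Any, version: int) -> None:
        self._states[version] = state
        self._evict()

    def _evict(self) -> None:
        protected = {0} if self._keep_first else set()
        candidates = [v for v in sorted(self._states) if v not in protected]
        # Drop the oldest unprotected versions beyond `keep_latest`.
        excess = len(candidates) - self._keep_latest
        for version in candidates[:max(excess, 0)]:
            del self._states[version]

    def versions(self) -> List[int]:
        return sorted(self._states)
```

After saving versions 0 through 4 with keep_latest=2, the manager holds versions 0, 3, and 4; with keep_first=False, only 3 and 4 remain.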
If initial_train_state is not None, its type signature should be the same as the type_signature of the result of train_process.initialize.
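The rule for choosing the starting state can be sketched as follows. This is an illustration only: structure_signature is a hypothetical stand-in for a TFF type signature (it maps a nested dict/list/tuple structure to the same nesting of Python leaf types), and the sketch calls initialize() to obtain the reference signature, whereas TFF compares the computation's type_signature. Raising TypeError on a mismatch is an assumption; the ValueError for a None state mirrors the documented Raises entry.

```python
from typing import Any, Callable, Optional


def structure_signature(value: Any) -> Any:
    """Hypothetical stand-in for a type signature: same nesting, leaf types."""
    if isinstance(value, dict):
        return {k: structure_signature(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(structure_signature(v) for v in value)
    return type(value)


def resolve_initial_state(
    initialize: Callable[[], Any],
    initial_train_state: Optional[Any] = None,
) -> Any:
    """Hypothetical helper: pick the provided state or fall back to initialize()."""
    default_state = initialize()
    if initial_train_state is None:
        state = default_state
    else:
        # The provided state must have the same signature as initialize().
        if structure_signature(initial_train_state) != structure_signature(default_state):
            raise TypeError(
                "initial_train_state does not match the signature of initialize()")
        state = initial_train_state
    if state is None:
        raise ValueError("The train state is None.")
    return state
```

For example, if initialize returns {"weights": [0.0, 0.0], "round": 0}, a provided state {"weights": [1.0, 2.0], "round": 3} is accepted, while {"weights": [1, 2], "round": 3} is rejected because its weights are ints rather than floats.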
Args | |
---|---|
train_process | A tff.learning.templates.LearningProcess to run for training. The state type of train_process should be a tff.learning.templates.LearningAlgorithmState; the initial training state can be provided using the initial_train_state argument. |
initial_train_state | (Optional) A tff.learning.templates.LearningAlgorithmState to use as the initial state of the training process. Its type signature should match the type_signature of the result of train_process.initialize. If not specified, the result of train_process.initialize is used. |
train_data_source | A tff.program.FederatedDataSource that returns the client data used during training. |
train_per_round_clients | The number of clients per round of training. |
train_total_rounds | The total number of rounds of training. |
program_state_manager | A tff.program.ProgramStateManager used to save program state for fault tolerance. |
model_output_manager | A tff.program.ReleaseManager used to release the model. The released model can be used to build inference models after training or to warm-start future training loops. |
train_metrics_manager | (Optional) A tff.program.ReleaseManager used to release training metrics. Use tff.program.GroupingReleaseManager to supply multiple release managers. |
evaluation_manager | An EvaluationManager used to create a state manager for each evaluation loop that is forked off from the training loop. |
evaluation_periodicity | Either an integer number of rounds or a datetime.timedelta to wait before sending a new training checkpoint to evaluation_manager.start_evaluation. The last training round is always evaluated, even if it does not satisfy the periodicity. |
Raises | |
---|---|
ValueError | If the train state is None. |
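The two forms of evaluation_periodicity can be illustrated with a small decision function. This is a hypothetical sketch, not TFF code: should_evaluate and last_eval_time are illustrative names, and it assumes the time-based check measures elapsed wall-clock time since the previous checkpoint was sent (or since training started). As documented, the final round is always evaluated.

```python
import datetime
from typing import Union


def should_evaluate(
    round_num: int,
    total_rounds: int,
    periodicity: Union[int, datetime.timedelta],
    last_eval_time: datetime.datetime,
    now: datetime.datetime,
) -> bool:
    """Hypothetical check for sending a checkpoint to evaluation."""
    # The last training round is always evaluated.
    if round_num == total_rounds:
        return True
    if isinstance(periodicity, datetime.timedelta):
        # Time-based: evaluate once enough wall-clock time has elapsed.
        return now - last_eval_time >= periodicity
    # Round-based: evaluate every `periodicity` rounds.
    return round_num % periodicity == 0
```

With an integer periodicity of 3 over 10 rounds, rounds 3, 6, and 9 are evaluated by periodicity and round 10 because it is the last; with a timedelta, the round number is irrelevant until the final round.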