tff.learning.programs.train_model

Runs specified rounds of training and optionally evaluates the model.

This method will create an initial training state and repeatedly call train_process.next, advancing the state of the training process. Depending on the configuration of evaluation_manager, asynchronous evaluation loops will be spawned and executed in parallel.

This method will save the initial state (the result of train_process.initialize, or the value passed via initial_train_state) using program_state_manager. If the state manager is configured to keep the first version (e.g. the keep_first parameter of tff.program.FileProgramStateManager), then round zero (the initialization) will be retained so that future experiments can use the same starting point. If there is a previous program state, this method will load the latest state and resume from it. In this case, the data source iterator will be restored from that state instead of being created from the input train_data_source.
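The save-then-resume behavior described above can be illustrated with a small, self-contained sketch. It does not use TFF and is not the library's implementation: InMemoryStateManager, run_rounds, and their methods are hypothetical names invented for this example, and the "state" is a plain integer.

```python
from typing import Any, Callable, Dict, Optional, Tuple


class InMemoryStateManager:
    """Hypothetical stand-in for a program state manager (not a TFF class)."""

    def __init__(self) -> None:
        self._versions: Dict[int, Any] = {}

    def save(self, state: Any, version: int) -> None:
        self._versions[version] = state

    def load_latest(self) -> Optional[Tuple[Any, int]]:
        # Return (state, version) for the highest saved version, if any.
        if not self._versions:
            return None
        latest = max(self._versions)
        return self._versions[latest], latest


def run_rounds(
    manager: InMemoryStateManager,
    initialize: Callable[[], Any],
    next_fn: Callable[[Any], Any],
    total_rounds: int,
) -> Any:
    """Resumes from the latest saved version, or saves round zero first."""
    loaded = manager.load_latest()
    if loaded is None:
        # No previous program state: create the initial state and save it
        # as version 0 so later runs can share the same starting point.
        state, start_round = initialize(), 0
        manager.save(state, version=0)
    else:
        # Previous program state exists: ignore `initialize` and resume.
        state, start_round = loaded
    for round_num in range(start_round + 1, total_rounds + 1):
        state = next_fn(state)
        manager.save(state, version=round_num)
    return state
```

Calling run_rounds a second time with the same manager and a larger total_rounds continues from the last saved round rather than starting over, which mirrors the resume behavior described above.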

If the initial_train_state is not None, its type signature should be the same as the type_signature of the result of train_process.initialize.

train_process A tff.learning.templates.LearningProcess to run for training. The state type of the train_process should be a tff.learning.templates.LearningAlgorithmState, and the initial train state can be provided using the initial_train_state argument.
initial_train_state (Optional) A tff.learning.templates.LearningAlgorithmState of the initial state of the train process. Its type signature should match the type_signature of the result of train_process.initialize. If not specified, the result of train_process.initialize is used.
train_data_source A tff.program.FederatedDataSource which returns client data used during training.
train_per_round_clients The number of clients per round of training.
train_total_rounds Total number of rounds of training.
should_discard_round A Callable that takes the tff.learning.templates.LearningProcessOutput returned by train_process.next and returns whether the round should be discarded. If a round should be discarded, the program will roll back to the state of the previous round and retry this round.
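As a rough illustration of the discard-and-retry semantics, the sketch below runs a toy training loop in plain Python. Everything here is an assumption made for the example: the round output is a plain dict with a "metrics" entry rather than a real LearningProcessOutput, the predicate is synchronous, and the max_attempts guard is not part of the documented API.

```python
import math
from typing import Any, Callable, Dict, Tuple


def loss_is_not_finite(output: Dict[str, Any]) -> bool:
    """Example predicate: discard a round whose loss is NaN or infinite."""
    return not math.isfinite(output["metrics"]["loss"])


def run_with_retries(
    state: Any,
    next_fn: Callable[[Any], Tuple[Any, Dict[str, Any]]],
    should_discard_round: Callable[[Dict[str, Any]], bool],
    total_rounds: int,
    max_attempts: int = 100,  # Hypothetical safety limit for this sketch.
) -> Tuple[Any, int]:
    """Returns the final state and the total number of attempts made."""
    completed_rounds = 0
    attempts = 0
    while completed_rounds < total_rounds:
        if attempts >= max_attempts:
            raise RuntimeError("Too many attempts; rounds keep being discarded.")
        attempts += 1
        new_state, output = next_fn(state)
        if should_discard_round(output):
            # Roll back: keep the previous state and retry the same round.
            continue
        state = new_state
        completed_rounds += 1
    return state, attempts
```

A discarded round leaves the state untouched, so a retried round starts from exactly the state the previous accepted round produced.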
program_state_manager A tff.program.ProgramStateManager used to save program state for fault tolerance.
model_output_manager A tff.program.ReleaseManager to release the model. The released results can be used for building inference models after training, or for warm-starting future training loops.
train_metrics_manager A tff.program.ReleaseManager to release metrics of training. Use tff.program.GroupingReleaseManager to supply multiple release managers.
evaluation_manager An EvaluationManager used to create a state manager for each evaluation loop that is forked off from the training loop.
evaluation_periodicity Either an integer number of rounds or a datetime.timedelta to wait before sending a new training checkpoint to evaluation_manager.start_evaluation. Note that the last training round will always be evaluated even if it does not satisfy the periodicity.
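One way to read the periodicity rule is sketched below in plain Python. This is an interpretation, not the library's code: the function name should_evaluate, the use of round_num % periodicity for integer periodicity, and the elapsed-time comparison for timedelta periodicity are all assumptions.

```python
import datetime
from typing import Union


def should_evaluate(
    round_num: int,
    total_rounds: int,
    periodicity: Union[int, datetime.timedelta],
    now: datetime.datetime,
    last_evaluation_time: datetime.datetime,
) -> bool:
    """Decides whether the checkpoint from `round_num` is sent to evaluation."""
    if round_num == total_rounds:
        # The last training round is always evaluated.
        return True
    if isinstance(periodicity, int):
        # Integer periodicity: evaluate every `periodicity` rounds.
        return round_num % periodicity == 0
    # Timedelta periodicity: evaluate once enough wall-clock time has passed.
    return now - last_evaluation_time >= periodicity
```

With periodicity=10 and total_rounds=25, this sketch selects rounds 10, 20, and 25: the final round is included even though 25 is not a multiple of 10.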

ValueError If the train state is None.