tff.simulation.run_training_process

Runs a federated training_process.

The following tff.Computation type signatures are required:

  • training_process.initialize: ( -> state)
  • training_process.next: <state, client_data> -> <state, metrics>
  • evaluation_fn: <state, client_data> -> metrics
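To make the required signatures concrete, here is a minimal sketch that uses plain Python callables in place of tff.Computation objects. This is an assumption made so the example runs without TFF. The toy state is a single float, and `ToyTrainingProcess`, `next_fn`, and the metric names are hypothetical.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

# Hypothetical pure-Python stand-ins: in TFF these would be tff.Computation
# objects, not plain Python callables.
State = float             # toy "model": a single scalar parameter
ClientData = List[float]  # toy client dataset: a list of numbers


def initialize() -> State:
    # ( -> state): produce the initial server state.
    return 0.0


def next_fn(state: State,
            client_data: List[ClientData]) -> Tuple[State, Dict[str, float]]:
    # <state, client_data> -> <state, metrics>: move the state halfway toward
    # the mean of all client examples (a stand-in for one federated round).
    examples = [x for client in client_data for x in client]
    target = sum(examples) / len(examples)
    new_state = state + 0.5 * (target - state)
    return new_state, {'loss': (target - state) ** 2}


def evaluation_fn(state: State,
                  client_data: List[ClientData]) -> Dict[str, float]:
    # <state, client_data> -> metrics: squared error against the mean of the
    # evaluation clients' examples.
    examples = [x for client in client_data for x in client]
    target = sum(examples) / len(examples)
    return {'eval_loss': (target - state) ** 2}


@dataclass
class ToyTrainingProcess:
    # Mirrors the initialize/next attributes of tff.templates.IterativeProcess.
    initialize: Callable[[], State]
    next: Callable[[State, List[ClientData]], Tuple[State, Dict[str, float]]]


process = ToyTrainingProcess(initialize=initialize, next=next_fn)
state = process.initialize()
state, metrics = process.next(state, [[1.0, 3.0], [2.0]])
```

Here the three examples have mean 2.0, so one round moves the state from 0.0 to 1.0.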

This function performs up to total_rounds updates to the state of the given training_process. At each training round, this update occurs by invoking training_process.next with state and the output of training_selection_fn. Depending on rounds_per_evaluation and rounds_per_saving_program_state, each training round may be followed by an invocation of the evaluation_fn and by saving the program state.
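The round-by-round behavior described above can be sketched as a plain Python loop. This is a simplified, hypothetical illustration, not TFF's implementation. It assumes that rounds are numbered from 1 and that evaluation and saving happen when the round number is a multiple of the corresponding interval. It also replaces program_state_manager with a hypothetical `save_fn` callback. Unlike the real function, it returns the per-round records alongside the state so the schedule is easy to inspect.

```python
from typing import Callable, Dict, List, Optional


def run_training_loop_sketch(
    training_process,
    training_selection_fn: Callable[[int], list],
    total_rounds: int,
    evaluation_fn: Optional[Callable] = None,
    evaluation_selection_fn: Optional[Callable[[int], list]] = None,
    rounds_per_evaluation: int = 1,
    save_fn: Optional[Callable[[object, int], None]] = None,
    rounds_per_saving_program_state: int = 1,
):
    """Hypothetical, simplified control flow; not TFF's implementation."""
    state = training_process.initialize()
    history: List[Dict[str, object]] = []
    # Assumption: rounds are numbered 1..total_rounds.
    for round_num in range(1, total_rounds + 1):
        client_data = training_selection_fn(round_num)
        state, metrics = training_process.next(state, client_data)
        record = {'round': round_num, 'train': metrics}
        # Assumption: evaluate when round_num is a multiple of the interval.
        if (evaluation_fn is not None and evaluation_selection_fn is not None
                and round_num % rounds_per_evaluation == 0):
            eval_data = evaluation_selection_fn(round_num)
            record['eval'] = evaluation_fn(state, eval_data)
        # Same assumption for saving program state.
        if (save_fn is not None
                and round_num % rounds_per_saving_program_state == 0):
            save_fn(state, round_num)
        history.append(record)
    return state, history


class CountingProcess:
    # Toy process whose state counts how many rounds have run.
    def initialize(self):
        return 0

    def next(self, state, client_data):
        return state + 1, {'num_clients': len(client_data)}


saved_rounds = []
final_state, history = run_training_loop_sketch(
    CountingProcess(),
    training_selection_fn=lambda r: ['client_a', 'client_b'],
    total_rounds=6,
    evaluation_fn=lambda state, data: {'state': state},
    evaluation_selection_fn=lambda r: ['client_c'],
    rounds_per_evaluation=3,
    save_fn=lambda state, r: saved_rounds.append(r),
    rounds_per_saving_program_state=2,
)
```

With six rounds, this schedule evaluates after rounds 3 and 6 and saves after rounds 2, 4, and 6.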

In addition to the training metrics and evaluation metrics, this function adds the following performance metrics (keys and descriptions):

  • tff.simulation.ROUND_NUMBER_KEY: The round number.
  • tff.simulation.TRAINING_TIME_KEY: The amount of time (in seconds) it takes to run one round of training.
  • tff.simulation.EVALUATION_TIME_KEY: The amount of time (in seconds) it takes to run one round of evaluation.
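One way to produce these performance metrics is to wrap each call with a wall-clock timer and merge the results into the metrics dictionary. In this sketch the key strings are hypothetical placeholders for the tff.simulation constants, whose actual values may differ. `timed_round` and `timed_evaluation` are illustrative helpers and are not part of TFF.

```python
import time

# Hypothetical placeholder strings standing in for
# tff.simulation.ROUND_NUMBER_KEY, TRAINING_TIME_KEY and EVALUATION_TIME_KEY.
ROUND_NUMBER_KEY = 'round_number'
TRAINING_TIME_KEY = 'training_time_in_seconds'
EVALUATION_TIME_KEY = 'evaluation_time_in_seconds'


def timed_round(next_fn, state, client_data, round_num):
    # Time one call to the training process's next function.
    start = time.perf_counter()
    state, metrics = next_fn(state, client_data)
    elapsed = time.perf_counter() - start
    metrics = dict(metrics)  # copy so the caller's dict is not mutated
    metrics[ROUND_NUMBER_KEY] = round_num
    metrics[TRAINING_TIME_KEY] = elapsed
    return state, metrics


def timed_evaluation(evaluation_fn, state, client_data, round_num):
    # Time one call to the evaluation function.
    start = time.perf_counter()
    metrics = dict(evaluation_fn(state, client_data))
    metrics[ROUND_NUMBER_KEY] = round_num
    metrics[EVALUATION_TIME_KEY] = time.perf_counter() - start
    return metrics


state, train_metrics = timed_round(
    lambda s, d: (s + 1, {'loss': 0.5}), 0, ['c1'], round_num=1)
eval_metrics = timed_evaluation(
    lambda s, d: {'accuracy': 0.9}, state, ['c2'], round_num=1)
```

time.perf_counter is used because it is a monotonic clock intended for measuring elapsed intervals.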

The function accepts the following arguments:

  • training_process: A tff.templates.IterativeProcess to run for training.
  • training_selection_fn: A Callable accepting an integer round number and returning a list of client data to use for training in that round.
  • total_rounds: The number of training rounds to run.
  • evaluation_fn: An optional tff.Computation to run for evaluation.
  • evaluation_selection_fn: An optional Callable accepting an integer round number and returning a list of client data to use for evaluation in that round.
  • rounds_per_evaluation: The number of training rounds to run between each invocation of evaluation_fn.
  • program_state_manager: An optional tff.program.ProgramStateManager to use to save program state for fault tolerance.
  • rounds_per_saving_program_state: The number of training rounds to run between saves of the program state.
  • metrics_managers: An optional list of tff.program.ReleaseManagers to use to save metrics.
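Because both selection functions take the round number as input, one natural design is to derive the sampled clients deterministically from it. That way a given round always sees the same clients, which helps, for example, when training resumes from saved program state. The sketch below assumes client data is held in a plain dictionary keyed by client ID. `make_selection_fn` and its parameters are hypothetical.

```python
import random
from typing import Callable, Dict, List, Sequence


def make_selection_fn(
    client_datasets: Dict[str, Sequence[float]],
    clients_per_round: int,
    seed: int = 0,
) -> Callable[[int], List[Sequence[float]]]:
    """Hypothetical helper: returns a round_num -> list of client data fn."""
    client_ids = sorted(client_datasets)  # fixed order for reproducibility

    def selection_fn(round_num: int) -> List[Sequence[float]]:
        # Seed a fresh RNG from (seed, round_num) so each round's sample is
        # reproducible and independent of how many rounds ran before it.
        rng = random.Random(seed * 1_000_003 + round_num)
        chosen = rng.sample(client_ids, clients_per_round)  # no repeats
        return [client_datasets[cid] for cid in chosen]

    return selection_fn


clients = {'a': [1.0], 'b': [2.0], 'c': [3.0], 'd': [4.0]}
training_selection_fn = make_selection_fn(clients, clients_per_round=2, seed=42)
# Evaluate on every client each time evaluation runs.
evaluation_selection_fn = lambda round_num: list(clients.values())
```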

Returns the state of the training process after training.