Trains and tunes a federated model using Vizier.
tff.learning.programs.train_model_with_vizier(
    *,
    study,
    total_trials,
    num_parallel_trials=1,
    update_hparams,
    train_model_program_logic,
    train_process_factory,
    train_data_source,
    total_rounds,
    num_clients,
    program_state_manager_factory,
    model_output_manager_factory,
    train_metrics_manager_factory=None,
    evaluation_manager_factory,
    evaluation_periodicity
)
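The total_trials and num_parallel_trials arguments together describe bounded concurrency: many trials in total, only a few running at once. The following is a minimal plain-Python sketch of that scheduling idea using asyncio.Semaphore. It assumes trials are independent, and run_trial and run_all_trials are hypothetical stand-ins; this is not how TFF actually implements its trial loop.

```python
import asyncio


# Hypothetical stand-in for training and evaluating one Vizier trial.
# It is NOT part of the TFF or Vizier APIs; it only records concurrency.
async def run_trial(trial_id: int, state: dict) -> int:
    state["running"] += 1
    state["max_running"] = max(state["max_running"], state["running"])
    await asyncio.sleep(0.01)  # Simulate training work for this trial.
    state["running"] -= 1
    return trial_id


async def run_all_trials(
    total_trials: int, num_parallel_trials: int = 1
) -> tuple[list[int], int]:
    """Runs `total_trials` trials, at most `num_parallel_trials` at a time.

    Returns the completed trial ids (in submission order) and the peak
    number of trials that were running concurrently.
    """
    semaphore = asyncio.Semaphore(num_parallel_trials)
    state = {"running": 0, "max_running": 0}

    async def bounded(trial_id: int) -> int:
        # Each trial waits for a free slot before it starts.
        async with semaphore:
            return await run_trial(trial_id, state)

    results = await asyncio.gather(*(bounded(i) for i in range(total_trials)))
    return list(results), state["max_running"]


if __name__ == "__main__":
    completed, peak = asyncio.run(run_all_trials(total_trials=5, num_parallel_trials=2))
    print(completed, peak)
```

With the default num_parallel_trials=1 the sketch degrades to running trials strictly one after another, which mirrors the documented default.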
Args | Description
--- | ---
study | The Vizier study to use to tune train_model_program_logic.
total_trials | The total number of Vizier trials.
num_parallel_trials | The number of Vizier trials to evaluate in parallel. Defaults to 1.
update_hparams | A tff.Computation used to update the model's hparams with a trial's parameters.
train_model_program_logic | The program logic to use for training and evaluating the model.
train_process_factory | A factory for creating the tff.learning.templates.LearningProcess to run for training.
train_data_source | A tff.program.FederatedDataSource that returns the client data used during training.
total_rounds | The number of rounds of training.
num_clients | The number of clients per round of training.
program_state_manager_factory | A factory for creating a tff.program.ProgramStateManager for each trial.
model_output_manager_factory | A factory for creating the tff.program.ReleaseManager instances used to release the model.
train_metrics_manager_factory | An optional factory for creating tff.program.ReleaseManager instances used to release training metrics for each trial. Defaults to None.
evaluation_manager_factory | A factory for creating a tff.learning.programs.EvaluationManager for each trial.
evaluation_periodicity | Either an integer number of rounds or a datetime.timedelta to wait before sending a new training checkpoint to evaluation_manager.start_evaluation.
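The evaluation_periodicity argument accepts two kinds of values, a round count or a wall-clock duration. The sketch below illustrates how such a dual-typed setting might be interpreted. The function should_send_checkpoint and its parameters are hypothetical and are not part of the TFF API; the sketch assumes "due" means at least the given number of rounds, or at least the given amount of time, has passed since the last checkpoint was sent.

```python
import datetime
from typing import Union

# Either a number of rounds or a wall-clock duration, as in the Args table.
Periodicity = Union[int, datetime.timedelta]


def should_send_checkpoint(
    periodicity: Periodicity,
    rounds_since_last: int,
    time_since_last: datetime.timedelta,
) -> bool:
    """Hypothetical check: is a new training checkpoint due for evaluation?"""
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(periodicity, bool):
        raise TypeError("periodicity must be an int or a datetime.timedelta")
    if isinstance(periodicity, int):
        if periodicity < 1:
            raise ValueError("an integer periodicity must be at least 1 round")
        # Round-based: wait until enough rounds have completed.
        return rounds_since_last >= periodicity
    if isinstance(periodicity, datetime.timedelta):
        # Time-based: wait until enough wall-clock time has elapsed.
        return time_since_last >= periodicity
    raise TypeError("periodicity must be an int or a datetime.timedelta")


if __name__ == "__main__":
    print(should_send_checkpoint(5, 5, datetime.timedelta(0)))
    print(should_send_checkpoint(
        datetime.timedelta(minutes=10), 100, datetime.timedelta(minutes=3)))
```

In the usage lines, the first call is due after 5 rounds, while the second is not yet due because only 3 of the required 10 minutes have elapsed, regardless of how many rounds have run.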