tff.learning.programs.train_model
Runs specified rounds of training and optionally evaluates the model.
tff.learning.programs.train_model(
    *,
    train_process,
    initial_train_state=None,
    train_data_source,
    train_per_round_clients,
    train_total_rounds,
    should_discard_round=None,
    program_state_manager,
    model_output_manager,
    train_metrics_manager=None,
    evaluation_manager,
    evaluation_periodicity
) -> bool
This method will create an initial training state and repeatedly call
train_process.next, advancing the state of the training process. Depending
on the configuration of evaluation_manager, asynchronous evaluation loops
will be spawned and executed in parallel.

This method will save the initial state (the result of train_process.initialize,
or the state passed via initial_train_state) using program_state_manager. If the
state manager is configured to keep the first version (e.g. the keep_first
parameter of tff.program.FileProgramStateManager), then round zero (the
initialization) will be retained so that future experiments can use the same
starting point. If there is a previous program state, this method will load the
latest state and resume from it; in this case, the data source iterator will be
restored from that state instead of being created from the input
train_data_source.

If initial_train_state is not None, its type signature should be the same as
the type_signature of the result of train_process.initialize.
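For concreteness, below is a minimal sketch of how this function might be wired
into a federated program. It assumes train_model is driven through TFF's
asyncio-based federated program API; model_fn, train_data_source, and
evaluation_manager are placeholders that would come from your own program
setup, and the optimizer, directory, client counts, and periodicity shown are
illustrative only.

import asyncio
import datetime

import tensorflow_federated as tff

# Placeholders assumed to be defined elsewhere in the program:
#   model_fn            -- a no-arg callable returning a TFF learning model
#   train_data_source   -- a tff.program.FederatedDataSource of client data
#   evaluation_manager  -- a configured tff.learning.programs.EvaluationManager

# Build a learning process whose state is a LearningAlgorithmState.
train_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.01),
)

# Keep the first (round zero) program state so future experiments can reuse it.
program_state_manager = tff.program.FileProgramStateManager(
    root_dir='/tmp/train_program', keep_first=True)
model_output_manager = tff.program.LoggingReleaseManager()

asyncio.run(
    tff.learning.programs.train_model(
        train_process=train_process,
        train_data_source=train_data_source,
        train_per_round_clients=10,
        train_total_rounds=100,
        program_state_manager=program_state_manager,
        model_output_manager=model_output_manager,
        evaluation_manager=evaluation_manager,
        evaluation_periodicity=datetime.timedelta(hours=1),
    )
)

If the program is interrupted, re-running the same sketch with the same
root_dir resumes from the latest saved program state rather than starting over.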
Args

  train_process: A tff.learning.templates.LearningProcess to run for training.
    The state type of the train_process should be a
    tff.learning.templates.LearningAlgorithmState, and the initial train state
    can be provided using the initial_train_state argument.
  initial_train_state: (Optional) A tff.learning.templates.LearningAlgorithmState
    to use as the initial state of the train process. Its type signature should
    match the type_signature of the result of train_process.initialize. If not
    specified, the result of train_process.initialize is used.
  train_data_source: A tff.program.FederatedDataSource which returns client
    data used during training.
  train_per_round_clients: The number of clients per round of training.
  train_total_rounds: The total number of rounds of training.
  should_discard_round: A callable that takes the
    tff.learning.templates.LearningProcessOutput returned by train_process.next
    and returns whether the round should be discarded. If a round is discarded,
    the program will roll back to the state of the previous round and retry the
    round (see the sketch after this argument list).
  program_state_manager: A tff.program.ProgramStateManager used to save program
    state for fault tolerance.
  model_output_manager: A tff.program.ReleaseManager used to release the model;
    the released results can be used for building inference models after
    training or for warm-starting future training loops.
  train_metrics_manager: A tff.program.ReleaseManager used to release training
    metrics. Use a tff.program.GroupingReleaseManager to supply multiple
    release managers.
  evaluation_manager: An EvaluationManager used to create a state manager for
    each evaluation loop that is forked off from the training loop.
  evaluation_periodicity: Either an integer number of rounds or a
    datetime.timedelta to wait before sending a new training checkpoint to
    evaluation_manager.start_evaluation. Note that the last training round will
    always be evaluated even if it does not satisfy the periodicity.
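As referenced under should_discard_round above, the following is a small sketch
of such a callable. It assumes, hypothetically, that the learning process
reports a client training loss under metrics['client_work']['train']['loss'];
the actual metrics structure depends on how the learning process is built.

import math

import tensorflow_federated as tff


def should_discard_round(
    output: tff.learning.templates.LearningProcessOutput,
) -> bool:
  """Discards a round whose reported training loss is NaN or infinite."""
  # The metrics path below is an assumption about this particular process.
  loss = output.metrics['client_work']['train']['loss']
  return not math.isfinite(loss)

A callable like this can be passed as should_discard_round so that diverging
rounds are rolled back and retried rather than committed to the program state.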
Raises

  ValueError: If the train state is None.