A checkpoint manager backed by a file system.
tff.simulation.FileCheckpointManager(
root_dir: str,
prefix: str = 'ckpt_',
keep_total: int = 5,
keep_first: bool = True
)
This checkpoint manager is a utility to save and load checkpoints. While the checkpoint manager is compatible with any nested structure supported by `tf.convert_to_tensor`, checkpoints often represent the output of a `tff.templates.IterativeProcess`. For example, one possible use case would be to save the `ServerState` output of an iterative process created via `tff.learning`. This is comparable to periodically saving model weights and optimizer states during non-federated training.
The implementation here differs slightly from `tf.train.CheckpointManager`: it yields nested structures that are immutable, whereas `tf.train.CheckpointManager` is used to manage `tf.train.Checkpoint` objects, which are mutable collections. Additionally, this implementation allows retaining the initial checkpoint as part of the total number of checkpoints that are kept.
The checkpoint manager is intended only to allow simulations to be resumed after interruption. In particular, it should be used only to restart the same simulation, run with the same version of TensorFlow Federated.
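For example, a simulation loop might checkpoint its state every round and resume from the latest checkpoint after an interruption. The sketch below assumes an illustrative state structure and directory; in practice the state would typically be the `ServerState` of an iterative process:

```python
import collections

import tensorflow_federated as tff

# Illustrative state; any nested structure `tf.convert_to_tensor`
# supports can be checkpointed.
state = collections.OrderedDict(weights=[0.0, 0.0, 0.0])

ckpt_manager = tff.simulation.FileCheckpointManager(
    root_dir='/tmp/my_simulation',  # hypothetical checkpoint directory
    keep_total=5,
    keep_first=True)

# Resume from the latest checkpoint if one exists; otherwise this saves
# `state` as the round-0 checkpoint and returns it unchanged.
state, round_num = ckpt_manager.load_latest_checkpoint_or_default(state)

for round_num in range(round_num + 1, 100):
  # ... run one round of the simulation, producing an updated `state` ...
  ckpt_manager.save_checkpoint(state, round_num)
```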
Args | |
---|---|
`root_dir` | A path on the filesystem to store checkpoints.
`prefix` | A string to use as the prefix for checkpoint names.
`keep_total` | An integer representing the total number of checkpoints to keep.
`keep_first` | A boolean indicating whether the first checkpoint should be kept, irrespective of whether it is among the last `keep_total` checkpoints. This is desirable in settings where you would like to ensure full reproducibility of the simulation, especially where model weights or optimizer states are initialized randomly; by loading from the initial checkpoint, one can avoid re-initializing and obtaining different results.
Methods
load_checkpoint
load_checkpoint(
structure: Any,
round_num: int
) -> Any
Returns the checkpointed state for the given `round_num`.
Args | |
---|---|
`structure` | A nested structure which `tf.convert_to_tensor` supports, used as a template when reconstructing the loaded state.
`round_num` | An integer representing the round to load from.
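For example, to restore the state saved at a specific round, pass a template whose nested structure (and tensor shapes/dtypes) matches the saved state. The structure and directory below are illustrative:

```python
import collections

import tensorflow_federated as tff

ckpt_manager = tff.simulation.FileCheckpointManager(root_dir='/tmp/my_simulation')

# Template matching the structure of the previously saved state.
template = collections.OrderedDict(weights=[0.0, 0.0, 0.0])

# Assumes a checkpoint for round 5 was saved earlier.
state_at_round_5 = ckpt_manager.load_checkpoint(template, round_num=5)
```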
load_latest_checkpoint
load_latest_checkpoint(
structure: Any
) -> Tuple[Any, Union[int, None]]
Loads the latest state and round number.
Args | |
---|---|
`structure` | A nested structure which `tf.convert_to_tensor` supports, used as a template when reconstructing the loaded state.
Returns | |
---|---|
A tuple `(state, round_num)`, where `state` matches the Python structure in `structure` and `round_num` is an integer. If no checkpoints have been previously saved, returns the tuple `(None, None)`.
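For example (with the same illustrative structure and directory as above), the `None` return can be used to detect a fresh start:

```python
import collections

import tensorflow_federated as tff

ckpt_manager = tff.simulation.FileCheckpointManager(root_dir='/tmp/my_simulation')
template = collections.OrderedDict(weights=[0.0, 0.0, 0.0])

state, round_num = ckpt_manager.load_latest_checkpoint(template)
if state is None:
  # No checkpoints have been saved yet; start from a fresh state.
  state, round_num = template, 0
```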
load_latest_checkpoint_or_default
load_latest_checkpoint_or_default(
default: Any
) -> Tuple[Any, int]
Loads the latest checkpoint, returning `default` if no checkpoints exist. In that case, `default` is also saved as the checkpoint for round 0.
Args | |
---|---|
`default` | A nested structure which `tf.convert_to_tensor` supports, used as a template when reconstructing the loaded state. This structure will be saved as the checkpoint for round number 0 and returned if there are no pre-existing saved checkpoints.
Returns | |
---|---|
A tuple `(state, round_num)`, where `state` matches the Python structure in `default` and `round_num` is an integer. If no checkpoints have been written, returns `(default, 0)`.
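This method folds the `None` check shown above into a single call. A minimal sketch, again with an illustrative structure and directory:

```python
import collections

import tensorflow_federated as tff

ckpt_manager = tff.simulation.FileCheckpointManager(root_dir='/tmp/my_simulation')
initial_state = collections.OrderedDict(weights=[0.0, 0.0, 0.0])

# Either the latest saved state, or `initial_state` (which is then
# written as the round-0 checkpoint) with round_num == 0.
state, round_num = ckpt_manager.load_latest_checkpoint_or_default(initial_state)
```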
save_checkpoint
save_checkpoint(
state: Any,
round_num: int
) -> None
Saves a new checkpointed `state` for the given `round_num`.
Args | |
---|---|
`state` | A nested structure which `tf.convert_to_tensor` supports.
`round_num` | An integer representing the current training round.
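For example (illustrative values), saving a checkpoint also prunes older ones so that at most `keep_total` remain, retaining the first when `keep_first=True`:

```python
import collections

import tensorflow_federated as tff

ckpt_manager = tff.simulation.FileCheckpointManager(
    root_dir='/tmp/my_simulation', keep_total=5, keep_first=True)

state = collections.OrderedDict(weights=[1.0, 2.0, 3.0])
ckpt_manager.save_checkpoint(state, round_num=10)
# Checkpoints outside the retained set are deleted from root_dir.
```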