A stateful process for learning tasks that produces metrics.
Inherits From: IterativeProcess

    tff.learning.templates.LearningProcess(
        initialize_fn: tff.Computation,
        next_fn: tff.Computation,
        get_model_weights: tff.Computation,
        set_model_weights: tff.Computation,
        *,
        get_hparams_fn: Optional[tff.Computation] = None,
        set_hparams_fn: Optional[tff.Computation] = None
    )

This class inherits the constraints documented by
tff.templates.IterativeProcess, including the initialize and next
attributes. A LearningProcess also provides additional attributes,
including get_model_weights and get_hparams. The former extracts
structures suitable for evaluation, while the latter extracts
hyperparameters from the process. The corresponding set_model_weights
and set_hparams attributes set these structures in a given state.
For example, given a LearningProcess named process and client data named
data, we could call the following to initialize, optionally load other model
weights, update the state three times, and extract the model weights of the state:

    state = process.initialize()
    # Optional: state = process.set_model_weights(state, other_weights)
    for _ in range(3):
      state, metrics = process.next(state, data)
    model_weights = process.get_model_weights(state)

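The call pattern above can be tried without TFF installed by using a plain-Python stand-in. The sketch below is not the TFF API: `ToyProcess`, `ToyOutput`, the dictionary-shaped state, and the "training" rule are all hypothetical, chosen only to show how the state is threaded through `next` and how the weight accessors fit in.

```python
import collections

# Hypothetical stand-in for LearningProcessOutput: only the `state` and
# `metrics` fields mentioned in the text are modeled.
ToyOutput = collections.namedtuple("ToyOutput", ["state", "metrics"])


class ToyProcess:
    """Plain-Python mock of the LearningProcess call pattern (not TFF)."""

    def initialize(self):
        # The state holds the model weights plus a round counter.
        return {"weights": [0.0, 0.0], "round": 0}

    def next(self, state, data):
        # Toy "training": move each weight halfway toward the data mean.
        mean = sum(data) / len(data)
        weights = [w + 0.5 * (mean - w) for w in state["weights"]]
        new_state = {"weights": weights, "round": state["round"] + 1}
        return ToyOutput(new_state, {"num_examples": len(data)})

    def get_model_weights(self, state):
        return list(state["weights"])

    def set_model_weights(self, state, weights):
        # Returns a new state; the input state is left untouched.
        return {**state, "weights": list(weights)}


process = ToyProcess()
data = [1.0, 3.0]
state = process.initialize()
state = process.set_model_weights(state, [4.0, 0.0])  # optional warm start
for _ in range(3):
    state, metrics = process.next(state, data)
model_weights = process.get_model_weights(state)
```

Note that every call returns a new state rather than mutating the old one, which mirrors the functional style of the example above.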
| Args | |
|---|---|
| initialize_fn | A no-arg tff.Computation that creates the initial state of the learning process. |
| next_fn | A tff.Computation that defines an iterated function. Given that initialize_fn returns a type S@SERVER, next_fn must accept two arguments with types assignable from values with type S@SERVER and {D*}@CLIENTS, and must return a LearningProcessOutput whose state attribute is assignable from values with type S@SERVER. |
| get_model_weights | A tff.Computation that accepts an input S whose type is assignable from the result of initialize_fn. This computation is used to create a representation of the state that can be used for downstream tasks without requiring access to the entire server state. For example, get_model_weights could be used to extract model weights suitable for computing evaluation metrics on held-out data. |
| set_model_weights | A tff.Computation that accepts two inputs S and M, where the type of S is assignable from values with the type returned by initialize_fn and M is a representation of the model weights stored in S. This updates the model weights representation within the state with the incoming value and returns a new value of type S. |
| get_hparams_fn | An optional tff.Computation accepting the state S and returning the hyperparameters H. If not provided, this defaults to a computation that returns an empty ordered dictionary, regardless of the contents of the state. |
| set_hparams_fn | An optional tff.Computation accepting the state S and hyperparameters H (matching the output of get_hparams_fn) and returning an updated state S. If not provided, this defaults to a pass-through computation that returns the input state regardless of the hparams passed in. |
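The default behavior described for get_hparams_fn and set_hparams_fn can be illustrated in plain Python. This is a minimal sketch of the documented semantics only, not the TFF implementation: the function names, the dictionary-shaped state, and the `learning_rate` key are all assumptions made for illustration.

```python
import collections


def default_get_hparams(state):
    # Documented default: an empty ordered dictionary, whatever the state holds.
    return collections.OrderedDict()


def default_set_hparams(state, hparams):
    # Documented default: pass the state through and ignore the hparams.
    return state


# Hypothetical custom pair for a state that stores a learning rate.
def get_hparams(state):
    return collections.OrderedDict(learning_rate=state["learning_rate"])


def set_hparams(state, hparams):
    # Build an updated state; the structure of hparams matches get_hparams.
    return {**state, "learning_rate": hparams["learning_rate"]}


state = {"weights": [0.0], "learning_rate": 0.1}
hparams = get_hparams(state)
hparams["learning_rate"] = 0.01
updated = set_hparams(state, hparams)
```

The custom pair shows the round trip the two arguments are meant to support: read the hyperparameters, change one, and write them back to obtain a new state.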
| Raises | |
|---|---|
| TypeError | If initialize_fn or next_fn is not an instance of tff.Computation. |
| TemplateInitFnParamNotEmptyError | If initialize_fn has any input arguments. |
| TemplateStateNotAssignableError | If the state returned by either initialize_fn or next_fn is not assignable to the first input argument of next_fn. |
| TemplateNextFnNumArgsError | If next_fn does not have exactly two input arguments. |
| LearningProcessPlacementError | If the placements of initialize_fn and next_fn do not match the expected type placements. |
| LearningProcessOutputError | If next_fn does not return a LearningProcessOutput. |
| GetModelWeightsTypeSignatureError | If the input type of get_model_weights does not match the process state type. |
| SetModelWeightsTypeSignatureError | If the type of the first input or the type of the output of set_model_weights does not match the process state type. |
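Two of the checks listed above concern argument counts, and they can be mimicked on ordinary Python callables. The sketch below is an analogy only: it inspects Python signatures rather than TFF type signatures, and `check_signatures` and both exception classes are hypothetical names, not part of TFF.

```python
import inspect


class InitFnParamNotEmptyError(TypeError):
    """Hypothetical analogue of TemplateInitFnParamNotEmptyError."""


class NextFnNumArgsError(TypeError):
    """Hypothetical analogue of TemplateNextFnNumArgsError."""


def check_signatures(initialize_fn, next_fn):
    # initialize_fn must take no arguments.
    if inspect.signature(initialize_fn).parameters:
        raise InitFnParamNotEmptyError("initialize_fn must take no arguments")
    # next_fn must take exactly two arguments: the state and the client data.
    num_args = len(inspect.signature(next_fn).parameters)
    if num_args != 2:
        raise NextFnNumArgsError(
            f"next_fn must take 2 arguments, got {num_args}")


# A valid pair passes silently.
check_signatures(lambda: 0, lambda state, data: state)

# A next_fn with a single argument is rejected.
try:
    check_signatures(lambda: 0, lambda state: state)
    error_name = None
except NextFnNumArgsError as e:
    error_name = type(e).__name__
```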
| Attributes | |
|---|---|
| get_hparams | A tff.Computation returning the hyperparameters of a server state. This computation accepts an unplaced state of the process (originally produced by the |
| get_model_weights | A tff.Computation returning the model weights of a server state. This computation accepts an unplaced state of the process (originally produced by the |
| initialize | A tff.Computation that initializes the process. This computation must have no input arguments, and its output must be the initial state of the learning process, placed at |
| next | A tff.Computation that runs one iteration of the process. The first argument of this computation should always be the current state (originally produced by the |
| set_hparams | A tff.Computation that sets the hyperparameters of a server state. This computation accepts two arguments: an unplaced state of the process (originally produced by the |
| set_model_weights | A tff.Computation that sets the model weights of a server state. This computation accepts two arguments: an unplaced state of the process (originally produced by the |
| state_type | The tff.Type of the state of the process. |