tff.learning.templates.LearningProcess

A stateful process for learning tasks that produces metrics.

Inherits From: IterativeProcess

This class inherits the constraints documented by tff.templates.IterativeProcess, including the initialize and next attributes. A LearningProcess also has two additional attributes, get_model_weights and set_model_weights.

All of initialize, next, get_model_weights and set_model_weights must be tff.Computations, with the following type signatures:

  • initialize: ( -> S@SERVER)
  • next: (<S@SERVER, {D*}@CLIENTS> -> <state=S@SERVER, metrics=M@SERVER>)
  • get_model_weights: (S -> M)
  • set_model_weights: (<S, M> -> S)

where {D*}@CLIENTS represents the sequence of data at a client, with D denoting the type of a single member of that sequence, and M represents the (unplaced) output type of the get_model_weights function.

Note that here, "model weights" is a loosely defined term intended to refer to some kind of "representation" of the model being learned. This is typically some nested structure of tensors, and is often suitable for evaluation purposes.

For example, given a LearningProcess named process and client data named data, the following code initializes the state, optionally loads other model weights, updates the state three times, and extracts the model weights from the final state:

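```python
# `process` is a LearningProcess; `data` is the client data accepted by `process.next`.
state = process.initialize()
# Optional: load other model weights into the state.
# state = process.set_model_weights(state, other_weights)
for _ in range(3):
    state, metrics = process.next(state, data)
model_weights = process.get_model_weights(state)
```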
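The sketch below runs the call sequence above without a TFF runtime, using a plain-Python stand-in that mimics the shape of the contract. Everything in it is hypothetical. ToyLearningProcess, the local LearningProcessOutput NamedTuple, and the dict-based state are illustration-only names. The "model" is a single float, and federated placements are ignored: the list of per-client lists merely plays the role of {D*}@CLIENTS. This is not how TFF builds a real process.

```python
from typing import NamedTuple, Sequence


class LearningProcessOutput(NamedTuple):
    """Hypothetical local stand-in with the two fields state and metrics."""
    state: dict
    metrics: dict


class ToyLearningProcess:
    """Plain-Python mimic of the LearningProcess contract (not TFF).

    The "model" is one float. Each round replaces it with the unweighted
    average of the client means.
    """

    def initialize(self) -> dict:
        # No arguments; the state holds the weights plus a round counter.
        return {"weights": 0.0, "round_num": 0}

    def next(self, state: dict,
             client_data: Sequence[Sequence[float]]) -> LearningProcessOutput:
        # client_data plays the role of {D*}@CLIENTS: one sequence per client.
        client_means = [sum(ds) / len(ds) for ds in client_data]
        new_weights = sum(client_means) / len(client_means)
        # Metric: mean squared error of the incoming weights over all examples.
        examples = [x for ds in client_data for x in ds]
        loss = sum((x - state["weights"]) ** 2 for x in examples) / len(examples)
        new_state = {"weights": new_weights, "round_num": state["round_num"] + 1}
        return LearningProcessOutput(state=new_state, metrics={"loss": loss})

    def get_model_weights(self, state: dict) -> float:
        # Expose only the weights, not the round counter.
        return state["weights"]

    def set_model_weights(self, state: dict, weights: float) -> dict:
        # Return a new state with the weights replaced; other fields are kept.
        return {**state, "weights": weights}


process = ToyLearningProcess()
data = [[1.0, 2.0, 3.0], [4.0, 6.0]]  # two hypothetical clients
state = process.initialize()
for _ in range(3):
    state, metrics = process.next(state, data)
model_weights = process.get_model_weights(state)  # 3.5: mean of client means 2.0 and 5.0
```

Because next returns a NamedTuple with state and metrics fields, the `state, metrics = ...` unpacking used in the loop works unchanged. The state deliberately carries more than the weights (a round counter), which is the situation get_model_weights is meant to handle.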

Args:
  • initialize_fn: A no-arg tff.Computation that creates the initial state of the learning process.
  • next_fn: A tff.Computation that defines an iterated function. Given that initialize_fn returns a type S@SERVER, next_fn must accept two arguments whose types are assignable from values of type S@SERVER and {D*}@CLIENTS, respectively. It must also return a LearningProcessOutput whose state attribute is assignable from values of type S@SERVER.
  • get_model_weights: A tff.Computation that accepts an input S whose type is assignable from the result of initialize_fn. This computation creates a representation of the state that downstream tasks can use without access to the entire server state. For example, get_model_weights could extract model weights suitable for computing evaluation metrics on held-out data.
  • set_model_weights: A tff.Computation that accepts two inputs S and M. The type of S is assignable from values of the type returned by initialize_fn, and M is a representation of the model weights stored in S. It replaces the model weights representation within the state with the incoming value and returns a new value of type S.

Raises:
  • TypeError: If initialize_fn or next_fn is not an instance of tff.Computation.
  • TemplateInitFnParamNotEmptyError: If initialize_fn has any input arguments.
  • TemplateStateNotAssignableError: If the state returned by either initialize_fn or next_fn is not assignable to the first input argument of next_fn.
  • TemplateNextFnNumArgsError: If next_fn does not have exactly two input arguments.
  • LearningProcessPlacementError: If the placements of initialize_fn and next_fn do not match the expected type placements.
  • LearningProcessOutputError: If next_fn does not return a LearningProcessOutput.
  • LearningProcessSequenceTypeError: If the second argument to next_fn is not a sequence type.
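Two of these errors come from simple structural rules: initialize_fn takes no arguments, and next_fn takes exactly two. The sketch below is a rough, hypothetical analogue that applies those two rules to plain Python callables with inspect.signature. TFF itself checks tff.Computation type signatures, not Python callables. The exception classes here are local stand-ins that only reuse the names listed above; they are not imported from TFF, and subclassing ValueError is a choice of this sketch.

```python
import inspect
from typing import Callable


class TemplateInitFnParamNotEmptyError(ValueError):
    """Local stand-in; reuses the TFF error name for illustration only."""


class TemplateNextFnNumArgsError(ValueError):
    """Local stand-in; reuses the TFF error name for illustration only."""


def check_process_fns(initialize_fn: Callable, next_fn: Callable) -> None:
    """Hypothetical plain-Python analogue of two LearningProcess checks."""
    # initialize_fn must accept no input arguments.
    if inspect.signature(initialize_fn).parameters:
        raise TemplateInitFnParamNotEmptyError(
            "initialize_fn must not take any arguments")
    # next_fn must accept exactly two arguments: (state, client_data).
    num_args = len(inspect.signature(next_fn).parameters)
    if num_args != 2:
        raise TemplateNextFnNumArgsError(
            f"next_fn must take exactly two arguments, got {num_args}")


# Passes: a no-arg initializer and a two-arg next function.
check_process_fns(lambda: {"round_num": 0}, lambda state, data: (state, {}))
```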

Attributes:

get_model_weights: A tff.Computation returning the model weights of a server state.

This computation accepts an unplaced state of the process (originally produced by the initialize attribute), and returns an unplaced representation of the model weights of the state. Note that this representation need not take the form of a tff.learning.ModelWeights object, and may depend on the specific LearningProcess in question.

initialize: A tff.Computation that initializes the process.

This computation must have no input arguments, and its output must be the initial state of the iterative process, placed at SERVER.

next: A tff.Computation that runs one iteration of the process.

The first argument of this computation must always be the current state (originally produced by the initialize attribute), and the second argument must be a tff.SequenceType placed at CLIENTS. The return type must be a LearningProcessOutput, with each field placed at SERVER.

set_model_weights: A tff.Computation that sets the model weights of a server state.

This computation accepts two arguments: an unplaced state of the process (originally produced by the initialize attribute) and a new structure of tensors representing the model weights. It returns a new unplaced state with the updated model weights. Note that the model weights representation need not take the form of a tff.learning.ModelWeights object, and may depend on the specific LearningProcess in question.
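get_model_weights and set_model_weights can be read together as a get/set pair on the state. One would expect that setting weights and then reading them back returns the same weights, while the rest of the state is carried over. The sketch below assumes a hypothetical ServerState NamedTuple holding trainable and non-trainable values plus an optimizer step counter. It uses a plain (trainable, non_trainable) tuple as the weights representation, which illustrates that the representation need not be a tff.learning.ModelWeights object. This is plain Python, not TFF code.

```python
from typing import NamedTuple, Tuple

# Hypothetical weights representation: a (trainable, non_trainable) pair of tuples.
Weights = Tuple[Tuple[float, ...], Tuple[float, ...]]


class ServerState(NamedTuple):
    """Hypothetical server state holding more than just the model weights."""
    trainable: Tuple[float, ...]
    non_trainable: Tuple[float, ...]
    optimizer_step: int


def get_model_weights(state: ServerState) -> Weights:
    # Extract only the weights, leaving optimizer bookkeeping behind.
    return (state.trainable, state.non_trainable)


def set_model_weights(state: ServerState, weights: Weights) -> ServerState:
    # Build a new state with the incoming weights. NamedTuples are immutable,
    # so the input state is unchanged and optimizer_step is carried over.
    trainable, non_trainable = weights
    return state._replace(trainable=trainable, non_trainable=non_trainable)


state = ServerState(trainable=(0.0, 0.0), non_trainable=(1.0,), optimizer_step=7)
new_weights = ((0.5, -0.5), (2.0,))
new_state = set_model_weights(state, new_weights)
```

Returning a new state rather than mutating the old one matches the description above, where set_model_weights "returns a new unplaced state".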

state_type: The tff.Type of the state of the process.