
tff.learning.templates.ClientWorkProcess

A stateful process capturing work at clients during learning.

Inherits From: MeasuredProcess, IterativeProcess


Client work encapsulates the main work that clients perform as part of a federated learning algorithm, such as running several steps of gradient descent on their local data and returning an update to the initial model weights.
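To make "several steps of gradient descent, then return an update" concrete, here is a minimal plain-Python sketch of one client's work. It does not use TFF. The ClientResult tuple is a local stand-in that mirrors the update/update_weight pair of tff.learning.templates.ClientResult. The linear model, the learning rate, the step count, and the sign convention (update = initial minus final weights) are all assumptions chosen for illustration.

```python
from typing import List, NamedTuple, Sequence, Tuple


class ClientResult(NamedTuple):
    """Local stand-in mirroring tff.learning.templates.ClientResult."""
    update: List[float]
    update_weight: float


def client_work(
    initial_weights: Sequence[float],
    data: Sequence[Tuple[float, float]],
    learning_rate: float = 0.1,
    num_steps: int = 5,
) -> ClientResult:
    """Full-batch gradient descent on y ~ w0 + w1 * x with mean squared error.

    Assumes `data` is a non-empty list of (x, y) pairs held by this client.
    """
    w0, w1 = initial_weights
    n = len(data)
    for _ in range(num_steps):
        # Residuals of the current model on this client's local data.
        residuals = [(w0 + w1 * x - y, x) for x, y in data]
        grad0 = 2.0 * sum(r for r, _ in residuals) / n
        grad1 = 2.0 * sum(r * x for r, x in residuals) / n
        w0 -= learning_rate * grad0
        w1 -= learning_rate * grad1
    # Assumed sign convention: update = initial weights - final weights.
    update = [initial_weights[0] - w0, initial_weights[1] - w1]
    # Weight the update by the number of local examples.
    return ClientResult(update=update, update_weight=float(n))
```

For data that the initial weights already fit exactly, for example client_work([1.0, 2.0], [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]), every gradient is zero, so the update is [0.0, 0.0] with an update weight of 3.0.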

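A ClientWorkProcess is a tff.templates.MeasuredProcess that formalizes the type signatures of initialize and next for the core work performed by clients in a learning process. The initialize_fn takes no arguments and returns the initial state, placed at SERVER. The next_fn takes three arguments: the current state (at SERVER), the model weights (at CLIENTS), and the client data, which must be a sequence type placed at CLIENTS. It returns a tff.templates.MeasuredProcessOutput whose state is placed at SERVER, whose result (the second output) is a tff.learning.templates.ClientResult placed at CLIENTS, and whose measurements are placed at SERVER. The get_hparams_fn maps a state to its hyperparameters. The set_hparams_fn takes a state and hyperparameters and returns an updated state.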
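The initialize/next contract can be approximated structurally with typing.Protocol. The following is a plain-Python sketch, not TFF's API. ClientWorkLike, CountingClientWork, and the stand-in NamedTuples are hypothetical names. Python lists play the role of CLIENTS placement, with entry i belonging to client i, and a single value plays the role of SERVER placement. TFF's real checks operate on tff.Type signatures, not on Python classes.

```python
from typing import Any, Dict, List, NamedTuple, Protocol, Sequence, runtime_checkable


class ClientResult(NamedTuple):
    """Local stand-in mirroring tff.learning.templates.ClientResult."""
    update: Any
    update_weight: float


class MeasuredProcessOutput(NamedTuple):
    """Local stand-in mirroring tff.templates.MeasuredProcessOutput."""
    state: Any                    # plays the role of the SERVER-placed state
    result: List[ClientResult]    # one entry per client, i.e. CLIENTS-placed
    measurements: Dict[str, Any]  # plays the role of SERVER-placed metrics


@runtime_checkable
class ClientWorkLike(Protocol):
    """Structural description of the initialize/next contract (hypothetical)."""

    def initialize(self) -> Any: ...

    def next(
        self,
        state: Any,
        weights: Sequence[Any],
        client_data: Sequence[Sequence[Any]],
    ) -> MeasuredProcessOutput: ...


class CountingClientWork:
    """Toy process: zero update per client, weighted by its example count."""

    def initialize(self) -> int:
        return 0  # state: number of rounds run so far

    def next(self, state, weights, client_data):
        results = [
            ClientResult(update=[0.0] * len(w), update_weight=float(len(d)))
            for w, d in zip(weights, client_data)
        ]
        measurements = {"num_examples": sum(len(d) for d in client_data)}
        return MeasuredProcessOutput(state + 1, results, measurements)
```

Note that the runtime_checkable check only verifies that initialize and next exist. The argument and return types are documentation for readers and static type checkers.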

Args

initialize_fn: A tff.Computation matching the criteria above.
next_fn: A tff.Computation matching the criteria above.
get_hparams_fn: An optional tff.Computation matching the criteria above. If not provided, this defaults to a computation that returns an empty ordered dictionary, regardless of the contents of the state.
set_hparams_fn: An optional tff.Computation matching the criteria above. If not provided, this defaults to a pass-through computation that returns the input state unchanged, ignoring the hparams passed in.
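The documented defaults for get_hparams_fn and set_hparams_fn can be mimicked with plain Python functions. This sketch uses ordinary functions rather than tff.Computation objects. The custom pair, and its learning_rate field in the state, are hypothetical and only show what a non-default pair might look like.

```python
from collections import OrderedDict


def default_get_hparams(state):
    # Mirrors the documented default: ignore the state, return an empty ordered dict.
    return OrderedDict()


def default_set_hparams(state, hparams):
    # Mirrors the documented default: pass the state through, ignoring hparams.
    return state


# Hypothetical custom pair for a process whose state stores a client learning rate.
def custom_get_hparams(state):
    return OrderedDict(learning_rate=state["learning_rate"])


def custom_set_hparams(state, hparams):
    # Return a new state rather than mutating the old one.
    new_state = OrderedDict(state)
    new_state["learning_rate"] = hparams["learning_rate"]
    return new_state
```

A typical use reads the hyperparameters, adjusts them, and writes them back. For example, custom_set_hparams(state, custom_get_hparams(state)) returns a state equal to the original.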

Raises

TemplateNotFederatedError: If any of the federated computations provided do not return a federated type.
TemplateNextFnNumArgsError: If the next_fn has an incorrect number of arguments.
TemplatePlacementError: If any of the federated computations have an incorrect placement.
ClientDataTypeError: If the third input of next_fn is not a sequence type placed at CLIENTS.
ClientResultTypeError: If the second output of next_fn does not meet the criteria outlined above.
GetHparamsTypeError: If the type signature of get_hparams_fn does not meet the criteria above.
SetHparamsTypeError: If the type signature of set_hparams_fn does not meet the criteria above.
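As a rough analogy for two of these checks, the sketch below validates a plain Python next_fn and per-client data. The exception classes are local stand-ins that only share names with the documented errors. Subclassing TypeError is a choice made for this sketch. TFF itself inspects the tff.Type signature of the computation rather than Python function arity or runtime values.

```python
import inspect
from collections.abc import Sequence


class TemplateNextFnNumArgsError(TypeError):
    """Local stand-in named after the documented error; not TFF's class."""


class ClientDataTypeError(TypeError):
    """Local stand-in named after the documented error; not TFF's class."""


def check_next_fn(next_fn):
    # Expect exactly three parameters: state, weights, client data.
    num_params = len(inspect.signature(next_fn).parameters)
    if num_params != 3:
        raise TemplateNextFnNumArgsError(
            f"next_fn must take 3 arguments (state, weights, data), got {num_params}"
        )


def check_client_data(client_data):
    # Each client's data must be a sequence; strings and bytes are rejected.
    for i, data in enumerate(client_data):
        if not isinstance(data, Sequence) or isinstance(data, (str, bytes)):
            raise ClientDataTypeError(
                f"client {i} data must be a sequence, got {type(data).__name__}"
            )
```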

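Attributes

get_hparams: The tff.Computation, supplied as get_hparams_fn or defaulted, that returns the hyperparameters of a given state.
initialize: A no-arg tff.Computation that returns the initial state.
next: A tff.Computation that runs one iteration of the process. Its first argument should always be the current state (originally produced by tff.templates.MeasuredProcess.initialize), and its return type must be a tff.templates.MeasuredProcessOutput.
set_hparams: The tff.Computation, supplied as set_hparams_fn or defaulted, that returns a state updated with the given hyperparameters.
state_type: The tff.Type of the state of the process.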
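The requirement that next always receives the current state, and returns a MeasuredProcessOutput whose state feeds the following round, can be sketched as a plain-Python driver loop. The driver run_rounds, the toy process, and the stand-in output tuple are hypothetical illustrations, not TFF code. In TFF, initialize and next would be tff.Computation objects invoked by an execution context.

```python
from typing import Any, Callable, List, NamedTuple, Sequence, Tuple


class MeasuredProcessOutput(NamedTuple):
    """Local stand-in mirroring tff.templates.MeasuredProcessOutput."""
    state: Any
    result: Any
    measurements: Any


def run_rounds(
    initialize: Callable[[], Any],
    next_fn: Callable[..., MeasuredProcessOutput],
    weights: Sequence[Any],
    client_data: Sequence[Sequence[Any]],
    num_rounds: int,
) -> Tuple[Any, List[Any]]:
    """Calls initialize once, then next once per round, threading the state."""
    state = initialize()
    history = []
    for _ in range(num_rounds):
        output = next_fn(state, weights, client_data)  # current state goes first
        if not isinstance(output, MeasuredProcessOutput):
            raise TypeError("next must return a MeasuredProcessOutput")
        state = output.state  # the returned state becomes the next round's input
        history.append(output.measurements)
    return state, history


# Hypothetical toy process: the state counts examples seen across all rounds.
def toy_initialize() -> int:
    return 0


def toy_next(state: int, weights, client_data) -> MeasuredProcessOutput:
    seen = sum(len(d) for d in client_data)
    results = [None for _ in client_data]  # placeholder per-client results
    return MeasuredProcessOutput(state + seen, results, {"examples_this_round": seen})
```

With two clients holding three examples in total, three rounds of run_rounds(toy_initialize, toy_next, [[0.0], [0.0]], [[1, 2], [3]], 3) end with a state of 9, because the same three examples are counted once per round.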