Composes specialized measured processes into a learning process.
tff.learning.templates.compose_learning_process(
    initial_model_weights_fn: tff.Computation,
    model_weights_distributor: tff.learning.templates.DistributionProcess,
    client_work: tff.learning.templates.ClientWorkProcess,
    model_update_aggregator: tff.templates.AggregationProcess,
    model_finalizer: tff.learning.templates.FinalizerProcess
) -> tff.learning.templates.LearningProcess
Given 4 specialized measured processes (described below) that make up a learning process, and a computation that returns the initial model weights to be used for training, this method validates that the processes fit together and returns a `LearningProcess`. See the tutorial at https://www.tensorflow.org/federated/tutorials/composing_learning_algorithms for more details on composing learning processes.
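To make "fit together" concrete, here is a plain-Python toy that checks each stage's output type against the input type of the stage that consumes it. It is a sketch under simplifying assumptions, not TFF's actual validation logic: the `Stage` dataclass and `check_chain` helper are hypothetical, type signatures are reduced to strings such as `"weights@CLIENTS"`, and each process is collapsed to a single input and output, even though the real processes also carry state and the finalizer additionally receives the current model weights.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Stage:
    """Hypothetical stand-in for the `next` signature of one measured process."""
    name: str
    input_type: str   # type of the value this stage consumes
    output_type: str  # type of the value this stage produces


def check_chain(stages):
    """Raise TypeError if a stage's output type differs from the next stage's input type."""
    for upstream, downstream in zip(stages, stages[1:]):
        if upstream.output_type != downstream.input_type:
            raise TypeError(
                f"{upstream.name} produces {upstream.output_type!r}, "
                f"but {downstream.name} expects {downstream.input_type!r}"
            )
    return True


# The four stages in the order `next` chains them (simplified signatures).
stages = [
    Stage("model_weights_distributor", "weights@SERVER", "weights@CLIENTS"),
    Stage("client_work", "weights@CLIENTS", "update@CLIENTS"),
    Stage("model_update_aggregator", "update@CLIENTS", "update@SERVER"),
    Stage("model_finalizer", "update@SERVER", "weights@SERVER"),
]
check_chain(stages)  # every output matches the next input, so no error is raised
```

Dropping the aggregator from the list would make `check_chain` fail, because `client_work` produces a client-placed update while the finalizer expects a server-placed one.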
The main purposes of the 4 measured processes are:
- `model_weights_distributor`: Makes the global model weights at the server available as the starting point for the learning work done at clients.
- `client_work`: Produces an update to the model received at clients.
- `model_update_aggregator`: Aggregates the model updates from clients to the server.
- `model_finalizer`: Updates the global model weights at the server using the aggregated model update.
The `next` computation of the created learning process is composed from the `next` computations of the 4 measured processes, in the order visualized below. The type signatures of the processes must be such that this chaining is possible. Each process also reports its own metrics.
┌─────────────────────────┐
│model_weights_distributor│
└△─┬─┬────────────────────┘
 │ │┌▽──────────┐
 │ ││client_work│
 │ │└┬─────┬────┘
 │┌▽─▽────┐│
 ││metrics││
 │└△─△────┘│
 │ │┌┴─────▽────────────────┐
 │ ││model_update_aggregator│
 │ │└┬──────────────────────┘
┌┴─┴─▽──────────┐
│model_finalizer│
└┬──────────────┘
┌▽─────┐
│result│
└──────┘
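The chaining in the diagram can be sketched in plain Python as one round of training. Everything here is hypothetical and deliberately simplified: the model weights are a single float, each client's dataset is a list of numbers, the client update simply moves the weights toward the local mean, the aggregator takes an example-weighted mean, and the finalizer applies the update scaled by a server learning rate kept in its own state. The state and metrics keys loosely mirror the four components; they are not the real `LearningProcess` structures, and real TFF processes run as federated computations rather than Python functions.

```python
from collections import OrderedDict
from statistics import fmean


def distributor_next(state, server_weights, num_clients):
    # Broadcast: every client starts from the same global weights.
    return state, [server_weights] * num_clients, OrderedDict()


def client_work_next(state, client_weights, client_datasets):
    # Toy local step: each client's update points from its weights to its data mean.
    updates = [fmean(data) - w for w, data in zip(client_weights, client_datasets)]
    counts = [len(data) for data in client_datasets]  # examples per client
    return state, (updates, counts), OrderedDict(num_examples=sum(counts))


def aggregator_next(state, client_result):
    # Example-weighted mean of the client updates, delivered to the server.
    updates, counts = client_result
    mean_update = sum(u * n for u, n in zip(updates, counts)) / sum(counts)
    return state, mean_update, OrderedDict(mean_update=mean_update)


def finalizer_next(state, server_weights, update):
    # Apply the aggregated update, scaled by a server learning rate held in state.
    new_weights = server_weights + state["learning_rate"] * update
    return state, new_weights, OrderedDict()


def learning_initialize(initial_model_weights_fn):
    # Global weights come from the weights function; each component owns its own state.
    return OrderedDict(
        global_model_weights=initial_model_weights_fn(),
        distributor=OrderedDict(),
        client_work=OrderedDict(),
        aggregator=OrderedDict(),
        finalizer=OrderedDict(learning_rate=0.5),
    )


def learning_next(state, client_datasets):
    """One round: distributor -> client_work -> aggregator -> finalizer."""
    weights = state["global_model_weights"]
    d_state, client_weights, d_metrics = distributor_next(
        state["distributor"], weights, len(client_datasets))
    c_state, client_result, c_metrics = client_work_next(
        state["client_work"], client_weights, client_datasets)
    a_state, update, a_metrics = aggregator_next(state["aggregator"], client_result)
    f_state, new_weights, f_metrics = finalizer_next(
        state["finalizer"], weights, update)
    new_state = OrderedDict(
        global_model_weights=new_weights,
        distributor=d_state, client_work=c_state,
        aggregator=a_state, finalizer=f_state)
    # Each component reports its own metrics under its own key.
    metrics = OrderedDict(
        distributor=d_metrics, client_work=c_metrics,
        aggregator=a_metrics, finalizer=f_metrics)
    return new_state, metrics


state = learning_initialize(lambda: 0.0)
state, metrics = learning_next(state, [[1.0, 3.0], [4.0]])
# Client updates are 2.0 (2 examples) and 4.0 (1 example), so the weighted mean
# is 8/3 and the new global weights are 0.5 * 8/3, roughly 1.333.
```

Keeping each component's state and metrics under its own key is what lets every process report its own metrics without knowing about the others.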
The `get_hparams` computation of the created learning process produces a nested ordered dictionary containing the results of `client_work.get_hparams` and `model_finalizer.get_hparams`. The `set_hparams` computation operates similarly, delegating to `client_work.set_hparams` and `model_finalizer.set_hparams` to set the hyperparameters in their associated states.
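The delegation can be sketched with ordinary dictionaries. In this hypothetical toy, `ToyProcess` stores its hyperparameters directly in its state, and the composed `get_hparams`/`set_hparams` functions nest and unnest them under `client_work` and `finalizer` keys. The class, the key names, and the state layout are assumptions for illustration, not the exact structure TFF produces.

```python
from collections import OrderedDict


class ToyProcess:
    """Hypothetical process whose state is just a dict of hyperparameters."""

    def __init__(self, **hparams):
        self.initial_state = OrderedDict(hparams)

    def get_hparams(self, state):
        return OrderedDict(state)  # a copy, so callers cannot mutate the state

    def set_hparams(self, state, hparams):
        new_state = OrderedDict(state)
        new_state.update(hparams)
        return new_state


def get_hparams(client_work, finalizer, state):
    # Nest each component's hyperparameters under its own key.
    return OrderedDict(
        client_work=client_work.get_hparams(state["client_work"]),
        finalizer=finalizer.get_hparams(state["finalizer"]),
    )


def set_hparams(client_work, finalizer, state, hparams):
    # Hand each sub-dictionary back to the process that owns it.
    new_state = OrderedDict(state)
    new_state["client_work"] = client_work.set_hparams(
        state["client_work"], hparams["client_work"])
    new_state["finalizer"] = finalizer.set_hparams(
        state["finalizer"], hparams["finalizer"])
    return new_state


client_work = ToyProcess(learning_rate=0.1, num_epochs=1)
finalizer = ToyProcess(learning_rate=1.0)
state = OrderedDict(client_work=client_work.initial_state,
                    finalizer=finalizer.initial_state)

hparams = get_hparams(client_work, finalizer, state)
hparams["client_work"]["learning_rate"] = 0.05  # edit the returned copy
state = set_hparams(client_work, finalizer, state, hparams)  # write the edits back
```

Because `get_hparams` returns copies, editing the returned dictionary has no effect until it is passed back through `set_hparams`.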
Args | |
---|---|
`initial_model_weights_fn` | A `tff.Computation` that returns (unplaced) initial model weights. |
`model_weights_distributor` | A `tff.learning.templates.DistributionProcess`. |
`client_work` | A `tff.learning.templates.ClientWorkProcess`. |
`model_update_aggregator` | A `tff.templates.AggregationProcess`. |
`model_finalizer` | A `tff.learning.templates.FinalizerProcess`. |
Returns | |
---|---|
A `tff.learning.templates.LearningProcess`. | |
Raises | |
---|---|
`ClientSequenceTypeError` | If the first argument of the `next` method of the resulting `LearningProcess` is not a structure of sequences placed at `tff.CLIENTS`. |
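As a rough picture of the condition behind `ClientSequenceTypeError`, the toy below accepts a client-data type only if it is placed at clients and every leaf of its (possibly nested) member structure is a sequence. `FederatedType`, `SequenceType`, `check_client_data_type`, and the local `ClientSequenceTypeError` class are hypothetical stand-ins defined here; TFF performs this check on its own federated type objects.

```python
from dataclasses import dataclass
from typing import Any


class ClientSequenceTypeError(TypeError):
    """Local stand-in for TFF's ClientSequenceTypeError (not imported from TFF)."""


@dataclass(frozen=True)
class SequenceType:
    element_type: str  # e.g. "float32"


@dataclass(frozen=True)
class FederatedType:
    member: Any     # a SequenceType, or a dict/list/tuple nesting of them
    placement: str  # "CLIENTS" or "SERVER"


def _is_structure_of_sequences(member):
    # Recurse through nested containers; every leaf must be a SequenceType.
    if isinstance(member, SequenceType):
        return True
    if isinstance(member, dict):
        return all(_is_structure_of_sequences(v) for v in member.values())
    if isinstance(member, (list, tuple)):
        return all(_is_structure_of_sequences(v) for v in member)
    return False


def check_client_data_type(client_data_type):
    """Raise ClientSequenceTypeError unless this is a structure of sequences at CLIENTS."""
    if not isinstance(client_data_type, FederatedType):
        raise ClientSequenceTypeError(f"not a federated type: {client_data_type!r}")
    if client_data_type.placement != "CLIENTS":
        raise ClientSequenceTypeError(
            f"placed at {client_data_type.placement}, expected CLIENTS")
    if not _is_structure_of_sequences(client_data_type.member):
        raise ClientSequenceTypeError("every leaf of the client data must be a sequence")
    return True


# A dict of two sequences placed at clients passes the check.
check_client_data_type(
    FederatedType({"x": SequenceType("float32"), "y": SequenceType("int32")}, "CLIENTS"))
```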