Composes specialized measured processes into a learning process.
```python
tff.learning.templates.compose_learning_process(
    initial_model_weights_fn: tff.Computation,
    model_weights_distributor: tff.learning.templates.DistributionProcess,
    client_work: tff.learning.templates.ClientWorkProcess,
    model_update_aggregator: tff.templates.AggregationProcess,
    model_finalizer: tff.learning.templates.FinalizerProcess
) -> tff.learning.templates.LearningProcess
```
Given 4 specialized measured processes (described below) that together make up a
learning process, and a computation that returns the initial model weights to be
used for training, this function validates that the processes fit together and
returns a LearningProcess. See the tutorial at
https://www.tensorflow.org/federated/tutorials/composing_learning_algorithms
for more details on composing learning processes.
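The compatibility check can be pictured as comparing each stage's output type with the next stage's input type. The sketch below is a plain-Python illustration under simplifying assumptions: the `Stage` class, its string-valued types and `validate_chain` are hypothetical and not part of TFF, which compares federated type signatures rather than strings. The real composition is also not a strictly linear chain (for example, model_finalizer also receives the current global model weights).

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Stage:
    """Hypothetical stand-in for a measured process with declared I/O types."""
    name: str
    input_type: str
    output_type: str


def validate_chain(stages: List[Stage]) -> None:
    """Raises TypeError if a stage's output type differs from the next stage's input type."""
    for upstream, downstream in zip(stages, stages[1:]):
        if upstream.output_type != downstream.input_type:
            raise TypeError(
                f"{upstream.name} produces {upstream.output_type!r}, "
                f"but {downstream.name} expects {downstream.input_type!r}"
            )


# Types are plain strings here purely for illustration ("value@PLACEMENT").
STAGES = [
    Stage("model_weights_distributor", "weights@SERVER", "weights@CLIENTS"),
    Stage("client_work", "weights@CLIENTS", "update@CLIENTS"),
    Stage("model_update_aggregator", "update@CLIENTS", "update@SERVER"),
    Stage("model_finalizer", "update@SERVER", "weights@SERVER"),
]

validate_chain(STAGES)  # Compatible chain: no exception is raised.
```

Failing early on a mismatch like this means a mis-assembled process is rejected when it is built rather than partway through training.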
The main purposes of the 4 measured processes are:
- model_weights_distributor: makes the global model weights at the server available as the starting point for the learning work done at clients.
- client_work: produces an update to the model received at the clients.
- model_update_aggregator: aggregates the model updates from the clients to the server.
- model_finalizer: updates the global model weights at the server using the aggregated model update.
The next computation of the created learning process is composed from the
next computations of the 4 measured processes, chained in the order visualized below.
The type signatures of the processes must be such that this chaining is
possible. Each process also reports its own metrics.
```
┌─────────────────────────┐
│model_weights_distributor│
└△─┬─┬────────────────────┘
 │ │┌▽──────────┐
 │ ││client_work│
 │ │└┬─────┬────┘
 │┌▽─▽────┐│
 ││metrics││
 │└△─△────┘│
 │ │┌┴─────▽────────────────┐
 │ ││model_update_aggregator│
 │ │└┬──────────────────────┘
┌┴─┴─▽──────────┐
│model_finalizer│
└┬──────────────┘
┌▽─────┐
│result│
└──────┘
```
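To make the data flow in the diagram above concrete, the following plain-Python sketch runs one round with toy stand-ins for the four processes, using a single float as the "model" and lists of floats as client data. The functions `toy_distribute`, `toy_client_work`, `toy_aggregate`, `toy_finalize` and `toy_next` are hypothetical and not TFF APIs. Per-process state is omitted for brevity, whereas in TFF each component is a stateful measured process whose next computation returns a state, a result and measurements. The point is only the ordering and the fact that each stage contributes its own metrics.

```python
from statistics import mean
from typing import Dict, List, Tuple

Metrics = Dict[str, float]


def toy_distribute(global_weights: float, num_clients: int) -> Tuple[List[float], Metrics]:
    """model_weights_distributor stand-in: broadcast the global weights to every client."""
    return [global_weights] * num_clients, {"num_clients": num_clients}


def toy_client_work(client_weights: List[float],
                    client_data: List[List[float]]) -> Tuple[List[float], Metrics]:
    """client_work stand-in: each update moves the weights toward that client's data mean."""
    updates = [mean(data) - w for w, data in zip(client_weights, client_data)]
    return updates, {"num_examples": sum(len(data) for data in client_data)}


def toy_aggregate(updates: List[float]) -> Tuple[float, Metrics]:
    """model_update_aggregator stand-in: unweighted mean of the client updates."""
    return mean(updates), {"num_updates": len(updates)}


def toy_finalize(global_weights: float, update: float,
                 learning_rate: float = 1.0) -> Tuple[float, Metrics]:
    """model_finalizer stand-in: apply the aggregated update to the global weights."""
    return global_weights + learning_rate * update, {"update_magnitude": abs(update)}


def toy_next(global_weights: float,
             client_data: List[List[float]]) -> Tuple[float, Dict[str, Metrics]]:
    """One round: chain the four stages in diagram order and collect each stage's metrics."""
    client_weights, dist_metrics = toy_distribute(global_weights, len(client_data))
    updates, work_metrics = toy_client_work(client_weights, client_data)
    aggregated, agg_metrics = toy_aggregate(updates)
    new_weights, fin_metrics = toy_finalize(global_weights, aggregated)
    metrics = {
        "distributor": dist_metrics,
        "client_work": work_metrics,
        "aggregator": agg_metrics,
        "finalizer": fin_metrics,
    }
    return new_weights, metrics


# Client data means are 2.0 and 4.0, so the aggregated update is 3.0
# and the new global weights are 0.0 + 3.0 = 3.0.
new_weights, metrics = toy_next(0.0, [[1.0, 3.0], [4.0]])
```

In this sketch each stand-in only consumes what the previous stage produced (plus the global weights for the finalizer), so replacing one of them, for example a weighted mean in `toy_aggregate`, leaves the other three unchanged.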
The get_hparams computation of the created learning process produces a
nested ordered dictionary containing the results of client_work.get_hparams
and model_finalizer.get_hparams. The set_hparams computation works the same
way in reverse, delegating to client_work.set_hparams and
model_finalizer.set_hparams to set the hyperparameters in their associated states.
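The delegation pattern can be sketched in plain Python as follows. Everything here is hypothetical: `ToyProcess`, `composed_get_hparams` and `composed_set_hparams` are not TFF APIs, the top-level keys `"client_work"` and `"finalizer"` are assumed for illustration, and the toy composite state holds only these two sub-states, whereas a real learning-process state also carries the distributor and aggregator states.

```python
import collections
from typing import Any, Dict


class ToyProcess:
    """Hypothetical stand-in for client_work or model_finalizer; not a TFF class."""

    def __init__(self, **initial_hparams: Any) -> None:
        self._initial_hparams = dict(initial_hparams)

    def initialize(self) -> Dict[str, Any]:
        # State holds the hyperparameters alongside other fields (here, a step counter).
        return {"hparams": dict(self._initial_hparams), "step": 0}

    def get_hparams(self, state: Dict[str, Any]) -> Dict[str, Any]:
        return collections.OrderedDict(state["hparams"])

    def set_hparams(self, state: Dict[str, Any], hparams: Dict[str, Any]) -> Dict[str, Any]:
        # Return a new state with only the hyperparameters replaced.
        return {**state, "hparams": dict(hparams)}


def composed_get_hparams(client_work, finalizer, state):
    """Nests each sub-process's hyperparameters under an (assumed) key."""
    return collections.OrderedDict(
        client_work=client_work.get_hparams(state["client_work"]),
        finalizer=finalizer.get_hparams(state["finalizer"]),
    )


def composed_set_hparams(client_work, finalizer, state, hparams):
    """Delegates each nested entry to the matching sub-process's set_hparams."""
    new_state = collections.OrderedDict(state)
    new_state["client_work"] = client_work.set_hparams(
        state["client_work"], hparams["client_work"])
    new_state["finalizer"] = finalizer.set_hparams(
        state["finalizer"], hparams["finalizer"])
    return new_state


# Usage: read the nested hyperparameters, change one value, write them back.
client_work = ToyProcess(learning_rate=0.1)
finalizer = ToyProcess(learning_rate=1.0)
state = collections.OrderedDict(
    client_work=client_work.initialize(), finalizer=finalizer.initialize())

hparams = composed_get_hparams(client_work, finalizer, state)
hparams["client_work"]["learning_rate"] = 0.05
state = composed_set_hparams(client_work, finalizer, state, hparams)
```

Because the composite only routes each sub-dictionary to its owner, the composed process does not need to know what hyperparameters each component defines.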

| Args | |
|---|---|
| initial_model_weights_fn | A tff.Computation that returns (unplaced) initial model weights. |
| model_weights_distributor | A tff.learning.templates.DistributionProcess. |
| client_work | A tff.learning.templates.ClientWorkProcess. |
| model_update_aggregator | A tff.templates.AggregationProcess. |
| model_finalizer | A tff.learning.templates.FinalizerProcess. |

| Returns |
|---|
| A tff.learning.templates.LearningProcess. |

| Raises | |
|---|---|
| ClientSequenceTypeError | If the first argument of the next method of the resulting LearningProcess is not a structure of sequences placed at tff.CLIENTS. |