View on TensorFlow.org | Run in Google Colab | View source on GitHub | Download notebook |
In this tutorial, we explain design principles behind the tff.aggregators
module and best practices for implementing custom aggregation of values from clients to server.
Prerequisites. This tutorial assumes you are already familiar with basic concepts of Federated Core such as placements (federated_language.SERVER
, federated_language.CLIENTS
), how TFF represents computations (tff.tensorflow.computation
, tff.federated_computation
) and their type signatures.
pip install --quite --upgrade federated_language
pip install --quiet --upgrade tensorflow-federated
Design summary
In TFF, "aggregation" refers to the movement of a set of values on federated_language.CLIENTS
to produce an aggregate value of the same type on federated_language.SERVER
. That is, each individual client value need not be available. For example in federated learning, client model updates are averaged to get an aggregate model update to apply to the global model on the server.
In addition to operators accomplishing this goal such as federated_language.federated_sum
, TFF provides tff.templates.AggregationProcess
(a stateful process) which formalizes the type signature for aggregation computation so it can generalize to more complex forms than a simple sum.
The main components of the tff.aggregators
module are factories for creation of the AggregationProcess
, which are designed to be generally useful and replacable building blocks of TFF in two aspects:
- Parameterized computations. Aggregation is an independent building block that can be plugged into other TFF modules designed to work with
tff.aggregators
to parameterize their necessary aggregation.
Example:
learning_process = tff.learning.algorithms.build_weighted_fed_avg(
...,
model_aggregator=tff.aggregators.MeanFactory())
- Aggregation composition. An aggregation building block can be composed with other aggregation building blocks to create more complex composite aggregations.
Example:
secure_mean = tff.aggregators.MeanFactory(
value_sum_factory=tff.aggregators.SecureSumFactory(...))
The rest of this tutorial explains how these two goals are achieved.
Aggregation process
We first summarize the tff.templates.AggregationProcess
, and follow with the factory pattern for its creation.
The tff.templates.AggregationProcess
is an tff.templates.MeasuredProcess
with type signatures specified for aggregation. In particular, the initialize
and next
functions have the following type signatures:
( -> state_type@SERVER)
(<state_type@SERVER, {value_type}@CLIENTS, *> -> <state_type@SERVER, value_type@SERVER, measurements_type@SERVER>)
The state (of type state_type
) must be placed at server. The next
function takes as input argument the state and a value to be aggregated (of type value_type
) placed at clients. The *
means optional other input arguments, for instance weights in a weighted mean. It returns an updated state object, the aggregated value of the same type placed at server, and some measurements.
Note that both the state to be passed between executions of the next
function, and the reported measurements intended to report any information depending on a specific execution of the next
function, may be empty. Nevertheless, they have to be explicitly specified for other parts of TFF to have a clear contract to follow.
Other TFF modules, for instance the model updates in tff.learning
, are expected to use the tff.templates.AggregationProcess
to parameterize how values are aggregated. However, what exactly are the values aggregated and what their type signatures are, depends on other details of the model being trained and the learning algorithm used to do it.
To make aggregation independent of the other aspects of computations, we use the factory pattern -- we create the appropriate tff.templates.AggregationProcess
once the relevant type signatures of objects to be aggregated are available, by invoking the create
method of the factory. Direct handling of the aggregation process is thus needed only for library authors, who are responsible for this creation.
Aggregation process factories
There are two abstract base factory classes for unweighted and weighted aggregation. Their create
method takes type signatures of value to be aggregated and returns a tff.templates.AggregationProcess
for aggregation of such values.
The process created by tff.aggregators.UnweightedAggregationFactory
takes two input arguments: (1) state at server and (2) value of specified type value_type
.
An example implementation is tff.aggregators.SumFactory
.
The process created by tff.aggregators.WeightedAggregationFactory
takes three input arguments: (1) state at server, (2) value of specified type value_type
and (3) weight of type weight_type
, as specified by the factory's user when invoking its create
method.
An example implementation is tff.aggregators.MeanFactory
which computes a weighted mean.
The factory pattern is how we achieve the first goal stated above; that aggregation is an independent building block. For example, when changing which model variables are trainable, a complex aggregation does not necessarily need to change; the factory representing it will be invoked with a different type signature when used by a method such as tff.learning.algorithms.build_weighted_fed_avg
.
Compositions
Recall that a general aggregation process can encapsulate (a) some preprocessing of the values at clients, (b) movement of values from client to server, and (c) some postprocessing of the aggregated value at the server. The second goal stated above, aggregation composition, is realized inside the tff.aggregators
module by structuring the implementation of the aggregation factories such that part (b) can be delegated to another aggregation factory.
Rather than implementing all necessary logic within a single factory class, the implementations are by default focused on a single aspect relevant for aggregation. When needed, this pattern then enables us to replace the building blocks one at a time.
An example is the weighted tff.aggregators.MeanFactory
. Its implementation multiplies provided values and weights at clients, then sums both weighted values and weights independently, and then divides the sum of weighted values by the sum of weights at the server. Instead of implementing the summations by directly using the federated_language.federated_sum
operator, the summation is delegated to two instances of tff.aggregators.SumFactory
.
Such structure makes it possible for the two default summations to be replaced by different factories, which realize the sum differently. For example, a tff.aggregators.SecureSumFactory
, or a custom implementation of the tff.aggregators.UnweightedAggregationFactory
. Conversely, time, tff.aggregators.MeanFactory
can itself be an inner aggregation of another factory such as tff.aggregators.clipping_factory
, if the values are to be clipped before averaging.
See the previous Tuning recommended aggregations for learning tutorial for receommended uses of the composition mechanism using existing factories in the tff.aggregators
module.
Best practices by example
We are going to illustrate the tff.aggregators
concepts in detail by implementing a simple example task, and make it progressively more general. Another way to learn is to look at the implementation of existing factories.
import collections
import federated_language
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
Instead of summing value
, the example task is to sum value * 2.0
and then divide the sum by 2.0
. The aggregation result is thus mathematically equivalent to directly summing the value
, and could be thought of as consisting of three parts: (1) scaling at clients (2) summing across clients (3) unscaling at server.
Following the design explained above, the logic will be implemented as a subclass of tff.aggregators.UnweightedAggregationFactory
, which creates appropriate tff.templates.AggregationProcess
when given a value_type
to aggregate:
Minimal implementation
For the example task, the computations necessary are always the same, so there is no need for using state. It is thus empty, and represented as federated_language.federated_value((), federated_language.SERVER)
. The same holds for measurements, for now.
The minimal implementation of the task is thus as follows:
class ExampleTaskFactory(tff.aggregators.UnweightedAggregationFactory):
def create(self, value_type):
@tff.federated_computation()
def initialize_fn():
return federated_language.federated_value((), federated_language.SERVER)
@tff.federated_computation(initialize_fn.type_signature.result,
federated_language.FederatedType(value_type, federated_language.CLIENTS))
def next_fn(state, value):
scaled_value = federated_language.federated_map(
tff.tensorflow.computation(lambda x: x * 2.0), value)
summed_value = federated_language.federated_sum(scaled_value)
unscaled_value = federated_language.federated_map(
tff.tensorflow.computation(lambda x: x / 2.0), summed_value)
measurements = federated_language.federated_value((), federated_language.SERVER)
return tff.templates.MeasuredProcessOutput(
state=state, result=unscaled_value, measurements=measurements)
return tff.templates.AggregationProcess(initialize_fn, next_fn)
Whether everything works as expected can be verified with the following code:
client_data = [1.0, 2.0, 5.0]
factory = ExampleTaskFactory()
aggregation_process = factory.create(federated_language.TensorType(np.float32))
print(f'Type signatures of the created aggregation process:\n'
f' - initialize: {aggregation_process.initialize.type_signature}\n'
f' - next: {aggregation_process.next.type_signature}\n')
state = aggregation_process.initialize()
output = aggregation_process.next(state, client_data)
print(f'Aggregation result: {output.result} (expected 8.0)')
Type signatures of the created aggregation process: - initialize: ( -> <>@SERVER) - next: (<state=<>@SERVER,value={float32}@CLIENTS> -> <state=<>@SERVER,result=float32@SERVER,measurements=<>@SERVER>) Aggregation result: 8.0 (expected 8.0)
Statefulness and measurements
Statefulness is broadly used in TFF to represent computations that are expected to be executed iteratively and change with each iteration. For example, the state of a learning computation contains the weights of the model being learned.
To illustrate how to use state in an aggregation computation, we modify the example task. Instead of multiplying value
by 2.0
, we multiply it by the iteration index - the number of times the aggregation has been executed.
To do so, we need a way to keep track of the iteration index, which is achieved through the concept of state. In the initialize_fn
, instead of creating an empty state, we initialize the state to be a scalar zero. Then, state can be used in the next_fn
in three steps: (1) increment by 1.0
, (2) use to multiply value
, and (3) return as the new updated state.
Once this is done, you may note: But exactly the same code as above can be used to verify all works as expected. How do I know something has actually changed?
Good question! This is where the concept of measurements becomes useful. In general, measurements can report any value relevant to a single execution of the next
function, which could be used for monitoring. In this case, it can be the summed_value
from the previous example. That is, the value before the "unscaling" step, which should depend on the iteration index. Again, this is not necessarily useful in practice, but illustrates the relevant mechanism.
The stateful answer to the task thus looks as follows:
class ExampleTaskFactory(tff.aggregators.UnweightedAggregationFactory):
def create(self, value_type):
@tff.federated_computation()
def initialize_fn():
return federated_language.federated_value(0.0, federated_language.SERVER)
@tff.federated_computation(initialize_fn.type_signature.result,
federated_language.FederatedType(value_type, federated_language.CLIENTS))
def next_fn(state, value):
new_state = federated_language.federated_map(
tff.tensorflow.computation(lambda x: x + 1.0), state)
state_at_clients = federated_language.federated_broadcast(new_state)
scaled_value = federated_language.federated_map(
tff.tensorflow.computation(lambda x, y: x * y), (value, state_at_clients))
summed_value = federated_language.federated_sum(scaled_value)
unscaled_value = federated_language.federated_map(
tff.tensorflow.computation(lambda x, y: x / y), (summed_value, new_state))
return tff.templates.MeasuredProcessOutput(
state=new_state, result=unscaled_value, measurements=summed_value)
return tff.templates.AggregationProcess(initialize_fn, next_fn)
Note that the state
that comes into next_fn
as input is placed at server. In order to use it at clients, it first needs to be communicated, which is achieved using the federated_language.federated_broadcast
operator.
To verify all works as expected, we can now look at the reported measurements
, which should be different with each round of execution, even if run with the same client_data
.
client_data = [1.0, 2.0, 5.0]
factory = ExampleTaskFactory()
aggregation_process = factory.create(federated_language.TensorType(np.float32))
print(f'Type signatures of the created aggregation process:\n'
f' - initialize: {aggregation_process.initialize.type_signature}\n'
f' - next: {aggregation_process.next.type_signature}\n')
state = aggregation_process.initialize()
output = aggregation_process.next(state, client_data)
print('| Round #1')
print(f'| Aggregation result: {output.result} (expected 8.0)')
print(f'| Aggregation measurements: {output.measurements} (expected 8.0 * 1)')
output = aggregation_process.next(output.state, client_data)
print('\n| Round #2')
print(f'| Aggregation result: {output.result} (expected 8.0)')
print(f'| Aggregation measurements: {output.measurements} (expected 8.0 * 2)')
output = aggregation_process.next(output.state, client_data)
print('\n| Round #3')
print(f'| Aggregation result: {output.result} (expected 8.0)')
print(f'| Aggregation measurements: {output.measurements} (expected 8.0 * 3)')
Type signatures of the created aggregation process: - initialize: ( -> float32@SERVER) - next: (<state=float32@SERVER,value={float32}@CLIENTS> -> <state=float32@SERVER,result=float32@SERVER,measurements=float32@SERVER>) | Round #1 | Aggregation result: 8.0 (expected 8.0) | Aggregation measurements: 8.0 (expected 8.0 * 1) | Round #2 | Aggregation result: 8.0 (expected 8.0) | Aggregation measurements: 16.0 (expected 8.0 * 2) | Round #3 | Aggregation result: 8.0 (expected 8.0) | Aggregation measurements: 24.0 (expected 8.0 * 3)
Structured types
The model weights of a model trained in federated learning are usually represented as a collection of tensors, rather than a single tensor. In TFF, this is represented as federated_language.StructType
and generally useful aggregation factories need to be able to accept the structured types.
However, in the above examples, we only worked with a federated_language.TensorType
object. If we try to use the previous factory to create the aggregation process with a federated_language.StructType([(np.float32, (2,)), (np.float32, (3,))])
, we get a strange error because TensorFlow will try to multiply a tf.Tensor
and a list
.
The problem is that instead of multiplying the structure of tensors by a constant, we need to multiply each tensor in the structure by a constant. The usual solution to this problem is to use the tf.nest
module inside of the created tff.tensorflow.computation
s.
The version of the previous ExampleTaskFactory
compatible with structured types thus looks as follows:
@tff.tensorflow.computation()
def scale(value, factor):
return tf.nest.map_structure(lambda x: x * factor, value)
@tff.tensorflow.computation()
def unscale(value, factor):
return tf.nest.map_structure(lambda x: x / factor, value)
@tff.tensorflow.computation()
def add_one(value):
return value + 1.0
class ExampleTaskFactory(tff.aggregators.UnweightedAggregationFactory):
def create(self, value_type):
@tff.federated_computation()
def initialize_fn():
return federated_language.federated_value(0.0, federated_language.SERVER)
@tff.federated_computation(initialize_fn.type_signature.result,
federated_language.FederatedType(value_type, federated_language.CLIENTS))
def next_fn(state, value):
new_state = federated_language.federated_map(add_one, state)
state_at_clients = federated_language.federated_broadcast(new_state)
scaled_value = federated_language.federated_map(scale, (value, state_at_clients))
summed_value = federated_language.federated_sum(scaled_value)
unscaled_value = federated_language.federated_map(unscale, (summed_value, new_state))
return tff.templates.MeasuredProcessOutput(
state=new_state, result=unscaled_value, measurements=summed_value)
return tff.templates.AggregationProcess(initialize_fn, next_fn)
This example highlights a pattern which may be useful to follow when structuring TFF code. When not dealing with very simple operations, the code becomes more legible when the tff.tensorflow.computation
s that will be used as building blocks inside a tff.federated_computation
are created in a separate place. Inside of the tff.federated_computation
, these building blocks are only connected using the intrinsic operators.
To verify it works as expected:
client_data = [[[1.0, 2.0], [3.0, 4.0, 5.0]],
[[1.0, 1.0], [3.0, 0.0, -5.0]]]
factory = ExampleTaskFactory()
aggregation_process = factory.create(
federated_language.to_type([(np.float32, (2,)), (np.float32, (3,))]))
print(f'Type signatures of the created aggregation process:\n'
f' - initialize: {aggregation_process.initialize.type_signature}\n'
f' - next: {aggregation_process.next.type_signature}\n')
state = aggregation_process.initialize()
output = aggregation_process.next(state, client_data)
print(f'Aggregation result: [{output.result[0]}, {output.result[1]}]\n'
f' Expected: [[2. 3.], [6. 4. 0.]]')
Type signatures of the created aggregation process: - initialize: ( -> float32@SERVER) - next: (<state=float32@SERVER,value={<float32[2],float32[3]>}@CLIENTS> -> <state=float32@SERVER,result=<float32[2],float32[3]>@SERVER,measurements=<float32[2],float32[3]>@SERVER>) Aggregation result: [[2. 3.], [6. 4. 0.]] Expected: [[2. 3.], [6. 4. 0.]]
Inner aggregations
The final step is to optionally enable delegation of the actual aggregation to other factories, in order to allow easy composition of different aggregation techniques.
This is achieved by creating an optional inner_factory
argument in the constructor of our ExampleTaskFactory
. If not specified, tff.aggregators.SumFactory
is used, which applies the federated_language.federated_sum
operator used directly in the previous section.
When create
is called, we can first call create
of the inner_factory
to create the inner aggregation process with the same value_type
.
The state of our process returned by initialize_fn
is a composition of two parts: the state created by "this" process, and the state of the just created inner process.
The implementation of the next_fn
differs in that the actual aggregation is delegated to the next
function of the inner process, and in how the final output is composed. The state is again composed of "this" and "inner" state, and measurements are composed in a similar manner as an OrderedDict
.
The following is an implementation of such pattern.
@tff.tensorflow.computation()
def scale(value, factor):
return tf.nest.map_structure(lambda x: x * factor, value)
@tff.tensorflow.computation()
def unscale(value, factor):
return tf.nest.map_structure(lambda x: x / factor, value)
@tff.tensorflow.computation()
def add_one(value):
return value + 1.0
class ExampleTaskFactory(tff.aggregators.UnweightedAggregationFactory):
def __init__(self, inner_factory=None):
if inner_factory is None:
inner_factory = tff.aggregators.SumFactory()
self._inner_factory = inner_factory
def create(self, value_type):
inner_process = self._inner_factory.create(value_type)
@tff.federated_computation()
def initialize_fn():
my_state = federated_language.federated_value(0.0, federated_language.SERVER)
inner_state = inner_process.initialize()
return federated_language.federated_zip((my_state, inner_state))
@tff.federated_computation(initialize_fn.type_signature.result,
federated_language.FederatedType(value_type, federated_language.CLIENTS))
def next_fn(state, value):
my_state, inner_state = state
my_new_state = federated_language.federated_map(add_one, my_state)
my_state_at_clients = federated_language.federated_broadcast(my_new_state)
scaled_value = federated_language.federated_map(scale, (value, my_state_at_clients))
# Delegation to an inner factory, returning values placed at SERVER.
inner_output = inner_process.next(inner_state, scaled_value)
unscaled_value = federated_language.federated_map(unscale, (inner_output.result, my_new_state))
new_state = federated_language.federated_zip((my_new_state, inner_output.state))
measurements = federated_language.federated_zip(
collections.OrderedDict(
scaled_value=inner_output.result,
example_task=inner_output.measurements))
return tff.templates.MeasuredProcessOutput(
state=new_state, result=unscaled_value, measurements=measurements)
return tff.templates.AggregationProcess(initialize_fn, next_fn)
When delegating to the inner_process.next
function, the return structure we get is a tff.templates.MeasuredProcessOutput
, with the same three fields - state
, result
and measurements
. When creating the overall return structure of the composed aggregation process, the state
and measurements
fields should be generally composed and returned together. In contrast, the result
field corresponds to the value being aggregated and instead "flows through" the composed aggregation.
The state
object should be seen as an implementation detail of the factory, and thus the composition could be of any structure. However, measurements
correspond to values to be reported to the user at some point. Therefore, we recommend to use OrderedDict
, with composed naming such that it would be clear where in an composition does a reported metric comes from.
Note also the use of the federated_language.federated_zip
operator. The state
object contolled by the created process should be a federated_language.FederatedType
. If we had instead returned (this_state, inner_state)
in the initialize_fn
, its return type signature would be a federated_language.StructType
containing a 2-tuple of federated_language.FederatedType
s. The use of federated_language.federated_zip
"lifts" the federated_language.FederatedType
to the top level. This is similarly used in the next_fn
when preparing the state and measurements to be returned.
Finally, we can see how this can be used with the default inner aggregation:
client_data = [1.0, 2.0, 5.0]
factory = ExampleTaskFactory()
aggregation_process = factory.create(federated_language.TensorType(np.float32))
state = aggregation_process.initialize()
output = aggregation_process.next(state, client_data)
print('| Round #1')
print(f'| Aggregation result: {output.result} (expected 8.0)')
print(f'| measurements[\'scaled_value\']: {output.measurements["scaled_value"]}')
print(f'| measurements[\'example_task\']: {output.measurements["example_task"]}')
output = aggregation_process.next(output.state, client_data)
print('\n| Round #2')
print(f'| Aggregation result: {output.result} (expected 8.0)')
print(f'| measurements[\'scaled_value\']: {output.measurements["scaled_value"]}')
print(f'| measurements[\'example_task\']: {output.measurements["example_task"]}')
| Round #1 | Aggregation result: 8.0 (expected 8.0) | measurements['scaled_value']: 8.0 | measurements['example_task']: () | Round #2 | Aggregation result: 8.0 (expected 8.0) | measurements['scaled_value']: 16.0 | measurements['example_task']: ()
... and with a different inner aggregation. For example, an ExampleTaskFactory
:
client_data = [1.0, 2.0, 5.0]
# Note the inner delegation can be to any UnweightedAggregaionFactory.
# In this case, each factory creates process that multiplies by the iteration
# index (1, 2, 3, ...), thus their combination multiplies by (1, 4, 9, ...).
factory = ExampleTaskFactory(ExampleTaskFactory())
aggregation_process = factory.create(federated_language.TensorType(np.float32))
state = aggregation_process.initialize()
output = aggregation_process.next(state, client_data)
print('| Round #1')
print(f'| Aggregation result: {output.result} (expected 8.0)')
print(f'| measurements[\'scaled_value\']: {output.measurements["scaled_value"]}')
print(f'| measurements[\'example_task\']: {output.measurements["example_task"]}')
output = aggregation_process.next(output.state, client_data)
print('\n| Round #2')
print(f'| Aggregation result: {output.result} (expected 8.0)')
print(f'| measurements[\'scaled_value\']: {output.measurements["scaled_value"]}')
print(f'| measurements[\'example_task\']: {output.measurements["example_task"]}')
| Round #1 | Aggregation result: 8.0 (expected 8.0) | measurements['scaled_value']: 8.0 | measurements['example_task']: OrderedDict([('scaled_value', 8.0), ('example_task', ())]) | Round #2 | Aggregation result: 8.0 (expected 8.0) | measurements['scaled_value']: 16.0 | measurements['example_task']: OrderedDict([('scaled_value', 32.0), ('example_task', ())])
Summary
In this tutorial, we explained the best practices to follow in order to create a general-purpose aggregation building block, represented as an aggregation factory. The generality comes through the design intent in two ways:
- Parameterized computations. Aggregation is an independent building block that can be plugged into other TFF modules designed to work with
tff.aggregators
to parameterize their necessary aggregation, such astff.learning.algorithms.build_weighted_fed_avg
. - Aggregation composition. An aggregation building block can be composed with other aggregation building blocks to create more complex composite aggregations.