Before you start
Before you start, please run the following to make sure that your environment is correctly set up. If you don't see a greeting, please refer to the Installation guide for instructions.
pip install --quiet --upgrade federated_language
pip install --quiet --upgrade tensorflow-federated
import collections
import federated_language
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
Composing Learning Algorithms
The Building Your Own Federated Learning Algorithm Tutorial used TFF's federated core to directly implement a version of the Federated Averaging (FedAvg) algorithm.
In this tutorial, you will use federated learning components in TFF's API to build federated learning algorithms in a modular manner, without having to re-implement everything from scratch.
For the purposes of this tutorial, you will implement a variant of FedAvg that employs gradient clipping during local training.
Learning Algorithm Building Blocks
At a high level, many learning algorithms can be separated into 4 components, referred to as building blocks. These are as follows:
- Distributor (i.e., server-to-client communication)
- Client work (i.e., local client computation)
- Aggregator (i.e., client-to-server communication)
- Finalizer (i.e., server computation using aggregated client outputs)
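To make the data flow between the four blocks concrete, here is a toy, pure-Python sketch of one round in which each block is a plain function. Everything in it is hypothetical and only illustrative: the function names, the scalar "model", and the rule that local training moves the weight halfway toward each client's data mean are not TFF APIs.

```python
from typing import List, Sequence, Tuple


def distributor(server_weight: float, num_clients: int) -> List[float]:
    # Server-to-client communication: every client receives the same weight.
    return [server_weight] * num_clients


def client_work(weight: float, data: Sequence[float]) -> Tuple[float, float]:
    # Toy local computation: move halfway toward the local data mean, then
    # report the update (initial minus trained) and its weight (example count).
    trained = weight + 0.5 * (sum(data) / len(data) - weight)
    return weight - trained, float(len(data))


def aggregator(results: Sequence[Tuple[float, float]]) -> float:
    # Client-to-server communication: example-weighted mean of the updates.
    total_weight = sum(w for _, w in results)
    return sum(u * w for u, w in results) / total_weight


def finalizer(server_weight: float, update: float, lr: float = 1.0) -> float:
    # Server computation: apply the aggregated update as if it were a gradient.
    return server_weight - lr * update


def run_round(server_weight: float, client_data: Sequence[Sequence[float]]) -> float:
    """Runs one toy round by chaining the four building blocks."""
    client_weights = distributor(server_weight, len(client_data))
    results = [client_work(w, d) for w, d in zip(client_weights, client_data)]
    return finalizer(server_weight, aggregator(results))


new_weight = run_round(0.0, [[2.0, 2.0], [4.0]])
```

Replacing only client_work while reusing the other three functions mirrors the kind of change this tutorial makes.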
While the Building Your Own Federated Learning Algorithm Tutorial implemented all of these building blocks from scratch, this is often unnecessary. Instead, you can re-use building blocks from similar algorithms.
In this case, to implement FedAvg with gradient clipping, you only need to modify the client work building block. The remaining blocks can be identical to what is used in "vanilla" FedAvg.
Implementing the Client Work
First, let's write TF logic that does local model training with gradient clipping. For simplicity, gradients will be clipped to have norm at most 1.
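As a framework-independent check of the clipping rule, here is a NumPy sketch. The global norm is the L2 norm over all gradient arrays taken together (the quantity tf.linalg.global_norm computes), and gradients are rescaled only when that norm exceeds the threshold. The helper name clip_by_global_norm_np is hypothetical.

```python
import numpy as np


def clip_by_global_norm_np(grads, max_norm=1.0):
    """Rescales a list of gradient arrays so their joint L2 norm is <= max_norm."""
    # Global norm: sqrt of the sum of squares over every element of every array.
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
    if global_norm > max_norm:
        # Scaling every array by the same factor preserves the overall direction.
        grads = [g * (max_norm / global_norm) for g in grads]
    return grads, global_norm


grads = [np.array([3.0, 0.0]), np.array([[0.0], [4.0]])]
clipped, norm = clip_by_global_norm_np(grads)
```

Here the global norm is 5, so both arrays are divided by 5. TensorFlow's tf.clip_by_global_norm(grads, 1.0) applies the same rule and returns the clipped list together with the global norm.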
TF Logic
@tf.function
def client_update(
    model: tff.learning.models.FunctionalModel,
    dataset: tf.data.Dataset,
    initial_weights: tff.learning.models.ModelWeights,
    client_optimizer: tff.learning.optimizers.Optimizer,
):
    """Performs training (using the initial server model weights) on the client's dataset."""
    # Keep track of the number of examples; it becomes the update weight.
    num_examples = 0.0
    # Use the client_optimizer to update the local model.
    trainable_weights, non_trainable_weights = (
        initial_weights.trainable,
        initial_weights.non_trainable,
    )
    optimizer_state = client_optimizer.initialize(
        tf.nest.map_structure(tf.TensorSpec.from_tensor, trainable_weights)
    )
    for batch in dataset:
        x, y = batch
        with tf.GradientTape() as tape:
            # The weights are plain tensors, not tf.Variables, so watch them.
            tape.watch(trainable_weights)
            logits = model.predict_on_batch(
                model_weights=(trainable_weights, non_trainable_weights),
                x=x,
                training=True,
            )
            num_examples += tf.cast(tf.shape(y)[0], tf.float32)
            loss = model.loss(output=logits, label=y)
        # Compute the corresponding gradient
        grads = tape.gradient(loss, trainable_weights)
        # Compute the gradient norm and clip it to at most 1
        gradient_norm = tf.linalg.global_norm(grads)
        if gradient_norm > 1:
            grads = tf.nest.map_structure(lambda g: g / gradient_norm, grads)
        # Apply the gradient using a client optimizer.
        optimizer_state, trainable_weights = client_optimizer.next(
            optimizer_state, trainable_weights, grads
        )
    # Compute the difference between the initial weights and the client weights
    client_update = tf.nest.map_structure(
        tf.subtract, initial_weights.trainable, trainable_weights
    )
    return tff.learning.templates.ClientResult(
        update=client_update, update_weight=num_examples
    )
There are two important points about the code above. First, it keeps track of the number of examples seen, as this will constitute the weight of the client update (when computing an average across clients).
Second, it uses tff.learning.templates.ClientResult to package the output. This return type is used to standardize client work building blocks in tff.learning.
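Because update_weight carries the number of examples, the server-side average weights each client's update by its dataset size. A NumPy sketch of that weighted mean follows; the function name is hypothetical, and in this tutorial the averaging itself is done by tff.aggregators.MeanFactory.

```python
import numpy as np


def weighted_mean_of_updates(updates, weights):
    """Averages client updates, weighting each one by its number of examples."""
    updates = np.stack([np.asarray(u, dtype=np.float64) for u in updates])
    weights = np.asarray(weights, dtype=np.float64)
    # Sum of weight * update over clients, divided by the total weight.
    return np.tensordot(weights, updates, axes=1) / weights.sum()


mean_update = weighted_mean_of_updates(
    updates=[[1.0, 0.0], [0.0, 1.0]], weights=[30.0, 10.0]
)
```

The client with 30 examples contributes three times as much to the average as the client with 10.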
Creating a ClientWorkProcess
While the TF logic above will do local training with clipping, it still needs to be wrapped in TFF code in order to create the necessary building block.
Specifically, each of the 4 building blocks is represented as a tff.templates.MeasuredProcess. This means that all 4 blocks have both an initialize and a next function used to instantiate and run the computation.
This allows each building block to keep track of its own state (stored at the server) as needed to perform its operations. While the client work in this tutorial does not use any state, state can be used for things like tracking how many iterations have occurred, or keeping track of optimizer states.
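A pure-Python sketch of the initialize/next contract: initialize creates the state, and next consumes the current state plus inputs and returns the new state, a result, and measurements. The class names are hypothetical; the three-field output only mirrors the (state, result, measurements) layout of tff.templates.MeasuredProcessOutput.

```python
from typing import Any, NamedTuple


class ProcessOutput(NamedTuple):
    # Mirrors the (state, result, measurements) layout of a measured process output.
    state: Any
    result: Any
    measurements: Any


class RoundCountingProcess:
    """Toy measured process whose state counts how many times next has run."""

    def initialize(self) -> int:
        return 0

    def next(self, state: int, value: float) -> ProcessOutput:
        new_state = state + 1
        # The result passes the input through; the measurement reports the count.
        return ProcessOutput(new_state, value, {"rounds": new_state})


process = RoundCountingProcess()
state = process.initialize()
for _ in range(3):
    output = process.next(state, value=1.0)
    state = output.state
```

The caller, not the process, carries the state from one call to the next, which is why next takes the state as an explicit argument.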
Client work TF logic should generally be wrapped as a tff.learning.templates.ClientWorkProcess, which codifies the expected types going into and out of the client's local training. It can be parameterized by a model and optimizer, as below.
def build_gradient_clipping_client_work(
    model: tff.learning.models.FunctionalModel,
    optimizer: tff.learning.optimizers.Optimizer,
) -> tff.learning.templates.ClientWorkProcess:
    """Creates a client work process that uses gradient clipping."""
    data_type = federated_language.SequenceType(
        tff.tensorflow.to_type(model.input_spec)
    )
    model_weights_type = federated_language.to_type(
        tf.nest.map_structure(
            lambda arr: federated_language.TensorType(
                shape=arr.shape, dtype=arr.dtype
            ),
            tff.learning.models.ModelWeights(*model.initial_weights),
        )
    )
    @tff.federated_computation
    def initialize_fn():
        return federated_language.federated_value((), federated_language.SERVER)
    @tff.tensorflow.computation(model_weights_type, data_type)
    def client_update_computation(model_weights, dataset):
        return client_update(model, dataset, model_weights, optimizer)
    @tff.federated_computation(
        initialize_fn.type_signature.result,
        federated_language.FederatedType(
            model_weights_type, federated_language.CLIENTS
        ),
        federated_language.FederatedType(data_type, federated_language.CLIENTS),
    )
    def next_fn(state, model_weights, client_dataset):
        client_result = federated_language.federated_map(
            client_update_computation, (model_weights, client_dataset)
        )
        # Return empty measurements, though a more complete algorithm might
        # measure something here.
        measurements = federated_language.federated_value(
            (), federated_language.SERVER
        )
        return tff.templates.MeasuredProcessOutput(
            state, client_result, measurements
        )
    return tff.learning.templates.ClientWorkProcess(initialize_fn, next_fn)
Composing a Learning Algorithm
Let's put the client work above into a full-fledged algorithm. First, let's set up our data and model.
Preparing the input data
Load and preprocess the EMNIST dataset included in TFF. For more details, see the image classification tutorial.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
In order to feed the dataset into our model, the data is flattened and converted into tuples of the form (flattened_image_vector, label).
Let's select a small number of clients, and apply the preprocessing above to their datasets.
NUM_CLIENTS = 10
BATCH_SIZE = 20
def preprocess(dataset):
    def batch_format_fn(element):
        """Flatten a batch of EMNIST data and return a (features, label) tuple."""
        return (
            tf.reshape(element['pixels'], [-1, 784]),
            tf.reshape(element['label'], [-1, 1]),
        )
    return dataset.batch(BATCH_SIZE).map(batch_format_fn)
client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [
    preprocess(emnist_train.create_tf_dataset_for_client(x)) for x in client_ids
]
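To see what batch_format_fn does to shapes, here is a NumPy sketch on a fake batch. Random values stand in for EMNIST pixels and labels (an assumption for illustration; only the shapes matter): each 28x28 image becomes a 784-long vector, and the labels become a column of shape [batch, 1].

```python
import numpy as np

BATCH_SIZE = 20

# Fake batch standing in for EMNIST: 20 grayscale 28x28 images and 20 labels.
rng = np.random.default_rng(0)
batch = {
    "pixels": rng.random((BATCH_SIZE, 28, 28), dtype=np.float32),
    "label": rng.integers(0, 10, size=(BATCH_SIZE,)),
}

# Same reshapes as batch_format_fn: flatten images, make labels a column.
features = np.reshape(batch["pixels"], [-1, 784])
labels = np.reshape(batch["label"], [-1, 1])
```

Using -1 for the batch dimension lets the last, possibly smaller batch of each client's dataset keep its own size.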
Preparing the model
This uses the same model as in the image classification tutorial. This model (implemented via tf.keras) has a single dense layer, followed by a softmax layer. In order to use this model in TFF, the Keras model is wrapped as a tff.learning.models.FunctionalModel. This allows us to perform the model's forward pass and compute its loss inside TFF computations, through the predict_on_batch and loss methods that client_update calls.
initializer = tf.keras.initializers.GlorotNormal(seed=0)
keras_model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(784,)),
    tf.keras.layers.Dense(10, kernel_initializer=initializer),
    tf.keras.layers.Softmax(),
])
tff_model = tff.learning.models.functional_model_from_keras(
    keras_model,
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
    input_spec=federated_train_data[0].element_spec,
    metrics_constructor=collections.OrderedDict(
        accuracy=tf.keras.metrics.SparseCategoricalAccuracy
    ),
)
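Since the Keras model is one Dense(10) layer followed by softmax, its forward pass and loss can be written out by hand. A NumPy sketch, assuming an all-zero kernel instead of the GlorotNormal initializer so the result is easy to check (the Dense bias starts at zero by default): with all-zero logits, softmax is uniform over the 10 classes and the sparse categorical cross-entropy is ln(10) for every example.

```python
import numpy as np


def forward(x, kernel, bias):
    """Dense layer followed by softmax, as in keras_model."""
    logits = x @ kernel + bias
    # Subtract the row max before exponentiating for numerical stability.
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)


def sparse_categorical_crossentropy(probs, labels):
    """Mean negative log-probability of the true class; labels have shape [batch, 1]."""
    true_class_probs = np.take_along_axis(probs, labels, axis=1)
    return float(np.mean(-np.log(true_class_probs)))


x = np.ones((4, 784), dtype=np.float32)
labels = np.array([[0], [3], [7], [9]])
kernel = np.zeros((784, 10), dtype=np.float32)  # assumption: zeros, not GlorotNormal
bias = np.zeros(10, dtype=np.float32)
probs = forward(x, kernel, bias)
loss = sparse_categorical_crossentropy(probs, labels)
```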
Preparing the optimizers
Just as in tff.learning.algorithms.build_weighted_fed_avg, there are two optimizers here: a client optimizer and a server optimizer. For simplicity, the optimizers will be SGD with different learning rates.
client_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.01)
server_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=1.0)
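With no momentum, build_sgdm behaves as plain SGD: w <- w - lr * g. The NumPy sketch below illustrates why a server learning rate of 1.0 is a natural choice, under two stated assumptions: each client update is the initial weights minus the trained weights, and the finalizer applies the averaged update as if it were a gradient. One server step with lr = 1.0 then lands exactly on the example-weighted average of the client models. The client gradients and the 3:1 example weights are made up for illustration.

```python
import numpy as np


def sgd_step(weights, gradient, learning_rate):
    """Plain SGD (no momentum): move against the gradient."""
    return weights - learning_rate * gradient


initial = np.array([1.0, 1.0])

# Hypothetical trained client models after one local step (client lr = 0.01).
trained_a = sgd_step(initial, gradient=np.array([10.0, 0.0]), learning_rate=0.01)
trained_b = sgd_step(initial, gradient=np.array([0.0, 20.0]), learning_rate=0.01)

# Client updates as initial minus trained, averaged with example weights 3:1.
updates = [initial - trained_a, initial - trained_b]
mean_update = (3.0 * updates[0] + 1.0 * updates[1]) / 4.0

# Server step with learning rate 1.0 treats the mean update as a gradient.
new_server = sgd_step(initial, mean_update, learning_rate=1.0)
```

A server learning rate other than 1.0 would move the global model proportionally less or more than this plain average.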
Defining the building blocks
Now that the client work building block, data, model, and optimizers are set up, it remains to create building blocks for the distributor, the aggregator, and the finalizer. This can be done by borrowing defaults that TFF provides and that FedAvg uses.
@tff.tensorflow.computation
def initial_model_weights_fn():
    return tff.learning.models.ModelWeights(*tff_model.initial_weights)
model_weights_type = initial_model_weights_fn.type_signature.result
distributor = tff.learning.templates.build_broadcast_process(model_weights_type)
client_work = build_gradient_clipping_client_work(tff_model, client_optimizer)
# TFF aggregators use a factory pattern, which creates an aggregator
# based on the output type of the client work. This also uses a float (the
# number of examples) to govern the weight in the average being computed.
aggregator_factory = tff.aggregators.MeanFactory()
aggregator = aggregator_factory.create(
    model_weights_type.trainable, federated_language.TensorType(np.float32)
)
finalizer = tff.learning.templates.build_apply_optimizer_finalizer(
    server_optimizer, model_weights_type
)
Composing the building blocks
Finally, you can use a built-in composer in TFF to put the building blocks together. This one is a relatively simple composer: it takes the initial model weights function and the 4 building blocks above, and wires their types together.
fed_avg_with_clipping = tff.learning.templates.compose_learning_process(
initial_model_weights_fn,
distributor,
client_work,
aggregator,
finalizer
)
Running the algorithm
Now that the algorithm is composed, let's run it. First, initialize the algorithm. The state of this algorithm has a component for each building block, along with one for the global model weights.
state = fed_avg_with_clipping.initialize()
state.client_work
()
As expected, the client work has an empty state (remember the client work code above!). However, other building blocks may have non-empty state. For example, the finalizer holds the state of the server optimizer; for the SGD optimizer used here, that state is just the learning rate.
state.finalizer
OrderedDict([('learning_rate', 1.0)])
Now run a training round.
learning_process_output = fed_avg_with_clipping.next(state, federated_train_data)
The output of this (a tff.learning.templates.LearningProcessOutput) has both a .state and a .metrics output. Let's look at both.
learning_process_output.state.finalizer
OrderedDict([('learning_rate', 1.0)])
The finalizer state is unchanged after one round of .next: plain SGD keeps a fixed learning rate, so its state still holds a learning_rate of 1.0.
learning_process_output.metrics
OrderedDict([('distributor', ()), ('client_work', ()), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
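While these metrics are mostly empty (the finalizer only reports whether the aggregated update contained non-finite values), more complex and practical algorithms will generally report much more useful information here.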
Conclusion
By using the building block/composer framework above, you can create entirely new learning algorithms without having to re-do everything from scratch. However, this is only the starting point. This framework makes it much easier to express algorithms as simple modifications of FedAvg. For more algorithms, see tff.learning.algorithms, which contains algorithms such as FedProx and FedAvg with client learning rate scheduling. These APIs can even aid implementations of entirely new algorithms, such as federated k-means clustering.