Building Your Own Federated Learning Algorithm

Before you start

Before you start, please run the following to make sure that your environment is correctly set up. If the installation or the imports fail, please refer to the Installation guide for instructions.

pip install --quiet --upgrade tensorflow-federated
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

In the image classification and text generation tutorials, you learned how to set up model and data pipelines for Federated Learning (FL), and performed federated training via the tff.learning API layer of TFF.

This is only the tip of the iceberg when it comes to FL research. This tutorial discusses how to implement federated learning algorithms without deferring to the tff.learning API. In this tutorial, you will accomplish the following:

Goals:

  • Understand the general structure of federated learning algorithms.
  • Explore the Federated Core of TFF.
  • Use the Federated Core to implement Federated Averaging directly.

While this tutorial is self-contained, it may be useful to first check out the image classification and text generation tutorials.

Preparing the input data

First load and preprocess the EMNIST dataset included in TFF. For more details, see the image classification tutorial.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

In order to feed the dataset into our model, the data is flattened, and each example is converted into a tuple of the form (flattened_image_vector, label).

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]),
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)
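
To make the shape change concrete, here is a framework-free sketch of what preprocess does: it groups examples into batches and flattens each image. It uses plain Python lists instead of tf.data, and the helper name batch_and_flatten and the tiny 2-by-2 "images" are assumptions chosen for brevity, not part of TFF or EMNIST.

```python
def batch_and_flatten(examples, batch_size):
    """Group examples into batches and flatten each image in row-major order.

    Each example is a dict with 'pixels' (a list of rows) and 'label' (an int),
    mirroring the element structure used by preprocess above.
    """
    batches = []
    for start in range(0, len(examples), batch_size):
        chunk = examples[start:start + batch_size]
        features = [[p for row in ex['pixels'] for p in row] for ex in chunk]
        labels = [[ex['label']] for ex in chunk]
        batches.append((features, labels))
    return batches


# Hypothetical toy data: three 2-by-2 "images" with integer labels.
toy_examples = [
    {'pixels': [[0.0, 0.1], [0.2, 0.3]], 'label': 7},
    {'pixels': [[1.0, 0.9], [0.8, 0.7]], 'label': 2},
    {'pixels': [[0.5, 0.5], [0.5, 0.5]], 'label': 4},
]
toy_batches = batch_and_flatten(toy_examples, batch_size=2)
```

Note that the final batch holds only one example. Because the last batch can be smaller than the batch size, the leading dimension of each batch is not fixed.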

Now, select a small number of clients, and apply the preprocessing above to their datasets.

client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [
    preprocess(emnist_train.create_tf_dataset_for_client(x))
    for x in client_ids
]

Preparing the model

This uses the same model as in the image classification tutorial. This model (implemented via tf.keras) has a single dense layer, followed by a softmax layer.

def create_keras_model():
  initializer = tf.keras.initializers.GlorotNormal(seed=0)
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])

In order to use this model in TFF, wrap the Keras model as a tff.learning.models.FunctionalModel. This allows one to perform the model's forward pass within TFF, and extract model outputs. For more details, also see the image classification tutorial.

keras_model = create_keras_model()
tff_model = tff.learning.models.functional_model_from_keras(
    keras_model,
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
    input_spec=federated_train_data[0].element_spec,
    metrics_constructor=collections.OrderedDict(
        accuracy=tf.keras.metrics.SparseCategoricalAccuracy
    ),
)

While the above used tf.keras to create a tff.learning.models.FunctionalModel, TFF supports much more general models. These models have the following relevant attributes capturing the model weights:

  • trainable_variables: An iterable of the tensors corresponding to trainable layers.
  • non_trainable_variables: An iterable of the tensors corresponding to non-trainable layers.

In this tutorial, only the trainable_variables will be used (as the model only has those!).

Building your own Federated Learning algorithm

While the tff.learning API allows one to create many variants of Federated Averaging, there are other federated algorithms that do not fit neatly into this framework. For example, you may want to add regularization, clipping, or more complicated algorithms such as federated GAN training. You may also instead be interested in federated analytics.

For these more advanced algorithms, you'll have to write your own custom algorithm using TFF. In many cases, federated algorithms have 4 main components:

  1. A server-to-client broadcast step.
  2. A local client update step.
  3. A client-to-server upload step.
  4. A server update step.

In TFF, a federated algorithm is typically represented as a tff.templates.IterativeProcess (which will be referred to as just an IterativeProcess throughout). This is a class that contains initialize and next functions. Here, initialize is used to initialize the server, and next will perform one communication round of the federated algorithm. Let's write a skeleton of what our iterative process for FedAvg should look like.

First, there is an initialize function that simply creates a tff.learning.models.FunctionalModel, and returns its trainable weights.

def initialize_fn():
  trainable_weights, _ = tff_model.initial_weights
  return trainable_weights

This function looks good, but as you will see later, you will need to make a small modification to make it a "TFF computation".

Next, let's write a sketch of the next_fn.

def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights
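
The sketch above calls helpers that do not exist yet. As a rough, framework-free illustration of how the four steps fit together, the following version fills them in with plain Python stand-ins. Everything here is an assumption made for illustration: weights are lists of floats, each client dataset is a list of numbers, and the toy client_update simply moves the weights toward the mean of the client's data rather than training a model.

```python
def broadcast(server_weights, num_clients):
    """Give every client its own copy of the server weights."""
    return [list(server_weights) for _ in range(num_clients)]


def client_update(dataset, weights, learning_rate=0.5):
    """Toy local step: move each weight toward the mean of the client's data."""
    target = sum(dataset) / len(dataset)
    return [w + learning_rate * (target - w) for w in weights]


def mean(client_weights):
    """Coordinate-wise unweighted average across clients."""
    num_clients = len(client_weights)
    return [sum(ws) / num_clients for ws in zip(*client_weights)]


def server_update(mean_client_weights):
    """Vanilla FedAvg: the server adopts the averaged client weights."""
    return mean_client_weights


def next_fn(server_weights, federated_dataset):
    server_weights_at_client = broadcast(server_weights, len(federated_dataset))
    client_weights = [
        client_update(ds, w)
        for ds, w in zip(federated_dataset, server_weights_at_client)
    ]
    mean_client_weights = mean(client_weights)
    return server_update(mean_client_weights)


# Three hypothetical clients with local means 2.0, 4.0, and 7.0.
toy_federated_dataset = [[1.0, 3.0], [4.0], [6.0, 8.0]]
toy_weights = next_fn([0.0], toy_federated_dataset)
```

Starting from a weight of 0.0, the three clients move halfway to their local means (1.0, 2.0, and 3.5), and the server averages those values. The TFF version developed in the rest of this tutorial has the same shape, but with explicit placements and types.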

Let's focus on implementing these four components separately. First, let's focus on the parts that can be implemented in pure TensorFlow, namely the client and server update steps.

TensorFlow Blocks

Client update

The tff.learning.models.FunctionalModel can be used to do client training in essentially the same way you would train a TensorFlow model. In particular, one can use tf.GradientTape to compute the gradient on batches of data, then apply these gradients using a client_optimizer. This will only involve the trainable weights.

@tf.function
def client_update(model, dataset, initial_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights and the optimizer
  # state.
  client_weights = initial_weights.trainable
  optimizer_state = client_optimizer.initialize(
      tf.nest.map_structure(tf.TensorSpec.from_tensor, client_weights)
  )

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    x, y = batch
    with tf.GradientTape() as tape:
      tape.watch(client_weights)
      # Compute a forward pass on the batch of data
      outputs = model.predict_on_batch(
          model_weights=(client_weights, ()), x=x, training=True
      )
      loss = model.loss(output=outputs, label=y)

    # Compute the corresponding gradient
    grads = tape.gradient(loss, client_weights)

    # Apply the gradient using a client optimizer.
    optimizer_state, client_weights = client_optimizer.next(
        optimizer_state, weights=client_weights, gradients=grads
    )

  return tff.learning.models.ModelWeights(client_weights, non_trainable=())
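
If the loop above feels dense, it may help to see the same pattern without TensorFlow. The following sketch runs one pass of SGD over batches for a scalar model y_hat = w * x with mean squared error, computing the gradient by hand instead of with tf.GradientTape. The function name local_sgd, the toy data, and the learning rate are assumptions for illustration only.

```python
def local_sgd(initial_weight, batches, learning_rate):
    """Run one pass of SGD for the model y_hat = w * x with mean squared error."""
    w = initial_weight
    for xs, ys in batches:
        # Gradient of mean((w * x - y) ** 2) with respect to w on this batch.
        grad = sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        # Plain SGD step, playing the role of client_optimizer.next above.
        w = w - learning_rate * grad
    return w


# Hypothetical client data generated by y = 2 * x, split into two batches.
toy_batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]
updated = local_sgd(initial_weight=0.0, batches=toy_batches, learning_rate=0.05)
```

After the two batches the weight has moved from 0.0 to roughly 1.85, most of the way toward the value 2.0 that generated the data. As in client_update, the function starts from the weights it is given and returns the locally updated weights.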

Server Update

The server update for FedAvg is simpler than the client update. This tutorial will implement "vanilla" federated averaging, in which the server model weights are replaced by the average of the client model weights. Again, this only uses the trainable weights.

@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  del model  # Unused, just take the mean_client_weights.
  return mean_client_weights

This snippet simply returns the mean_client_weights. However, more advanced implementations of Federated Averaging use mean_client_weights with more sophisticated techniques, such as momentum or adaptivity.
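
As one illustration of such a technique, here is a minimal plain-Python sketch of a server update with momentum. It treats the difference between the averaged client weights and the current server weights as a pseudo-gradient. Unlike the server_update above, this hypothetical function also receives the current server weights and a velocity, and the parameter names momentum and server_learning_rate are assumptions rather than TFF API.

```python
def server_update_with_momentum(server_weights, mean_client_weights, velocity,
                                momentum=0.9, server_learning_rate=1.0):
    """Treat (mean_client_weights - server_weights) as a pseudo-gradient step."""
    new_velocity = [
        momentum * v + (m - s)
        for v, s, m in zip(velocity, server_weights, mean_client_weights)
    ]
    new_weights = [
        s + server_learning_rate * v
        for s, v in zip(server_weights, new_velocity)
    ]
    return new_weights, new_velocity


# First round: with zero velocity this matches the vanilla update.
weights_1, velocity_1 = server_update_with_momentum(
    [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]
)
# Second round: the accumulated velocity keeps pushing the first coordinate.
weights_2, velocity_2 = server_update_with_momentum(
    weights_1, [1.0, 1.0], velocity_1
)
```

With a zero initial velocity and a server learning rate of 1.0, the first round reproduces the vanilla behavior of adopting the client average. From the second round on, the velocity term changes the result.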

Challenge: Implement a version of server_update that updates the server weights to be the midpoint of model_weights and mean_client_weights. (Note: This kind of "midpoint" approach is analogous to recent work on the Lookahead optimizer!).

So far, this has only involved TensorFlow code. This is by design, as TFF allows you to use much of the TensorFlow code you're already familiar with. Next you will have to specify the orchestration logic, that is, the logic that dictates what the server broadcasts to the client, and what the client uploads to the server.

This will require the Federated Core of TFF.

Introduction to the Federated Core

The Federated Core (FC) is a set of lower-level interfaces that serve as the foundation for the tff.learning API. However, these interfaces are not limited to learning. In fact, they can be used for analytics and many other computations over distributed data.

At a high-level, the federated core is a development environment that enables compactly expressed program logic to combine TensorFlow code with distributed communication operators (such as distributed sums and broadcasts). The goal is to give researchers and practitioners explicit control over the distributed communication in their systems, without requiring system implementation details (such as specifying point-to-point network message exchanges).

One key point is that TFF is designed for privacy-preservation. Therefore, it allows explicit control over where data resides, to prevent unwanted accumulation of data at the centralized server location.

Federated data

A key concept in TFF is "federated data", which refers to a collection of data items hosted across a group of devices in a distributed system (e.g. client datasets, or the server model weights). The entire collection of values across all devices is represented as a single federated value.

For example, suppose there are client devices that each have a float representing the temperature of a sensor. These floats can be represented as a federated float by

federated_float_on_clients = tff.FederatedType(np.float32, tff.CLIENTS)

A federated type is specified by a type T of its member constituents (e.g. np.float32) and a group G of devices. Typically, G is either tff.CLIENTS or tff.SERVER. Such a federated type is represented as {T}@G, as shown below.

str(federated_float_on_clients)
'{float32}@CLIENTS'
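
To get a feel for this notation, the following toy class mimics how a federated type might be rendered as a string. It is not the TFF implementation: ToyFederatedType and its fields are hypothetical. It assumes that types whose members are all equal (such as a single value on the server) are printed without braces, which matches the float32@SERVER result type shown later in this tutorial.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyFederatedType:
    """Hypothetical stand-in for a federated type: member type plus placement."""
    member: str       # e.g. 'float32'
    placement: str    # e.g. 'CLIENTS' or 'SERVER'
    all_equal: bool   # True when every device holds the same value

    def __str__(self):
        if self.all_equal:
            return f'{self.member}@{self.placement}'
        # Braces signal that each device may hold a different value.
        return f'{{{self.member}}}@{self.placement}'


clients_float = ToyFederatedType('float32', 'CLIENTS', all_equal=False)
server_float = ToyFederatedType('float32', 'SERVER', all_equal=True)
```

Here str(clients_float) gives '{float32}@CLIENTS' and str(server_float) gives 'float32@SERVER'.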

Why does TFF care so much about placements? A key goal of TFF is to enable writing code that could be deployed on a real distributed system. This means that it is vital to reason about which subsets of devices execute which code, and where different pieces of data reside.

TFF focuses on three things: data, where the data is placed, and how the data is being transformed. The first two are encapsulated in federated types, while the last is encapsulated in federated computations.

Federated computations

TFF is a strongly-typed functional programming environment whose basic units are federated computations. These are pieces of logic that accept federated values as input, and return federated values as output.

For example, suppose you wanted to average the temperatures on our client sensors. You could define the following (using our federated float):

@tff.federated_computation(tff.FederatedType(np.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)

You might ask, how is this different from the tf.function decorator in TensorFlow? The key answer is that the code generated by tff.federated_computation is neither TensorFlow nor Python code; it is a specification of a distributed system in an internal platform-independent glue language.

While this may sound complicated, you can think of TFF computations as functions with well-defined type signatures. These type signatures can be directly queried.

str(get_average_temperature.type_signature)
'({float32}@CLIENTS -> float32@SERVER)'

This tff.federated_computation accepts arguments of federated type {float32}@CLIENTS, and returns values of federated type float32@SERVER. Federated computations may also go from server to client, from client to client, or from server to server. Federated computations can also be composed like normal functions, as long as their type signatures match up.

To support development, TFF allows you to invoke a tff.federated_computation as a Python function. For example, you can call

get_average_temperature([68.5, 70.3, 69.8])
69.53333
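
The result above is the ordinary arithmetic mean of the three client values. The plain-Python sketch below reproduces that arithmetic and, for comparison, shows a weighted mean, which is the idea behind weighting clients by how much data they hold. The helper names and the reading_counts values are assumptions for illustration and are not TFF calls.

```python
def unweighted_mean(values):
    """Every client contributes equally."""
    return sum(values) / len(values)


def weighted_mean(values, weights):
    """Clients with larger weights contribute proportionally more."""
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


temperatures = [68.5, 70.3, 69.8]
average = unweighted_mean(temperatures)

# Hypothetical number of sensor readings behind each client's value.
reading_counts = [1, 1, 2]
weighted_average = weighted_mean(temperatures, reading_counts)
```

The unweighted average agrees with the value returned by get_average_temperature, while the weighted version shifts toward the third client because it is counted twice.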

Non-eager computations and TensorFlow

There are two key restrictions to be aware of. First, when the Python interpreter encounters a tff.federated_computation decorator, the function is traced once and serialized for future use. Due to the decentralized nature of Federated Learning, this future usage may occur elsewhere, such as a remote execution environment. Therefore, TFF computations are fundamentally non-eager. This behavior is somewhat analogous to that of the tf.function decorator in TensorFlow.
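
The trace-once behavior can be imitated in a few lines of plain Python. In this sketch, a hypothetical trace_once decorator runs the function body a single time with a symbolic Placeholder, records the operations, and later replays them on real values. This is only an analogy for how tracing works, not how TFF or tf.function is implemented, and it supports just addition and multiplication.

```python
class Placeholder:
    """Symbolic stand-in that records operations instead of computing them."""

    def __init__(self, ops=()):
        self.ops = tuple(ops)

    def __add__(self, other):
        return Placeholder(self.ops + (lambda v: v + other,))

    def __mul__(self, other):
        return Placeholder(self.ops + (lambda v: v * other,))


def trace_once(fn):
    """Run the Python body once, now, and keep only the recorded operations."""
    recorded = fn(Placeholder())

    def run(value):
        for op in recorded.ops:
            value = op(value)
        return value

    return run


trace_calls = []


@trace_once
def add_half_then_double(x):
    trace_calls.append('traced')  # Python side effect: happens at trace time.
    return (x + 0.5) * 2


first = add_half_then_double(1.0)
second = add_half_then_double(2.0)
```

Even though the computation is invoked twice, trace_calls contains a single entry, because the Python body ran only when the decorator was applied. Python side effects inside a traced function therefore do not run on every invocation.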

Second, a federated computation can only consist of federated operators (such as tff.federated_mean); it cannot contain TensorFlow operations. TensorFlow code must be confined to blocks decorated with tff.tensorflow.computation. Most ordinary TensorFlow code can be directly decorated, such as the following function that takes a number and adds 0.5 to it.

@tff.tensorflow.computation(np.float32)
def add_half(x):
  return tf.add(x, 0.5)

These also have type signatures, but without placements. For example, you can call

str(add_half.type_signature)
'(float32 -> float32)'

This showcases an important difference between tff.federated_computation and tff.tensorflow.computation. The former has explicit placements, while the latter does not.

You can use tff.tensorflow.computation blocks in federated computations by specifying placements. Let's create a function that adds half, but only to federated floats at the clients. You can do this by using tff.federated_map, which applies a given tff.tensorflow.computation, while preserving the placement.

@tff.federated_computation(tff.FederatedType(np.float32, tff.CLIENTS))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)

This function is almost identical to add_half, except that it only accepts values with placement at tff.CLIENTS, and returns values with the same placement. This can be seen in its type signature:

str(add_half_on_clients.type_signature)
'({float32}@CLIENTS -> {float32}@CLIENTS)'
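
As a mental model, a value placed at the clients can be pictured as a list with one entry per client. Under that assumption, the sketch below imitates mapping add_half over the clients and then composing the result with an average. The functions federated_map and federated_mean here are plain-Python stand-ins named after the TFF operators, not the operators themselves.

```python
def add_half(x):
    return x + 0.5


def federated_map(fn, client_values):
    """Apply fn independently to each client's value; one output per client."""
    return [fn(v) for v in client_values]


def federated_mean(client_values):
    """Aggregate the client values into a single value held by the server."""
    return sum(client_values) / len(client_values)


client_temperatures = [68.5, 70.3, 69.8]
shifted = federated_map(add_half, client_temperatures)

# Composition works because the output of the map (one float per client)
# matches the input expected by the mean.
server_average = federated_mean(shifted)
```

The mapped result still has one value per client, which corresponds to the {float32}@CLIENTS result type, while the mean collapses those values into one.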

In summary:

  • TFF operates on federated values.
  • Each federated value has a federated type, with a type (e.g. np.float32) and a placement (e.g. tff.CLIENTS).
  • Federated values can be transformed using federated computations, which must be decorated with tff.federated_computation and a federated type signature.
  • TensorFlow code must be contained in blocks with tff.tensorflow.computation decorators.
  • These blocks can then be incorporated into federated computations.

Building your own Federated Learning algorithm, revisited

Now that you've gotten a glimpse of the Federated Core, you can build your own federated learning algorithm. Remember that above, you defined an initialize_fn and next_fn for our algorithm. The next_fn will make use of the client_update and server_update you defined using pure TensorFlow code.

However, in order to make our algorithm a federated computation, you will need both the next_fn and initialize_fn to each be a tff.federated_computation.

TensorFlow Federated blocks

Creating the initialization computation

The initialize function will be quite simple: you will return the initial weights of tff_model. However, remember that you must separate out our TensorFlow code using tff.tensorflow.computation.

@tff.tensorflow.computation
def server_init():
  return tff.learning.models.ModelWeights(*tff_model.initial_weights)

You can then pass this directly into a federated computation using tff.federated_eval.

@tff.federated_computation
def initialize_fn():
  return tff.federated_eval(server_init, tff.SERVER)

Creating the next_fn

The client and server update code can now be used to write the actual algorithm. First, you will turn the client_update into a tff.tensorflow.computation that accepts a client dataset and server weights, and outputs the updated client weights.

You will need the corresponding types to properly decorate our function. Luckily, the dataset type can be extracted directly from our model's input spec.

tf_dataset_type = tff.SequenceType(
    tff.types.tensorflow_to_type(tff_model.input_spec)
)

Let's look at the dataset type signature. Remember that you took 28 by 28 images (with integer labels) and flattened them.

str(tf_dataset_type)
'<float32[?,784],int32[?,1]>*'

You can also extract the model weights type by using our server_init function above.

model_weights_type = server_init.type_signature.result

Examining the type signature, you'll be able to see the architecture of our model!

str(model_weights_type)
'<trainable=<float32[784,10],float32[10]>,non_trainable=<>>'

You can now create our tff.tensorflow.computation for the client update.

@tff.tensorflow.computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  client_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.01)
  return client_update(tff_model, tf_dataset, server_weights, client_optimizer)

The tff.tensorflow.computation version of the server update can be defined in a similar way, using types you've already extracted.

@tff.tensorflow.computation(model_weights_type)
def server_update_fn(mean_client_weights):
  return server_update(tff_model, mean_client_weights)

Last, but not least, you need to create the tff.federated_computation that brings this all together. This function will accept two federated values, one corresponding to the server weights (with placement tff.SERVER), and the other corresponding to the client datasets (with placement tff.CLIENTS).

Note that both these types were defined above! You simply need to give them the proper placement using tff.FederatedType.

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

Remember the 4 elements of an FL algorithm?

  1. A server-to-client broadcast step.
  2. A local client update step.
  3. A client-to-server upload step.
  4. A server update step.

Now that you've built up the above, each part can be compactly represented as a single line of TFF code. This simplicity is why you had to take extra care to specify things such as federated types!

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client)
  )

  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights

You now have a tff.federated_computation for both the algorithm initialization, and for running one step of the algorithm. To finish our algorithm, you pass these into tff.templates.IterativeProcess.

federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

Let's look at the type signature of the initialize and next functions of our iterative process.

str(federated_algorithm.initialize.type_signature)
'( -> <trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER)'

This reflects the fact that federated_algorithm.initialize is a no-arg function that returns a single-layer model (with a 784-by-10 weight matrix, and 10 bias units).

str(federated_algorithm.next.type_signature)
'(<server_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER)'

Here, one can see that federated_algorithm.next accepts a server model and client data, and returns an updated server model.

Evaluating the algorithm

Let's run a few rounds, and see how the loss changes. First, you will define an evaluation function using the centralized approach discussed in the second tutorial.

You will first create a centralized evaluation dataset, and then apply the same preprocessing you used for the training data.

central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)

Next, you will write a function that accepts a server state, and uses Keras to evaluate on the test dataset. If you're familiar with tf.keras, this will all look familiar, though note the use of assign_weights_to!

def evaluate(model_weights):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
  )
  model_weights.assign_weights_to(keras_model)
  keras_model.evaluate(central_emnist_test)

Now, let's initialize our algorithm and evaluate on the test set.

server_state = federated_algorithm.initialize()
evaluate(server_state)
2042/2042 [==============================] - 26s 10ms/step - loss: 2.8479 - sparse_categorical_accuracy: 0.1027

Let's train for a few rounds and see if anything changes.

for _ in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
2042/2042 [==============================] - 4s 1ms/step - loss: 2.5867 - sparse_categorical_accuracy: 0.0980

There is a slight decrease in the loss function. While the jump is small, you've only performed 15 training rounds, and on a small subset of clients. To see better results, you may have to do hundreds if not thousands of rounds.

Modifying our algorithm

At this point, let's stop and think about what you've accomplished. You've implemented Federated Averaging directly by combining pure TensorFlow code (for the client and server updates) with federated computations from the Federated Core of TFF.

To perform more sophisticated learning, you can simply alter what you have above. In particular, by editing the pure TF code above, you can change how the client performs training, or how the server updates its model.

Challenge: Add gradient clipping to the client_update function.
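
As a starting point for this challenge, here is what clipping by global norm computes, written in plain Python on a flat list of gradient values. This is a sketch of the arithmetic only: the function clip_by_global_norm below is a hypothetical helper, and in client_update you would apply an equivalent step to grads between the call to tape.gradient and the call to client_optimizer.next (TensorFlow's tf.clip_by_global_norm is one possible option).

```python
import math


def clip_by_global_norm(grads, clip_norm):
    """Scale all gradients by one factor so their joint L2 norm is <= clip_norm."""
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= clip_norm or global_norm == 0.0:
        # Already within the bound, so leave the gradients unchanged.
        return list(grads)
    scale = clip_norm / global_norm
    return [g * scale for g in grads]


# Global norm is 5.0, so every entry is scaled by 1/5.
clipped = clip_by_global_norm([3.0, 4.0], clip_norm=1.0)
# Global norm is 0.5, so nothing changes.
untouched = clip_by_global_norm([0.3, 0.4], clip_norm=1.0)
```

Because all entries share one scale factor, the direction of the gradient is preserved and only its length is limited.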

If you wanted to make larger changes, you could also have the server store and broadcast more data. For example, the server could also store the client learning rate, and make it decay over time! Note that this will require changes to the type signatures used in the tff.tensorflow.computation calls above.

Harder Challenge: Implement Federated Averaging with learning rate decay on the clients.
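
Before attempting this in TFF, it can help to see the shape of the change in plain Python. In the sketch below, the server state is a dictionary that holds both the weights and the client learning rate, and each round broadcasts both and then decays the learning rate. All names, the toy client_update (which moves the weights toward the mean of the client's data), and the decay factor are assumptions for illustration. This is not the TFF solution.

```python
def initialize():
    """Server state now carries the client learning rate next to the weights."""
    return {'weights': [0.0], 'client_learning_rate': 0.5}


def client_update(dataset, weights, learning_rate):
    """Toy local step: move each weight toward the mean of the client's data."""
    target = sum(dataset) / len(dataset)
    return [w + learning_rate * (target - w) for w in weights]


def next_fn(state, federated_dataset, decay=0.5):
    # Broadcast both the weights and the current learning rate to every client.
    client_weights = [
        client_update(ds, state['weights'], state['client_learning_rate'])
        for ds in federated_dataset
    ]
    num_clients = len(client_weights)
    mean_weights = [sum(ws) / num_clients for ws in zip(*client_weights)]
    # The server decays the learning rate it will broadcast next round.
    return {
        'weights': mean_weights,
        'client_learning_rate': state['client_learning_rate'] * decay,
    }


state = initialize()
learning_rates = []
for _ in range(3):
    learning_rates.append(state['client_learning_rate'])
    state = next_fn(state, [[2.0], [4.0]])
```

The learning rates used in the three rounds are 0.5, 0.25, and 0.125. The key structural point is that the state returned by initialize and next_fn has changed, which in TFF means the server state type and the client update's input types change too.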

At this point, you may begin to realize how much flexibility there is in what you can implement in this framework. For ideas (including the answer to the harder challenge above) you can see the source code for tff.learning.algorithms.build_weighted_fed_avg, or check out various research projects using TFF.