Use TFF optimizers in custom iterative process

This is an alternative to the Build Your Own Federated Learning Algorithm tutorial and the simple_fedavg example for building a custom iterative process implementing the federated averaging algorithm. This tutorial uses TFF optimizers instead of Keras optimizers. The TFF optimizer abstraction is designed to be state-in-state-out, which makes it easier to incorporate into a TFF iterative process (see the short sketch after the setup cell below). The tff.learning APIs also accept TFF optimizers as input arguments.

Before we start

Before we start, please run the following to make sure that your environment is correctly set up. If you don't see a greeting, please refer to the Installation guide for instructions.

pip install --quiet --upgrade tensorflow-federated
from typing import Any
import functools

import attrs
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

# Quick check that the TFF environment works.
tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'
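
Before moving on, here is a minimal sketch of the state-in-state-out contract mentioned in the introduction: initialize builds the optimizer state from tensor specs, and next consumes the state and returns a new state along with updated weights. The toy tensors below are purely illustrative.

# Toy illustration of the state-in-state-out optimizer contract.
sgdm = tff.learning.optimizers.build_sgdm(learning_rate=0.01, momentum=0.9)
specs = (tf.TensorSpec(shape=[2], dtype=tf.float32),)
state = sgdm.initialize(specs)  # State in ...
weights = (tf.constant([1.0, 2.0]),)
gradients = (tf.constant([0.1, 0.1]),)
state, weights = sgdm.next(state, weights, gradients)  # ... new state out.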

Preparing data and model

The EMNIST data processing and model are very similar to the simple_fedavg example.

only_digits=True

# Load dataset.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits)

# Define preprocessing functions.
def preprocess_fn(dataset, batch_size=16):

  def batch_format_fn(element):
    return (tf.expand_dims(element['pixels'], -1), element['label'])

  return dataset.batch(batch_size).map(batch_format_fn)

# Preprocess and sample clients for prototyping.
train_client_ids = sorted(emnist_train.client_ids)
train_data = emnist_train.preprocess(preprocess_fn)
central_test_data = preprocess_fn(
    emnist_train.create_tf_dataset_for_client(train_client_ids[0]))

# Define model.
def create_keras_model():
  """The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format = 'channels_last'
  input_shape = [28, 28, 1]

  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)

  model = tf.keras.models.Sequential([
      conv2d(filters=32, input_shape=input_shape),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10 if only_digits else 62),
  ])

  return model

# Wrap as `tff.learning.models.VariableModel`.
def model_fn():
  keras_model = create_keras_model()
  return tff.learning.models.from_keras_model(
      keras_model,
      input_spec=central_test_data.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

Custom iterative process

In many cases, federated algorithms have 4 main components:

  1. A server-to-client broadcast step.
  2. A local client update step.
  3. A client-to-server upload step.
  4. A server update step.

In TFF, we generally represent federated algorithms as a tff.templates.IterativeProcess (which we refer to as just an IterativeProcess throughout). This is a class that contains initialize and next functions. Here, initialize is used to initialize the server, and next will perform one communication round of the federated algorithm.
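
As a minimal illustration of this contract (a toy sketch, not the FedAvg process built below), consider an IterativeProcess whose state is a single integer counter that advances by one each "round":

# Toy IterativeProcess: `initialize` creates an integer state, and each
# call to `next` increments it by one.
@tff.tf_computation
def counter_init():
  return tf.constant(0)

@tff.tf_computation(np.int32)
def counter_next(state):
  return state + 1

counter_process = tff.templates.IterativeProcess(
    initialize_fn=counter_init, next_fn=counter_next)
state = counter_process.initialize()  # -> 0
state = counter_process.next(state)   # -> 1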

We will introduce the different components needed to build the federated averaging (FedAvg) algorithm, which uses one optimizer in the client update step and another optimizer in the server update step. The core logic of the client and server updates can be expressed as pure TF blocks.

TF blocks: client and server update

On each client, a local client_optimizer is initialized and used to update the client model weights. On the server, server_optimizer will use the state from the previous round, and update the state for the next round.

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs local training on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)
  # Initialize the client optimizer.
  trainable_tensor_specs = tf.nest.map_structure(
          lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)
  optimizer_state = client_optimizer.initialize(trainable_tensor_specs)
  # Use the client_optimizer to update the local model.
  for batch in iter(dataset):
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data.
      outputs = model.forward_pass(batch)
    # Compute the corresponding gradient.
    grads = tape.gradient(outputs.loss, client_weights)
    # Apply the gradient using a client optimizer.
    optimizer_state, updated_weights = client_optimizer.next(
        optimizer_state, client_weights, grads)
    tf.nest.map_structure(lambda a, b: a.assign(b),
                          client_weights, updated_weights)
  # Return model deltas.
  return tf.nest.map_structure(tf.subtract, client_weights, server_weights)

@attrs.define(eq=False, frozen=True)
class ServerState(object):
  """State held by the server: model weights and server optimizer state."""
  trainable_weights: Any
  optimizer_state: Any

@tf.function
def server_update(server_state, mean_model_delta, server_optimizer):
  """Updates the server model weights."""
  # Use aggregated negative model delta as pseudo gradient.
  negative_weights_delta = tf.nest.map_structure(
      lambda w: -1.0 * w, mean_model_delta)
  new_optimizer_state, updated_weights = server_optimizer.next(
      server_state.optimizer_state, server_state.trainable_weights,
      negative_weights_delta)
  return tff.structure.update_struct(
      server_state,
      trainable_weights=updated_weights,
      optimizer_state=new_optimizer_state)

TFF blocks: tff.tf_computation and tff.federated_computation

We now use TFF for orchestration and build the iterative process for FedAvg. We wrap the TF blocks defined above with tff.tf_computation, and use the TFF intrinsics tff.federated_broadcast, tff.federated_map, and tff.federated_mean inside a tff.federated_computation function. The tff.learning.optimizers.Optimizer API, with its initialize and next functions, is straightforward to use when defining a custom iterative process.
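
To see these intrinsics in isolation before using them for FedAvg, here is a tiny standalone computation (an illustration only, not part of the pipeline below) that averages client-placed values on the server:

# Average float values placed at CLIENTS and return the result at SERVER.
@tff.federated_computation(tff.FederatedType(np.float32, tff.CLIENTS))
def average_at_server(client_values):
  return tff.federated_mean(client_values)

print(average_at_server([1.0, 2.0, 3.0]))  # 2.0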

# 1. Server and client optimizer to be used.
server_optimizer = tff.learning.optimizers.build_sgdm(
    learning_rate=0.05, momentum=0.9)
client_optimizer = tff.learning.optimizers.build_sgdm(
    learning_rate=0.01)

# 2. Functions return initial state on server.
@tff.tf_computation
def server_init():
  model = model_fn()
  trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype), model.trainable_variables)
  optimizer_state = server_optimizer.initialize(trainable_tensor_specs)
  return ServerState(
      trainable_weights=model.trainable_variables,
      optimizer_state=optimizer_state)

@tff.federated_computation
def server_init_tff():
  return tff.federated_value(server_init(), tff.SERVER)

# 3. One round of computation and communication.
server_state_type = server_init.type_signature.result
print('server_state_type:\n',
      server_state_type.formatted_representation())
trainable_weights_type = server_state_type.trainable_weights
print('trainable_weights_type:\n',
      trainable_weights_type.formatted_representation())

# 3-1. Wrap server and client TF blocks with `tff.tf_computation`.
@tff.tf_computation(server_state_type, trainable_weights_type)
def server_update_fn(server_state, model_delta):
  return server_update(server_state, model_delta, server_optimizer)

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(tff.types.tensorflow_to_type(whimsy_model.input_spec))
print('tf_dataset_type:\n',
      tf_dataset_type.formatted_representation())
@tff.tf_computation(tf_dataset_type, trainable_weights_type)
def client_update_fn(dataset, server_weights):
  model = model_fn()
  return client_update(model, dataset, server_weights, client_optimizer)

# 3-2. Orchestration with `tff.federated_computation`.
federated_server_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
@tff.federated_computation(federated_server_type, federated_dataset_type)
def run_one_round(server_state, federated_dataset):
  # Server-to-client broadcast.
  server_weights_at_client = tff.federated_broadcast(
      server_state.trainable_weights)
  # Local client update.
  model_deltas = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))
  # Client-to-server upload and aggregation.
  mean_model_delta = tff.federated_mean(model_deltas)
  # Server update.
  server_state = tff.federated_map(
      server_update_fn, (server_state, mean_model_delta))
  return server_state

# 4. Build the iterative process for FedAvg.
fedavg_process = tff.templates.IterativeProcess(
    initialize_fn=server_init_tff, next_fn=run_one_round)
print('type signature of `initialize`:\n',
      fedavg_process.initialize.type_signature.formatted_representation())
print('type signature of `next`:\n',
      fedavg_process.next.type_signature.formatted_representation())
server_state_type:
 <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    learning_rate=float32,
    momentum=float32,
    accumulator=<
      float32[5,5,1,32],
      float32[32],
      float32[5,5,32,64],
      float32[64],
      float32[3136,512],
      float32[512],
      float32[512,10],
      float32[10]
    >
  >
>
trainable_weights_type:
 <
  float32[5,5,1,32],
  float32[32],
  float32[5,5,32,64],
  float32[64],
  float32[3136,512],
  float32[512],
  float32[512,10],
  float32[10]
>
tf_dataset_type:
 <
  float32[?,28,28,1],
  int32[?]
>*
type signature of `initialize`:
 ( -> <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    learning_rate=float32,
    momentum=float32,
    accumulator=<
      float32[5,5,1,32],
      float32[32],
      float32[5,5,32,64],
      float32[64],
      float32[3136,512],
      float32[512],
      float32[512,10],
      float32[10]
    >
  >
>@SERVER)
type signature of `next`:
 (<
  server_state=<
    trainable_weights=<
      float32[5,5,1,32],
      float32[32],
      float32[5,5,32,64],
      float32[64],
      float32[3136,512],
      float32[512],
      float32[512,10],
      float32[10]
    >,
    optimizer_state=<
      learning_rate=float32,
      momentum=float32,
      accumulator=<
        float32[5,5,1,32],
        float32[32],
        float32[5,5,32,64],
        float32[64],
        float32[3136,512],
        float32[512],
        float32[512,10],
        float32[10]
      >
    >
  >@SERVER,
  federated_dataset={<
    float32[?,28,28,1],
    int32[?]
  >*}@CLIENTS
> -> <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    learning_rate=float32,
    momentum=float32,
    accumulator=<
      float32[5,5,1,32],
      float32[32],
      float32[5,5,32,64],
      float32[64],
      float32[3136,512],
      float32[512],
      float32[512,10],
      float32[10]
    >
  >
>@SERVER)

Evaluating the algorithm

We evaluate the performance on a centralized dataset; for prototyping purposes, this is the preprocessed data of a single client, central_test_data, defined above.

def evaluate(server_state):
  keras_model = create_keras_model()
  tf.nest.map_structure(
      lambda var, t: var.assign(t),
      keras_model.trainable_weights, server_state.trainable_weights)
  metric = tf.keras.metrics.SparseCategoricalAccuracy()
  for batch in iter(central_test_data):
    preds = keras_model(batch[0], training=False)
    metric.update_state(y_true=batch[1], y_pred=preds)
  return metric.result().numpy()

server_state = fedavg_process.initialize()
acc = evaluate(server_state)
print('Initial test accuracy', acc)

# Evaluate after a few rounds
CLIENTS_PER_ROUND = 2
sampled_clients = train_client_ids[:CLIENTS_PER_ROUND]
sampled_train_data = [
    train_data.create_tf_dataset_for_client(client)
    for client in sampled_clients]
for _ in range(20):
  server_state = fedavg_process.next(server_state, sampled_train_data)
acc = evaluate(server_state)
print('Test accuracy', acc)
Initial test accuracy 0.06451613
Test accuracy 0.086021505
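
The accuracy stays low because we trained for only 20 rounds on the same two clients. A natural next step (a sketch under the same setup, not part of the original run) is to resample a different set of clients each round:

# Resample a new set of clients at every round (without replacement).
for _ in range(20):
  sampled_clients = np.random.choice(
      train_client_ids, size=CLIENTS_PER_ROUND, replace=False)
  sampled_train_data = [
      train_data.create_tf_dataset_for_client(client)
      for client in sampled_clients]
  server_state = fedavg_process.next(server_state, sampled_train_data)
print('Test accuracy after resampling:', evaluate(server_state))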