Loading Remote Data in TFF


In a real-world application of federated learning, the raw training data is typically distributed across many devices or data silos, so it requires special preprocessing and loading before it is usable.

This tutorial describes how to load examples stored in those remote locations with TFF's DataBackend and DataExecutor interfaces, and how to use them to train a model with federated learning. To demonstrate the data loading APIs, we'll use a training dataset stored locally and simulate sampling examples as if the dataset were partitioned across distinct remote clients. As you adapt this tutorial to your use case, you can swap that dataset for your own distributed data.

If you're new to federated learning or TFF, consider reading Federated Learning for Image Classification for a primer.


Before we start

Before we start, please run the following to make sure that your environment is set up correctly. Refer to the Installation guide for more information.

Set up open-source environment

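Import packages

import collections
import random
from typing import Any, List, Optional, Sequence

import tensorflow as tf
import tensorflow_federated as tff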

Preparing the input data

Let's begin by loading TFF's federated version of the EMNIST dataset from the built-in repository:

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

Next, construct a preprocessing function to transform the raw examples in the EMNIST dataset:

NUM_EPOCHS = 5
SHUFFLE_BUFFER = 100


def preprocess(dataset):

  def map_fn(element):
    # Rename the features from `pixels` and `label`, to `x` and `y` for use with
    # Keras.
    return collections.OrderedDict(
        # Transform each `28x28` image into a `784`-element array.
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  # Repeat the examples over several epochs, then shuffle them.
  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).map(map_fn)

Let's verify this works:

# The local dataset corresponding to a single client as tf.data.Dataset.
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

preprocessed_example_dataset = preprocess(example_dataset)
print(preprocessed_example_dataset)
<MapDataset element_spec=OrderedDict([('x', TensorSpec(shape=(1, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))])>
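
To see why the element spec reports shapes `(1, 784)` and `(1, 1)`, the reshape step can be reproduced outside TensorFlow. The sketch below uses NumPy as a stand-in for `tf.reshape` (an assumption made for illustration; both follow the same `-1` convention), with a made-up all-zero image and label:

```python
import numpy as np

# Stand-ins for one raw EMNIST element: a 28x28 image and a scalar label.
# The values are made up; only the shapes matter here.
pixels = np.zeros((28, 28), dtype=np.float32)
label = np.int32(7)

# `-1` asks the library to infer the leading dimension, which works out to 1.
x = np.reshape(pixels, [-1, 784])
y = np.reshape(label, [-1, 1])

print(x.shape, y.shape)  # (1, 784) (1, 1)
```

Each element therefore behaves like a batch containing a single example, which is consistent with `num_batches` equalling `num_examples` in the training metrics printed later in this tutorial.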

Next, we'll construct an implementation of DataBackend that loads and preprocesses the local examples of clients in the EMNIST dataset. TFF relies on it to fetch trainable examples during federated learning.

Defining how to fetch client data

We need an instance of DataBackend to instruct the TFF workers how to load and transform the local data.

TFF workers are the processes that run on edge machines and perform the work for one or more logical clients. In this example, the EMNIST dataset we'll use for training is already partitioned by logical client, and all the workers run in the same local environment, so our DataBackend can reference the data corresponding to any client. In a non-experimental setting, however, the TFF workers will be distributed over individual remote machines, each mapping to a distinct set of clients, and you need to ensure that the DataBackend can correctly resolve data references according to its local context.

# A `DataBackend` is a programmatic construct that resolves symbolic references,
# represented as application-specific URIs, to materialized examples that
# TFF operations can process.
class MyDataBackend(tff.framework.DataBackend):

  async def materialize(self, data, type_spec):
    # In this example, the last character of the URI identifies a client.
    client_id = int(data.uri[-1])
    # The client ID is used to retrieve the corresponding local data.
    client_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[client_id])
    # We preprocess the client dataset before returning it so that it's
    # compatible with our model definitions.
    return preprocess(client_dataset)
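
Note that `data.uri[-1]` reads only the final character of the URI, so every URI resolves to one of the first ten clients. That is enough for this tutorial, but a larger index range would need the whole suffix parsed. A minimal sketch of that idea in plain Python, where the helper `parse_client_index` is hypothetical and not part of TFF:

```python
URI_PREFIX = 'uri://'  # Same scheme as the example URIs in this tutorial.


def parse_client_index(uri: str) -> int:
  """Hypothetical helper: returns the integer that follows `uri://`."""
  if not uri.startswith(URI_PREFIX):
    raise ValueError(f'Unexpected URI: {uri!r}')
  return int(uri[len(URI_PREFIX):])


# Reading only the last character drops every digit but the final one.
print(int('uri://12'[-1]))             # 2
print(parse_client_index('uri://12'))  # 12
```

This sketch assumes the URI carries a numeric index. If the URI carried a client ID string instead, the suffix could be passed to `create_tf_dataset_for_client` directly.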

Setting up the runtime environment

TFF computations are invoked by an ExecutionContext. For data URIs defined in TFF computations to be understood at runtime, we must define a custom context for the TFF workers that includes a pointer to the DataBackend we just created, so that the URIs can be resolved properly.

def ex_fn(device: tf.config.LogicalDevice) -> tff.framework.DataExecutor:
  # A `DataBackend` object is wrapped by a `DataExecutor`, which queries the
  # backend when a TFF worker encounters an operation that requires fetching
  # local data.
  return tff.framework.DataExecutor(
      tff.framework.EagerTFExecutor(device), data_backend=MyDataBackend())


# In a distributed setting, this needs to run in the TFF worker as a service
# connecting to some port. The top-level controller feeding TFF computations
# would then connect to this port.
factory = tff.framework.local_executor_factory(leaf_executor_fn=ex_fn)
ctx = tff.framework.SyncExecutionContext(executor_fn=factory)
tff.framework.set_default_context(ctx)
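
The wrapping relationship can be pictured with a toy analogy in plain Python. Every class below is hypothetical and greatly simplified, and none of it is TFF's actual implementation. It only illustrates the idea that the outer executor intercepts symbolic data references, asks the backend to materialize them, and hands everything else to the executor it wraps:

```python
import asyncio
from dataclasses import dataclass


@dataclass(frozen=True)
class DataRef:
  """Toy stand-in for a symbolic data reference (not a TFF type)."""
  uri: str


class ToyBackend:
  """Toy backend: maps URIs to in-memory lists instead of real datasets."""

  def __init__(self, store):
    self._store = store

  async def materialize(self, ref):
    return self._store[ref.uri]


class ToyInnerExecutor:
  """Toy inner executor: passes concrete values through unchanged."""

  async def create_value(self, value):
    return value


class ToyDataExecutor:
  """Resolves data references via the backend; delegates everything else."""

  def __init__(self, inner, backend):
    self._inner = inner
    self._backend = backend

  async def create_value(self, value):
    if isinstance(value, DataRef):
      value = await self._backend.materialize(value)
    return await self._inner.create_value(value)


executor = ToyDataExecutor(
    ToyInnerExecutor(), ToyBackend({'uri://0': [1, 2, 3]}))
resolved = asyncio.run(executor.create_value(DataRef('uri://0')))
passthrough = asyncio.run(executor.create_value(42))
print(resolved, passthrough)  # [1, 2, 3] 42
```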

Training the model

Now we are ready to train a model using federated learning. Let's define a Keras model:

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])


def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

We can pass this TFF-wrapped definition of our model to a Federated Averaging algorithm by invoking the helper function tff.learning.algorithms.build_weighted_fed_avg, as follows:

iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

state = iterative_process.initialize()

The initialize computation returns the initial state of the Federated Averaging process.
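
As a rough picture of what one round of weighted Federated Averaging computes, the sketch below uses NumPy vectors in place of model weights. It is a textbook-style simplification under stated assumptions (each client returns its locally trained weights, updates are weighted by example counts, and the server applies a plain SGD-style step), not the code that `build_weighted_fed_avg` generates. The function name `weighted_fed_avg_step` is made up for illustration:

```python
import numpy as np


def weighted_fed_avg_step(server_weights, client_weights, num_examples,
                          server_lr=1.0):
  """One simplified round: example-weighted mean of client deltas."""
  # How far each client moved away from the current server weights.
  deltas = [w - server_weights for w in client_weights]
  total = float(sum(num_examples))
  # Clients with more examples contribute proportionally more.
  mean_delta = sum(n / total * d for n, d in zip(num_examples, deltas))
  # The server moves along the averaged delta, scaled by its learning rate.
  return server_weights + server_lr * mean_delta


server = np.zeros(2)
clients = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
new_server = weighted_fed_avg_step(server, clients, num_examples=[30, 10])
print(new_server)
```

With a server learning rate of 1.0, as in the configuration above, the new server weights in this sketch are simply the example-weighted mean of the client weights.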

To run a round of training, we need to construct the round's input data by collecting a sample of URI references, as follows:

NUM_CLIENTS = 10

element_type = tff.types.StructWithPythonType(
    preprocessed_example_dataset.element_spec,
    container_type=collections.OrderedDict)
dataset_type = tff.types.SequenceType(element_type)

round_data_uris = [f'uri://{i}' for i in range(NUM_CLIENTS)]
round_train_data = tff.framework.CreateDataDescriptor(
    arg_uris=round_data_uris, arg_type=dataset_type)

Now we can run a round of training:

result = iterative_process.next(state, round_train_data)
state = result.state
metrics = result.metrics
print('round 1, metrics={}'.format(metrics))
round 1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.11234568), ('loss', 11.965633), ('num_examples', 4860), ('num_batches', 4860)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])

Training over multiple rounds

We can define a FederatedDataSource container for selecting clients and assembling the inputs for retrieving local data. This makes it convenient to loop over multiple rounds of training, and the container can be reused across multiple training jobs.

class MyFederatedDataSourceIterator(tff.program.FederatedDataSourceIterator):

  def __init__(self, client_ids: Sequence[str],
               federated_type: tff.FederatedType):
    self._client_ids = client_ids
    self._federated_type = federated_type

  @property
  def federated_type(self) -> tff.FederatedType:
    return self._federated_type

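  def select(self, num_clients: Optional[int] = None) -> Any:
    # `random.sample` needs a concrete sample size, so reject `None` early.
    if num_clients is None:
      raise ValueError('`num_clients` must be specified.')
    client_ids_sample = random.sample(self._client_ids, num_clients)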
    data_uris = [f'uri://{i}' for i in client_ids_sample]
    return tff.framework.CreateDataDescriptor(
        arg_uris=data_uris, arg_type=self._federated_type)


class MyFederatedDataSource(tff.program.FederatedDataSource):

  def __init__(self, client_ids: Sequence[str],
               federated_type: tff.FederatedType):
    self._client_ids = client_ids
    self._federated_type = federated_type
    self._capabilities = [tff.program.Capability.RANDOM_UNIFORM]

  @property
  def federated_type(self) -> tff.FederatedType:
    return self._federated_type

  @property
  def capabilities(self) -> List[tff.program.Capability]:
    return self._capabilities

  def iterator(self) -> tff.program.FederatedDataSourceIterator:
    return MyFederatedDataSourceIterator(self._client_ids, self._federated_type)


train_data_source = MyFederatedDataSource(
    client_ids=emnist_train.client_ids, federated_type=dataset_type)
train_data_iterator = train_data_source.iterator()
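
The client selection inside `select` is sampling without replacement. When reproducible rounds are wanted, one option is to validate the requested sample size and draw from a seeded generator. A minimal sketch of that idea in plain Python, where `select_uris` and the `client_{i}` IDs are hypothetical and the data descriptor construction is left out:

```python
import random
from typing import List, Optional, Sequence


def select_uris(client_ids: Sequence[str],
                num_clients: Optional[int],
                rng: random.Random) -> List[str]:
  """Hypothetical helper mirroring the URI-building part of `select`."""
  if num_clients is None or num_clients <= 0:
    raise ValueError('num_clients must be a positive integer.')
  if num_clients > len(client_ids):
    raise ValueError('Cannot sample more clients than are available.')
  # Sample without replacement, so a client appears at most once per round.
  sample = rng.sample(list(client_ids), num_clients)
  return [f'uri://{client_id}' for client_id in sample]


rng = random.Random(0)  # Fixed seed so repeated runs pick the same clients.
ids = [f'client_{i}' for i in range(100)]  # Made-up client IDs.
uris = select_uris(ids, 10, rng)
print(len(uris), len(set(uris)))  # 10 10
```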

Now we can run our federated learning training loop like so:

NUM_ROUNDS = 10

for round_num in range(2, NUM_ROUNDS + 1):
  round_train_data = train_data_iterator.select(NUM_CLIENTS)
  result = iterative_process.next(state, round_train_data)
  state = result.state
  metrics = result.metrics
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.12357217), ('loss', 9.161968), ('num_examples', 4815), ('num_batches', 4815)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.20563674), ('loss', 7.0862083), ('num_examples', 4790), ('num_batches', 4790)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.30241227), ('loss', 5.6945825), ('num_examples', 4560), ('num_batches', 4560)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  5, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.3867347), ('loss', 4.7210026), ('num_examples', 4900), ('num_batches', 4900)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  6, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.42311886), ('loss', 4.205554), ('num_examples', 4585), ('num_batches', 4585)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  7, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.4501548), ('loss', 4.1297464), ('num_examples', 4845), ('num_batches', 4845)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  8, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.56590474), ('loss', 2.8927681), ('num_examples', 5250), ('num_batches', 5250)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  9, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.59917355), ('loss', 2.7431731), ('num_examples', 4840), ('num_batches', 4840)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 10, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.5717234), ('loss', 2.9738288), ('num_examples', 4845), ('num_batches', 4845)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
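
The printed metrics are nested ordered dictionaries, so individual values can be pulled out by key. The sketch below copies the round 2 metrics from the output above into a literal and reads the training loss and accuracy from it:

```python
import collections

# One round's metrics, copied from the round 2 output above.
metrics = collections.OrderedDict([
    ('distributor', ()),
    ('client_work', collections.OrderedDict([
        ('train', collections.OrderedDict([
            ('sparse_categorical_accuracy', 0.12357217),
            ('loss', 9.161968),
            ('num_examples', 4815),
            ('num_batches', 4815),
        ])),
    ])),
    ('aggregator', collections.OrderedDict([
        ('mean_value', ()), ('mean_weight', ())])),
    ('finalizer', ()),
])

# The training metrics live under `client_work` -> `train`.
train_metrics = metrics['client_work']['train']
print('loss={:.3f}, accuracy={:.3f}'.format(
    train_metrics['loss'], train_metrics['sparse_categorical_accuracy']))
# loss=9.162, accuracy=0.124
```

Collecting these values per round in a list makes it straightforward to plot or compare training progress across runs.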

Conclusion

This concludes the tutorial. We encourage you to explore the other tutorials we've developed to learn about the many other features of the TFF framework.