In a production environment, the raw data for a federated computation is typically distributed across machines and requires special preprocessing and loading before it's usable. This tutorial describes how to load data stored in such remote locations with TFF's DataBackend and DataExecutor interfaces. To keep the example simple, the dataset will exist entirely in memory, and we'll simulate the fetching as if the dataset were partitioned over a network.
Before we start
Before we start, please run the following to make sure that your environment is correctly set up. If you don't see a greeting, please refer to the Installation guide for instructions.
Set up open-source environment
!pip install --quiet --upgrade tensorflow-federated
!pip install --quiet --upgrade nest-asyncio
import nest_asyncio
nest_asyncio.apply()
Import packages
import collections
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
np.random.seed(0)
Preparing the input data
Let's begin by loading TFF's federated version of the EMNIST dataset from the built-in repository:
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
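Before preprocessing, it can help to confirm what we loaded. Here's a quick, optional check (a minimal sketch using the standard ClientData API) that prints how many simulated clients the training split contains:
# Each client ID in federated EMNIST corresponds to one writer's examples.
print(f'Number of training clients: {len(emnist_train.client_ids)}')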
We'll construct a preprocessing function to transform the raw examples in the EMNIST dataset from 28x28 images into 784-element arrays. Additionally, the function will shuffle the individual examples and rename the features from pixels and label to x and y for use with Keras. We also throw in a repeat over the dataset to run several epochs.
NUM_CLIENTS = 10
NUM_EPOCHS = 5
SHUFFLE_BUFFER = 100
def preprocess(dataset):

  def map_fn(element):
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).map(map_fn)
Let's verify this works:
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])
preprocessed_example_dataset = preprocess(example_dataset)
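As an optional sanity check (a minimal sketch using standard tf.data APIs), we can inspect one element of the preprocessed dataset and confirm the shapes we expect:
# Each element is an OrderedDict with a (1, 784) image vector and a (1, 1) label.
print(preprocessed_example_dataset.element_spec)
sample = next(iter(preprocessed_example_dataset))
print(sample['x'].shape, sample['y'].shape)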
We'll use the EMNIST dataset to train a model by loading and preprocessing individual clients (emulating distinct partitions) through an implementation of DataBackend.
Defining a DataBackend
We need an instance of DataBackend to instruct TFF workers (the processes that run the client side of federated computations) how to load and transform the local data stored in remote locations. A DataBackend is a programmatic construct that resolves symbolic references, represented as application-specific URIs, to materialized payloads that downstream TFF operations can process. Specifically, a DataBackend object is wrapped by a DataExecutor, which queries the object when the TFF runtime encounters an operation that fetches the data.
In this example, a client ID is encoded in a URI, which is parsed by our DataBackend definition to retrieve the corresponding client data, convert it to a tf.data.Dataset, and then apply our preprocess function.
class TestDataBackend(tff.framework.DataBackend):

  async def materialize(self, data, type_spec):
    # Parse the client index out of the URI (e.g. 'uri://3' -> 3).
    client_id = int(data.uri[-1])
    client_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[client_id])
    return preprocess(client_dataset)
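Since materialize is an ordinary coroutine, we can exercise it outside the TFF runtime for debugging. The sketch below uses a hypothetical stand-in object for the data reference that TFF would normally pass in; only its uri attribute is read by our implementation:
import asyncio

class FakeDataReference:
  """Hypothetical stand-in for the reference TFF hands to materialize()."""

  def __init__(self, uri):
    self.uri = uri

# Resolve 'uri://3' to the preprocessed dataset of the fourth client.
ds = asyncio.get_event_loop().run_until_complete(
    TestDataBackend().materialize(FakeDataReference('uri://3'), None))
print(ds.element_spec)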
Plugging the DataBackend into the ExecutionContext
TFF computations are invoked by an ExecutionContext, and in order for data URIs defined in TFF computations to be understood at runtime, a custom context must be defined that includes a pointer to the DataBackend we just created, so URIs can be properly resolved.
The DataBackend works in tandem with a DataExecutor to supply the executor with operable data that it can relay to requesting executors in order to complete a TFF computation.
def ex_fn(device: tf.config.LogicalDevice) -> tff.framework.DataExecutor:
  return tff.framework.DataExecutor(
      tff.framework.EagerTFExecutor(device),
      data_backend=TestDataBackend())
factory = tff.framework.local_executor_factory(leaf_executor_fn=ex_fn)
ctx = tff.framework.ExecutionContext(executor_fn=factory)
tff.framework.set_default_context(ctx)
Training the model
Now we are ready to train a model in a federated fashion. Let's define a Keras model along with training hyperparameters:
def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
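As an optional local check (outside of any federated computation), we can build the Keras model eagerly and confirm its parameter count; a single dense layer mapping 784 inputs to 10 logits yields 7,850 parameters:
# 784 * 10 weights + 10 biases = 7,850 trainable parameters.
create_keras_model().summary()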
We can pass this TFF-wrapped definition of our model to a Federated Averaging algorithm by invoking the helper function tff.learning.algorithms.build_weighted_fed_avg, as follows:
iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
state = iterative_process.initialize()
The initialize computation returns the initial state of the Federated Averaging process.
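If you'd like to peek at this state, recent TFF learning processes expose a get_model_weights helper (a minimal sketch; the exact attribute layout may vary across TFF versions):
# The initial global weights come from model_fn; the dense kernel was
# zero-initialized, so the trainable weights start at all zeros.
model_weights = iterative_process.get_model_weights(state)
print([w.shape for w in model_weights.trainable])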
To run a round of training, we need to construct a sample of data by organizing a sample of URI references as follows:
element_type = tff.types.StructWithPythonType(
    preprocessed_example_dataset.element_spec,
    container_type=collections.OrderedDict)
dataset_type = tff.types.SequenceType(element_type)

data_uris = [f'uri://{i}' for i in range(5)]
data_handle = tff.framework.CreateDataDescriptor(
    arg_uris=data_uris, arg_type=dataset_type)
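To make the mapping concrete, here is how our TestDataBackend will interpret these URIs; note that reading only the final character limits this toy scheme to single-digit client indices:
# 'uri://0' -> client index 0, ..., 'uri://4' -> client index 4.
for uri in data_uris:
  print(uri, '->', int(uri[-1]))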
Now we can run a round of training:
result = iterative_process.next(state, data_handle)
state = result.state
metrics = result.metrics
print('round 1, metrics={}'.format(metrics))
round 1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.11625), ('loss', 12.682652), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
And we can run a few more rounds:
NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  result = iterative_process.next(state, data_handle)
  state = result.state
  metrics = result.metrics
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round 2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.12375), ('loss', 10.2836895), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.17916666), ('loss', 7.733705), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.29458332), ('loss', 5.6188993), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 5, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.38541666), ('loss', 4.4057455), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 6, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.47041667), ('loss', 3.512454), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 7, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.535), ('loss', 3.0268242), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 8, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.5729167), ('loss', 2.7468147), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 9, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.62416667), ('loss', 2.3982067), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 10, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.62333333), ('loss', 2.3998983), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
Conclusion
This concludes the tutorial. We encourage you to explore the other tutorials we've developed to learn about the many other features of the TFF framework.