Working with TFF's ClientData


The notion of a dataset keyed by clients (e.g. users) is essential to federated computation as modeled in TFF. TFF provides the interface tff.simulation.datasets.ClientData to abstract over this concept, and the datasets which TFF hosts (stackoverflow, shakespeare, emnist, cifar100, and gldv2) all implement this interface.

If you are working on federated learning with your own dataset, TFF strongly encourages you to either implement the ClientData interface or use one of TFF's helper functions to generate a ClientData which represents your data on disk, e.g. tff.simulation.datasets.ClientData.from_clients_and_fn.
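At its core, the ClientData abstraction is a list of client IDs plus a way to build any one client's data on demand. These are roughly the two inputs that from_clients_and_fn takes: client IDs and a function from an ID to a dataset. The sketch below mimics that idea in plain Python, with lists standing in for tf.data.Datasets. MiniClientData and its methods are hypothetical illustrations, not TFF's API.

```python
from typing import Callable, Dict, Iterable, List


class MiniClientData:
  """Hypothetical, simplified analogue of a ClientData (not TFF's API)."""

  def __init__(self, client_ids: List[str],
               dataset_fn: Callable[[str], Iterable]):
    self.client_ids = list(client_ids)
    # Maps a client ID to that client's examples; called lazily, per client.
    self._dataset_fn = dataset_fn

  def create_dataset_for_client(self, client_id: str) -> list:
    if client_id not in self.client_ids:
      raise ValueError(f'Unknown client id: {client_id}')
    return list(self._dataset_fn(client_id))

  def preprocess(self, preprocess_fn: Callable[[list], Iterable]):
    # Returns a new object whose per-client construction composes the
    # original dataset function with `preprocess_fn`; nothing is built yet.
    return MiniClientData(
        self.client_ids,
        lambda cid: preprocess_fn(self.create_dataset_for_client(cid)))


# Toy "data on disk": each client owns a few integer examples.
_RAW: Dict[str, List[int]] = {'alice': [1, 2, 3], 'bob': [4, 5]}

toy_client_data = MiniClientData(sorted(_RAW), lambda cid: _RAW[cid])
doubled = toy_client_data.preprocess(lambda examples: [2 * x for x in examples])
print(doubled.create_dataset_for_client('bob'))  # [8, 10]
```

Keeping construction lazy and per client means no client's data is built until that client is actually used.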

As most of TFF's end-to-end examples start with ClientData objects, implementing the ClientData interface with your custom dataset will make it easier to spelunk through existing code written with TFF. Further, the tf.data.Datasets which ClientData constructs can be iterated over directly to yield structures of numpy arrays, so ClientData objects can be used with any Python-based ML framework before moving to TFF.

There are several patterns that can make your life easier if you intend to scale up your simulations to many machines or deploy them. Below we walk through a few ways to use ClientData and TFF to make the path from small-scale iteration, to large-scale experimentation, to production deployment as smooth as possible.

Which pattern should I use to pass ClientData into TFF?

We will discuss two usages of TFF's ClientData in depth; if you fit in either of the two categories below, you will clearly prefer one over the other. If not, you may need a more detailed understanding of the pros and cons of each to make a more nuanced choice.

  • I want to iterate as quickly as possible on a local machine; I don't need to be able to easily take advantage of TFF's distributed runtime.

    • You want to pass tf.data.Datasets in to TFF directly.
    • This allows you to program imperatively with tf.data.Dataset objects, and process them arbitrarily.
    • It provides more flexibility than the option below; pushing logic to the clients requires that this logic be serializable.
  • I want to run my federated computation in TFF's remote runtime, or I plan to do so soon.

    • In this case you want to map dataset construction and preprocessing to clients.
    • This means you simply pass a list of client_ids directly to your federated computation.
    • Pushing dataset construction and preprocessing to the clients avoids bottlenecks in serialization, and significantly increases performance with hundreds-to-thousands of clients.

Set up open-source environment

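Import packages

import collections
import time

import federated_language
import tensorflow as tf
import tensorflow_federated as tff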

Manipulating a ClientData object

Let's begin by loading and exploring TFF's EMNIST ClientData:

client_data, _ = tff.simulation.datasets.emnist.load_data()

Inspecting the first dataset can tell us what type of examples are in the ClientData.

first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

Note that the dataset yields collections.OrderedDict objects that have pixels and label keys, where pixels is a tensor with shape [28, 28]. Suppose we wish to flatten our inputs out to shape [784]. One possible way we can do this would be to apply a pre-processing function to our ClientData object.

def preprocess_dataset(dataset):
  """Create batches of 5 examples, and limit to 5 batches."""

  def map_fn(element):
    # Flatten each 28x28 image to 784 values and rename the keys to x and y.
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], shape=(-1, 784)),
        # Labels become int64 with shape (batch_size, 1), as the model expects.
        y=tf.cast(tf.reshape(element['label'], shape=(-1, 1)), tf.int64),
    )

  return dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.AUTOTUNE).take(5)


preprocessed_client_data = client_data.preprocess(preprocess_dataset)

# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])
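The leading None in the element spec above reflects that batch sizes can vary. dataset.batch(5) keeps a final partial batch by default (drop_remainder=False), so the batch dimension is not statically known. The rough pure-Python sketch below mirrors the batch, flatten, take pipeline and shows the uneven final batch. The helper batch_flatten_take is hypothetical, and nested lists stand in for tensors.

```python
def batch_flatten_take(examples, batch_size=5, max_batches=5):
  """Mimics batch(batch_size) -> flatten pixels -> take(max_batches)."""
  batches = []
  for start in range(0, len(examples), batch_size):
    if len(batches) == max_batches:
      break  # take(max_batches): stop after this many batches.
    chunk = examples[start:start + batch_size]  # the last chunk may be shorter
    batches.append({
        # Flatten each 28x28 image (a list of 28 rows) into 784 values.
        'x': [[p for row in ex['pixels'] for p in row] for ex in chunk],
        # Give each label its own row, matching shape (batch_size, 1).
        'y': [[int(ex['label'])] for ex in chunk],
    })
  return batches


image = [[0.0] * 28 for _ in range(28)]
examples = [{'pixels': image, 'label': i % 10} for i in range(12)]
batches = batch_flatten_take(examples)
print([len(b['x']) for b in batches])  # [5, 5, 2]
print(len(batches[0]['x'][0]))  # 784
```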

We may also want to perform more complex (and possibly stateful) preprocessing, such as shuffling.

def preprocess_and_shuffle(dataset):
  """Applies `preprocess_dataset` above and shuffles the result."""
  preprocessed = preprocess_dataset(dataset)
  return preprocessed.shuffle(buffer_size=5)

preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)

# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])
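A note on buffer_size: tf.data's shuffle fills a buffer with buffer_size elements, randomly samples an element from the buffer, and replaces it with the next input element. Because preprocess_dataset yields at most 5 batches, a buffer of 5 amounts to a full shuffle of each client's batches. The function below is a simplified pure-Python sketch of that buffering scheme. The name buffered_shuffle is hypothetical, and this is not tf.data's actual implementation.

```python
import random


def buffered_shuffle(elements, buffer_size, seed=None):
  """Yields `elements` in an order randomized by a fixed-size buffer."""
  rng = random.Random(seed)
  iterator = iter(elements)
  buffer = []
  # Fill the buffer with (up to) the first `buffer_size` elements.
  for element in iterator:
    buffer.append(element)
    if len(buffer) == buffer_size:
      break
  # Emit a randomly chosen buffered element, refilling its slot from the input.
  for element in iterator:
    index = rng.randrange(len(buffer))
    yield buffer[index]
    buffer[index] = element
  # Input exhausted: drain whatever remains in random order.
  rng.shuffle(buffer)
  yield from buffer


batches = ['b0', 'b1', 'b2', 'b3', 'b4']
shuffled = list(buffered_shuffle(batches, buffer_size=5, seed=0))
print(sorted(shuffled) == batches)  # True: same batches, possibly new order
```

With buffer_size=1 the order is unchanged. A buffer at least as large as the input gives a uniform shuffle.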

Interfacing with a federated_language.Computation

Now that we can perform some basic manipulations with ClientData objects, we are ready to feed data to a federated_language.Computation. We define a tff.templates.IterativeProcess which implements Federated Averaging, and explore different methods of passing it data.

keras_model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(784,)),
    tf.keras.layers.Dense(10, kernel_initializer='zeros'),
])
tff_model = tff.learning.models.functional_model_from_keras(
    keras_model,
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # Note: input spec is the _batched_ shape, and includes the
    # label tensor which will be passed to the loss function. This model is
    # therefore configured to accept data _after_ it has been preprocessed.
    input_spec=collections.OrderedDict(
        x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
        y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64),
    ),
    metrics_constructor=collections.OrderedDict(
        loss=lambda: tf.keras.metrics.SparseCategoricalCrossentropy(
            from_logits=True
        ),
        accuracy=tf.keras.metrics.SparseCategoricalAccuracy,
    ),
)

trainer = tff.learning.algorithms.build_weighted_fed_avg(
    tff_model,
    client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.01),
)
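By default, build_weighted_fed_avg averages client model updates with a weighted mean, weighting each client by its number of training examples. The arithmetic of that server-side averaging step can be sketched in plain Python. Here weighted_mean is a hypothetical helper, flat lists stand in for model weights, and the client updates are made-up numbers. The server optimizer step that follows averaging in TFF is not modeled.

```python
def weighted_mean(updates, weights):
  """Returns sum_i w_i * u_i / sum_i w_i, elementwise over update vectors."""
  total_weight = sum(weights)
  if total_weight <= 0:
    raise ValueError('Total weight must be positive.')
  dim = len(updates[0])
  return [
      sum(w * u[j] for u, w in zip(updates, weights)) / total_weight
      for j in range(dim)
  ]


# Two clients: one with 30 examples, one with 10 (illustrative numbers).
client_updates = [[1.0, 0.0], [-1.0, 4.0]]
num_examples = [30, 10]
print(weighted_mean(client_updates, num_examples))  # [0.5, 1.0]
```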

Before we begin working with this IterativeProcess, one comment on the semantics of ClientData is in order. A ClientData object represents the entire population available for federated training. This notion is specific to simulation: in general, the full population is not available to the execution environment of a production FL system. In fact, ClientData lets the user bypass federated computing entirely and simply train a server-side model as usual via ClientData.create_tf_dataset_from_all_clients.

TFF's simulation environment puts the researcher in complete control of the outer loop. In particular, this means that client availability, client dropout, and similar concerns must be handled by the user or the Python driver script. One could, for example, model client dropout by adjusting the sampling distribution over your ClientData's client_ids so that users with more data (and correspondingly longer-running local computations) are selected with lower probability.
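One way to realize that dropout model is to sample a cohort of client IDs without replacement, giving each client a selection weight inversely proportional to its number of examples. The sketch below does this with the Efraimidis-Spirakis key method. It computes keys in log space to avoid floating-point underflow. The function name and the example counts are hypothetical. In practice the counts might be gathered by iterating each client's dataset once.

```python
import heapq
import math
import random


def sample_clients_inverse_size(num_examples_by_client, k, seed=None):
  """Samples k distinct client IDs, favoring clients with fewer examples.

  Efraimidis-Spirakis weighted sampling without replacement: each client
  gets key u ** (1 / weight) with u ~ Uniform(0, 1], and the k largest keys
  win. Keys are compared in log space to avoid underflow.
  """
  rng = random.Random(seed)
  keyed = []
  for client_id, num_examples in num_examples_by_client.items():
    weight = 1.0 / max(num_examples, 1)  # more data -> lower weight
    u = 1.0 - rng.random()  # in (0, 1], so log(u) is defined
    log_key = math.log(u) / weight  # log of u ** (1 / weight)
    keyed.append((log_key, client_id))
  return [client_id for _, client_id in heapq.nlargest(k, keyed)]


# Hypothetical per-client example counts.
counts = {'c0': 10, 'c1': 200, 'c2': 15, 'c3': 500, 'c4': 12}
cohort = sample_clients_inverse_size(counts, k=3, seed=0)
print(cohort)
```

The resulting IDs could then be used in place of a fixed slice such as client_ids[:10] when choosing each round's clients.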

In a real federated system, however, clients cannot be selected explicitly by the model trainer; the selection of clients is delegated to the system which is executing the federated computation.

Passing tf.data.Datasets directly to TFF

One option for interfacing between a ClientData and an IterativeProcess is to construct tf.data.Datasets in Python and pass these datasets to TFF.

Notice that if we use our preprocessed ClientData, the datasets it yields have the type expected by the model defined above.

selected_client_ids = preprocessed_and_shuffled.client_ids[:10]

preprocessed_data_for_clients = [
    preprocessed_and_shuffled.create_tf_dataset_for_client(client_id)
    for client_id in selected_client_ids
]

state = trainer.initialize()
for _ in range(5):
  t1 = time.time()
  result = trainer.next(state, preprocessed_data_for_clients)
  state = result.state
  train_metrics = result.metrics['client_work']['train']
  t2 = time.time()
  print(f'loss {train_metrics["loss"]:.2f}, round time {t2 - t1:.2f} seconds')
loss 2.89, round time 2.35 seconds
loss 3.05, round time 2.26 seconds
loss 2.80, round time 0.63 seconds
loss 2.94, round time 3.18 seconds
loss 3.17, round time 2.44 seconds

If we take this route, however, we will be unable to move trivially to multi-machine simulation. The datasets we construct in the local TensorFlow runtime can capture state from the surrounding Python environment, and then fail during serialization or deserialization when they attempt to reference state that is no longer available to them. This can manifest, for example, as the inscrutable error from TensorFlow's tensor_util.cc:

Check failed: DT_VARIANT == input.dtype() (21 vs. 20)

Mapping construction and preprocessing over the clients

To avoid this issue, TFF recommends that users treat dataset instantiation and preprocessing as something that happens locally on each client, and use TFF's helpers or federated_map to explicitly run this preprocessing code at each client.

Conceptually, the reason for preferring this is clear: in TFF's local runtime, the clients only "accidentally" have access to the global Python environment because the entire federated orchestration is happening on a single machine. Similar thinking gives rise to TFF's cross-platform, always-serializable, functional philosophy.

TFF makes such a change simple via ClientData's attribute dataset_computation, a federated_language.Computation which takes a client_id and returns the associated tf.data.Dataset.

Note that preprocess integrates seamlessly with dataset_computation: the dataset_computation attribute of the preprocessed ClientData incorporates the entire preprocessing pipeline we just defined:

print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing:
(str -> <label=int32,pixels=float32[28,28]>*)


dataset computation with preprocessing:
(str -> <x=float32[?,784],y=int64[?,1]>*)

We could invoke dataset_computation directly and receive an eager dataset in the Python runtime. The real power of this approach, however, comes from composing it with an iterative process or another computation, so that these datasets are never materialized in the global eager runtime at all. TFF provides a helper function, tff.simulation.compose_dataset_computation_with_iterative_process, that does exactly this.

trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
    preprocessed_and_shuffled.dataset_computation, trainer)

Both this tff.templates.IterativeProcess and the one above run the same way. The original trainer accepts preprocessed client datasets, while trainer_accepting_ids accepts strings representing client IDs and handles both dataset construction and preprocessing in its body. In fact, state can be passed between the two.

for _ in range(5):
  t1 = time.time()
  result = trainer_accepting_ids.next(state, selected_client_ids)
  state = result.state
  train_metrics = result.metrics['client_work']['train']
  t2 = time.time()
  print(f'loss {train_metrics["loss"]:.2f}, round time {t2 - t1:.2f} seconds')
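Conceptually, the composed process is a new next function that builds each selected client's data from its ID and then calls the original next. Because initialize is shared, state can move freely between the two processes. The plain-Python sketch below illustrates that idea. compose_with_ids, toy_next, and the toy data are hypothetical. Unlike TFF, everything runs in one process, with no type checking or placement.

```python
from typing import Callable, Sequence


def compose_with_ids(dataset_fn: Callable[[str], list],
                     next_fn: Callable[[int, Sequence[list]], int]):
  """Returns a next function that accepts client IDs instead of datasets."""

  def next_accepting_ids(state: int, client_ids: Sequence[str]) -> int:
    # Dataset construction happens inside the composed function, per client.
    datasets = [dataset_fn(client_id) for client_id in client_ids]
    return next_fn(state, datasets)

  return next_accepting_ids


def toy_next(state: int, datasets: Sequence[list]) -> int:
  """Toy process: the state is a running count of examples seen so far."""
  return state + sum(len(d) for d in datasets)


toy_data = {'a': [1, 2, 3], 'b': [4]}
next_with_ids = compose_with_ids(lambda cid: toy_data[cid], toy_next)

state = 0
state = toy_next(state, [toy_data['a']])  # pass datasets directly: 3
state = next_with_ids(state, ['a', 'b'])  # pass IDs, reusing the state: 7
print(state)  # 7
```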

Scaling to large numbers of clients

trainer_accepting_ids can immediately be used in TFF's multi-machine runtime, and it avoids materializing tf.data.Datasets on the controller (and therefore serializing them and sending them out to the workers).

This significantly speeds up distributed simulations, especially with a large number of clients, and enables intermediate aggregation to avoid similar serialization/deserialization overhead.
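Intermediate aggregation is possible in part because the weighted sums behind Federated Averaging decompose over any partition of the clients. Each worker can reduce its own clients to a partial (weighted sum, total weight) pair, and only those small pairs need to travel onward. The plain-Python sketch below checks this numerically. The helper names and the two-worker split are hypothetical, and scalars stand in for model updates.

```python
def partial_aggregate(updates_and_weights):
  """Reduces one worker's clients to (weighted sum, total weight)."""
  weighted_sum = sum(w * u for u, w in updates_and_weights)
  total_weight = sum(w for _, w in updates_and_weights)
  return weighted_sum, total_weight


def combine(partials):
  """Merges partial aggregates into the global weighted mean."""
  weighted_sum = sum(s for s, _ in partials)
  total_weight = sum(w for _, w in partials)
  return weighted_sum / total_weight


# (update, num_examples) per client; illustrative values.
clients = [(1.0, 30), (-1.0, 10), (2.0, 20), (0.5, 40)]
worker_a, worker_b = clients[:2], clients[2:]

hierarchical = combine([partial_aggregate(worker_a), partial_aggregate(worker_b)])
flat = combine([partial_aggregate(clients)])
print(hierarchical, flat)  # 0.8 0.8
```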

Optional deep dive: manually composing preprocessing logic in TFF

TFF is designed for compositionality from the ground up; the kind of composition just performed by TFF's helper is fully within our control as users. We could have manually composed the preprocessing computation we just defined with the trainer's own next function quite simply:

selected_clients_type = federated_language.FederatedType(
    preprocessed_and_shuffled.dataset_computation.type_signature.parameter,
    federated_language.CLIENTS,
)


@tff.federated_computation(
    trainer.next.type_signature.parameter[0], selected_clients_type
)
def new_next(server_state, selected_clients):
  preprocessed_data = federated_language.federated_map(
      preprocessed_and_shuffled.dataset_computation, selected_clients
  )
  return trainer.next(server_state, preprocessed_data)


manual_trainer_with_preprocessing = tff.templates.IterativeProcess(
    initialize_fn=trainer.initialize, next_fn=new_next
)

In fact, this is effectively what the helper we used is doing under the hood (plus performing appropriate type checking and manipulation). We could even have expressed the same logic slightly differently, by serializing preprocess_and_shuffle into a federated_language.Computation, and decomposing the federated_map into one step which constructs un-preprocessed datasets and another which runs preprocess_and_shuffle at each client.
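The two-step variant described above works because, for deterministic per-client functions, mapping one function and then another gives the same result as mapping their composition in a single pass. The toy check below runs in plain Python. construct and preprocess are hypothetical stand-ins for dataset construction and preprocessing, operating on lists rather than tf.data.Datasets. A randomized step such as shuffling would only match under a fixed seed.

```python
def construct(client_id):
  """Hypothetical per-client construction: raw examples for an ID."""
  return list(range(len(client_id)))


def preprocess(examples):
  """Hypothetical per-client preprocessing: keep even examples, square them."""
  return [x * x for x in examples if x % 2 == 0]


client_ids = ['ab', 'abcde', 'abcd']

# Variant 1: one map over the fused construct-then-preprocess function.
fused = [preprocess(construct(cid)) for cid in client_ids]

# Variant 2: two maps, first constructing raw data, then preprocessing it.
raw = [construct(cid) for cid in client_ids]
two_step = [preprocess(examples) for examples in raw]

print(fused == two_step)  # True
```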

We can verify that this more-manual path results in computations with the same type signature as TFF's helper (modulo parameter names):

print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,client_data={str}@CLIENTS> -> <state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,metrics=<distributor=<>,client_work=<train=<loss=float32,accuracy=float32>>,aggregator=<mean_value=<>,mean_weight=<>>,finalizer=<update_non_finite=int32>>@SERVER>)
(<server_state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,selected_clients={str}@CLIENTS> -> <state=<global_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,distributor=<>,client_work=<>,aggregator=<value_sum_process=<>,weight_sum_process=<>>,finalizer=<learning_rate=float32>>@SERVER,metrics=<distributor=<>,client_work=<train=<loss=float32,accuracy=float32>>,aggregator=<mean_value=<>,mean_weight=<>>,finalizer=<update_non_finite=int32>>@SERVER>)