用于图像分类的 TensorFlow Federated

准备工作

要编辑 Colab 笔记本,请转到“File”->“Save a copy in Drive”并对您的副本进行任何编辑。

在开始之前,请运行以下示例来确保您的环境已正确设置。如果未看到问候语,请参阅安装指南查看说明。

Upgrade tensorflow_federated and load TensorBoard

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 Github 上查看源代码

让我们在模拟中测试联合学习。在本教程中,我们使用经典的 MNIST 训练示例来介绍 TFF 的联合学习 (FL) API 层 tff.learning - 一组更高级的接口,可用于执行常见类型的联合学习任务,例如针对在 TensorFlow 中实现的用户提供模型进行联合训练。

教程大纲

我们将训练一个模型来使用经典的 MNIST 数据集执行图像分类,过程中会使用神经网络学习对图像中的数字分类。在这种情况下,我们将模拟训练数据分布在不同设备上的联合学习。

部分

  1. 加载 TFF 库。
  2. 探索/预处理联合 EMNIST 数据集。
  3. 创建一个模型。
  4. 为训练建立联合平均过程。
  5. 分析训练指标。
  6. 设置联合评估计算。
  7. 分析评估指标。

准备输入数据

我们从数据开始。联合学习需要一个联合数据集,即来自多个用户的数据集合。联合数据通常是非独立同分布数据,这会带来一系列独特的挑战。根据使用模式,用户通常具有不同的数据分布。

为了方便实验,我们在 TFF 仓库中植入了一些数据集。

下面是我们加载示例数据集的方式。

# Code for loading federated data from TFF repository
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

load_data() 返回的数据集是 tff.simulation.datasets.ClientData 的实例,后者是一个允许您枚举用户集、构造表示特定用户的数据的 tf.data.Dataset 以及查询各个元素的结构的接口。

我们来探索数据集。

len(emnist_train.client_ids)
# Let's look at the shape of our data
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_dataset.element_spec
# Let's select an example dataset from one of our simulated clients
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

# Your code to get an example element from one client:
example_element = next(iter(example_dataset))

example_element['label'].numpy()
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

探索非独立同分布数据

## Example MNIST digits for one client
f = plt.figure(figsize=(20,4))
j = 0

for e in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(e['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1
# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12,7))
f.suptitle("Label Counts for a Sample of Clients")
for i in range(6):
  ds = emnist_train.create_tf_dataset_for_client(emnist_train.client_ids[i])
  k = collections.defaultdict(list)
  for e in ds:
    k[e['label'].numpy()].append(e['label'].numpy())
  plt.subplot(2, 3, i+1)
  plt.title("Client {}".format(i))
  for j in range(10):
    plt.hist(k[j], density=False, bins=[0,1,2,3,4,5,6,7,8,9,10])
# Let's play around with the emnist_train dataset.
# Let's explore the non-iid charateristic of the example data.

for i in range(5):
  ds = emnist_train.create_tf_dataset_for_client(emnist_train.client_ids[i])
  k = collections.defaultdict(list)
  for e in ds:
    k[e['label'].numpy()].append(e['pixels'].numpy())
  f = plt.figure(i, figsize=(12,5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mn_img = np.mean(k[j],0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mn_img.reshape((28,28)))#,cmap='gray') 
    plt.axis('off')

# Each client has different mean images -- each client will be nudging the model
# in their own directions.

预处理数据

由于数据已经是一个 tf.data.Dataset,可以使用数据集转换来完成预处理。有关转换的详情,请参阅此处

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER=10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

我们验证一下这是否有效。

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch

下面是一个简单的辅助函数,它将从给定的用户集构造一个数据集列表,作为一轮训练或评估的输入。

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

现在,我们如何选择客户端?

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

# Your code to get the federated dataset here for the sampled clients:
federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))

使用 Keras 创建模型

如果您正在使用 Keras,可能已经拥有构造 Keras 模型的代码。下面是一个足以满足我们需求的简单模型示例。

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

使用 Keras 进行集中训练

## Centralized training with keras ---------------------------------------------

# This is separate from the TFF tutorial, and demonstrates how to train a
# Keras model in a centralized fashion (contrasting training in a federated env)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Preprocess the data (these are NumPy arrays)
x_train = x_train.reshape(60000, 784).astype("float32") / 255

y_train = y_train.astype("float32")

mod = create_keras_model()
mod.compile(
    optimizer=tf.keras.optimizers.RMSprop(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
)
h = mod.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=2
)

# ------------------------------------------------------------------------------

使用 Keras 模型进行联合训练

要将任意模型与 TFF 一起使用,需要将该模型封装在 tff.learning.Model 接口的实例中。

可以在此处找到更多您可以添加的 Keras 指标。

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

基于联合数据训练模型

现在,我们有一个封装为 tff.learning.Model 的模型可与 TFF 一起使用,我们可以让 TFF 通过调用辅助函数 tff.learning.build_federated_averaging_process 来构造联合平均算法,具体如下所示。

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    # Add server optimizer here!
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

刚刚发生了什么?TFF 构造了一对联合计算并将它们打包到一个 tff.templates.IterativeProcess 中,其中这些计算可作为一对属性 initializenext 使用。

迭代过程通常由控制循环驱动,例如:

def initialize():
  ...

def next(state):
  ...

iterative_process = IterativeProcess(initialize, next)
state = iterative_process.initialize()
for round in range(num_rounds):
  state = iterative_process.next(state)

我们调用 initialize 计算来构造服务器状态。

state = iterative_process.initialize()

这对联合计算中的第二个属性 next 代表单轮联合平均,其中包括将服务器状态(包括模型参数)推送到客户端、基于其本地数据进行设备端训练、收集模型更新并计算平均值,以及在服务器端生成一个新的更新模型。

让我们运行一轮训练并呈现结果。我们可以将上面已经生成的联合数据用于用户样本。

# Run one single round of training.
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics['train']))

让我们再运行几个轮次。如前面所述,通常在这个时候,您将从新随机选择的用户样本中为每个轮次选择模拟数据的一个子集,以便模拟用户不断往返的真实部署,但在此交互式笔记本中,出于演示的目的,我们只重用相同的用户,以便系统能够快速收敛。

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics['train']))

每轮联合训练后训练损失都在减少,这表明模型正在收敛。但是,这些训练指标有一些重要的注意事项,请参阅本教程后面的评估部分。

在 TensorBoard 中显示模型指标
接下来,我们使用 Tensorboard 呈现来自这些联合计算的指标。

我们首先创建用于写入指标的目录和相应的摘要编写器。

import os
import shutil

logdir = "/tmp/logs/scalars/training/"
if os.path.exists(logdir):
  shutil.rmtree(logdir)

# Your code to create a summary writer:
summary_writer = tf.summary.create_file_writer(logdir)

state = iterative_process.initialize()

使用同一个摘要编写器绘制相关的标量指标。

with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics['train'].items():
      tf.summary.scalar(name, value, step=round_num)

使用上面指定的根日志目录启动 TensorBoard。加载数据可能需要几秒钟的时间。

%tensorboard --logdir /tmp/logs/scalars/ --port=0

要以相同的方式查看评估指标,您可以创建一个单独的 eval 文件夹,如“logs/scalars/eval”,以写入 TensorBoard。

评估

要对联合数据执行评估,您可以使用 tff.learning.build_federated_evaluation 函数构造另一个专为此目的而设计的联合计算,并将模型构造函数作为参数传递。

# Construct federated evaluation computation here:
evaluation = tff.learning.build_federated_evaluation(model_fn)

现在,我们编译一个联合数据的测试样本并对测试数据重新运行评估。数据将来自不同的用户样本,但来自一个独特的保留数据集。

import random
shuffled_ids = emnist_test.client_ids.copy()
random.shuffle(shuffled_ids)
sample_clients = shuffled_ids[0:NUM_CLIENTS]

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
# Run evaluation on the test data here, using the federated model produced from 
# training:
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)

本教程到此结束。我们鼓励您使用参数(例如,批次大小、用户数量、周期、学习率等)修改上面的代码以模拟每个轮次中用户随机样本的训练,并探索我们已经开发的其他教程。

构建您自己的 FL 算法

在之前的教程中,我们学习了如何设置模型和数据流水线,并使用它们通过 tff.learning API 执行联合训练。

当然,这只是 FL 研究的冰山一角。在本教程中,我们将探讨如何在依赖 tff.learning API 的情况下实现联合学习算法。我们要实现的目标如下:

目标:

  • 了解联合学习算法的一般结构。
  • 探索 TFF 的 Federated Core
  • 使用 Federated Core 直接实现联合平均。

准备输入数据

我们首先加载和预处理包含在 TFF 中的 EMNIST 数据集。我们基本上使用与第一个教程中相同的代码。

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]), 
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)
client_ids = np.random.choice(emnist_train.client_ids, size=NUM_CLIENTS, replace=False)

federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]

准备模型

我们使用与第一个教程相同的模型,它有一个隐藏层,后面是一个 softmax 层。

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

我们将此 Keras 模型封装为 tff.learning.Model

def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

自定义 FL 算法

虽然 tff.learning API 包含联合平均的许多变体,但也有许多其他算法不适合此框架。例如,您可能想要添加正则化、裁剪或更复杂的算法,例如f联合 GAN 训练。另外,您可能还对联合分析感兴趣。

对于这些更高级的算法,我们必须编写自己的自定义 FL 算法。

一般而言,FL 算法包括 4 个主要组件:

  1. 服务器到客户端的广播步骤。
  2. 本地客户端更新步骤。
  3. 客户端到服务器的上传步骤。
  4. 服务器更新步骤。

在 TFF 中,我们通常将联合算法表示为 IterativeProcess。这只是一个包含 initialize_fnnext_fn 的类。initialize_fn 将用于初始化服务器,next_fn 将执行一个通信轮次的联合平均。我们为 FedAvg 的迭代过程编写一个框架。

首先,我们有一个初始化函数,它简单地创建一个 tff.learning.Model,并返回其可训练权重。

def initialize_fn():
  model = model_fn()
  return model.weights.trainable

此函数看起来不错,但正如我们稍后会看到的,我们需要做一点小小的修改,使其成为 TFF 计算。

我们还想绘制 next_fn

def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights

我们将专注于分别实现这四个组件。我们将首先关注可以在纯 TensorFlow 中实现的部分,即客户端和服务器更新步骤。

TensorFlow 块

客户端更新

我们将使用我们的 tff.learning.Model 以与训练 TF 模型基本相同的方式进行客户端训练。特别是,我们将使用 tf.GradientTape 来计算数据批次的梯度,然后使用 client_optimizer 应用这些梯度。

请注意,每个 tff.learning.Model 实例都有一个 weights 属性和两个子属性:

  • trainable:与可训练层对应的张量列表。
  • non_trainable:与不可训练层对应的张量列表。

出于说明目的,我们将只使用可训练权重(因为我们的模型只有这些!)。

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.weights.trainable
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)

    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)

  return client_weights

服务器更新

服务器更新将需要更少的工作。我们将实现普通联合平均,其中只需要用客户端模型权重的平均值替换服务器模型权重。同样,我们将只关注可训练权重。

@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  model_weights = model.weights.trainable
  # Assign the mean client weights to the server model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        model_weights, mean_client_weights)
  return model_weights

请注意,上面的代码片段显然有些多余,因为我们可以简单地返回 mean_client_weights。但是,联合平均的更高级实现可以使用 mean_client_weights 和更复杂的技术,例如动量或自适应。

到目前为止,我们只编写了纯 TensorFlow 代码。我们有意这样设计,因为 TFF 允许您使用许多您已经熟悉的 TensorFlow 代码。但是,现在我们必须指定编排逻辑,即规定服务器向客户端广播哪些内容以及客户端向服务器上传哪些内容的逻辑。

这将需要 TFF 的“Federated Core”。

Federated Core 简介

Federated Core (FC) 是一组用作 tff.learning API 基础的低级接口。不过,这些接口不仅限于学习。事实上,它们可用于对分布式数据进行分析和许多其他计算。

在高层级上,Federated Core 是一个开发环境,可让简洁表达的程序逻辑能够将 TensorFlow 代码与分布式通信算子(例如分布式和与广播)相结合。目标是让研究员和从业者明确控制他们系统中的分布式通信,而不需要系统实现细节(例如指定点对点网络消息交换)。

一个关键点在于,TFF 是专为隐私保护而设计。因此,它允许显式控制数据驻留的位置,以防止在集中式服务器位置不必要地积累数据。

联合数据

与 TensorFlow 中作为基本概念之一的“张量”类似,TFF 中的一个关键概念是“联合数据”,它指的是分布式系统中跨一组设备托管的数据项的集合(例如客户端数据集,或服务器模型权重)。我们将跨所有设备的整个数据项集合建模为单个联合值

例如,假设我们有客户端设备,每个设备都有一个表示传感器温度的浮点数。我们可以通过以下代码将其表示为联合浮点数

federated_float_on_clients = tff.type_at_clients(tf.float32)

联合类型由其组成成员(例如 tf.float32)的类型 T 和一组 G 设备指定。我们将关注 Gtff.CLIENTStff.SERVER 的情况。此类联合类型表示为 {T}@G,具体如下所示。

str(federated_float_on_clients)

为什么我们如此关心布局?TFF 的一个关键目标是能够编写可以部署在真实分布式系统上的代码。这意味着推断哪些设备子集执行哪些代码以及不同的数据段驻留在何处至关重要。

TFF 关注三个信息:数据、数据放置的位置以及数据如何转换。前两个封装在联合类型中,而最后一个封装在联合计算中。

联合计算

TFF 是一种强类型函数式编程环境,其基本单元是联合计算。这些单元是接受联合值作为输入并返回联合值作为输出的逻辑片段。

例如,假设我们想要计算客户端传感器上温度的平均值。我们可以定义以下代码(使用我们的联合浮点数):

@tff.federated_computation(tff.type_at_clients(tf.float32))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)

您可能会问,这和 TensorFlow 中的 tf.function 装饰器有什么不同?关键的答案是 tff.federated_computation 生成的代码既不是 TensorFlow 也不是 Python 代码;它是以独立于内部平台的胶水语言编写的分布式系统规范。

虽然这听起来很复杂,但您可以将 TFF 计算视为具有明确定义的类型签名的函数。可以直接查询这些类型签名。

str(get_average_temperature.type_signature)

tff.federated_computation 接受联合类型 <float>@CLIENTS 的参数,并返回联合类型 <float>@SERVER 的值。联合计算可以从服务器到客户端、从客户端到客户端或者从服务器到服务器。另外,联合计算的构成也可以像普通函数一样,只要它们的类型签名匹配即可。

为了支持开发,TFF 允许您调用 tff.federated_computation 作为 Python 函数。例如,我们可以调用

get_average_temperature([68.5, 70.3, 69.8])

非 Eager 计算和 TensorFlow

有两个关键限制需要注意。首先,当 Python 解释器遇到 tff.federated_computation 装饰器时,该函数会被跟踪一次并序列化以备将来使用。因此,TFF 计算从根本上来说是非 Eager 计算。这种行为有点类似于 TensorFlow 中的 tf.function 装饰器。

其次,联合计算只能由联合算子(例如 tff.federated_mean)组成,不能包含 TensorFlow 算子。TensorFlow 代码必须限制在用 tff.tf_computation 装饰的块中。大多数普通 TensorFlow 代码都可以直接进行装饰,例如下面的函数,它会取一个数字并加 0.5

@tff.tf_computation(tf.float32)
def add_half(x):
  return tf.add(x, 0.5)

这些也有类型签名,但没有安置。例如,我们可以调用

str(add_half.type_signature)

在这里,我们看到了 tff.federated_computationtff.tf_computation 之间的重要区别。前者有显式安置,而后者没有。

我们可以通过指定安置在联合计算中使用 tff.tf_computation 块。我们创建一个增加一半值的函数,但仅适用于客户端的联合浮点数。我们可以通过使用 tff.federated_map 来做到这一点,它会应用给定的 tff.tf_computation,同时保留安置。

@tff.federated_computation(tff.type_at_clients(tf.float32))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)

此函数与 add_half 几乎相同,不同之处在于它只接受安置位于 tff.CLIENTS 中的值,并返回具有相同安置的值。我们可以在它的类型签名中看到这一点:

str(add_half_on_clients.type_signature)

总结:

  • TFF 对联合值进行运算。
  • 每个联合值都有一个联合类型,而联合类型包含类型(例如 tf.float32)和安置(例如 tff.CLIENTS)。
  • 联合值可以使用联合计算进行转换,联合计算必须使用 tff.federated_computation 和联合类型签名进行装饰。
  • TensorFlow 代码必须包含在带有 tff.tf_computation 装饰器的块中。
  • 随后可以将这些块合并到联合计算中。

构建您自己的 FL 算法(第 2 部分)

现在我们已经了解了 Federated Core,我们可以构建自己的联合学习算法。请记住上面,我们已为我们的算法定义了 initialize_fnnext_fnnext_fn 将利用我们使用纯 TensorFlow 代码定义的 client_updateserver_update

不过,为了使我们的算法成为联合计算,我们需要 next_fninitialize_fn 均为 tff.federated_computations

TensorFlow 联合块

创建初始化计算

初始化函数将非常简单:我们将使用 model_fn 创建一个模型。不过,请记住,我们必须使用 tff.tf_computation 分离出我们的 TensorFlow 代码。

@tff.tf_computation
def server_init():
  model = model_fn()
  return model.weights.trainable

随后,我们可以使用 tff.federated_value 将其直接传递到联合计算中。

@tff.federated_computation
def initialize_fn():
  return tff.federated_value(server_init(), tff.SERVER)

创建 next_fn

我们现在使用我们的客户端和服务器更新代码以编写实际的算法。我们首先将我们的 client_update 转换为 tff.tf_computation,它接受客户端数据集和服务器权重,并输出更新的客户端权重张量。

我们将需要相应的类型来正确装饰我们的函数。幸运的是,服务器权重的类型可以从我们的模型中直接提取。

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

我们看一下数据集类型签名。请记住,我们获取了 28 x 28 张图像(带整数标签)并将它们展平。

str(tf_dataset_type)

我们还可以使用上面的 server_init 函数提取模型权重类型。

model_weights_type = server_init.type_signature.result

检查类型签名,我们将能够看到我们模型的架构!

str(model_weights_type)

我们现在可以为客户端更新创建我们的 tff.tf_computation

@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  model = model_fn()
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(model, tf_dataset, server_weights, client_optimizer)

服务器更新的 tff.tf_computation 版本可以使用我们已经提取的类型以类似方式定义。

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  return server_update(model, mean_client_weights)

最后,我们需要创建 tff.federated_computation 将所有这些结合在一起。此函数将接受两个联合值,一个对应于服务器权重(采用 tff.SERVER 布局),另一个对应于客户端数据集(采用 tff.CLIENTS 布局)。

请注意,这两种类型均已在上面定义!我们只需要使用 tff.type_at_{server/clients}` 为它们提供适当的布局。

federated_server_type = tff.type_at_server(model_weights_type)
federated_dataset_type = tff.type_at_clients(tf_dataset_type)

还记得 FL 算法的 4 个元素吗?

  1. 服务器到客户端的广播步骤。
  2. 本地客户端更新步骤。
  3. 客户端到服务器的上传步骤。
  4. 服务器更新步骤。

现在,我们已经构建了上面的元素,每个部分都可以用一行 TFF 代码简洁地表示。这种简洁性就是为什么我们必须格外小心地指定诸如联合类型之类的内容!

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))

  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights

我们现在有一个 tff.federated_computation 用于算法初始化和运行算法的一个步骤。为了完成我们的算法,我们将这些传递给 tff.templates.IterativeProcess

federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

我们看看迭代过程的 initializenext 函数的类型签名

str(federated_algorithm.initialize.type_signature)

这反映了 federated_algorithm.initialize 是一个返回单层模型(具有 784×10 权重矩阵和 10 个偏置单元)的无参数函数的事实。

str(federated_algorithm.next.type_signature)

在这里,我们看到 federated_algorithm.next 接受服务器模型和客户端数据,并返回更新的服务器模型。

评估算法

我们来运行几个轮次,看看损失如何变化。首先,我们将使用第二个教程中讨论的集中式方式定义评估函数。

我们首先创建一个集中式评估数据集,然后应用我们用于训练数据的相同预处理。

请注意,出于计算效率的原因,我们仅 take 前 1000 个元素,但通常我们会使用整个测试数据集。

central_emnist_test = emnist_test.create_tf_dataset_from_all_clients().take(1000)
central_emnist_test = preprocess(central_emnist_test)

接下来,我们编写一个接受服务器状态的函数,并使用 Keras 对测试数据集进行评估。如果您熟悉 tf.Keras,那么这一切都轻车熟路,但要注意 set_weights 的使用!

def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_emnist_test)

现在,我们初始化我们的算法并对测试集进行评估。

server_state = federated_algorithm.initialize()
evaluate(server_state)

我们训练几个轮次,看看有什么变化。

for round in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)

我们看到损失函数略有下降。虽然跳跃很小,但请注意,我们只对一小部分客户端进行了 10 轮训练。为了获得更好的结果,我们可能需要进行数百轮甚至数千轮训练。

修改我们的算法

此时此刻,我们停下来思考一下我们已经完成的工作。我们通过将纯 TensorFlow 代码(用于客户端和服务器更新)与来自 TFF Federated Core 的联合计算相结合,直接实现了联合平均。

为了执行更复杂的学习,我们可以简单地改变上面的内容。特别是,通过编辑上面的纯 TF 代码,我们可以更改客户端执行训练的方式,或者服务器更新其模型的方式。

挑战:梯度裁剪添加到 client_update 函数。