Sử dụng trình tối ưu hóa TFF trong quy trình lặp lại tùy chỉnh

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Đây là một thay thế cho xây dựng riêng Federated Learning Algorithm của bạn hướng dẫn và simple_fedavg dụ để xây dựng một quy trình lặp đi lặp lại tùy chỉnh cho các trung bình liên thuật toán. Hướng dẫn này sẽ sử dụng tối ưu TFF thay vì tối ưu Keras. Tính trừu tượng của trình tối ưu hóa TFF được mô tả là trạng thái trong trạng thái-ngoài để dễ dàng kết hợp hơn trong một quy trình lặp lại TFF. Các tff.learning API cũng chấp nhận tối ưu TFF như là đối số đầu vào.

Trước khi chúng ta bắt đầu

Trước khi chúng tôi bắt đầu, vui lòng chạy phần sau để đảm bảo rằng môi trường của bạn được thiết lập chính xác. Nếu bạn không thấy một lời chào, xin vui lòng tham khảo các cài đặt hướng dẫn để được hướng dẫn.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import functools
import attr
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

Chuẩn bị dữ liệu và mô hình

Việc xử lý dữ liệu EMNIST và mô hình rất giống với các simple_fedavg ví dụ.

only_digits=True

# Load dataset.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits)

# Define preprocessing functions.
def preprocess_fn(dataset, batch_size=16):

  def batch_format_fn(element):
    return (tf.expand_dims(element['pixels'], -1), element['label'])

  return dataset.batch(batch_size).map(batch_format_fn)

# Preprocess and sample clients for prototyping.
train_client_ids = sorted(emnist_train.client_ids)
train_data = emnist_train.preprocess(preprocess_fn)
central_test_data = preprocess_fn(
    emnist_train.create_tf_dataset_for_client(train_client_ids[0]))

# Define model.
def create_keras_model():
  """The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format = 'channels_last'
  input_shape = [28, 28, 1]

  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)

  model = tf.keras.models.Sequential([
      conv2d(filters=32, input_shape=input_shape),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10 if only_digits else 62),
  ])

  return model

# Wrap as `tff.learning.Model`.
def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=central_test_data.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

Quy trình lặp lại tùy chỉnh

Trong nhiều trường hợp, các thuật toán liên hợp có 4 thành phần chính:

  1. Bước truyền phát từ máy chủ đến máy khách.
  2. Một bước cập nhật ứng dụng khách cục bộ.
  3. Một bước tải lên từ máy khách đến máy chủ.
  4. Một bước cập nhật máy chủ.

Trong TFF, chúng tôi thường đại diện cho các thuật toán liên như một tff.templates.IterativeProcess (mà chúng tôi gọi là chỉ là một IterativeProcess suốt). Đây là một lớp học có chứa initializenext chức năng. Ở đây, initialize được sử dụng để khởi tạo máy chủ, và next sẽ thực hiện một vòng thông tin liên lạc của thuật toán liên.

Chúng tôi sẽ giới thiệu các thành phần khác nhau để xây dựng thuật toán tính trung bình liên kết (FedAvg), thuật toán này sẽ sử dụng một trình tối ưu hóa trong bước cập nhật máy khách và một trình tối ưu hóa khác trong bước cập nhật máy chủ. Logic cốt lõi của các cập nhật máy khách và máy chủ có thể được biểu thị dưới dạng khối TF thuần túy.

Khối TF: cập nhật máy khách và máy chủ

Trên mỗi khách hàng, một địa phương client_optimizer được khởi tạo và sử dụng để cập nhật các trọng số mô hình client. Trên máy chủ, server_optimizer sẽ sử dụng nhà nước từ vòng trước, và cập nhật trạng thái cho các vòng tiếp theo.

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs local training on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)
  # Initialize the client optimizer.
  trainable_tensor_specs = tf.nest.map_structure(
          lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)
  optimizer_state = client_optimizer.initialize(trainable_tensor_specs)
  # Use the client_optimizer to update the local model.
  for batch in iter(dataset):
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data.
      outputs = model.forward_pass(batch)
    # Compute the corresponding gradient.
    grads = tape.gradient(outputs.loss, client_weights)
    # Apply the gradient using a client optimizer.
    optimizer_state, updated_weights = client_optimizer.next(
        optimizer_state, client_weights, grads)
    tf.nest.map_structure(lambda a, b: a.assign(b), 
                          client_weights, updated_weights)
  # Return model deltas.
  return tf.nest.map_structure(tf.subtract, client_weights, server_weights)
@attr.s(eq=False, frozen=True, slots=True)
class ServerState(object):
  trainable_weights = attr.ib()
  optimizer_state = attr.ib()

@tf.function
def server_update(server_state, mean_model_delta, server_optimizer):
  """Updates the server model weights."""
  # Use aggregated negative model delta as pseudo gradient. 
  negative_weights_delta = tf.nest.map_structure(
      lambda w: -1.0 * w, mean_model_delta)
  new_optimizer_state, updated_weights = server_optimizer.next(
      server_state.optimizer_state, server_state.trainable_weights, 
      negative_weights_delta)
  return tff.structure.update_struct(
      server_state,
      trainable_weights=updated_weights,
      optimizer_state=new_optimizer_state)

Khối TFF: tff.tf_computationtff.federated_computation

Giờ đây, chúng tôi sử dụng TFF để điều phối và xây dựng quy trình lặp lại cho FedAvg. Chúng ta phải quấn các khối TF định nghĩa ở trên với tff.tf_computation , và các phương pháp sử dụng TFF tff.federated_broadcast , tff.federated_map , tff.federated_mean trong một tff.federated_computation chức năng. Nó rất dễ dàng để sử dụng tff.learning.optimizers.Optimizer API với initializenext chức năng khi xác định một quá trình lặp đi lặp lại tùy chỉnh.

# 1. Server and client optimizer to be used.
server_optimizer = tff.learning.optimizers.build_sgdm(
    learning_rate=0.05, momentum=0.9)
client_optimizer = tff.learning.optimizers.build_sgdm(
    learning_rate=0.01)

# 2. Functions return initial state on server. 
@tff.tf_computation
def server_init():
  model = model_fn()
  trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype), model.trainable_variables)
  optimizer_state = server_optimizer.initialize(trainable_tensor_specs)
  return ServerState(
      trainable_weights=model.trainable_variables,
      optimizer_state=optimizer_state)

@tff.federated_computation
def server_init_tff():
  return tff.federated_value(server_init(), tff.SERVER)

# 3. One round of computation and communication.
server_state_type = server_init.type_signature.result
print('server_state_type:\n', 
      server_state_type.formatted_representation())
trainable_weights_type = server_state_type.trainable_weights
print('trainable_weights_type:\n', 
      trainable_weights_type.formatted_representation())

# 3-1. Wrap server and client TF blocks with `tff.tf_computation`.
@tff.tf_computation(server_state_type, trainable_weights_type)
def server_update_fn(server_state, model_delta):
  return server_update(server_state, model_delta, server_optimizer)

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
print('tf_dataset_type:\n', 
      tf_dataset_type.formatted_representation())
@tff.tf_computation(tf_dataset_type, trainable_weights_type)
def client_update_fn(dataset, server_weights):
  model = model_fn()
  return client_update(model, dataset, server_weights, client_optimizer)

# 3-2. Orchestration with `tff.federated_computation`.
federated_server_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
@tff.federated_computation(federated_server_type, federated_dataset_type)
def run_one_round(server_state, federated_dataset):
  # Server-to-client broadcast.
  server_weights_at_client = tff.federated_broadcast(
      server_state.trainable_weights)
  # Local client update.
  model_deltas = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))
  # Client-to-server upload and aggregation.
  mean_model_delta = tff.federated_mean(model_deltas)
  # Server update.
  server_state = tff.federated_map(
      server_update_fn, (server_state, mean_model_delta))
  return server_state

# 4. Build the iterative process for FedAvg.
fedavg_process = tff.templates.IterativeProcess(
    initialize_fn=server_init_tff, next_fn=run_one_round)
print('type signature of `initialize`:\n', 
      fedavg_process.initialize.type_signature.formatted_representation())
print('type signature of `next`:\n', 
      fedavg_process.next.type_signature.formatted_representation())
server_state_type:
 <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >
>
trainable_weights_type:
 <
  float32[5,5,1,32],
  float32[32],
  float32[5,5,32,64],
  float32[64],
  float32[3136,512],
  float32[512],
  float32[512,10],
  float32[10]
>
tf_dataset_type:
 <
  float32[?,28,28,1],
  int32[?]
>*
type signature of `initialize`:
 ( -> <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >
>@SERVER)
type signature of `next`:
 (<
  server_state=<
    trainable_weights=<
      float32[5,5,1,32],
      float32[32],
      float32[5,5,32,64],
      float32[64],
      float32[3136,512],
      float32[512],
      float32[512,10],
      float32[10]
    >,
    optimizer_state=<
      float32[5,5,1,32],
      float32[32],
      float32[5,5,32,64],
      float32[64],
      float32[3136,512],
      float32[512],
      float32[512,10],
      float32[10]
    >
  >@SERVER,
  federated_dataset={<
    float32[?,28,28,1],
    int32[?]
  >*}@CLIENTS
> -> <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >
>@SERVER)

Đánh giá thuật toán

Chúng tôi đánh giá hiệu suất trên tập dữ liệu đánh giá tập trung.

def evaluate(server_state):
  keras_model = create_keras_model()
  tf.nest.map_structure(
      lambda var, t: var.assign(t),
      keras_model.trainable_weights, server_state.trainable_weights)
  metric = tf.keras.metrics.SparseCategoricalAccuracy()
  for batch in iter(central_test_data):
    preds = keras_model(batch[0], training=False)
    metric.update_state(y_true=batch[1], y_pred=preds)
  return metric.result().numpy()
server_state = fedavg_process.initialize()
acc = evaluate(server_state)
print('Initial test accuracy', acc)

# Evaluate after a few rounds
CLIENTS_PER_ROUND=2
sampled_clients = train_client_ids[:CLIENTS_PER_ROUND]
sampled_train_data = [
    train_data.create_tf_dataset_for_client(client)
    for client in sampled_clients]
for round in range(20):
  server_state = fedavg_process.next(server_state, sampled_train_data)
acc = evaluate(server_state)
print('Test accuracy', acc)
Initial test accuracy 0.09677419
Test accuracy 0.13978495