استخدم محسنات TFF في عملية تكرارية مخصصة

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثب تحميل دفتر

هذا هو بديل ل بناء الخاصة بك الاتحادية التعلم خوارزمية البرنامج التعليمي و simple_fedavg سبيل المثال لبناء عملية تكرارية مخصصة لل المتوسط الاتحادية الخوارزمية. هذا البرنامج التعليمي سوف تستخدم أبتيميزر TFF بدلا من أبتيميزر Keras. تم تصميم تجريد مُحسِّن TFF ليكون في حالة في الحالة الخارجية ليكون من الأسهل دمجه في عملية تكرارية لـ TFF. و tff.learning واجهات برمجة التطبيقات أيضا قبول أبتيميزر TFF كوسيطة الإدخال.

قبل أن نبدأ

قبل أن نبدأ ، يرجى تشغيل ما يلي للتأكد من إعداد بيئتك بشكل صحيح. إذا كنت لا ترى تحية، يرجى الرجوع إلى تركيب دليل للتعليمات.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import functools
import attr
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

تجهيز البيانات والنموذج

معالجة البيانات EMNIST والنموذج هي مشابهة جدا ل simple_fedavg سبيل المثال.

only_digits=True

# Load dataset.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits)

# Define preprocessing functions.
def preprocess_fn(dataset, batch_size=16):

  def batch_format_fn(element):
    return (tf.expand_dims(element['pixels'], -1), element['label'])

  return dataset.batch(batch_size).map(batch_format_fn)

# Preprocess and sample clients for prototyping.
train_client_ids = sorted(emnist_train.client_ids)
train_data = emnist_train.preprocess(preprocess_fn)
central_test_data = preprocess_fn(
    emnist_train.create_tf_dataset_for_client(train_client_ids[0]))

# Define model.
def create_keras_model():
  """The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format = 'channels_last'
  input_shape = [28, 28, 1]

  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)

  model = tf.keras.models.Sequential([
      conv2d(filters=32, input_shape=input_shape),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10 if only_digits else 62),
  ])

  return model

# Wrap as `tff.learning.Model`.
def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=central_test_data.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

عملية تكرارية مخصصة

في كثير من الحالات ، تحتوي الخوارزميات الموحدة على 4 مكونات رئيسية:

  1. خطوة بث من خادم إلى عميل.
  2. خطوة تحديث العميل المحلي.
  3. خطوة تحميل من عميل إلى خادم.
  4. خطوة تحديث الخادم.

في TFF، التي نمثلها عموما خوارزميات الاتحادية باعتبارها tff.templates.IterativeProcess (والذي نطلق عليه مجرد IterativeProcess طوال). هذا هو الفئة التي تحتوي initialize و next الوظائف. هنا، initialize يستخدم لتهيئة الخادم، و next سوف تؤدي جولة اتصال واحد من الخوارزمية الاتحادية.

سنقدم مكونات مختلفة لبناء خوارزمية المتوسط ​​الموحد (FedAvg) ، والتي ستستخدم مُحسِّنًا في خطوة تحديث العميل ، ومحسنًا آخر في خطوة تحديث الخادم. يمكن التعبير عن المنطق الأساسي لتحديثات العميل والخادم ككتل TF خالصة.

كتل TF: تحديث العميل والخادم

على كل عميل، محلية client_optimizer تتم تهيئة واستخدامها في تحديث الأوزان نموذج العميل. على الخادم، server_optimizer سوف تستخدم الدولة من الجولة السابقة، وتحديث الدولة للجولة القادمة.

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs local training on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)
  # Initialize the client optimizer.
  trainable_tensor_specs = tf.nest.map_structure(
          lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)
  optimizer_state = client_optimizer.initialize(trainable_tensor_specs)
  # Use the client_optimizer to update the local model.
  for batch in iter(dataset):
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data.
      outputs = model.forward_pass(batch)
    # Compute the corresponding gradient.
    grads = tape.gradient(outputs.loss, client_weights)
    # Apply the gradient using a client optimizer.
    optimizer_state, updated_weights = client_optimizer.next(
        optimizer_state, client_weights, grads)
    tf.nest.map_structure(lambda a, b: a.assign(b), 
                          client_weights, updated_weights)
  # Return model deltas.
  return tf.nest.map_structure(tf.subtract, client_weights, server_weights)
@attr.s(eq=False, frozen=True, slots=True)
class ServerState(object):
  trainable_weights = attr.ib()
  optimizer_state = attr.ib()

@tf.function
def server_update(server_state, mean_model_delta, server_optimizer):
  """Updates the server model weights."""
  # Use aggregated negative model delta as pseudo gradient. 
  negative_weights_delta = tf.nest.map_structure(
      lambda w: -1.0 * w, mean_model_delta)
  new_optimizer_state, updated_weights = server_optimizer.next(
      server_state.optimizer_state, server_state.trainable_weights, 
      negative_weights_delta)
  return tff.structure.update_struct(
      server_state,
      trainable_weights=updated_weights,
      optimizer_state=new_optimizer_state)

كتل TFF: tff.tf_computation و tff.federated_computation

نستخدم الآن TFF للتنسيق وبناء العملية التكرارية لـ FedAvg. علينا أن التفاف الكتل TF المحددة أعلاه مع tff.tf_computation ، واستخدام أساليب TFF tff.federated_broadcast ، tff.federated_map ، tff.federated_mean في tff.federated_computation وظيفة. فمن السهل أن استخدام tff.learning.optimizers.Optimizer واجهات برمجة التطبيقات مع initialize و next وظائف عند تعريف عملية تكرارية المخصصة.

# 1. Server and client optimizer to be used.
server_optimizer = tff.learning.optimizers.build_sgdm(
    learning_rate=0.05, momentum=0.9)
client_optimizer = tff.learning.optimizers.build_sgdm(
    learning_rate=0.01)

# 2. Functions return initial state on server. 
@tff.tf_computation
def server_init():
  model = model_fn()
  trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype), model.trainable_variables)
  optimizer_state = server_optimizer.initialize(trainable_tensor_specs)
  return ServerState(
      trainable_weights=model.trainable_variables,
      optimizer_state=optimizer_state)

@tff.federated_computation
def server_init_tff():
  return tff.federated_value(server_init(), tff.SERVER)

# 3. One round of computation and communication.
server_state_type = server_init.type_signature.result
print('server_state_type:\n', 
      server_state_type.formatted_representation())
trainable_weights_type = server_state_type.trainable_weights
print('trainable_weights_type:\n', 
      trainable_weights_type.formatted_representation())

# 3-1. Wrap server and client TF blocks with `tff.tf_computation`.
@tff.tf_computation(server_state_type, trainable_weights_type)
def server_update_fn(server_state, model_delta):
  return server_update(server_state, model_delta, server_optimizer)

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
print('tf_dataset_type:\n', 
      tf_dataset_type.formatted_representation())
@tff.tf_computation(tf_dataset_type, trainable_weights_type)
def client_update_fn(dataset, server_weights):
  model = model_fn()
  return client_update(model, dataset, server_weights, client_optimizer)

# 3-2. Orchestration with `tff.federated_computation`.
federated_server_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
@tff.federated_computation(federated_server_type, federated_dataset_type)
def run_one_round(server_state, federated_dataset):
  # Server-to-client broadcast.
  server_weights_at_client = tff.federated_broadcast(
      server_state.trainable_weights)
  # Local client update.
  model_deltas = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))
  # Client-to-server upload and aggregation.
  mean_model_delta = tff.federated_mean(model_deltas)
  # Server update.
  server_state = tff.federated_map(
      server_update_fn, (server_state, mean_model_delta))
  return server_state

# 4. Build the iterative process for FedAvg.
fedavg_process = tff.templates.IterativeProcess(
    initialize_fn=server_init_tff, next_fn=run_one_round)
print('type signature of `initialize`:\n', 
      fedavg_process.initialize.type_signature.formatted_representation())
print('type signature of `next`:\n', 
      fedavg_process.next.type_signature.formatted_representation())
server_state_type:
 <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >
>
trainable_weights_type:
 <
  float32[5,5,1,32],
  float32[32],
  float32[5,5,32,64],
  float32[64],
  float32[3136,512],
  float32[512],
  float32[512,10],
  float32[10]
>
tf_dataset_type:
 <
  float32[?,28,28,1],
  int32[?]
>*
type signature of `initialize`:
 ( -> <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >
>@SERVER)
type signature of `next`:
 (<
  server_state=<
    trainable_weights=<
      float32[5,5,1,32],
      float32[32],
      float32[5,5,32,64],
      float32[64],
      float32[3136,512],
      float32[512],
      float32[512,10],
      float32[10]
    >,
    optimizer_state=<
      float32[5,5,1,32],
      float32[32],
      float32[5,5,32,64],
      float32[64],
      float32[3136,512],
      float32[512],
      float32[512,10],
      float32[10]
    >
  >@SERVER,
  federated_dataset={<
    float32[?,28,28,1],
    int32[?]
  >*}@CLIENTS
> -> <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >
>@SERVER)

تقييم الخوارزمية

نقوم بتقييم الأداء على مجموعة بيانات تقييم مركزية.

def evaluate(server_state):
  keras_model = create_keras_model()
  tf.nest.map_structure(
      lambda var, t: var.assign(t),
      keras_model.trainable_weights, server_state.trainable_weights)
  metric = tf.keras.metrics.SparseCategoricalAccuracy()
  for batch in iter(central_test_data):
    preds = keras_model(batch[0], training=False)
    metric.update_state(y_true=batch[1], y_pred=preds)
  return metric.result().numpy()
server_state = fedavg_process.initialize()
acc = evaluate(server_state)
print('Initial test accuracy', acc)

# Evaluate after a few rounds
CLIENTS_PER_ROUND=2
sampled_clients = train_client_ids[:CLIENTS_PER_ROUND]
sampled_train_data = [
    train_data.create_tf_dataset_for_client(client)
    for client in sampled_clients]
for round in range(20):
  server_state = fedavg_process.next(server_state, sampled_train_data)
acc = evaluate(server_state)
print('Test accuracy', acc)
Initial test accuracy 0.09677419
Test accuracy 0.13978495