カスタム反復プロセスでTFFオプティマイザーを使用する

これは、の代替であるビルド自分の連合学習アルゴリズムのチュートリアルとsimple_fedavg用のカスタム反復プロセス構築する例連合平均化アルゴリズムを。このチュートリアルでは、使用するTFFオプティマイザの代わりKerasオプティマイザを。 TFFオプティマイザーの抽象化は、TFF反復プロセスに簡単に組み込むことができるように、state-in-state-outになるように設計されています。 tff.learning APIも入力引数としてTFFオプティマイザを受け入れます。

始める前に

開始する前に、以下を実行して、環境が正しくセットアップされていることを確認してください。あなたが挨拶が表示されない場合は、を参照してください。インストールの手順についてのガイド。

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio
.apply()
import functools
import attr
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

データとモデルの準備

EMNISTデータ処理及びモデルは非常に類似しているsimple_fedavg例。

only_digits=True

# Load dataset.
emnist_train
, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits)

# Define preprocessing functions.
def preprocess_fn(dataset, batch_size=16):

 
def batch_format_fn(element):
   
return (tf.expand_dims(element['pixels'], -1), element['label'])

 
return dataset.batch(batch_size).map(batch_format_fn)

# Preprocess and sample clients for prototyping.
train_client_ids
= sorted(emnist_train.client_ids)
train_data
= emnist_train.preprocess(preprocess_fn)
central_test_data
= preprocess_fn(
    emnist_train
.create_tf_dataset_for_client(train_client_ids[0]))

# Define model.
def create_keras_model():
 
"""The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format
= 'channels_last'
  input_shape
= [28, 28, 1]

  max_pool
= functools.partial(
      tf
.keras.layers.MaxPooling2D,
      pool_size
=(2, 2),
      padding
='same',
      data_format
=data_format)
  conv2d
= functools.partial(
      tf
.keras.layers.Conv2D,
      kernel_size
=5,
      padding
='same',
      data_format
=data_format,
      activation
=tf.nn.relu)

  model
= tf.keras.models.Sequential([
      conv2d
(filters=32, input_shape=input_shape),
      max_pool
(),
      conv2d
(filters=64),
      max_pool
(),
      tf
.keras.layers.Flatten(),
      tf
.keras.layers.Dense(512, activation=tf.nn.relu),
      tf
.keras.layers.Dense(10 if only_digits else 62),
 
])

 
return model

# Wrap as `tff.learning.Model`.
def model_fn():
  keras_model
= create_keras_model()
 
return tff.learning.from_keras_model(
      keras_model
,
      input_spec
=central_test_data.element_spec,
      loss
=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

カスタム反復プロセス

多くの場合、フェデレーションアルゴリズムには4つの主要なコンポーネントがあります。

  1. サーバーからクライアントへのブロードキャストステップ。
  2. ローカルクライアントの更新手順。
  3. クライアントからサーバーへのアップロード手順。
  4. サーバーの更新手順。

TFFにおいて、我々は、一般的に、フェデレーテッド・アルゴリズム表すtff.templates.IterativeProcess (先ほどと呼ぶIterativeProcess全体を)。これが含まれているクラスでinitializeし、 next機能を。ここでは、 initializeサーバーを初期化するために使用され、そしてnext連合アルゴリズムの一つの通信ラウンドを実行します。

さまざまなコンポーネントを導入して、フェデレーション平均(FedAvg)アルゴリズムを構築します。このアルゴリズムは、クライアントの更新ステップでオプティマイザーを使用し、サーバーの更新ステップで別のオプティマイザーを使用します。クライアントとサーバーの更新のコアロジックは、純粋なTFブロックとして表現できます。

TFブロック:クライアントとサーバーの更新

各クライアントでは、ローカルclient_optimizer初期化され、クライアントモデルの重みを更新するために使用されます。サーバーでは、 server_optimizer前のラウンドからの状態を使用し、次のラウンドのために状態を更新します。

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
 
"""Performs local training on the client's dataset."""
 
# Initialize the client model with the current server weights.
  client_weights
= model.trainable_variables
 
# Assign the server weights to the client model.
  tf
.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights
, server_weights)
 
# Initialize the client optimizer.
  trainable_tensor_specs
= tf.nest.map_structure(
         
lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)
  optimizer_state
= client_optimizer.initialize(trainable_tensor_specs)
 
# Use the client_optimizer to update the local model.
 
for batch in iter(dataset):
   
with tf.GradientTape() as tape:
     
# Compute a forward pass on the batch of data.
      outputs
= model.forward_pass(batch)
   
# Compute the corresponding gradient.
    grads
= tape.gradient(outputs.loss, client_weights)
   
# Apply the gradient using a client optimizer.
    optimizer_state
, updated_weights = client_optimizer.next(
        optimizer_state
, client_weights, grads)
    tf
.nest.map_structure(lambda a, b: a.assign(b),
                          client_weights
, updated_weights)
 
# Return model deltas.
 
return tf.nest.map_structure(tf.subtract, client_weights, server_weights)
@attr.s(eq=False, frozen=True, slots=True)
class ServerState(object):
  trainable_weights
= attr.ib()
  optimizer_state
= attr.ib()

@tf.function
def server_update(server_state, mean_model_delta, server_optimizer):
 
"""Updates the server model weights."""
 
# Use aggregated negative model delta as pseudo gradient.
  negative_weights_delta
= tf.nest.map_structure(
     
lambda w: -1.0 * w, mean_model_delta)
  new_optimizer_state
, updated_weights = server_optimizer.next(
      server_state
.optimizer_state, server_state.trainable_weights,
      negative_weights_delta
)
 
return tff.structure.update_struct(
      server_state
,
      trainable_weights
=updated_weights,
      optimizer_state
=new_optimizer_state)

TFFブロック: tff.tf_computationtff.federated_computation

現在、オーケストレーションにTFFを使用し、FedAvgの反復プロセスを構築しています。私たちは、と上で定義されTFブロックラップする必要がありtff.tf_computation 、および使用のTFF法tff.federated_broadcasttff.federated_maptff.federated_meanしてtff.federated_computation機能。使いやすいですtff.learning.optimizers.OptimizerでAPIをinitializeし、 nextカスタム反復プロセスを定義する際に機能しています。

# 1. Server and client optimizer to be used.
server_optimizer
= tff.learning.optimizers.build_sgdm(
    learning_rate
=0.05, momentum=0.9)
client_optimizer
= tff.learning.optimizers.build_sgdm(
    learning_rate
=0.01)

# 2. Functions return initial state on server.
@tff.tf_computation
def server_init():
  model
= model_fn()
  trainable_tensor_specs
= tf.nest.map_structure(
       
lambda v: tf.TensorSpec(v.shape, v.dtype), model.trainable_variables)
  optimizer_state
= server_optimizer.initialize(trainable_tensor_specs)
 
return ServerState(
      trainable_weights
=model.trainable_variables,
      optimizer_state
=optimizer_state)

@tff.federated_computation
def server_init_tff():
 
return tff.federated_value(server_init(), tff.SERVER)

# 3. One round of computation and communication.
server_state_type
= server_init.type_signature.result
print('server_state_type:\n',
      server_state_type
.formatted_representation())
trainable_weights_type
= server_state_type.trainable_weights
print('trainable_weights_type:\n',
      trainable_weights_type
.formatted_representation())

# 3-1. Wrap server and client TF blocks with `tff.tf_computation`.
@tff.tf_computation(server_state_type, trainable_weights_type)
def server_update_fn(server_state, model_delta):
 
return server_update(server_state, model_delta, server_optimizer)

whimsy_model
= model_fn()
tf_dataset_type
= tff.SequenceType(whimsy_model.input_spec)
print('tf_dataset_type:\n',
      tf_dataset_type
.formatted_representation())
@tff.tf_computation(tf_dataset_type, trainable_weights_type)
def client_update_fn(dataset, server_weights):
  model
= model_fn()
 
return client_update(model, dataset, server_weights, client_optimizer)

# 3-2. Orchestration with `tff.federated_computation`.
federated_server_type
= tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type
= tff.FederatedType(tf_dataset_type, tff.CLIENTS)
@tff.federated_computation(federated_server_type, federated_dataset_type)
def run_one_round(server_state, federated_dataset):
 
# Server-to-client broadcast.
  server_weights_at_client
= tff.federated_broadcast(
      server_state
.trainable_weights)
 
# Local client update.
  model_deltas
= tff.federated_map(
      client_update_fn
, (federated_dataset, server_weights_at_client))
 
# Client-to-server upload and aggregation.
  mean_model_delta
= tff.federated_mean(model_deltas)
 
# Server update.
  server_state
= tff.federated_map(
      server_update_fn
, (server_state, mean_model_delta))
 
return server_state

# 4. Build the iterative process for FedAvg.
fedavg_process
= tff.templates.IterativeProcess(
    initialize_fn
=server_init_tff, next_fn=run_one_round)
print('type signature of `initialize`:\n',
      fedavg_process
.initialize.type_signature.formatted_representation())
print('type signature of `next`:\n',
      fedavg_process
.next.type_signature.formatted_representation())
server_state_type:
 <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >
>
trainable_weights_type:
 <
  float32[5,5,1,32],
  float32[32],
  float32[5,5,32,64],
  float32[64],
  float32[3136,512],
  float32[512],
  float32[512,10],
  float32[10]
>
tf_dataset_type:
 <
  float32[?,28,28,1],
  int32[?]
>*
type signature of `initialize`:
 ( -> <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >
>@SERVER)
type signature of `next`:
 (<
  server_state=<
    trainable_weights=<
      float32[5,5,1,32],
      float32[32],
      float32[5,5,32,64],
      float32[64],
      float32[3136,512],
      float32[512],
      float32[512,10],
      float32[10]
    >,
    optimizer_state=<
      float32[5,5,1,32],
      float32[32],
      float32[5,5,32,64],
      float32[64],
      float32[3136,512],
      float32[512],
      float32[512,10],
      float32[10]
    >
  >@SERVER,
  federated_dataset={<
    float32[?,28,28,1],
    int32[?]
  >*}@CLIENTS
> -> <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >
>@SERVER)

アルゴリズムの評価

一元化された評価データセットでパフォーマンスを評価します。

def evaluate(server_state):
  keras_model
= create_keras_model()
  tf
.nest.map_structure(
     
lambda var, t: var.assign(t),
      keras_model
.trainable_weights, server_state.trainable_weights)
  metric
= tf.keras.metrics.SparseCategoricalAccuracy()
 
for batch in iter(central_test_data):
    preds
= keras_model(batch[0], training=False)
    metric
.update_state(y_true=batch[1], y_pred=preds)
 
return metric.result().numpy()
server_state = fedavg_process.initialize()
acc
= evaluate(server_state)
print('Initial test accuracy', acc)

# Evaluate after a few rounds
CLIENTS_PER_ROUND
=2
sampled_clients
= train_client_ids[:CLIENTS_PER_ROUND]
sampled_train_data
= [
    train_data
.create_tf_dataset_for_client(client)
   
for client in sampled_clients]
for round in range(20):
  server_state
= fedavg_process.next(server_state, sampled_train_data)
acc
= evaluate(server_state)
print('Test accuracy', acc)
Initial test accuracy 0.09677419
Test accuracy 0.13978495