TensorFlow.orgで表示 | GoogleColabで実行 | GitHubでソースを表示 | ノートブックをダウンロード |
これは、の代替であるビルド自分の連合学習アルゴリズムのチュートリアルとsimple_fedavg用のカスタム反復プロセス構築する例連合平均化アルゴリズムを。このチュートリアルでは、使用するTFFオプティマイザの代わりKerasオプティマイザを。 TFFオプティマイザーの抽象化は、TFF反復プロセスに簡単に組み込むことができるように、state-in-state-outになるように設計されています。 tff.learning
APIも入力引数としてTFFオプティマイザを受け入れます。
始める前に
開始する前に、以下を実行して、環境が正しくセットアップされていることを確認してください。あなたが挨拶が表示されない場合は、を参照してください。インストールの手順についてのガイド。
!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio
import nest_asyncio
nest_asyncio.apply()
import functools
import attr
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
データとモデルの準備
EMNISTデータ処理及びモデルは非常に類似しているsimple_fedavg例。
only_digits=True
# Load dataset.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits)
# Define preprocessing functions.
def preprocess_fn(dataset, batch_size=16):
def batch_format_fn(element):
return (tf.expand_dims(element['pixels'], -1), element['label'])
return dataset.batch(batch_size).map(batch_format_fn)
# Preprocess and sample clients for prototyping.
train_client_ids = sorted(emnist_train.client_ids)
train_data = emnist_train.preprocess(preprocess_fn)
central_test_data = preprocess_fn(
emnist_train.create_tf_dataset_for_client(train_client_ids[0]))
# Define model.
def create_keras_model():
"""The CNN model used in https://arxiv.org/abs/1602.05629."""
data_format = 'channels_last'
input_shape = [28, 28, 1]
max_pool = functools.partial(
tf.keras.layers.MaxPooling2D,
pool_size=(2, 2),
padding='same',
data_format=data_format)
conv2d = functools.partial(
tf.keras.layers.Conv2D,
kernel_size=5,
padding='same',
data_format=data_format,
activation=tf.nn.relu)
model = tf.keras.models.Sequential([
conv2d(filters=32, input_shape=input_shape),
max_pool(),
conv2d(filters=64),
max_pool(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation=tf.nn.relu),
tf.keras.layers.Dense(10 if only_digits else 62),
])
return model
# Wrap as `tff.learning.Model`.
def model_fn():
keras_model = create_keras_model()
return tff.learning.from_keras_model(
keras_model,
input_spec=central_test_data.element_spec,
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
カスタム反復プロセス
多くの場合、フェデレーションアルゴリズムには4つの主要なコンポーネントがあります。
- サーバーからクライアントへのブロードキャストステップ。
- ローカルクライアントの更新手順。
- クライアントからサーバーへのアップロード手順。
- サーバーの更新手順。
TFFにおいて、我々は、一般的に、フェデレーテッド・アルゴリズム表すtff.templates.IterativeProcess
(先ほどと呼ぶIterativeProcess
全体を)。これが含まれているクラスでinitialize
し、 next
機能を。ここでは、 initialize
サーバーを初期化するために使用され、そしてnext
連合アルゴリズムの一つの通信ラウンドを実行します。
さまざまなコンポーネントを導入して、フェデレーション平均(FedAvg)アルゴリズムを構築します。このアルゴリズムは、クライアントの更新ステップでオプティマイザーを使用し、サーバーの更新ステップで別のオプティマイザーを使用します。クライアントとサーバーの更新のコアロジックは、純粋なTFブロックとして表現できます。
TFブロック:クライアントとサーバーの更新
各クライアントでは、ローカルclient_optimizer
初期化され、クライアントモデルの重みを更新するために使用されます。サーバーでは、 server_optimizer
前のラウンドからの状態を使用し、次のラウンドのために状態を更新します。
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
"""Performs local training on the client's dataset."""
# Initialize the client model with the current server weights.
client_weights = model.trainable_variables
# Assign the server weights to the client model.
tf.nest.map_structure(lambda x, y: x.assign(y),
client_weights, server_weights)
# Initialize the client optimizer.
trainable_tensor_specs = tf.nest.map_structure(
lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)
optimizer_state = client_optimizer.initialize(trainable_tensor_specs)
# Use the client_optimizer to update the local model.
for batch in iter(dataset):
with tf.GradientTape() as tape:
# Compute a forward pass on the batch of data.
outputs = model.forward_pass(batch)
# Compute the corresponding gradient.
grads = tape.gradient(outputs.loss, client_weights)
# Apply the gradient using a client optimizer.
optimizer_state, updated_weights = client_optimizer.next(
optimizer_state, client_weights, grads)
tf.nest.map_structure(lambda a, b: a.assign(b),
client_weights, updated_weights)
# Return model deltas.
return tf.nest.map_structure(tf.subtract, client_weights, server_weights)
@attr.s(eq=False, frozen=True, slots=True)
class ServerState(object):
trainable_weights = attr.ib()
optimizer_state = attr.ib()
@tf.function
def server_update(server_state, mean_model_delta, server_optimizer):
"""Updates the server model weights."""
# Use aggregated negative model delta as pseudo gradient.
negative_weights_delta = tf.nest.map_structure(
lambda w: -1.0 * w, mean_model_delta)
new_optimizer_state, updated_weights = server_optimizer.next(
server_state.optimizer_state, server_state.trainable_weights,
negative_weights_delta)
return tff.structure.update_struct(
server_state,
trainable_weights=updated_weights,
optimizer_state=new_optimizer_state)
TFFブロック: tff.tf_computation
とtff.federated_computation
現在、オーケストレーションにTFFを使用し、FedAvgの反復プロセスを構築しています。私たちは、と上で定義されTFブロックラップする必要がありtff.tf_computation
、および使用のTFF法tff.federated_broadcast
、 tff.federated_map
、 tff.federated_mean
してtff.federated_computation
機能。使いやすいですtff.learning.optimizers.Optimizer
でAPIをinitialize
し、 next
カスタム反復プロセスを定義する際に機能しています。
# 1. Server and client optimizer to be used.
server_optimizer = tff.learning.optimizers.build_sgdm(
learning_rate=0.05, momentum=0.9)
client_optimizer = tff.learning.optimizers.build_sgdm(
learning_rate=0.01)
# 2. Functions return initial state on server.
@tff.tf_computation
def server_init():
model = model_fn()
trainable_tensor_specs = tf.nest.map_structure(
lambda v: tf.TensorSpec(v.shape, v.dtype), model.trainable_variables)
optimizer_state = server_optimizer.initialize(trainable_tensor_specs)
return ServerState(
trainable_weights=model.trainable_variables,
optimizer_state=optimizer_state)
@tff.federated_computation
def server_init_tff():
return tff.federated_value(server_init(), tff.SERVER)
# 3. One round of computation and communication.
server_state_type = server_init.type_signature.result
print('server_state_type:\n',
server_state_type.formatted_representation())
trainable_weights_type = server_state_type.trainable_weights
print('trainable_weights_type:\n',
trainable_weights_type.formatted_representation())
# 3-1. Wrap server and client TF blocks with `tff.tf_computation`.
@tff.tf_computation(server_state_type, trainable_weights_type)
def server_update_fn(server_state, model_delta):
return server_update(server_state, model_delta, server_optimizer)
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
print('tf_dataset_type:\n',
tf_dataset_type.formatted_representation())
@tff.tf_computation(tf_dataset_type, trainable_weights_type)
def client_update_fn(dataset, server_weights):
model = model_fn()
return client_update(model, dataset, server_weights, client_optimizer)
# 3-2. Orchestration with `tff.federated_computation`.
federated_server_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
@tff.federated_computation(federated_server_type, federated_dataset_type)
def run_one_round(server_state, federated_dataset):
# Server-to-client broadcast.
server_weights_at_client = tff.federated_broadcast(
server_state.trainable_weights)
# Local client update.
model_deltas = tff.federated_map(
client_update_fn, (federated_dataset, server_weights_at_client))
# Client-to-server upload and aggregation.
mean_model_delta = tff.federated_mean(model_deltas)
# Server update.
server_state = tff.federated_map(
server_update_fn, (server_state, mean_model_delta))
return server_state
# 4. Build the iterative process for FedAvg.
fedavg_process = tff.templates.IterativeProcess(
initialize_fn=server_init_tff, next_fn=run_one_round)
print('type signature of `initialize`:\n',
fedavg_process.initialize.type_signature.formatted_representation())
print('type signature of `next`:\n',
fedavg_process.next.type_signature.formatted_representation())
server_state_type: < trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > > trainable_weights_type: < float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > tf_dataset_type: < float32[?,28,28,1], int32[?] >* type signature of `initialize`: ( -> < trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > >@SERVER) type signature of `next`: (< server_state=< trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > >@SERVER, federated_dataset={< float32[?,28,28,1], int32[?] >*}@CLIENTS > -> < trainable_weights=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] >, optimizer_state=< float32[5,5,1,32], float32[32], float32[5,5,32,64], float32[64], float32[3136,512], float32[512], float32[512,10], float32[10] > >@SERVER)
アルゴリズムの評価
一元化された評価データセットでパフォーマンスを評価します。
def evaluate(server_state):
keras_model = create_keras_model()
tf.nest.map_structure(
lambda var, t: var.assign(t),
keras_model.trainable_weights, server_state.trainable_weights)
metric = tf.keras.metrics.SparseCategoricalAccuracy()
for batch in iter(central_test_data):
preds = keras_model(batch[0], training=False)
metric.update_state(y_true=batch[1], y_pred=preds)
return metric.result().numpy()
server_state = fedavg_process.initialize()
acc = evaluate(server_state)
print('Initial test accuracy', acc)
# Evaluate after a few rounds
CLIENTS_PER_ROUND=2
sampled_clients = train_client_ids[:CLIENTS_PER_ROUND]
sampled_train_data = [
train_data.create_tf_dataset_for_client(client)
for client in sampled_clients]
for round in range(20):
server_state = fedavg_process.next(server_state, sampled_train_data)
acc = evaluate(server_state)
print('Test accuracy', acc)
Initial test accuracy 0.09677419 Test accuracy 0.13978495