このチュートリアルでは、アクセラレータを使用してTFFシミュレーションを設定する方法について説明します。現時点ではシングルマシンの(マルチ)GPUに焦点を当てており、今後、マルチマシンおよびTPUの設定についてもこのチュートリアルを更新する予定です。
TensorFlow.orgで表示 | Google Colabで実行 | GitHubでソースを表示 | ノートブックをダウンロード |
始める前に
まず、関連するコンポーネントがコンパイルされたバックエンドにノートブックが接続されていることを確認しましょう。
!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio
!pip install -U tensorboard_plugin_profile
import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard
import collections
import time
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
TFが物理GPUを検出し、TFFのGPUシミュレーション用に仮想マルチGPU環境を作成できるかどうかを確認します。TFFランタイムの構成方法を示すため、2つの仮想GPUのメモリは制限されています。
# Require a physical GPU, then split the first one into two logical GPUs with
# the same memory_limit each, so multi-GPU TFF execution can be demonstrated
# on a single card.
gpu_devices = tf.config.list_physical_devices('GPU')
if not gpu_devices:
  raise ValueError('Cannot detect physical GPU device in TF')
tf.config.set_logical_device_configuration(
    gpu_devices[0],
    [tf.config.LogicalDeviceConfiguration(memory_limit=1024) for _ in range(2)])
# Show the resulting logical devices (CPU:0, GPU:0, GPU:1).
tf.config.list_logical_devices()
[LogicalDevice(name='/device:CPU:0', device_type='CPU'), LogicalDevice(name='/device:GPU:0', device_type='GPU'), LogicalDevice(name='/device:GPU:1', device_type='GPU')]
次の「Hello World」の例を実行して、TFF環境が正しくセットアップされていることを確認します。動作しない場合は、インストール手順についてのガイドを参照してください。
@tff.federated_computation
def hello_world():
  """Smoke-test federated computation that returns a constant greeting."""
  greeting = 'Hello, World!'
  return greeting


hello_world()
b'Hello, World!'
EMNIST実験セットアップ
このチュートリアルでは、Federated Averagingアルゴリズムを使用してEMNIST画像分類器をトレーニングします。まず、TFFのWebサイトからMNISTの例を読み込むことから始めましょう。
# Federated EMNIST restricted to digits; the model below defaults to 10 classes.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True,
)
simple_fedavg の例に従って、EMNISTの例を前処理する関数を定義します。引数 client_epochs_per_round は、
連合学習におけるクライアントのローカルエポック数を制御することに注意してください。
def preprocess_emnist_dataset(client_epochs_per_round, batch_size, test_batch_size):
  """Builds the federated training data and the centralized test dataset.

  Every example is converted to an `OrderedDict(x=..., y=...)` where `x` is
  `pixels` with a trailing channel axis added and `y` is `label`.

  Args:
    client_epochs_per_round: How many times each client's data is repeated,
      i.e. local epochs per federated round.
    batch_size: Batch size for the per-client training datasets.
    test_batch_size: Batch size for the centralized test dataset.

  Returns:
    A `(train_set, test_set)` pair: `emnist_train` with the per-client training
    pipeline applied, and all of `emnist_test` merged into one batched dataset.
  """

  def to_xy(example):
    return collections.OrderedDict(
        x=tf.expand_dims(example['pixels'], -1),
        y=example['label'],
    )

  def prepare_client(ds):
    ds = ds.map(to_xy)
    # Shuffle buffer equals the largest client dataset size
    # (418 for Federated EMNIST), so each client is fully shuffled.
    ds = ds.shuffle(buffer_size=418)
    ds = ds.repeat(count=client_epochs_per_round)
    return ds.batch(batch_size, drop_remainder=False)

  def prepare_eval(ds):
    ds = ds.map(to_xy)
    return ds.batch(test_batch_size, drop_remainder=False)

  train_set = emnist_train.preprocess(prepare_client)
  all_test_examples = emnist_test.create_tf_dataset_from_all_clients()
  return train_set, prepare_eval(all_test_examples)
VGGのようなモデルを使用します。つまり、各ブロックには2つの3x3畳み込みがあり、特徴マップがサブサンプリングされると、フィルターの数が2倍になります。
def _conv_3x3(input_tensor, filters, strides):
  """Applies one bias-free 3x3 `Conv2D` with 'same' padding.

  Args:
    input_tensor: Keras tensor to convolve.
    filters: Number of output channels.
    strides: Convolution stride; values above 1 downsample spatially.

  Returns:
    The convolved Keras tensor (no activation applied).
  """
  conv_layer = tf.keras.layers.Conv2D(
      filters=filters,
      kernel_size=3,
      strides=strides,
      padding='same',
      use_bias=False,
      kernel_initializer='he_normal',
  )
  return conv_layer(input_tensor)
def _basic_block(input_tensor, filters, strides):
  """Two stacked (3x3 conv -> ReLU) layers; only the first conv uses `strides`."""
  out = input_tensor
  for stride in (strides, 1):
    out = _conv_3x3(out, filters, stride)
    out = tf.keras.layers.Activation('relu')(out)
  return out
def _vgg_block(input_tensor, size, filters, strides):
  """Stacks `size` basic blocks with `filters` channels.

  Only the first block uses `strides`; the remaining ones use stride 1.
  At least one block is always built, even when `size` is below 1.
  """
  x = input_tensor
  for block_strides in [strides] + [1] * (size - 1):
    x = _basic_block(x, filters, strides=block_strides)
  return x
def create_cnn(num_blocks, conv_width_multiplier=1, num_classes=10):
  """Builds a VGG-like Keras CNN for 28x28x1 (channels_last) images.

  The network is a per-image standardization, a 3x3 stem convolution, and
  three stages of `num_blocks` basic blocks each with 16, 32 and 64 base
  filters (scaled by `conv_width_multiplier`); the second and third stages
  start with stride 2. Global average pooling and a dense layer follow.
  Counting the convolutions plus the dense layer, the CNN has
  (6*num_blocks + 2) layers, which is encoded in the model name.

  Args:
    num_blocks: Number of basic blocks per stage.
    conv_width_multiplier: Factor applied to every stage's filter count.
    num_classes: Number of output units of the final dense layer.

  Returns:
    A `tf.keras.Model` producing unnormalized logits (no output activation).
  """
  width = conv_width_multiplier
  inputs = tf.keras.layers.Input(shape=(28, 28, 1))  # channels_last
  features = tf.image.per_image_standardization(inputs)
  features = _conv_3x3(features, 16 * width, 1)
  # (base filter count, stride of the stage's first block)
  for base_filters, stage_strides in ((16, 1), (32, 2), (64, 2)):
    features = _vgg_block(
        features,
        size=num_blocks,
        filters=base_filters * width,
        strides=stage_strides)
  features = tf.keras.layers.GlobalAveragePooling2D()(features)
  logits = tf.keras.layers.Dense(num_classes)(features)
  return tf.keras.models.Model(
      inputs,
      logits,
      name='cnn-{}-{}'.format(6 * num_blocks + 2, conv_width_multiplier))
次に、EMNISTのトレーニングループを定義しましょう。tff.learning.build_federated_averaging_process の use_experimental_simulation_loop=True
は、高性能なTFFシミュレーションのために推奨されており、単一マシン上で複数のGPUを活用するためには必須です。
GPU上で高い性能を発揮するカスタマイズされた連合学習アルゴリズムの定義方法については、simple_fedavg の例を参照してください。
重要なポイントの1つは、トレーニングループで for ... iter(dataset) を明示的に使用することです。
def keras_evaluate(model, test_data, metric):
  """Evaluates `model` on every batch of `test_data` with a fresh `metric`.

  Args:
    model: Callable invoked as `model(x, training=False)` returning predictions.
    test_data: Iterable of batches, each indexable by 'x' (inputs) and 'y'
      (targets).
    metric: Keras-style metric; it is reset before use and then updated
      once per batch.

  Returns:
    `metric.result()` after all batches have been consumed.
  """
  metric.reset_states()
  for features, labels in ((batch['x'], batch['y']) for batch in test_data):
    metric.update_state(y_true=labels, y_pred=model(features, training=False))
  return metric.result()
def run_federated_training(client_epochs_per_round,
                           train_batch_size,
                           test_batch_size,
                           cnn_num_blocks,
                           conv_width_multiplier,
                           server_learning_rate,
                           client_learning_rate,
                           total_rounds,
                           clients_per_round,
                           rounds_per_eval,
                           logdir='logdir'):
  """Trains a CNN on federated EMNIST with Federated Averaging.

  Each round samples `clients_per_round` distinct training clients, runs one
  `iterative_process.next` step on their datasets, and prints the training
  loss together with the average wall-clock seconds per round since training
  started. The final round is profiled with `tf.profiler` into `logdir`.
  Every `rounds_per_eval` rounds, and after the final round, the server model
  weights are copied into a separate Keras model and evaluated on the
  centralized test set.

  Args:
    client_epochs_per_round: Local epochs each client runs per round.
    train_batch_size: Batch size of the client training datasets.
    test_batch_size: Batch size of the centralized test dataset.
    cnn_num_blocks: `num_blocks` passed to `create_cnn`.
    conv_width_multiplier: `conv_width_multiplier` passed to `create_cnn`.
    server_learning_rate: SGD learning rate of the server optimizer.
    client_learning_rate: SGD learning rate of the client optimizers.
    total_rounds: Number of federated rounds to run.
    clients_per_round: Clients sampled without replacement each round.
    rounds_per_eval: Evaluate whenever `round_num % rounds_per_eval == 0`.
    logdir: Directory that receives the profile of the final round.
  """
  import contextlib  # Local import keeps this notebook cell self-contained.

  train_data, test_data = preprocess_emnist_dataset(
      client_epochs_per_round, train_batch_size, test_batch_size)
  data_spec = test_data.element_spec

  def _model_fn():
    keras_model = create_cnn(cnn_num_blocks, conv_width_multiplier)
    # create_cnn outputs raw logits, hence from_logits=True.
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return tff.learning.from_keras_model(
        keras_model, input_spec=data_spec, loss=loss)

  def _server_optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=server_learning_rate)

  def _client_optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=client_learning_rate)

  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=_model_fn,
      server_optimizer_fn=_server_optimizer_fn,
      client_optimizer_fn=_client_optimizer_fn,
      use_experimental_simulation_loop=True)

  metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
  eval_model = create_cnn(cnn_num_blocks, conv_width_multiplier)
  # Model.summary() prints the table itself and returns None, so it is called
  # directly rather than being passed to a logger.
  eval_model.summary()

  server_state = iterative_process.initialize()
  start_time = time.time()
  for round_num in range(total_rounds):
    is_last_round = round_num == total_rounds - 1
    sampled_clients = np.random.choice(
        train_data.client_ids,
        size=clients_per_round,
        replace=False)
    sampled_train_data = [
        train_data.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]
    # Only the final round is profiled; other rounds use a no-op context so
    # the training step is written once.
    round_context = (
        tf.profiler.experimental.Profile(logdir)
        if is_last_round else contextlib.nullcontext())
    with round_context:
      server_state, train_metrics = iterative_process.next(
          server_state, sampled_train_data)
    print(f'Round {round_num} training loss: {train_metrics["train"]["loss"]}, '
          f'time: {(time.time()-start_time)/(round_num+1.)} secs')
    if round_num % rounds_per_eval == 0 or is_last_round:
      server_state.model.assign_weights_to(eval_model)
      accuracy = keras_evaluate(eval_model, test_data, metric)
      print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')
シングルGPU実行
TFFのデフォルトのランタイムはTFと同じです。GPUが利用可能な場合は、最初のGPUが実行用に選択されます。先ほど定義した連合トレーニングを、比較的小さなモデルで数ラウンド実行します。実行の最後のラウンドは tf.profiler
でプロファイリングされ、tensorboard
で可視化されます。プロファイリングにより、最初のGPUが使用されていることを確認できます。
# Baseline with the default runtime: a cnn-14-4 model trained for 10 rounds.
run_federated_training(
    # Data pipeline.
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    # Model size.
    cnn_num_blocks=2,
    conv_width_multiplier=4,
    # Optimizers.
    client_learning_rate=0.01,
    server_learning_rate=1.0,
    # Round schedule.
    clients_per_round=16,
    total_rounds=10,
    rounds_per_eval=2,
)
Model: "cnn-14-4" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 28, 28, 1)] 0 _________________________________________________________________ tf.image.per_image_standardi (None, 28, 28, 1) 0 _________________________________________________________________ conv2d (Conv2D) (None, 28, 28, 64) 576 _________________________________________________________________ conv2d_1 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_1 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_3 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_2 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_4 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_3 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_5 (Conv2D) (None, 14, 14, 128) 73728 _________________________________________________________________ activation_4 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_6 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_5 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_7 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_6 (Activation) (None, 14, 14, 
128) 0 _________________________________________________________________ conv2d_8 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_7 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_9 (Conv2D) (None, 7, 7, 256) 294912 _________________________________________________________________ activation_8 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_10 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_9 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_11 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_10 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_12 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_11 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ global_average_pooling2d (Gl (None, 256) 0 _________________________________________________________________ dense (Dense) (None, 10) 2570 ================================================================= Total params: 2,731,082 Trainable params: 2,731,082 Non-trainable params: 0 _________________________________________________________________ Round 0 training loss: 2.4688243865966797, time: 13.382015466690063 secs Round 0 validation accuracy: 15.240497589111328 Round 1 training loss: 2.3217368125915527, time: 9.311999917030334 secs Round 2 training loss: 2.3100595474243164, time: 6.972411632537842 secs Round 2 validation accuracy: 11.226489067077637 Round 3 training loss: 2.303222417831421, time: 6.467299699783325 secs Round 4 training loss: 2.2976326942443848, time: 5.526083135604859 secs Round 4 
validation accuracy: 11.224040031433105 Round 5 training loss: 2.2919719219207764, time: 5.468692660331726 secs Round 6 training loss: 2.2911534309387207, time: 4.935825347900391 secs Round 6 validation accuracy: 11.833855628967285 Round 7 training loss: 2.2871201038360596, time: 4.918408691883087 secs Round 8 training loss: 2.2818832397460938, time: 4.602836343977186 secs Round 8 validation accuracy: 11.385677337646484 Round 9 training loss: 2.2790346145629883, time: 4.99558527469635 secs Round 9 validation accuracy: 11.226489067077637
# Open TensorBoard on `logdir`, where the previous run wrote the profile of its
# final round.
%tensorboard --logdir=logdir --port=0
CPU実行
比較として、CPU実行用にTFFランタイムを構成してみましょう。この比較的小さなモデルでは、CPUの実行がわずかに遅くなります。
# Pin both the server and every client to the first logical CPU.
cpu_device = tf.config.list_logical_devices('CPU')[0]
tff.backends.native.set_local_python_execution_context(
    client_tf_devices=[cpu_device], server_tf_device=cpu_device)
# Same hyperparameters as the single-GPU run, for a direct comparison.
run_federated_training(
    # Data pipeline.
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    # Model size.
    cnn_num_blocks=2,
    conv_width_multiplier=4,
    # Optimizers.
    client_learning_rate=0.01,
    server_learning_rate=1.0,
    # Round schedule.
    clients_per_round=16,
    total_rounds=10,
    rounds_per_eval=2,
)
Model: "cnn-14-4" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_2 (InputLayer) [(None, 28, 28, 1)] 0 _________________________________________________________________ tf.image.per_image_standardi (None, 28, 28, 1) 0 _________________________________________________________________ conv2d_13 (Conv2D) (None, 28, 28, 64) 576 _________________________________________________________________ conv2d_14 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_12 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_15 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_13 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_16 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_14 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_17 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_15 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_18 (Conv2D) (None, 14, 14, 128) 73728 _________________________________________________________________ activation_16 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_19 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_17 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_20 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_18 
(Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_21 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_19 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_22 (Conv2D) (None, 7, 7, 256) 294912 _________________________________________________________________ activation_20 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_23 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_21 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_24 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_22 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_25 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_23 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ global_average_pooling2d_1 ( (None, 256) 0 _________________________________________________________________ dense_1 (Dense) (None, 10) 2570 ================================================================= Total params: 2,731,082 Trainable params: 2,731,082 Non-trainable params: 0 _________________________________________________________________ Round 0 training loss: 2.4787657260894775, time: 15.264191627502441 secs Round 0 validation accuracy: 12.191418647766113 Round 1 training loss: 2.3097336292266846, time: 11.785032272338867 secs Round 2 training loss: 2.3062121868133545, time: 9.677561124165853 secs Round 2 validation accuracy: 11.415066719055176 Round 3 training loss: 2.2982261180877686, time: 9.301376760005951 secs Round 4 training loss: 2.2953946590423584, 
time: 8.377780866622924 secs Round 4 validation accuracy: 20.537813186645508 Round 5 training loss: 2.290337324142456, time: 8.385509928067526 secs Round 6 training loss: 2.2842795848846436, time: 7.809031554630825 secs Round 6 validation accuracy: 11.934267044067383 Round 7 training loss: 2.2752432823181152, time: 7.8433578312397 secs Round 8 training loss: 2.2698657512664795, time: 7.478067080179851 secs Round 8 validation accuracy: 26.16330337524414 Round 9 training loss: 2.2609798908233643, time: 7.632814192771912 secs Round 9 validation accuracy: 23.079936981201172
マルチGPUの実行
マルチGPU実行用にTFFを構成するのは簡単です。TFFではクライアントトレーニングが並列化されていることに注意してください。マルチGPU設定では、クライアントはラウンドロビン方式で各GPUに割り当てられます。以下の2 GPUでの実行は、シングルGPUでの実行より高速にはなりません。クライアントトレーニングはシングルGPU設定とマルチGPU設定のどちらでも並列化されており、さらにマルチGPU設定の2つの仮想GPUは単一の物理GPUから作成されているためです。
# Hand every logical GPU to the runtime for client training.
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_python_execution_context(
    client_tf_devices=gpu_devices)
run_federated_training(
    # Data pipeline.
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    # Model size.
    cnn_num_blocks=2,
    conv_width_multiplier=4,
    # Optimizers.
    client_learning_rate=0.01,
    server_learning_rate=1.0,
    # Round schedule.
    clients_per_round=16,
    total_rounds=10,
    rounds_per_eval=2,
    # Separate profile directory so the single-GPU profile is kept.
    logdir='multigpu',
)
Model: "cnn-14-4" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_3 (InputLayer) [(None, 28, 28, 1)] 0 _________________________________________________________________ tf.image.per_image_standardi (None, 28, 28, 1) 0 _________________________________________________________________ conv2d_26 (Conv2D) (None, 28, 28, 64) 576 _________________________________________________________________ conv2d_27 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_24 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_28 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_25 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_29 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_26 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_30 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_27 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_31 (Conv2D) (None, 14, 14, 128) 73728 _________________________________________________________________ activation_28 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_32 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_29 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_33 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_30 
(Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_34 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_31 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_35 (Conv2D) (None, 7, 7, 256) 294912 _________________________________________________________________ activation_32 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_36 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_33 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_37 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_34 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_38 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_35 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ global_average_pooling2d_2 ( (None, 256) 0 _________________________________________________________________ dense_2 (Dense) (None, 10) 2570 ================================================================= Total params: 2,731,082 Trainable params: 2,731,082 Non-trainable params: 0 _________________________________________________________________ Round 0 training loss: 2.911365270614624, time: 12.759389877319336 secs Round 0 validation accuracy: 9.541536331176758 Round 1 training loss: 2.3175694942474365, time: 9.202919721603394 secs Round 2 training loss: 2.311001777648926, time: 6.802880525588989 secs Round 2 validation accuracy: 9.911344528198242 Round 3 training loss: 2.3105244636535645, time: 6.611470937728882 secs Round 4 training loss: 2.3082072734832764, time: 
5.678833389282227 secs Round 4 validation accuracy: 10.212578773498535 Round 5 training loss: 2.304673671722412, time: 5.5404335260391235 secs Round 6 training loss: 2.3035168647766113, time: 5.008027451378958 secs Round 6 validation accuracy: 9.935834884643555 Round 7 training loss: 2.3052737712860107, time: 5.1173741817474365 secs Round 8 training loss: 2.3007171154022217, time: 4.745321141348945 secs Round 8 validation accuracy: 10.768514633178711 Round 9 training loss: 2.302018404006958, time: 5.0809732437133786 secs Round 9 validation accuracy: 12.311422348022461
大型モデルとOOM
CPU上で、より大きなモデルを少ないフェデレーションラウンド数で実行してみましょう。
# Pin both the server and every client to the first logical CPU.
cpu_device = tf.config.list_logical_devices('CPU')[0]
tff.backends.native.set_local_python_execution_context(
    client_tf_devices=[cpu_device], server_tf_device=cpu_device)
# Larger cnn-26-4 model, trained for only 5 rounds.
run_federated_training(
    # Data pipeline.
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    # Model size.
    cnn_num_blocks=4,
    conv_width_multiplier=4,
    # Optimizers.
    client_learning_rate=0.01,
    server_learning_rate=1.0,
    # Round schedule.
    clients_per_round=16,
    total_rounds=5,
    rounds_per_eval=2,
)
Model: "cnn-26-4" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_4 (InputLayer) [(None, 28, 28, 1)] 0 _________________________________________________________________ tf.image.per_image_standardi (None, 28, 28, 1) 0 _________________________________________________________________ conv2d_39 (Conv2D) (None, 28, 28, 64) 576 _________________________________________________________________ conv2d_40 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_36 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_41 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_37 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_42 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_38 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_43 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_39 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_44 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_40 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_45 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_41 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_46 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_42 (Activation) 
(None, 28, 28, 64) 0 _________________________________________________________________ conv2d_47 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_43 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_48 (Conv2D) (None, 14, 14, 128) 73728 _________________________________________________________________ activation_44 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_49 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_45 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_50 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_46 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_51 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_47 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_52 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_48 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_53 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_49 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_54 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_50 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_55 (Conv2D) (None, 14, 14, 128) 147456 
_________________________________________________________________ activation_51 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_56 (Conv2D) (None, 7, 7, 256) 294912 _________________________________________________________________ activation_52 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_57 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_53 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_58 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_54 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_59 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_55 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_60 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_56 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_61 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_57 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_62 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_58 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_63 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_59 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ global_average_pooling2d_3 ( 
(None, 256) 0 _________________________________________________________________ dense_3 (Dense) (None, 10) 2570 ================================================================= Total params: 5,827,658 Trainable params: 5,827,658 Non-trainable params: 0 _________________________________________________________________ Round 0 training loss: 2.437223434448242, time: 24.121686458587646 secs Round 0 validation accuracy: 9.024785041809082 Round 1 training loss: 2.3081459999084473, time: 19.48685622215271 secs Round 2 training loss: 2.305305242538452, time: 15.73950457572937 secs Round 2 validation accuracy: 9.791339874267578 Round 3 training loss: 2.303149700164795, time: 15.194068729877472 secs Round 4 training loss: 2.3026506900787354, time: 14.036769819259643 secs Round 4 validation accuracy: 12.193867683410645
このモデルでは、単一のGPUでメモリ不足(OOM)の問題が発生する可能性があります。GPUのメモリは限られていることが多いため、大規模なCPU実験からGPUシミュレーションへの移行は、メモリ使用量によって制約を受ける場合があります。OOMの問題を軽減するために、TFFランタイムで調整できるパラメーターがいくつかあります。
- マルチGPU実行を試す。
- tff.backends.native.set_local_python_execution_context
でより大きな clients_per_thread
を使用して、クライアントトレーニングの同時実行数を制御する。clients_per_thread
のデフォルトは1で、同時に実行されるクライアント数はおおよそ clients_per_round/clients_per_thread
です。
- tff.backends.native.set_local_python_execution_context
の client_tf_devices
をCPUに固定する。
- tff.backends.native.set_local_python_execution_context
の max_fanout
を clients_per_round
より大きく設定して、
階層的集約を無効にする。
# Single GPU execution might hit OOM.
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_python_execution_context(
    client_tf_devices=[gpu_devices[0]])
try:
  run_federated_training(
      client_epochs_per_round=1,
      train_batch_size=16,
      test_batch_size=128,
      cnn_num_blocks=4,
      conv_width_multiplier=4,
      server_learning_rate=1.0,
      client_learning_rate=0.01,
      total_rounds=5,
      clients_per_round=16,
      rounds_per_eval=2,
  )
except tf.errors.ResourceExhaustedError as e:
  # The OOM error class lives in `tf.errors`; an unqualified name is not
  # defined in this notebook and would itself raise NameError during handling.
  # Print the OOM message instead of letting it abort the cell.
  print(e)
# Control concurrency by `clients_per_thread`: with 2 clients per thread,
# roughly clients_per_round / 2 clients train at the same time on the one GPU.
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_python_execution_context(
    clients_per_thread=2,
    client_tf_devices=[gpu_devices[0]],
)
run_federated_training(
    # Data pipeline.
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    # Model size (same cnn-26-4 model that may OOM above).
    cnn_num_blocks=4,
    conv_width_multiplier=4,
    # Optimizers.
    client_learning_rate=0.01,
    server_learning_rate=1.0,
    # Round schedule.
    clients_per_round=16,
    total_rounds=5,
    rounds_per_eval=2,
)
Model: "cnn-26-4" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 28, 28, 1)] 0 _________________________________________________________________ tf.image.per_image_standardi (None, 28, 28, 1) 0 _________________________________________________________________ conv2d (Conv2D) (None, 28, 28, 64) 576 _________________________________________________________________ conv2d_1 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_1 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_3 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_2 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_4 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_3 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_5 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_4 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_6 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_5 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_7 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_6 (Activation) (None, 28, 28, 64) 0 
_________________________________________________________________ conv2d_8 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_7 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_9 (Conv2D) (None, 14, 14, 128) 73728 _________________________________________________________________ activation_8 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_10 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_9 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_11 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_10 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_12 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_11 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_13 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_12 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_14 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_13 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_15 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_14 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_16 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_15 
(Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_17 (Conv2D) (None, 7, 7, 256) 294912 _________________________________________________________________ activation_16 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_18 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_17 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_19 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_18 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_20 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_19 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_21 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_20 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_22 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_21 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_23 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_22 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_24 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_23 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ global_average_pooling2d (Gl (None, 256) 0 _________________________________________________________________ dense 
(Dense) (None, 10) 2570 ================================================================= Total params: 5,827,658 Trainable params: 5,827,658 Non-trainable params: 0 _________________________________________________________________ Round 0 training loss: 2.4990053176879883, time: 11.922378778457642 secs Round 0 validation accuracy: 11.224040031433105 Round 1 training loss: 2.307560920715332, time: 9.916815996170044 secs Round 2 training loss: 2.3032877445220947, time: 7.68927804629008 secs Round 2 validation accuracy: 11.224040031433105 Round 3 training loss: 2.302366256713867, time: 7.681552231311798 secs Round 4 training loss: 2.301671028137207, time: 7.613566827774048 secs Round 4 validation accuracy: 11.224040031433105
# Multi-GPU execution with configuration to mitigate OOM.
# Server-side TF computations are pinned to the first logical CPU and
# client-side TF computations are placed on the logical GPUs. This cell relies
# on `clients_per_thread=1` and `max_fanout=32` to limit memory pressure
# (NOTE(review): the exact per-GPU effect depends on TFF's executor stack --
# confirm against the TFF API docs).
cpu_device = tf.config.list_logical_devices('CPU')[0]
gpu_devices = tf.config.list_logical_devices('GPU')

tff.backends.native.set_local_python_execution_context(
    client_tf_devices=gpu_devices,
    server_tf_device=cpu_device,
    max_fanout=32,
    clients_per_thread=1,
)

# Train for 5 rounds with 16 sampled clients per round, evaluating every
# 2 rounds.
run_federated_training(
    # Model size.
    cnn_num_blocks=4,
    conv_width_multiplier=4,
    # Batch sizes for training and evaluation.
    train_batch_size=16,
    test_batch_size=128,
    # Local work and learning rates on each side of the federation.
    client_epochs_per_round=1,
    client_learning_rate=0.01,
    server_learning_rate=1.0,
    # Round schedule.
    total_rounds=5,
    clients_per_round=16,
    rounds_per_eval=2,
)
Model: "cnn-26-4" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 28, 28, 1)] 0 _________________________________________________________________ tf.image.per_image_standardi (None, 28, 28, 1) 0 _________________________________________________________________ conv2d (Conv2D) (None, 28, 28, 64) 576 _________________________________________________________________ conv2d_1 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_1 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_3 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_2 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_4 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_3 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_5 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_4 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_6 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_5 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_7 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_6 (Activation) (None, 28, 28, 64) 0 
_________________________________________________________________ conv2d_8 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_7 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_9 (Conv2D) (None, 14, 14, 128) 73728 _________________________________________________________________ activation_8 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_10 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_9 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_11 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_10 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_12 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_11 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_13 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_12 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_14 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_13 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_15 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_14 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_16 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_15 
(Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_17 (Conv2D) (None, 7, 7, 256) 294912 _________________________________________________________________ activation_16 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_18 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_17 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_19 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_18 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_20 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_19 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_21 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_20 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_22 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_21 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_23 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_22 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_24 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_23 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ global_average_pooling2d (Gl (None, 256) 0 _________________________________________________________________ dense 
(Dense) (None, 10) 2570 ================================================================= Total params: 5,827,658 Trainable params: 5,827,658 Non-trainable params: 0 _________________________________________________________________ Round 0 training loss: 2.4691953659057617, time: 17.81941556930542 secs Round 0 validation accuracy: 10.817495346069336 Round 1 training loss: 2.3081436157226562, time: 12.986191034317017 secs Round 2 training loss: 2.3028159141540527, time: 9.518500963846842 secs Round 2 validation accuracy: 11.500783920288086 Round 3 training loss: 2.303886651992798, time: 8.989932537078857 secs Round 4 training loss: 2.3030669689178467, time: 8.733866214752197 secs Round 4 validation accuracy: 12.992260932922363
パフォーマンスを最適化する
TFで性能向上に使える手法(例えば、混合精度トレーニングやXLA)は、一般にTFFでもそのまま使用できます。(V100などのGPUでは)混合精度による高速化とメモリ節約の効果は大きいことが多く、その効果は tf.profiler で確認できます。
# Mixed precision training.
# Re-create the multi-GPU execution context: server-side TF work on the first
# logical CPU, client-side TF work on the logical GPUs, with the same
# `clients_per_thread` / `max_fanout` settings as the OOM-mitigation run.
cpu_device = tf.config.list_logical_devices('CPU')[0]
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_python_execution_context(
    server_tf_device=cpu_device,
    client_tf_devices=gpu_devices,
    clients_per_thread=1,
    max_fanout=32)
# Use the stable mixed-precision API. The old
# `tf.keras.mixed_precision.experimental.Policy` / `.set_policy` pair is
# deprecated since TF 2.4 and slated for removal, and this notebook installs
# TFF nightly. `set_global_policy` sets the global Keras policy, so it also
# applies to Keras layers created in later cells.
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)
run_federated_training(
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    cnn_num_blocks=4,
    conv_width_multiplier=4,
    server_learning_rate=1.0,
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16,
    rounds_per_eval=2,
    # Output directory read by the `%tensorboard --logdir=mixed` cell below.
    logdir='mixed'
)
Model: "cnn-26-4" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 28, 28, 1)] 0 _________________________________________________________________ tf.image.per_image_standardi (None, 28, 28, 1) 0 _________________________________________________________________ conv2d (Conv2D) (None, 28, 28, 64) 576 _________________________________________________________________ conv2d_1 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_1 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_3 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_2 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_4 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_3 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_5 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_4 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_6 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_5 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_7 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_6 (Activation) (None, 28, 28, 64) 0 
_________________________________________________________________ conv2d_8 (Conv2D) (None, 28, 28, 64) 36864 _________________________________________________________________ activation_7 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_9 (Conv2D) (None, 14, 14, 128) 73728 _________________________________________________________________ activation_8 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_10 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_9 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_11 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_10 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_12 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_11 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_13 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_12 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_14 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_13 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_15 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_14 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_16 (Conv2D) (None, 14, 14, 128) 147456 _________________________________________________________________ activation_15 
(Activation) (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_17 (Conv2D) (None, 7, 7, 256) 294912 _________________________________________________________________ activation_16 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_18 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_17 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_19 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_18 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_20 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_19 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_21 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_20 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_22 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_21 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_23 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_22 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_24 (Conv2D) (None, 7, 7, 256) 589824 _________________________________________________________________ activation_23 (Activation) (None, 7, 7, 256) 0 _________________________________________________________________ global_average_pooling2d (Gl (None, 256) 0 _________________________________________________________________ dense 
(Dense) (None, 10) 2570 ================================================================= Total params: 5,827,658 Trainable params: 5,827,658 Non-trainable params: 0 _________________________________________________________________ Round 0 training loss: 2.4187185764312744, time: 18.763780117034912 secs Round 0 validation accuracy: 9.977468490600586 Round 1 training loss: 2.305102825164795, time: 13.712820529937744 secs Round 2 training loss: 2.304737091064453, time: 9.993690172831217 secs Round 2 validation accuracy: 11.779976844787598 Round 3 training loss: 2.2996833324432373, time: 9.29404467344284 secs Round 4 training loss: 2.299349308013916, time: 9.195427560806275 secs Round 4 validation accuracy: 11.224040031433105
# Open TensorBoard inside the notebook on the `mixed` directory, i.e. the
# `logdir` passed to run_federated_training in the mixed-precision cell
# (presumably where its profiler/summary output is written -- confirm in
# run_federated_training). `--port=0` asks for an automatically chosen port.
%tensorboard --logdir=mixed --port=0