This tutorial describes how to set up TFF simulations with accelerators. We focus on single-machine (multi-)GPU setups for now and will update this tutorial with multi-machine and TPU settings.
Before we begin
First, let us make sure the notebook is connected to a backend that has the relevant components compiled.
!pip install --quiet --upgrade tensorflow-federated
!pip install -U tensorboard_plugin_profile
%load_ext tensorboard
import collections
import time
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
Check whether TF can detect physical GPUs, then create a virtual multi-GPU environment for TFF GPU simulations. The two virtual GPUs are given limited memory to demonstrate how to configure the TFF runtime.
gpu_devices = tf.config.list_physical_devices('GPU')
if not gpu_devices:
  raise ValueError('Cannot detect physical GPU device in TF')
# TODO: b/277213652 - Remove this call, as it doesn't work with C++ executor
tf.config.set_logical_device_configuration(
    gpu_devices[0],
    [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
     tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
tf.config.list_logical_devices()
[LogicalDevice(name='/device:CPU:0', device_type='CPU'), LogicalDevice(name='/device:GPU:0', device_type='GPU'), LogicalDevice(name='/device:GPU:1', device_type='GPU')]
Run the following "Hello World" example to make sure the TFF environment is correctly set up. If it doesn't work, please refer to the Installation guide for instructions.
@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
b'Hello, World!'
EMNIST experimental setup
In this tutorial, we train an EMNIST image classifier with the Federated Averaging algorithm. Let us start by loading the federated EMNIST dataset (digits only) from TFF's simulation datasets.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits=True)
We define a function that preprocesses the EMNIST examples, following the simple_fedavg example. Note that the argument client_epochs_per_round controls the number of local epochs on clients in federated learning.
def preprocess_emnist_dataset(client_epochs_per_round, batch_size, test_batch_size):

  def element_fn(element):
    return collections.OrderedDict(
        x=tf.expand_dims(element['pixels'], -1), y=element['label'])

  def preprocess_train_dataset(dataset):
    # Use a buffer_size equal to the maximum client dataset size,
    # which is 418 for Federated EMNIST.
    return dataset.map(element_fn).shuffle(buffer_size=418).repeat(
        count=client_epochs_per_round).batch(batch_size, drop_remainder=False)

  def preprocess_test_dataset(dataset):
    return dataset.map(element_fn).batch(test_batch_size, drop_remainder=False)

  train_set = emnist_train.preprocess(preprocess_train_dataset)
  test_set = preprocess_test_dataset(
      emnist_test.create_tf_dataset_from_all_clients())
  return train_set, test_set
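As a rough illustration of what the training pipeline above produces, the pure-Python sketch below counts how many batches (local steps) a client yields per round. It relies only on what the code states: repeat(count=client_epochs_per_round) runs before batch(), and batch(..., drop_remainder=False) keeps the last partial batch, so batches can span epoch boundaries. The helper names are hypothetical and not part of TFF or tf.data.

```python
def local_steps_per_round(num_examples, client_epochs_per_round, batch_size):
    """Batches a client yields per round when repeat() runs before batch().

    The repeated stream has num_examples * epochs elements; with
    drop_remainder=False it is cut into full batches plus one partial batch
    if the total is not a multiple of batch_size (integer ceiling division).
    """
    total = num_examples * client_epochs_per_round
    return (total + batch_size - 1) // batch_size


def simulate_batch_sizes(num_examples, client_epochs_per_round, batch_size):
    """Cross-check by materializing the repeated stream and chunking it."""
    stream = list(range(num_examples)) * client_epochs_per_round
    return [len(stream[i:i + batch_size])
            for i in range(0, len(stream), batch_size)]


# A client with the maximum of 418 examples, one local epoch, batch size 16.
print(local_steps_per_round(418, 1, 16))     # 27
print(simulate_batch_sizes(418, 1, 16)[-1])  # 2 (the final partial batch)
```

Raising client_epochs_per_round therefore scales the number of local steps, and hence per-client compute, roughly linearly.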
We use a VGG-like model, i.e., each block has two 3x3 convolutions, and the number of filters is doubled whenever the feature maps are subsampled.
def _conv_3x3(input_tensor, filters, strides):
  """2D Convolutional layer with kernel size 3x3."""

  x = tf.keras.layers.Conv2D(
      filters=filters,
      strides=strides,
      kernel_size=3,
      padding='same',
      kernel_initializer='he_normal',
      use_bias=False,
  )(input_tensor)
  return x


def _basic_block(input_tensor, filters, strides):
  """A block of two 3x3 conv layers."""

  x = input_tensor
  x = _conv_3x3(x, filters, strides)
  x = tf.keras.layers.Activation('relu')(x)

  x = _conv_3x3(x, filters, 1)
  x = tf.keras.layers.Activation('relu')(x)
  return x


def _vgg_block(input_tensor, size, filters, strides):
  """A stack of basic blocks."""
  x = _basic_block(input_tensor, filters, strides=strides)
  for _ in range(size - 1):
    x = _basic_block(x, filters, strides=1)
  return x


def create_cnn(num_blocks, conv_width_multiplier=1, num_classes=10):
  """Create a VGG-like CNN model.

  The CNN has (6*num_blocks + 2) layers.
  """
  input_shape = (28, 28, 1)  # channels_last
  img_input = tf.keras.layers.Input(shape=input_shape)
  x = img_input
  x = tf.image.per_image_standardization(x)

  x = _conv_3x3(x, 16 * conv_width_multiplier, 1)
  x = _vgg_block(x, size=num_blocks, filters=16 * conv_width_multiplier, strides=1)
  x = _vgg_block(x, size=num_blocks, filters=32 * conv_width_multiplier, strides=2)
  x = _vgg_block(x, size=num_blocks, filters=64 * conv_width_multiplier, strides=2)

  x = tf.keras.layers.GlobalAveragePooling2D()(x)
  x = tf.keras.layers.Dense(num_classes)(x)

  model = tf.keras.models.Model(
      img_input,
      x,
      name='cnn-{}-{}'.format(6 * num_blocks + 2, conv_width_multiplier))
  return model
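To sanity-check the layer count in the create_cnn docstring and the parameter totals printed by model.summary() in the runs of this tutorial, the stdlib-only sketch below recomputes both from the architecture alone. Each 3x3 convolution contributes 3*3*c_in*c_out weights (no bias, since use_bias=False), the pooling layer has no weights, and the Dense classifier has a kernel plus biases. The function name cnn_stats is a hypothetical helper, not part of Keras or TFF.

```python
def cnn_stats(num_blocks, conv_width_multiplier=1, num_classes=10, in_channels=1):
    """Return (weighted_layer_count, trainable_params) for create_cnn.

    Mirrors create_cnn: a stem 3x3 conv, three VGG stages of num_blocks basic
    blocks (two bias-free 3x3 convs each), global average pooling (no
    weights), and a Dense classifier with bias.
    """
    def conv3x3_params(c_in, c_out):
        return 3 * 3 * c_in * c_out  # use_bias=False, so no bias term

    layers = 0
    channels = 16 * conv_width_multiplier
    params = conv3x3_params(in_channels, channels)  # stem conv
    layers += 1
    for base_filters in (16, 32, 64):
        filters = base_filters * conv_width_multiplier
        for _ in range(num_blocks):
            params += conv3x3_params(channels, filters)  # first conv (may stride)
            params += conv3x3_params(filters, filters)   # second conv
            layers += 2
            channels = filters
    params += channels * num_classes + num_classes  # Dense kernel + bias
    layers += 1
    return layers, params


print(cnn_stats(2, 4))  # (14, 2731082), matching "cnn-14-4"
print(cnn_stats(4, 4))  # (26, 5827658), matching "cnn-26-4"
```

Doubling num_blocks roughly doubles the parameter count here, which is why the larger model in the OOM section is more memory hungry.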
Now let us define the training loop for EMNIST. Note that use_experimental_simulation_loop=True in tff.learning.algorithms.build_weighted_fed_avg is suggested for performant TFF simulation, and is required to take advantage of multiple GPUs on a single machine. See the simple_fedavg example for how to define a customized federated learning algorithm that performs well on GPUs; one of its key features is explicitly using for ... iter(dataset) for training loops.
def keras_evaluate(model, test_data, metric):
  metric.reset_states()
  for batch in test_data:
    preds = model(batch['x'], training=False)
    metric.update_state(y_true=batch['y'], y_pred=preds)
  return metric.result()
def run_federated_training(client_epochs_per_round,
                           train_batch_size,
                           test_batch_size,
                           cnn_num_blocks,
                           conv_width_multiplier,
                           server_learning_rate,
                           client_learning_rate,
                           total_rounds,
                           clients_per_round,
                           rounds_per_eval,
                           logdir='logdir'):
  train_data, test_data = preprocess_emnist_dataset(
      client_epochs_per_round, train_batch_size, test_batch_size)
  data_spec = test_data.element_spec

  def _model_fn():
    keras_model = create_cnn(cnn_num_blocks, conv_width_multiplier)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return tff.learning.models.from_keras_model(
        keras_model, input_spec=data_spec, loss=loss)

  server_optimizer = tff.learning.optimizers.build_sgdm(server_learning_rate)
  client_optimizer = tff.learning.optimizers.build_sgdm(client_learning_rate)

  learning_process = tff.learning.algorithms.build_weighted_fed_avg(
      model_fn=_model_fn,
      server_optimizer_fn=server_optimizer,
      client_optimizer_fn=client_optimizer,
      use_experimental_simulation_loop=True)

  metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
  eval_model = create_cnn(cnn_num_blocks, conv_width_multiplier)
  # summary() prints the model table itself (and returns None).
  eval_model.summary()

  server_state = learning_process.initialize()
  start_time = time.time()
  for round_num in range(total_rounds):
    sampled_clients = np.random.choice(
        train_data.client_ids,
        size=clients_per_round,
        replace=False)
    sampled_train_data = [
        train_data.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]
    # Profile only the last round.
    if round_num == total_rounds - 1:
      with tf.profiler.experimental.Profile(logdir):
        result = learning_process.next(
            server_state, sampled_train_data)
    else:
      result = learning_process.next(
          server_state, sampled_train_data)
    server_state = result.state
    train_metrics = result.metrics['client_work']['train']
    # The reported time is the running average of seconds per round so far.
    print(f'Round {round_num} training loss: {train_metrics["loss"]}, '
          f'time: {(time.time()-start_time)/(round_num+1.)} secs')

    if round_num % rounds_per_eval == 0 or round_num == total_rounds - 1:
      model_weights = learning_process.get_model_weights(server_state)
      model_weights.assign_weights_to(eval_model)
      accuracy = keras_evaluate(eval_model, test_data, metric)
      print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')
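For intuition about what one call to learning_process.next computes, here is a toy, pure-Python sketch of a single weighted FedAvg round on a one-parameter least-squares model. It assumes client updates are weighted by their number of examples and models both the client and server optimizers as plain SGD (which we understand build_sgdm to reduce to when no momentum is passed); the model, data, and function names are hypothetical, and this is not TFF's implementation. The client loop iterates its dataset explicitly with for ... in iter(dataset), the pattern recommended above.

```python
def client_update(server_weight, dataset, client_lr):
    """Run local SGD on y ~ w * x, starting from the server weight.

    Returns the model delta and the number of examples seen, which is used
    as the aggregation weight.
    """
    w = server_weight
    num_examples = 0
    for xs, ys in iter(dataset):  # explicit loop over batches
        # Gradient of the batch mean of 0.5 * (w * x - y) ** 2 w.r.t. w.
        grad = sum((w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        w -= client_lr * grad
        num_examples += len(xs)
    return w - server_weight, num_examples


def fedavg_round(server_weight, client_datasets, client_lr, server_lr):
    """One round: local training, example-weighted delta average, server SGD."""
    updates = [client_update(server_weight, ds, client_lr)
               for ds in client_datasets]
    total_examples = sum(n for _, n in updates)
    avg_delta = sum(delta * n for delta, n in updates) / total_examples
    # Server SGD treats -avg_delta as a pseudo-gradient:
    # w - server_lr * (-avg_delta) == w + server_lr * avg_delta.
    return server_weight + server_lr * avg_delta


# Two hypothetical clients whose batches are (xs, ys) pairs with y = 3 * x.
clients = [
    [([1.0, 2.0], [3.0, 6.0])],
    [([0.5], [1.5]), ([1.5], [4.5])],
]
w = 0.0
for _ in range(50):
    w = fedavg_round(w, clients, client_lr=0.1, server_lr=1.0)
print(round(w, 4))  # 3.0
```

With server_learning_rate=1.0, as in the runs below, the server simply adds the weighted average of the client deltas to its model.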
Single GPU execution
TFF's default runtime behaves like TF's: when GPUs are available, the first GPU is chosen for execution. We run the previously defined federated training for several rounds with a relatively small model. The last round of execution is profiled with tf.profiler and visualized with TensorBoard. The profile confirms that the first GPU is used.
run_federated_training(
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    cnn_num_blocks=2,
    conv_width_multiplier=4,
    server_learning_rate=1.0,
    client_learning_rate=0.01,
    total_rounds=10,
    clients_per_round=16,
    rounds_per_eval=2,
)
Model: "cnn-14-4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0
tf.image.per_image_standardi (None, 28, 28, 1)         0
conv2d (Conv2D)              (None, 28, 28, 64)        576
conv2d_1 (Conv2D)            (None, 28, 28, 64)        36864
activation (Activation)      (None, 28, 28, 64)        0
conv2d_2 (Conv2D)            (None, 28, 28, 64)        36864
activation_1 (Activation)    (None, 28, 28, 64)        0
conv2d_3 (Conv2D)            (None, 28, 28, 64)        36864
activation_2 (Activation)    (None, 28, 28, 64)        0
conv2d_4 (Conv2D)            (None, 28, 28, 64)        36864
activation_3 (Activation)    (None, 28, 28, 64)        0
conv2d_5 (Conv2D)            (None, 14, 14, 128)       73728
activation_4 (Activation)    (None, 14, 14, 128)       0
conv2d_6 (Conv2D)            (None, 14, 14, 128)       147456
activation_5 (Activation)    (None, 14, 14, 128)       0
conv2d_7 (Conv2D)            (None, 14, 14, 128)       147456
activation_6 (Activation)    (None, 14, 14, 128)       0
conv2d_8 (Conv2D)            (None, 14, 14, 128)       147456
activation_7 (Activation)    (None, 14, 14, 128)       0
conv2d_9 (Conv2D)            (None, 7, 7, 256)         294912
activation_8 (Activation)    (None, 7, 7, 256)         0
conv2d_10 (Conv2D)           (None, 7, 7, 256)         589824
activation_9 (Activation)    (None, 7, 7, 256)         0
conv2d_11 (Conv2D)           (None, 7, 7, 256)         589824
activation_10 (Activation)   (None, 7, 7, 256)         0
conv2d_12 (Conv2D)           (None, 7, 7, 256)         589824
activation_11 (Activation)   (None, 7, 7, 256)         0
global_average_pooling2d (Gl (None, 256)               0
dense (Dense)                (None, 10)                2570
=================================================================
Total params: 2,731,082
Trainable params: 2,731,082
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4688243865966797, time: 13.382015466690063 secs
Round 0 validation accuracy: 15.240497589111328
Round 1 training loss: 2.3217368125915527, time: 9.311999917030334 secs
Round 2 training loss: 2.3100595474243164, time: 6.972411632537842 secs
Round 2 validation accuracy: 11.226489067077637
Round 3 training loss: 2.303222417831421, time: 6.467299699783325 secs
Round 4 training loss: 2.2976326942443848, time: 5.526083135604859 secs
Round 4 validation accuracy: 11.224040031433105
Round 5 training loss: 2.2919719219207764, time: 5.468692660331726 secs
Round 6 training loss: 2.2911534309387207, time: 4.935825347900391 secs
Round 6 validation accuracy: 11.833855628967285
Round 7 training loss: 2.2871201038360596, time: 4.918408691883087 secs
Round 8 training loss: 2.2818832397460938, time: 4.602836343977186 secs
Round 8 validation accuracy: 11.385677337646484
Round 9 training loss: 2.2790346145629883, time: 4.99558527469635 secs
Round 9 validation accuracy: 11.226489067077637
%tensorboard --logdir=logdir --port=0
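Note that the time printed each round is (time.time() - start_time) / (round_num + 1), a running average over all rounds so far rather than the duration of the latest round. The stdlib sketch below (per_round_seconds is a hypothetical helper) recovers per-round wall-clock intervals from those averages. Because evaluation runs after the timing print, each interval also includes any evaluation performed after the previous round.

```python
def per_round_seconds(running_avgs):
    """Convert logged running averages into per-round durations.

    After round k the elapsed time is (k + 1) * avg_k, so the interval ending
    at round k is (k + 1) * avg_k - k * avg_{k-1}.
    """
    durations = []
    prev_elapsed = 0.0
    for k, avg in enumerate(running_avgs):
        elapsed = (k + 1) * avg
        durations.append(elapsed - prev_elapsed)
        prev_elapsed = elapsed
    return durations


# Running averages logged by the single-GPU run above.
logged = [13.382015466690063, 9.311999917030334, 6.972411632537842,
          6.467299699783325, 5.526083135604859, 5.468692660331726,
          4.935825347900391, 4.918408691883087, 4.602836343977186,
          4.99558527469635]
for k, secs in enumerate(per_round_seconds(logged)):
    print(f'Round {k}: {secs:.2f} secs')
```

For example, by this reckoning round 1 took about 5.24 s, well below the 9.31 s average it logged, because the first round's one-time costs dominate the early averages.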
Larger model and OOM
Let us run a larger model on CPU with fewer federated rounds.
run_federated_training(
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    cnn_num_blocks=4,
    conv_width_multiplier=4,
    server_learning_rate=1.0,
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16,
    rounds_per_eval=2,
)
Model: "cnn-26-4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_4 (InputLayer)         [(None, 28, 28, 1)]       0
tf.image.per_image_standardi (None, 28, 28, 1)         0
conv2d_39 (Conv2D)           (None, 28, 28, 64)        576
conv2d_40 (Conv2D)           (None, 28, 28, 64)        36864
activation_36 (Activation)   (None, 28, 28, 64)        0
conv2d_41 (Conv2D)           (None, 28, 28, 64)        36864
activation_37 (Activation)   (None, 28, 28, 64)        0
conv2d_42 (Conv2D)           (None, 28, 28, 64)        36864
activation_38 (Activation)   (None, 28, 28, 64)        0
conv2d_43 (Conv2D)           (None, 28, 28, 64)        36864
activation_39 (Activation)   (None, 28, 28, 64)        0
conv2d_44 (Conv2D)           (None, 28, 28, 64)        36864
activation_40 (Activation)   (None, 28, 28, 64)        0
conv2d_45 (Conv2D)           (None, 28, 28, 64)        36864
activation_41 (Activation)   (None, 28, 28, 64)        0
conv2d_46 (Conv2D)           (None, 28, 28, 64)        36864
activation_42 (Activation)   (None, 28, 28, 64)        0
conv2d_47 (Conv2D)           (None, 28, 28, 64)        36864
activation_43 (Activation)   (None, 28, 28, 64)        0
conv2d_48 (Conv2D)           (None, 14, 14, 128)       73728
activation_44 (Activation)   (None, 14, 14, 128)       0
conv2d_49 (Conv2D)           (None, 14, 14, 128)       147456
activation_45 (Activation)   (None, 14, 14, 128)       0
conv2d_50 (Conv2D)           (None, 14, 14, 128)       147456
activation_46 (Activation)   (None, 14, 14, 128)       0
conv2d_51 (Conv2D)           (None, 14, 14, 128)       147456
activation_47 (Activation)   (None, 14, 14, 128)       0
conv2d_52 (Conv2D)           (None, 14, 14, 128)       147456
activation_48 (Activation)   (None, 14, 14, 128)       0
conv2d_53 (Conv2D)           (None, 14, 14, 128)       147456
activation_49 (Activation)   (None, 14, 14, 128)       0
conv2d_54 (Conv2D)           (None, 14, 14, 128)       147456
activation_50 (Activation)   (None, 14, 14, 128)       0
conv2d_55 (Conv2D)           (None, 14, 14, 128)       147456
activation_51 (Activation)   (None, 14, 14, 128)       0
conv2d_56 (Conv2D)           (None, 7, 7, 256)         294912
activation_52 (Activation)   (None, 7, 7, 256)         0
conv2d_57 (Conv2D)           (None, 7, 7, 256)         589824
activation_53 (Activation)   (None, 7, 7, 256)         0
conv2d_58 (Conv2D)           (None, 7, 7, 256)         589824
activation_54 (Activation)   (None, 7, 7, 256)         0
conv2d_59 (Conv2D)           (None, 7, 7, 256)         589824
activation_55 (Activation)   (None, 7, 7, 256)         0
conv2d_60 (Conv2D)           (None, 7, 7, 256)         589824
activation_56 (Activation)   (None, 7, 7, 256)         0
conv2d_61 (Conv2D)           (None, 7, 7, 256)         589824
activation_57 (Activation)   (None, 7, 7, 256)         0
conv2d_62 (Conv2D)           (None, 7, 7, 256)         589824
activation_58 (Activation)   (None, 7, 7, 256)         0
conv2d_63 (Conv2D)           (None, 7, 7, 256)         589824
activation_59 (Activation)   (None, 7, 7, 256)         0
global_average_pooling2d_3 ( (None, 256)               0
dense_3 (Dense)              (None, 10)                2570
=================================================================
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.437223434448242, time: 24.121686458587646 secs
Round 0 validation accuracy: 9.024785041809082
Round 1 training loss: 2.3081459999084473, time: 19.48685622215271 secs
Round 2 training loss: 2.305305242538452, time: 15.73950457572937 secs
Round 2 validation accuracy: 9.791339874267578
Round 3 training loss: 2.303149700164795, time: 15.194068729877472 secs
Round 4 training loss: 2.3026506900787354, time: 14.036769819259643 secs
Round 4 validation accuracy: 12.193867683410645
This model might hit an out-of-memory (OOM) issue on a single GPU. Migrating large-scale CPU experiments to GPU simulation can be constrained by memory usage, as GPUs often have limited memory. Several parameters can be tuned in the TFF runtime to mitigate OOM issues:
- Adjust max_concurrent_computation_calls in tff.backends.native.set_sync_local_cpp_execution_context to control the concurrency of client training.
# Control concurrency by `max_concurrent_computation_calls`.
# The value must be an integer: 16 // 2 == 8, half of clients_per_round.
tff.backends.native.set_sync_local_cpp_execution_context(
    max_concurrent_computation_calls=16 // 2)
run_federated_training(
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    cnn_num_blocks=4,
    conv_width_multiplier=4,
    server_learning_rate=1.0,
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16,
    rounds_per_eval=2,
)
Model: "cnn-26-4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0
tf.image.per_image_standardi (None, 28, 28, 1)         0
conv2d (Conv2D)              (None, 28, 28, 64)        576
conv2d_1 (Conv2D)            (None, 28, 28, 64)        36864
activation (Activation)      (None, 28, 28, 64)        0
conv2d_2 (Conv2D)            (None, 28, 28, 64)        36864
activation_1 (Activation)    (None, 28, 28, 64)        0
conv2d_3 (Conv2D)            (None, 28, 28, 64)        36864
activation_2 (Activation)    (None, 28, 28, 64)        0
conv2d_4 (Conv2D)            (None, 28, 28, 64)        36864
activation_3 (Activation)    (None, 28, 28, 64)        0
conv2d_5 (Conv2D)            (None, 28, 28, 64)        36864
activation_4 (Activation)    (None, 28, 28, 64)        0
conv2d_6 (Conv2D)            (None, 28, 28, 64)        36864
activation_5 (Activation)    (None, 28, 28, 64)        0
conv2d_7 (Conv2D)            (None, 28, 28, 64)        36864
activation_6 (Activation)    (None, 28, 28, 64)        0
conv2d_8 (Conv2D)            (None, 28, 28, 64)        36864
activation_7 (Activation)    (None, 28, 28, 64)        0
conv2d_9 (Conv2D)            (None, 14, 14, 128)       73728
activation_8 (Activation)    (None, 14, 14, 128)       0
conv2d_10 (Conv2D)           (None, 14, 14, 128)       147456
activation_9 (Activation)    (None, 14, 14, 128)       0
conv2d_11 (Conv2D)           (None, 14, 14, 128)       147456
activation_10 (Activation)   (None, 14, 14, 128)       0
conv2d_12 (Conv2D)           (None, 14, 14, 128)       147456
activation_11 (Activation)   (None, 14, 14, 128)       0
conv2d_13 (Conv2D)           (None, 14, 14, 128)       147456
activation_12 (Activation)   (None, 14, 14, 128)       0
conv2d_14 (Conv2D)           (None, 14, 14, 128)       147456
activation_13 (Activation)   (None, 14, 14, 128)       0
conv2d_15 (Conv2D)           (None, 14, 14, 128)       147456
activation_14 (Activation)   (None, 14, 14, 128)       0
conv2d_16 (Conv2D)           (None, 14, 14, 128)       147456
activation_15 (Activation)   (None, 14, 14, 128)       0
conv2d_17 (Conv2D)           (None, 7, 7, 256)         294912
activation_16 (Activation)   (None, 7, 7, 256)         0
conv2d_18 (Conv2D)           (None, 7, 7, 256)         589824
activation_17 (Activation)   (None, 7, 7, 256)         0
conv2d_19 (Conv2D)           (None, 7, 7, 256)         589824
activation_18 (Activation)   (None, 7, 7, 256)         0
conv2d_20 (Conv2D)           (None, 7, 7, 256)         589824
activation_19 (Activation)   (None, 7, 7, 256)         0
conv2d_21 (Conv2D)           (None, 7, 7, 256)         589824
activation_20 (Activation)   (None, 7, 7, 256)         0
conv2d_22 (Conv2D)           (None, 7, 7, 256)         589824
activation_21 (Activation)   (None, 7, 7, 256)         0
conv2d_23 (Conv2D)           (None, 7, 7, 256)         589824
activation_22 (Activation)   (None, 7, 7, 256)         0
conv2d_24 (Conv2D)           (None, 7, 7, 256)         589824
activation_23 (Activation)   (None, 7, 7, 256)         0
global_average_pooling2d (Gl (None, 256)               0
dense (Dense)                (None, 10)                2570
=================================================================
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4990053176879883, time: 11.922378778457642 secs
Round 0 validation accuracy: 11.224040031433105
Round 1 training loss: 2.307560920715332, time: 9.916815996170044 secs
Round 2 training loss: 2.3032877445220947, time: 7.68927804629008 secs
Round 2 validation accuracy: 11.224040031433105
Round 3 training loss: 2.302366256713867, time: 7.681552231311798 secs
Round 4 training loss: 2.301671028137207, time: 7.613566827774048 secs
Round 4 validation accuracy: 11.224040031433105
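The effect of max_concurrent_computation_calls can be pictured as a cap on how many client computations are in flight at once. Assuming peak accelerator memory grows roughly with the number of client computations running concurrently, the cap bounds memory at the cost of throughput. The sketch below is only an analogy using plain Python threads, not how TFF's C++ executor is implemented; run_clients and its parameters are hypothetical.

```python
import threading
import time


def run_clients(num_clients, max_concurrent, work_seconds=0.01):
    """Simulate num_clients client computations with at most max_concurrent
    running at the same time; return the peak concurrency observed."""
    slots = threading.BoundedSemaphore(max_concurrent)
    lock = threading.Lock()
    state = {'active': 0, 'peak': 0}

    def client_computation():
        with slots:  # blocks while max_concurrent computations are running
            with lock:
                state['active'] += 1
                state['peak'] = max(state['peak'], state['active'])
            time.sleep(work_seconds)  # stand-in for local training
            with lock:
                state['active'] -= 1

    threads = [threading.Thread(target=client_computation)
               for _ in range(num_clients)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return state['peak']


# 16 clients per round, as in this tutorial, with the cap set to 16 // 2.
print(run_clients(num_clients=16, max_concurrent=16 // 2))  # at most 8
```

With a cap of 8, at most half of the 16 sampled clients run at the same time, so each round takes more sequential steps but needs less concurrent memory.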
Optimize performance
Techniques that improve performance in TF can generally be used in TFF as well, e.g., mixed precision training and XLA. The speedup (on GPUs such as the V100) and memory savings from mixed precision can often be significant, and can be examined with tf.profiler.
# Mixed precision training.
# Reset the execution context to its default settings.
tff.backends.native.set_sync_local_cpp_execution_context()
# The stable API replaces the removed tf.keras.mixed_precision.experimental one.
tf.keras.mixed_precision.set_global_policy('mixed_float16')
run_federated_training(
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    cnn_num_blocks=4,
    conv_width_multiplier=4,
    server_learning_rate=1.0,
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16,
    rounds_per_eval=2,
    logdir='mixed'
)
Model: "cnn-26-4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0
tf.image.per_image_standardi (None, 28, 28, 1)         0
conv2d (Conv2D)              (None, 28, 28, 64)        576
conv2d_1 (Conv2D)            (None, 28, 28, 64)        36864
activation (Activation)      (None, 28, 28, 64)        0
conv2d_2 (Conv2D)            (None, 28, 28, 64)        36864
activation_1 (Activation)    (None, 28, 28, 64)        0
conv2d_3 (Conv2D)            (None, 28, 28, 64)        36864
activation_2 (Activation)    (None, 28, 28, 64)        0
conv2d_4 (Conv2D)            (None, 28, 28, 64)        36864
activation_3 (Activation)    (None, 28, 28, 64)        0
conv2d_5 (Conv2D)            (None, 28, 28, 64)        36864
activation_4 (Activation)    (None, 28, 28, 64)        0
conv2d_6 (Conv2D)            (None, 28, 28, 64)        36864
activation_5 (Activation)    (None, 28, 28, 64)        0
conv2d_7 (Conv2D)            (None, 28, 28, 64)        36864
activation_6 (Activation)    (None, 28, 28, 64)        0
conv2d_8 (Conv2D)            (None, 28, 28, 64)        36864
activation_7 (Activation)    (None, 28, 28, 64)        0
conv2d_9 (Conv2D)            (None, 14, 14, 128)       73728
activation_8 (Activation)    (None, 14, 14, 128)       0
conv2d_10 (Conv2D)           (None, 14, 14, 128)       147456
activation_9 (Activation)    (None, 14, 14, 128)       0
conv2d_11 (Conv2D)           (None, 14, 14, 128)       147456
activation_10 (Activation)   (None, 14, 14, 128)       0
conv2d_12 (Conv2D)           (None, 14, 14, 128)       147456
activation_11 (Activation)   (None, 14, 14, 128)       0
conv2d_13 (Conv2D)           (None, 14, 14, 128)       147456
activation_12 (Activation)   (None, 14, 14, 128)       0
conv2d_14 (Conv2D)           (None, 14, 14, 128)       147456
activation_13 (Activation)   (None, 14, 14, 128)       0
conv2d_15 (Conv2D)           (None, 14, 14, 128)       147456
activation_14 (Activation)   (None, 14, 14, 128)       0
conv2d_16 (Conv2D)           (None, 14, 14, 128)       147456
activation_15 (Activation)   (None, 14, 14, 128)       0
conv2d_17 (Conv2D)           (None, 7, 7, 256)         294912
activation_16 (Activation)   (None, 7, 7, 256)         0
conv2d_18 (Conv2D)           (None, 7, 7, 256)         589824
activation_17 (Activation)   (None, 7, 7, 256)         0
conv2d_19 (Conv2D)           (None, 7, 7, 256)         589824
activation_18 (Activation)   (None, 7, 7, 256)         0
conv2d_20 (Conv2D)           (None, 7, 7, 256)         589824
activation_19 (Activation)   (None, 7, 7, 256)         0
conv2d_21 (Conv2D)           (None, 7, 7, 256)         589824
activation_20 (Activation)   (None, 7, 7, 256)         0
conv2d_22 (Conv2D)           (None, 7, 7, 256)         589824
activation_21 (Activation)   (None, 7, 7, 256)         0
conv2d_23 (Conv2D)           (None, 7, 7, 256)         589824
activation_22 (Activation)   (None, 7, 7, 256)         0
conv2d_24 (Conv2D)           (None, 7, 7, 256)         589824
activation_23 (Activation)   (None, 7, 7, 256)         0
global_average_pooling2d (Gl (None, 256)               0
dense (Dense)                (None, 10)                2570
=================================================================
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4187185764312744, time: 18.763780117034912 secs
Round 0 validation accuracy: 9.977468490600586
Round 1 training loss: 2.305102825164795, time: 13.712820529937744 secs
Round 2 training loss: 2.304737091064453, time: 9.993690172831217 secs
Round 2 validation accuracy: 11.779976844787598
Round 3 training loss: 2.2996833324432373, time: 9.29404467344284 secs
Round 4 training loss: 2.299349308013916, time: 9.195427560806275 secs
Round 4 validation accuracy: 11.224040031433105
%tensorboard --logdir=mixed --port=0
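Why does mixed precision save memory, and why does it need care? Under the mixed_float16 policy, Keras layers compute in float16 while keeping variables in float32. float16 uses 2 bytes per value instead of 4, but has only an 11-bit significand and a narrow range, so very small gradients can underflow to zero; the usual remedy in Keras is loss scaling (tf.keras.mixed_precision.LossScaleOptimizer). The stdlib sketch below uses Python's struct 'e' format (IEEE 754 half precision) to illustrate these effects; it does not show whether the TFF training loop in this tutorial applies loss scaling, and to_float16 is a hypothetical helper.

```python
import struct


def to_float16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack('<e', struct.pack('<e', x))[0]


# Storage: float16 takes half the bytes of float32.
print(struct.calcsize('<e'), struct.calcsize('<f'))  # 2 4

# Precision: not every integer above 2048 is representable in float16.
print(to_float16(2049.0))  # 2048.0

# Range: a tiny gradient underflows to zero in float16 ...
grad = 1e-8
print(to_float16(grad))  # 0.0

# ... but survives if it is scaled up first (loss scaling) and the
# division by the scale factor is done at higher precision.
scale = 1024.0
print(to_float16(grad * scale) / scale)  # nonzero, close to 1e-8
```

The halved storage is where the memory saving comes from, while the underflow example shows why mixed precision training is usually paired with loss scaling.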