محاكاة TFF مع مسرعات

سيشرح هذا البرنامج التعليمي كيفية إعداد محاكاة TFF باستخدام المسرّعات. نركز حاليًا على جهاز واحد مزوّد بوحدة معالجة رسومات (GPU) واحدة أو أكثر، وسنحدّث هذا البرنامج التعليمي لاحقًا ليشمل إعدادات الأجهزة المتعددة ووحدات TPU.

قبل أن نبدأ

أولًا، لنتأكد من أن دفتر الملاحظات (notebook) متصل بواجهة خلفية جُمِّعت فيها المكونات ذات الصلة.

!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio
!pip install -U tensorboard_plugin_profile

import nest_asyncio
%load_ext tensorboard
import collections
import time

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

تحقّق مما إذا كان TF قادرًا على اكتشاف وحدات معالجة الرسومات المادية، ثم أنشئ بيئة افتراضية متعددة وحدات معالجة الرسومات لمحاكاة TFF على GPU. ستكون ذاكرة وحدتي معالجة الرسومات الافتراضيتين محدودة لتوضيح كيفية تهيئة وقت تشغيل TFF.

gpu_devices = tf.config.list_physical_devices('GPU')
if not gpu_devices:
  raise ValueError('Cannot detect physical GPU device in TF')
# Split the first physical GPU into two virtual GPUs, each with a capped
# memory budget, so the multi-GPU and OOM sections below can be reproduced
# on a single physical GPU. This must run before TF initializes the GPU;
# TensorFlow raises a RuntimeError if virtual devices are configured later.
# NOTE(review): 1024 MB per virtual GPU is a reconstructed value — confirm.
tf.config.set_logical_device_configuration(
    gpu_devices[0],
    [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
     tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
print(tf.config.list_logical_devices())
[LogicalDevice(name='/device:CPU:0', device_type='CPU'),
 LogicalDevice(name='/device:GPU:0', device_type='GPU'),
 LogicalDevice(name='/device:GPU:1', device_type='GPU')]

شغّل مثال "Hello World" التالي للتأكد من أن بيئة TFF مُعدّة بشكل صحيح. إذا لم يعمل، يُرجى الرجوع إلى دليل التثبيت للاطلاع على التعليمات.

@tff.federated_computation
def hello_world():
  """Return a constant greeting from a TFF computation.

  The `tff.federated_computation` decorator routes the call through the TFF
  runtime, so a successful call confirms that TFF is installed and working.
  A plain Python function would not exercise TFF at all.
  """
  return 'Hello, World!'

print(hello_world())

b'Hello, World!'

الإعداد التجريبي EMNIST

في هذا البرنامج التعليمي، ندرّب مصنّف صور على EMNIST باستخدام خوارزمية المتوسط الاتحادي (Federated Averaging). لنبدأ بتحميل مجموعة بيانات EMNIST (الأرقام فقط) المتاحة في TFF.

# Load the federated EMNIST train/test splits, restricted to the 10 digit
# classes.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True)

نعرّف دالة لمعالجة بيانات EMNIST مسبقًا على غرار مثال simple_fedavg. لاحظ أن الوسيط client_epochs_per_round يتحكم في عدد الحقب (epochs) المحلية التي يتدرّب فيها العملاء في كل جولة من التعلم الاتحادي.

def preprocess_emnist_dataset(client_epochs_per_round, batch_size, test_batch_size):
  """Build the federated training data and the centralized test dataset.

  Reads the module-level `emnist_train` and `emnist_test` client data.

  Args:
    client_epochs_per_round: How many times each client's dataset is
      repeated, i.e. the number of local epochs per federated round.
    batch_size: Batch size of each client's training dataset.
    test_batch_size: Batch size of the test dataset.

  Returns:
    A `(train_set, test_set)` tuple. `train_set` is `emnist_train` with the
    per-client training pipeline attached via `preprocess`. `test_set` is a
    single dataset built from all test clients.
  """

  def element_fn(element):
    # Add a trailing channel axis (channels_last) to each image.
    return collections.OrderedDict(
        x=tf.expand_dims(element['pixels'], -1), y=element['label'])

  def preprocess_train_dataset(dataset):
    # Use buffer_size same as the maximum client dataset size,
    # 418 for Federated EMNIST
    return dataset.map(element_fn).shuffle(buffer_size=418).repeat(
        count=client_epochs_per_round).batch(batch_size, drop_remainder=False)

  def preprocess_test_dataset(dataset):
    return dataset.map(element_fn).batch(test_batch_size, drop_remainder=False)

  train_set = emnist_train.preprocess(preprocess_train_dataset)
  # Evaluation is centralized, so merge every test client into one dataset.
  test_set = preprocess_test_dataset(
      emnist_test.create_tf_dataset_from_all_clients())
  return train_set, test_set

نستخدم نموذجًا شبيهًا بـ VGG، أي أن كل كتلة تحتوي على طبقتي التفاف (convolution) بحجم 3×3، ويتضاعف عدد المرشحات كلما خُفِّض حجم خرائط الميزات (downsampling).

def _conv_3x3(input_tensor, filters, strides):
  """Apply a bias-free 2D convolution with a 3x3 kernel.

  Args:
    input_tensor: Keras tensor to convolve.
    filters: Number of output channels.
    strides: Convolution stride. With 'same' padding, a stride of 2 halves
      the spatial size (e.g. 28 -> 14).

  Returns:
    The output Keras tensor of the convolution.
  """
  x = tf.keras.layers.Conv2D(
      filters=filters,
      kernel_size=3,
      strides=strides,
      padding='same',
      # He initialization suits the ReLU activations applied after the convs.
      kernel_initializer='he_normal',
      use_bias=False,
  )(input_tensor)
  return x

def _basic_block(input_tensor, filters, strides):
  """Stack two 3x3 conv + ReLU stages with the same filter count.

  Only the first convolution uses `strides`; the second always uses stride 1.
  """
  out = input_tensor
  for stage_stride in (strides, 1):
    out = _conv_3x3(out, filters, stage_stride)
    out = tf.keras.layers.Activation('relu')(out)
  return out

def _vgg_block(input_tensor, size, filters, strides):
  """Chain `size` basic blocks that share the same `filters`.

  The first basic block applies `strides` (so it may downsample), and every
  following block uses stride 1. At least one basic block is always built,
  even when `size` is less than 1.
  """
  stride_schedule = [strides] + [1] * (size - 1)
  x = input_tensor
  for block_stride in stride_schedule:
    x = _basic_block(x, filters, strides=block_stride)
  return x

def create_cnn(num_blocks, conv_width_multiplier=1, num_classes=10):
  """Create a VGG-like CNN model for 28x28 single-channel images.

  The network has these parts:
  - per-image standardization;
  - a 3x3 stem convolution;
  - three VGG stages of `num_blocks` basic blocks each, with 16, 32 and 64
    filters (scaled by `conv_width_multiplier`); the last two stages
    downsample with stride 2;
  - global average pooling and a dense classifier.

  The CNN has (6*num_blocks + 2) weight layers: the stem, 6*num_blocks stage
  convolutions and the dense layer.

  Args:
    num_blocks: Number of basic blocks per VGG stage.
    conv_width_multiplier: Multiplier applied to every filter count.
    num_classes: Number of output logits.

  Returns:
    A Keras functional model named 'cnn-<depth>-<multiplier>' that outputs
    unnormalized logits (no final activation).
  """
  input_shape = (28, 28, 1)  # channels_last
  img_input = tf.keras.layers.Input(shape=input_shape)
  x = tf.image.per_image_standardization(img_input)

  x = _conv_3x3(x, 16 * conv_width_multiplier, 1)
  x = _vgg_block(x, size=num_blocks, filters=16 * conv_width_multiplier, strides=1)
  x = _vgg_block(x, size=num_blocks, filters=32 * conv_width_multiplier, strides=2)
  x = _vgg_block(x, size=num_blocks, filters=64 * conv_width_multiplier, strides=2)

  x = tf.keras.layers.GlobalAveragePooling2D()(x)
  # No activation: pair the logits with a `from_logits=True` loss.
  x = tf.keras.layers.Dense(num_classes)(x)

  model = tf.keras.models.Model(
      inputs=img_input,
      outputs=x,
      name='cnn-{}-{}'.format(6 * num_blocks + 2, conv_width_multiplier))
  return model

لنعرّف الآن حلقة التدريب لـ EMNIST. لاحظ أن الخيار use_experimental_simulation_loop=True في tff.learning.build_federated_averaging_process يُنصح به للحصول على محاكاة TFF عالية الأداء، وهو شرط للاستفادة من وحدات معالجة رسومات متعددة على جهاز واحد. راجع مثال simple_fedavg لمعرفة كيفية تعريف خوارزمية تعلم اتحادي مخصصة عالية الأداء على وحدات معالجة الرسومات؛ ومن أهم سماتها استخدام حلقة for ... iter(dataset) صريحة في حلقات التدريب.

def keras_evaluate(model, test_data, metric):
  """Evaluate `model` on `test_data` and return the value of `metric`.

  The metric's state is reset first. Each call therefore reports the result
  of this evaluation only, instead of accumulating over every previous call
  that reused the same metric object.

  Args:
    model: Callable Keras model, invoked as `model(x, training=False)`.
    test_data: Iterable of batches. Each batch is a mapping with an 'x'
      (inputs) entry and a 'y' (labels) entry.
    metric: Keras metric, updated via `update_state(y_true=..., y_pred=...)`.

  Returns:
    `metric.result()` after all batches have been consumed.
  """
  metric.reset_states()
  for batch in test_data:
    preds = model(batch['x'], training=False)
    metric.update_state(y_true=batch['y'], y_pred=preds)
  return metric.result()

def run_federated_training(client_epochs_per_round,
                           train_batch_size,
                           test_batch_size,
                           cnn_num_blocks,
                           conv_width_multiplier,
                           server_learning_rate,
                           client_learning_rate,
                           total_rounds,
                           clients_per_round,
                           rounds_per_eval,
                           logdir='logdir'):
  """Train a VGG-like CNN on federated EMNIST with FedAvg and log progress.

  Each round:
  - samples `clients_per_round` distinct training clients;
  - runs one step of the federated averaging process on their datasets;
  - prints the training loss and the mean wall-clock time per round so far.

  The final round is profiled with `tf.profiler` into `logdir`. Every
  `rounds_per_eval` rounds, and after the final round, the current server
  model weights are copied into a Keras model and evaluated on the
  centralized test set.

  Args:
    client_epochs_per_round: Local epochs each client runs per round.
    train_batch_size: Batch size of the client training datasets.
    test_batch_size: Batch size of the centralized test dataset.
    cnn_num_blocks: Basic blocks per VGG stage (see `create_cnn`).
    conv_width_multiplier: Filter-count multiplier (see `create_cnn`).
    server_learning_rate: SGD learning rate of the server optimizer.
    client_learning_rate: SGD learning rate of the client optimizer.
    total_rounds: Number of federated rounds to run.
    clients_per_round: Clients sampled without replacement each round. Must
      not exceed the number of training clients (numpy raises ValueError).
    rounds_per_eval: Evaluate every this many rounds.
    logdir: Directory that receives the profiler trace of the last round.
  """
  train_data, test_data = preprocess_emnist_dataset(
      client_epochs_per_round, train_batch_size, test_batch_size)
  data_spec = test_data.element_spec

  def _model_fn():
    keras_model = create_cnn(cnn_num_blocks, conv_width_multiplier)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return tff.learning.from_keras_model(
        keras_model, input_spec=data_spec, loss=loss)

  def _server_optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=server_learning_rate)

  def _client_optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=client_learning_rate)

  # The experimental simulation loop is recommended for performant TFF
  # simulations and is required to use multiple GPUs on one machine.
  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=_model_fn,
      client_optimizer_fn=_client_optimizer_fn,
      server_optimizer_fn=_server_optimizer_fn,
      use_experimental_simulation_loop=True)

  metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
  eval_model = create_cnn(cnn_num_blocks, conv_width_multiplier)
  eval_model.summary()

  server_state = iterative_process.initialize()
  start_time = time.time()
  for round_num in range(total_rounds):
    sampled_clients = np.random.choice(
        train_data.client_ids, size=clients_per_round, replace=False)
    sampled_train_data = [
        train_data.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]
    is_last_round = round_num == total_rounds - 1
    if is_last_round:
      # Only the final round is captured by the profiler.
      with tf.profiler.experimental.Profile(logdir):
        server_state, train_metrics = iterative_process.next(
            server_state, sampled_train_data)
    else:
      server_state, train_metrics = iterative_process.next(
          server_state, sampled_train_data)
    print(f'Round {round_num} training loss: {train_metrics["train"]["loss"]}, '
          f'time: {(time.time()-start_time)/(round_num+1.)} secs')
    if round_num % rounds_per_eval == 0 or is_last_round:
      # Without this copy, eval_model would keep its random initial weights.
      server_state.model.assign_weights_to(eval_model)
      # Reset so each evaluation is reported independently of earlier ones.
      metric.reset_states()
      accuracy = keras_evaluate(eval_model, test_data, metric)
      print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')

تنفيذ GPU واحد

وقت التشغيل الافتراضي في TFF مماثل لوقت تشغيل TF: عند توفر وحدات معالجة رسومات، تُختار أول وحدة GPU للتنفيذ. نشغّل التدريب الاتحادي المعرَّف أعلاه لعدة جولات باستخدام نموذج صغير نسبيًا، ونُحلّل أداء (profile) الجولة الأخيرة من التنفيذ باستخدام tf.profiler ونعرض النتائج في tensorboard. ويتضح من ذلك أن أول وحدة GPU هي المستخدمة.

Model: "cnn-14-4"
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
tf.image.per_image_standardi (None, 28, 28, 1)         0         
conv2d (Conv2D)              (None, 28, 28, 64)        576       
conv2d_1 (Conv2D)            (None, 28, 28, 64)        36864     
activation (Activation)      (None, 28, 28, 64)        0         
conv2d_2 (Conv2D)            (None, 28, 28, 64)        36864     
activation_1 (Activation)    (None, 28, 28, 64)        0         
conv2d_3 (Conv2D)            (None, 28, 28, 64)        36864     
activation_2 (Activation)    (None, 28, 28, 64)        0         
conv2d_4 (Conv2D)            (None, 28, 28, 64)        36864     
activation_3 (Activation)    (None, 28, 28, 64)        0         
conv2d_5 (Conv2D)            (None, 14, 14, 128)       73728     
activation_4 (Activation)    (None, 14, 14, 128)       0         
conv2d_6 (Conv2D)            (None, 14, 14, 128)       147456    
activation_5 (Activation)    (None, 14, 14, 128)       0         
conv2d_7 (Conv2D)            (None, 14, 14, 128)       147456    
activation_6 (Activation)    (None, 14, 14, 128)       0         
conv2d_8 (Conv2D)            (None, 14, 14, 128)       147456    
activation_7 (Activation)    (None, 14, 14, 128)       0         
conv2d_9 (Conv2D)            (None, 7, 7, 256)         294912    
activation_8 (Activation)    (None, 7, 7, 256)         0         
conv2d_10 (Conv2D)           (None, 7, 7, 256)         589824    
activation_9 (Activation)    (None, 7, 7, 256)         0         
conv2d_11 (Conv2D)           (None, 7, 7, 256)         589824    
activation_10 (Activation)   (None, 7, 7, 256)         0         
conv2d_12 (Conv2D)           (None, 7, 7, 256)         589824    
activation_11 (Activation)   (None, 7, 7, 256)         0         
global_average_pooling2d (Gl (None, 256)               0         
dense (Dense)                (None, 10)                2570      
Total params: 2,731,082
Trainable params: 2,731,082
Non-trainable params: 0
Round 0 training loss: 2.4688243865966797, time: 13.382015466690063 secs
Round 0 validation accuracy: 15.240497589111328
Round 1 training loss: 2.3217368125915527, time: 9.311999917030334 secs
Round 2 training loss: 2.3100595474243164, time: 6.972411632537842 secs
Round 2 validation accuracy: 11.226489067077637
Round 3 training loss: 2.303222417831421, time: 6.467299699783325 secs
Round 4 training loss: 2.2976326942443848, time: 5.526083135604859 secs
Round 4 validation accuracy: 11.224040031433105
Round 5 training loss: 2.2919719219207764, time: 5.468692660331726 secs
Round 6 training loss: 2.2911534309387207, time: 4.935825347900391 secs
Round 6 validation accuracy: 11.833855628967285
Round 7 training loss: 2.2871201038360596, time: 4.918408691883087 secs
Round 8 training loss: 2.2818832397460938, time: 4.602836343977186 secs
Round 8 validation accuracy: 11.385677337646484
Round 9 training loss: 2.2790346145629883, time: 4.99558527469635 secs
Round 9 validation accuracy: 11.226489067077637
تنفيذ وحدة المعالجة المركزية

للمقارنة، لنُهيّئ وقت تشغيل TFF للتنفيذ على وحدة المعالجة المركزية (CPU). التنفيذ على وحدة المعالجة المركزية أبطأ قليلًا فقط مع هذا النموذج الصغير نسبيًا.

# Run both the server and all clients on the CPU.
cpu_device = tf.config.list_logical_devices('CPU')[0]
tff.backends.native.set_local_execution_context(
    server_tf_device=cpu_device, client_tf_devices=[cpu_device])

# NOTE(review): cnn_num_blocks, conv_width_multiplier, total_rounds and
# rounds_per_eval match the logged run below; the remaining hyperparameters
# are reconstructed — confirm.
run_federated_training(
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    cnn_num_blocks=2,
    conv_width_multiplier=4,
    server_learning_rate=1.0,
    client_learning_rate=0.01,
    total_rounds=10,
    clients_per_round=16,
    rounds_per_eval=2)

Model: "cnn-14-4"
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 28, 28, 1)]       0         
tf.image.per_image_standardi (None, 28, 28, 1)         0         
conv2d_13 (Conv2D)           (None, 28, 28, 64)        576       
conv2d_14 (Conv2D)           (None, 28, 28, 64)        36864     
activation_12 (Activation)   (None, 28, 28, 64)        0         
conv2d_15 (Conv2D)           (None, 28, 28, 64)        36864     
activation_13 (Activation)   (None, 28, 28, 64)        0         
conv2d_16 (Conv2D)           (None, 28, 28, 64)        36864     
activation_14 (Activation)   (None, 28, 28, 64)        0         
conv2d_17 (Conv2D)           (None, 28, 28, 64)        36864     
activation_15 (Activation)   (None, 28, 28, 64)        0         
conv2d_18 (Conv2D)           (None, 14, 14, 128)       73728     
activation_16 (Activation)   (None, 14, 14, 128)       0         
conv2d_19 (Conv2D)           (None, 14, 14, 128)       147456    
activation_17 (Activation)   (None, 14, 14, 128)       0         
conv2d_20 (Conv2D)           (None, 14, 14, 128)       147456    
activation_18 (Activation)   (None, 14, 14, 128)       0         
conv2d_21 (Conv2D)           (None, 14, 14, 128)       147456    
activation_19 (Activation)   (None, 14, 14, 128)       0         
conv2d_22 (Conv2D)           (None, 7, 7, 256)         294912    
activation_20 (Activation)   (None, 7, 7, 256)         0         
conv2d_23 (Conv2D)           (None, 7, 7, 256)         589824    
activation_21 (Activation)   (None, 7, 7, 256)         0         
conv2d_24 (Conv2D)           (None, 7, 7, 256)         589824    
activation_22 (Activation)   (None, 7, 7, 256)         0         
conv2d_25 (Conv2D)           (None, 7, 7, 256)         589824    
activation_23 (Activation)   (None, 7, 7, 256)         0         
global_average_pooling2d_1 ( (None, 256)               0         
dense_1 (Dense)              (None, 10)                2570      
Total params: 2,731,082
Trainable params: 2,731,082
Non-trainable params: 0
Round 0 training loss: 2.4787657260894775, time: 15.264191627502441 secs
Round 0 validation accuracy: 12.191418647766113
Round 1 training loss: 2.3097336292266846, time: 11.785032272338867 secs
Round 2 training loss: 2.3062121868133545, time: 9.677561124165853 secs
Round 2 validation accuracy: 11.415066719055176
Round 3 training loss: 2.2982261180877686, time: 9.301376760005951 secs
Round 4 training loss: 2.2953946590423584, time: 8.377780866622924 secs
Round 4 validation accuracy: 20.537813186645508
Round 5 training loss: 2.290337324142456, time: 8.385509928067526 secs
Round 6 training loss: 2.2842795848846436, time: 7.809031554630825 secs
Round 6 validation accuracy: 11.934267044067383
Round 7 training loss: 2.2752432823181152, time: 7.8433578312397 secs
Round 8 training loss: 2.2698657512664795, time: 7.478067080179851 secs
Round 8 validation accuracy: 26.16330337524414
Round 9 training loss: 2.2609798908233643, time: 7.632814192771912 secs
Round 9 validation accuracy: 23.079936981201172

التنفيذ على وحدات معالجة رسومات متعددة

من السهل تهيئة TFF للتنفيذ على وحدات معالجة رسومات متعددة. لاحظ أن تدريب العملاء يتم بالتوازي في TFF. في الإعداد متعدد وحدات معالجة الرسومات، يُوزَّع العملاء على وحدات معالجة الرسومات بالتناوب (round-robin). التنفيذ التالي على وحدتي معالجة رسومات ليس أسرع من التنفيذ على وحدة واحدة، لسببين: تدريب العملاء متوازٍ في كلا الإعدادين، والإعداد متعدد الوحدات هنا يتكون من وحدتين افتراضيتين أُنشئتا من وحدة GPU مادية واحدة.

# Server on the CPU; clients are spread across all logical GPUs.
cpu_device = tf.config.list_logical_devices('CPU')[0]
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_execution_context(
    server_tf_device=cpu_device, client_tf_devices=gpu_devices)

# NOTE(review): cnn_num_blocks, conv_width_multiplier, total_rounds and
# rounds_per_eval match the logged run below; the remaining hyperparameters
# are reconstructed — confirm.
run_federated_training(
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    cnn_num_blocks=2,
    conv_width_multiplier=4,
    server_learning_rate=1.0,
    client_learning_rate=0.01,
    total_rounds=10,
    clients_per_round=16,
    rounds_per_eval=2)

Model: "cnn-14-4"
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 28, 28, 1)]       0         
tf.image.per_image_standardi (None, 28, 28, 1)         0         
conv2d_26 (Conv2D)           (None, 28, 28, 64)        576       
conv2d_27 (Conv2D)           (None, 28, 28, 64)        36864     
activation_24 (Activation)   (None, 28, 28, 64)        0         
conv2d_28 (Conv2D)           (None, 28, 28, 64)        36864     
activation_25 (Activation)   (None, 28, 28, 64)        0         
conv2d_29 (Conv2D)           (None, 28, 28, 64)        36864     
activation_26 (Activation)   (None, 28, 28, 64)        0         
conv2d_30 (Conv2D)           (None, 28, 28, 64)        36864     
activation_27 (Activation)   (None, 28, 28, 64)        0         
conv2d_31 (Conv2D)           (None, 14, 14, 128)       73728     
activation_28 (Activation)   (None, 14, 14, 128)       0         
conv2d_32 (Conv2D)           (None, 14, 14, 128)       147456    
activation_29 (Activation)   (None, 14, 14, 128)       0         
conv2d_33 (Conv2D)           (None, 14, 14, 128)       147456    
activation_30 (Activation)   (None, 14, 14, 128)       0         
conv2d_34 (Conv2D)           (None, 14, 14, 128)       147456    
activation_31 (Activation)   (None, 14, 14, 128)       0         
conv2d_35 (Conv2D)           (None, 7, 7, 256)         294912    
activation_32 (Activation)   (None, 7, 7, 256)         0         
conv2d_36 (Conv2D)           (None, 7, 7, 256)         589824    
activation_33 (Activation)   (None, 7, 7, 256)         0         
conv2d_37 (Conv2D)           (None, 7, 7, 256)         589824    
activation_34 (Activation)   (None, 7, 7, 256)         0         
conv2d_38 (Conv2D)           (None, 7, 7, 256)         589824    
activation_35 (Activation)   (None, 7, 7, 256)         0         
global_average_pooling2d_2 ( (None, 256)               0         
dense_2 (Dense)              (None, 10)                2570      
Total params: 2,731,082
Trainable params: 2,731,082
Non-trainable params: 0
Round 0 training loss: 2.911365270614624, time: 12.759389877319336 secs
Round 0 validation accuracy: 9.541536331176758
Round 1 training loss: 2.3175694942474365, time: 9.202919721603394 secs
Round 2 training loss: 2.311001777648926, time: 6.802880525588989 secs
Round 2 validation accuracy: 9.911344528198242
Round 3 training loss: 2.3105244636535645, time: 6.611470937728882 secs
Round 4 training loss: 2.3082072734832764, time: 5.678833389282227 secs
Round 4 validation accuracy: 10.212578773498535
Round 5 training loss: 2.304673671722412, time: 5.5404335260391235 secs
Round 6 training loss: 2.3035168647766113, time: 5.008027451378958 secs
Round 6 validation accuracy: 9.935834884643555
Round 7 training loss: 2.3052737712860107, time: 5.1173741817474365 secs
Round 8 training loss: 2.3007171154022217, time: 4.745321141348945 secs
Round 8 validation accuracy: 10.768514633178711
Round 9 training loss: 2.302018404006958, time: 5.0809732437133786 secs
Round 9 validation accuracy: 12.311422348022461

نموذج أكبر و OOM

لنشغّل نموذجًا أكبر على وحدة المعالجة المركزية بعدد أقل من الجولات الاتحادية.

# Run both the server and all clients on the CPU.
cpu_device = tf.config.list_logical_devices('CPU')[0]
tff.backends.native.set_local_execution_context(
    server_tf_device=cpu_device, client_tf_devices=[cpu_device])

# NOTE(review): cnn_num_blocks, conv_width_multiplier, total_rounds and
# rounds_per_eval match the logged run below; the remaining hyperparameters
# are reconstructed — confirm.
run_federated_training(
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    cnn_num_blocks=4,
    conv_width_multiplier=4,
    server_learning_rate=1.0,
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16,
    rounds_per_eval=2)

Model: "cnn-26-4"
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         [(None, 28, 28, 1)]       0         
tf.image.per_image_standardi (None, 28, 28, 1)         0         
conv2d_39 (Conv2D)           (None, 28, 28, 64)        576       
conv2d_40 (Conv2D)           (None, 28, 28, 64)        36864     
activation_36 (Activation)   (None, 28, 28, 64)        0         
conv2d_41 (Conv2D)           (None, 28, 28, 64)        36864     
activation_37 (Activation)   (None, 28, 28, 64)        0         
conv2d_42 (Conv2D)           (None, 28, 28, 64)        36864     
activation_38 (Activation)   (None, 28, 28, 64)        0         
conv2d_43 (Conv2D)           (None, 28, 28, 64)        36864     
activation_39 (Activation)   (None, 28, 28, 64)        0         
conv2d_44 (Conv2D)           (None, 28, 28, 64)        36864     
activation_40 (Activation)   (None, 28, 28, 64)        0         
conv2d_45 (Conv2D)           (None, 28, 28, 64)        36864     
activation_41 (Activation)   (None, 28, 28, 64)        0         
conv2d_46 (Conv2D)           (None, 28, 28, 64)        36864     
activation_42 (Activation)   (None, 28, 28, 64)        0         
conv2d_47 (Conv2D)           (None, 28, 28, 64)        36864     
activation_43 (Activation)   (None, 28, 28, 64)        0         
conv2d_48 (Conv2D)           (None, 14, 14, 128)       73728     
activation_44 (Activation)   (None, 14, 14, 128)       0         
conv2d_49 (Conv2D)           (None, 14, 14, 128)       147456    
activation_45 (Activation)   (None, 14, 14, 128)       0         
conv2d_50 (Conv2D)           (None, 14, 14, 128)       147456    
activation_46 (Activation)   (None, 14, 14, 128)       0         
conv2d_51 (Conv2D)           (None, 14, 14, 128)       147456    
activation_47 (Activation)   (None, 14, 14, 128)       0         
conv2d_52 (Conv2D)           (None, 14, 14, 128)       147456    
activation_48 (Activation)   (None, 14, 14, 128)       0         
conv2d_53 (Conv2D)           (None, 14, 14, 128)       147456    
activation_49 (Activation)   (None, 14, 14, 128)       0         
conv2d_54 (Conv2D)           (None, 14, 14, 128)       147456    
activation_50 (Activation)   (None, 14, 14, 128)       0         
conv2d_55 (Conv2D)           (None, 14, 14, 128)       147456    
activation_51 (Activation)   (None, 14, 14, 128)       0         
conv2d_56 (Conv2D)           (None, 7, 7, 256)         294912    
activation_52 (Activation)   (None, 7, 7, 256)         0         
conv2d_57 (Conv2D)           (None, 7, 7, 256)         589824    
activation_53 (Activation)   (None, 7, 7, 256)         0         
conv2d_58 (Conv2D)           (None, 7, 7, 256)         589824    
activation_54 (Activation)   (None, 7, 7, 256)         0         
conv2d_59 (Conv2D)           (None, 7, 7, 256)         589824    
activation_55 (Activation)   (None, 7, 7, 256)         0         
conv2d_60 (Conv2D)           (None, 7, 7, 256)         589824    
activation_56 (Activation)   (None, 7, 7, 256)         0         
conv2d_61 (Conv2D)           (None, 7, 7, 256)         589824    
activation_57 (Activation)   (None, 7, 7, 256)         0         
conv2d_62 (Conv2D)           (None, 7, 7, 256)         589824    
activation_58 (Activation)   (None, 7, 7, 256)         0         
conv2d_63 (Conv2D)           (None, 7, 7, 256)         589824    
activation_59 (Activation)   (None, 7, 7, 256)         0         
global_average_pooling2d_3 ( (None, 256)               0         
dense_3 (Dense)              (None, 10)                2570      
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
Round 0 training loss: 2.437223434448242, time: 24.121686458587646 secs
Round 0 validation accuracy: 9.024785041809082
Round 1 training loss: 2.3081459999084473, time: 19.48685622215271 secs
Round 2 training loss: 2.305305242538452, time: 15.73950457572937 secs
Round 2 validation accuracy: 9.791339874267578
Round 3 training loss: 2.303149700164795, time: 15.194068729877472 secs
Round 4 training loss: 2.3026506900787354, time: 14.036769819259643 secs
Round 4 validation accuracy: 12.193867683410645

قد يواجه هذا النموذج مشكلة نفاد الذاكرة (OOM) على وحدة معالجة رسومات واحدة. قد يكون الانتقال من التجارب واسعة النطاق على وحدة المعالجة المركزية إلى المحاكاة على وحدة معالجة الرسومات مقيّدًا باستهلاك الذاكرة، لأن ذاكرة وحدات معالجة الرسومات غالبًا ما تكون محدودة. هناك عدة معاملات يمكن ضبطها في وقت تشغيل TFF للتخفيف من مشكلة نفاد الذاكرة.

# Single GPU execution might hit OOM.
# NOTE(review): in both runs below, cnn_num_blocks, conv_width_multiplier,
# total_rounds and rounds_per_eval match the logged output; the remaining
# hyperparameters are reconstructed — confirm.
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_execution_context(
    server_tf_device=gpu_devices[0], client_tf_devices=[gpu_devices[0]])
try:
  run_federated_training(
      client_epochs_per_round=1,
      train_batch_size=16,
      test_batch_size=128,
      cnn_num_blocks=4,
      conv_width_multiplier=4,
      server_learning_rate=1.0,
      client_learning_rate=0.01,
      total_rounds=5,
      clients_per_round=16,
      rounds_per_eval=2)
except tf.errors.ResourceExhaustedError as e:
  print(e)

# Control concurrency by `clients_per_thread`.
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_execution_context(
    server_tf_device=gpu_devices[0],
    client_tf_devices=[gpu_devices[0]],
    clients_per_thread=2)
run_federated_training(
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    cnn_num_blocks=4,
    conv_width_multiplier=4,
    server_learning_rate=1.0,
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16,
    rounds_per_eval=2)

Model: "cnn-26-4"
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
tf.image.per_image_standardi (None, 28, 28, 1)         0         
conv2d (Conv2D)              (None, 28, 28, 64)        576       
conv2d_1 (Conv2D)            (None, 28, 28, 64)        36864     
activation (Activation)      (None, 28, 28, 64)        0         
conv2d_2 (Conv2D)            (None, 28, 28, 64)        36864     
activation_1 (Activation)    (None, 28, 28, 64)        0         
conv2d_3 (Conv2D)            (None, 28, 28, 64)        36864     
activation_2 (Activation)    (None, 28, 28, 64)        0         
conv2d_4 (Conv2D)            (None, 28, 28, 64)        36864     
activation_3 (Activation)    (None, 28, 28, 64)        0         
conv2d_5 (Conv2D)            (None, 28, 28, 64)        36864     
activation_4 (Activation)    (None, 28, 28, 64)        0         
conv2d_6 (Conv2D)            (None, 28, 28, 64)        36864     
activation_5 (Activation)    (None, 28, 28, 64)        0         
conv2d_7 (Conv2D)            (None, 28, 28, 64)        36864     
activation_6 (Activation)    (None, 28, 28, 64)        0         
conv2d_8 (Conv2D)            (None, 28, 28, 64)        36864     
activation_7 (Activation)    (None, 28, 28, 64)        0         
conv2d_9 (Conv2D)            (None, 14, 14, 128)       73728     
activation_8 (Activation)    (None, 14, 14, 128)       0         
conv2d_10 (Conv2D)           (None, 14, 14, 128)       147456    
activation_9 (Activation)    (None, 14, 14, 128)       0         
conv2d_11 (Conv2D)           (None, 14, 14, 128)       147456    
activation_10 (Activation)   (None, 14, 14, 128)       0         
conv2d_12 (Conv2D)           (None, 14, 14, 128)       147456    
activation_11 (Activation)   (None, 14, 14, 128)       0         
conv2d_13 (Conv2D)           (None, 14, 14, 128)       147456    
activation_12 (Activation)   (None, 14, 14, 128)       0         
conv2d_14 (Conv2D)           (None, 14, 14, 128)       147456    
activation_13 (Activation)   (None, 14, 14, 128)       0         
conv2d_15 (Conv2D)           (None, 14, 14, 128)       147456    
activation_14 (Activation)   (None, 14, 14, 128)       0         
conv2d_16 (Conv2D)           (None, 14, 14, 128)       147456    
activation_15 (Activation)   (None, 14, 14, 128)       0         
conv2d_17 (Conv2D)           (None, 7, 7, 256)         294912    
activation_16 (Activation)   (None, 7, 7, 256)         0         
conv2d_18 (Conv2D)           (None, 7, 7, 256)         589824    
activation_17 (Activation)   (None, 7, 7, 256)         0         
conv2d_19 (Conv2D)           (None, 7, 7, 256)         589824    
activation_18 (Activation)   (None, 7, 7, 256)         0         
conv2d_20 (Conv2D)           (None, 7, 7, 256)         589824    
activation_19 (Activation)   (None, 7, 7, 256)         0         
conv2d_21 (Conv2D)           (None, 7, 7, 256)         589824    
activation_20 (Activation)   (None, 7, 7, 256)         0         
conv2d_22 (Conv2D)           (None, 7, 7, 256)         589824    
activation_21 (Activation)   (None, 7, 7, 256)         0         
conv2d_23 (Conv2D)           (None, 7, 7, 256)         589824    
activation_22 (Activation)   (None, 7, 7, 256)         0         
conv2d_24 (Conv2D)           (None, 7, 7, 256)         589824    
activation_23 (Activation)   (None, 7, 7, 256)         0         
global_average_pooling2d (Gl (None, 256)               0         
dense (Dense)                (None, 10)                2570      
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
Round 0 training loss: 2.4990053176879883, time: 11.922378778457642 secs
Round 0 validation accuracy: 11.224040031433105
Round 1 training loss: 2.307560920715332, time: 9.916815996170044 secs
Round 2 training loss: 2.3032877445220947, time: 7.68927804629008 secs
Round 2 validation accuracy: 11.224040031433105
Round 3 training loss: 2.302366256713867, time: 7.681552231311798 secs
Round 4 training loss: 2.301671028137207, time: 7.613566827774048 secs
Round 4 validation accuracy: 11.224040031433105
# Multi-GPU execution with configuration to mitigate OOM.
cpu_device = tf.config.list_logical_devices('CPU')[0]
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_execution_context(
    server_tf_device=cpu_device,
    client_tf_devices=gpu_devices,
    clients_per_thread=2)

# NOTE(review): cnn_num_blocks, conv_width_multiplier, total_rounds and
# rounds_per_eval match the logged run below; the remaining hyperparameters
# are reconstructed — confirm.
run_federated_training(
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    cnn_num_blocks=4,
    conv_width_multiplier=4,
    server_learning_rate=1.0,
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16,
    rounds_per_eval=2)

Model: "cnn-26-4"
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
tf.image.per_image_standardi (None, 28, 28, 1)         0         
conv2d (Conv2D)              (None, 28, 28, 64)        576       
conv2d_1 (Conv2D)            (None, 28, 28, 64)        36864     
activation (Activation)      (None, 28, 28, 64)        0         
conv2d_2 (Conv2D)            (None, 28, 28, 64)        36864     
activation_1 (Activation)    (None, 28, 28, 64)        0         
conv2d_3 (Conv2D)            (None, 28, 28, 64)        36864     
activation_2 (Activation)    (None, 28, 28, 64)        0         
conv2d_4 (Conv2D)            (None, 28, 28, 64)        36864     
activation_3 (Activation)    (None, 28, 28, 64)        0         
conv2d_5 (Conv2D)            (None, 28, 28, 64)        36864     
activation_4 (Activation)    (None, 28, 28, 64)        0         
conv2d_6 (Conv2D)            (None, 28, 28, 64)        36864     
activation_5 (Activation)    (None, 28, 28, 64)        0         
conv2d_7 (Conv2D)            (None, 28, 28, 64)        36864     
activation_6 (Activation)    (None, 28, 28, 64)        0         
conv2d_8 (Conv2D)            (None, 28, 28, 64)        36864     
activation_7 (Activation)    (None, 28, 28, 64)        0         
conv2d_9 (Conv2D)            (None, 14, 14, 128)       73728     
activation_8 (Activation)    (None, 14, 14, 128)       0         
conv2d_10 (Conv2D)           (None, 14, 14, 128)       147456    
activation_9 (Activation)    (None, 14, 14, 128)       0         
conv2d_11 (Conv2D)           (None, 14, 14, 128)       147456    
activation_10 (Activation)   (None, 14, 14, 128)       0         
conv2d_12 (Conv2D)           (None, 14, 14, 128)       147456    
activation_11 (Activation)   (None, 14, 14, 128)       0         
conv2d_13 (Conv2D)           (None, 14, 14, 128)       147456    
activation_12 (Activation)   (None, 14, 14, 128)       0         
conv2d_14 (Conv2D)           (None, 14, 14, 128)       147456    
activation_13 (Activation)   (None, 14, 14, 128)       0         
conv2d_15 (Conv2D)           (None, 14, 14, 128)       147456    
activation_14 (Activation)   (None, 14, 14, 128)       0         
conv2d_16 (Conv2D)           (None, 14, 14, 128)       147456    
activation_15 (Activation)   (None, 14, 14, 128)       0         
conv2d_17 (Conv2D)           (None, 7, 7, 256)         294912    
activation_16 (Activation)   (None, 7, 7, 256)         0         
conv2d_18 (Conv2D)           (None, 7, 7, 256)         589824    
activation_17 (Activation)   (None, 7, 7, 256)         0         
conv2d_19 (Conv2D)           (None, 7, 7, 256)         589824    
activation_18 (Activation)   (None, 7, 7, 256)         0         
conv2d_20 (Conv2D)           (None, 7, 7, 256)         589824    
activation_19 (Activation)   (None, 7, 7, 256)         0         
conv2d_21 (Conv2D)           (None, 7, 7, 256)         589824    
activation_20 (Activation)   (None, 7, 7, 256)         0         
conv2d_22 (Conv2D)           (None, 7, 7, 256)         589824    
activation_21 (Activation)   (None, 7, 7, 256)         0         
conv2d_23 (Conv2D)           (None, 7, 7, 256)         589824    
activation_22 (Activation)   (None, 7, 7, 256)         0         
conv2d_24 (Conv2D)           (None, 7, 7, 256)         589824    
activation_23 (Activation)   (None, 7, 7, 256)         0         
global_average_pooling2d (Gl (None, 256)               0         
dense (Dense)                (None, 10)                2570      
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
Round 0 training loss: 2.4691953659057617, time: 17.81941556930542 secs
Round 0 validation accuracy: 10.817495346069336
Round 1 training loss: 2.3081436157226562, time: 12.986191034317017 secs
Round 2 training loss: 2.3028159141540527, time: 9.518500963846842 secs
Round 2 validation accuracy: 11.500783920288086
Round 3 training loss: 2.303886651992798, time: 8.989932537078857 secs
Round 4 training loss: 2.3030669689178467, time: 8.733866214752197 secs
Round 4 validation accuracy: 12.992260932922363

تحسين الأداء

بشكل عام، يمكن أيضًا استخدام تقنيات TF التي تحقق أداءً أفضل في TFF، مثل التدريب بالدقة المختلطة (mixed precision) وXLA. غالبًا ما يكون التسريع (على وحدات معالجة رسومات مثل V100) وتوفير الذاكرة الناتجان عن الدقة المختلطة كبيرين، ويمكن فحصهما باستخدام tf.profiler.

# Mixed precision training.
cpu_device = tf.config.list_logical_devices('CPU')[0]
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_execution_context(
    server_tf_device=cpu_device,
    client_tf_devices=gpu_devices,
    clients_per_thread=2)
policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
# Creating the policy alone has no effect; it must be set globally so that
# the Keras layers built afterwards compute in float16.
tf.keras.mixed_precision.experimental.set_policy(policy)

# NOTE(review): cnn_num_blocks, conv_width_multiplier, total_rounds and
# rounds_per_eval match the logged run below; the remaining hyperparameters
# are reconstructed — confirm.
run_federated_training(
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    cnn_num_blocks=4,
    conv_width_multiplier=4,
    server_learning_rate=1.0,
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16,
    rounds_per_eval=2)

Model: "cnn-26-4"
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
tf.image.per_image_standardi (None, 28, 28, 1)         0         
conv2d (Conv2D)              (None, 28, 28, 64)        576       
conv2d_1 (Conv2D)            (None, 28, 28, 64)        36864     
activation (Activation)      (None, 28, 28, 64)        0         
conv2d_2 (Conv2D)            (None, 28, 28, 64)        36864     
activation_1 (Activation)    (None, 28, 28, 64)        0         
conv2d_3 (Conv2D)            (None, 28, 28, 64)        36864     
activation_2 (Activation)    (None, 28, 28, 64)        0         
conv2d_4 (Conv2D)            (None, 28, 28, 64)        36864     
activation_3 (Activation)    (None, 28, 28, 64)        0         
conv2d_5 (Conv2D)            (None, 28, 28, 64)        36864     
activation_4 (Activation)    (None, 28, 28, 64)        0         
conv2d_6 (Conv2D)            (None, 28, 28, 64)        36864     
activation_5 (Activation)    (None, 28, 28, 64)        0         
conv2d_7 (Conv2D)            (None, 28, 28, 64)        36864     
activation_6 (Activation)    (None, 28, 28, 64)        0         
conv2d_8 (Conv2D)            (None, 28, 28, 64)        36864     
activation_7 (Activation)    (None, 28, 28, 64)        0         
conv2d_9 (Conv2D)            (None, 14, 14, 128)       73728     
activation_8 (Activation)    (None, 14, 14, 128)       0         
conv2d_10 (Conv2D)           (None, 14, 14, 128)       147456    
activation_9 (Activation)    (None, 14, 14, 128)       0         
conv2d_11 (Conv2D)           (None, 14, 14, 128)       147456    
activation_10 (Activation)   (None, 14, 14, 128)       0         
conv2d_12 (Conv2D)           (None, 14, 14, 128)       147456    
activation_11 (Activation)   (None, 14, 14, 128)       0         
conv2d_13 (Conv2D)           (None, 14, 14, 128)       147456    
activation_12 (Activation)   (None, 14, 14, 128)       0         
conv2d_14 (Conv2D)           (None, 14, 14, 128)       147456    
activation_13 (Activation)   (None, 14, 14, 128)       0         
conv2d_15 (Conv2D)           (None, 14, 14, 128)       147456    
activation_14 (Activation)   (None, 14, 14, 128)       0         
conv2d_16 (Conv2D)           (None, 14, 14, 128)       147456    
activation_15 (Activation)   (None, 14, 14, 128)       0         
conv2d_17 (Conv2D)           (None, 7, 7, 256)         294912    
activation_16 (Activation)   (None, 7, 7, 256)         0         
conv2d_18 (Conv2D)           (None, 7, 7, 256)         589824    
activation_17 (Activation)   (None, 7, 7, 256)         0         
conv2d_19 (Conv2D)           (None, 7, 7, 256)         589824    
activation_18 (Activation)   (None, 7, 7, 256)         0         
conv2d_20 (Conv2D)           (None, 7, 7, 256)         589824    
activation_19 (Activation)   (None, 7, 7, 256)         0         
conv2d_21 (Conv2D)           (None, 7, 7, 256)         589824    
activation_20 (Activation)   (None, 7, 7, 256)         0         
conv2d_22 (Conv2D)           (None, 7, 7, 256)         589824    
activation_21 (Activation)   (None, 7, 7, 256)         0         
conv2d_23 (Conv2D)           (None, 7, 7, 256)         589824    
activation_22 (Activation)   (None, 7, 7, 256)         0         
conv2d_24 (Conv2D)           (None, 7, 7, 256)         589824    
activation_23 (Activation)   (None, 7, 7, 256)         0         
global_average_pooling2d (Gl (None, 256)               0         
dense (Dense)                (None, 10)                2570      
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
Round 0 training loss: 2.4187185764312744, time: 18.763780117034912 secs
Round 0 validation accuracy: 9.977468490600586
Round 1 training loss: 2.305102825164795, time: 13.712820529937744 secs
Round 2 training loss: 2.304737091064453, time: 9.993690172831217 secs
Round 2 validation accuracy: 11.779976844787598
Round 3 training loss: 2.2996833324432373, time: 9.29404467344284 secs
Round 4 training loss: 2.299349308013916, time: 9.195427560806275 secs
Round 4 validation accuracy: 11.224040031433105
