TFF لبحوث التعلم الموحد: ضغط النموذج والتحديث

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثب تحميل دفتر

في هذا البرنامج التعليمي، ونحن نستخدم EMNIST بيانات لشرح كيفية تمكين خوارزميات الضغط الضياع للحد من تكلفة الاتصالات في خوارزمية المتوسط الاتحادية باستخدام tff.learning.build_federated_averaging_process API و tensor_encoding API. لمزيد من التفاصيل حول الخوارزمية المتوسط الاتحادية، راجع ورقة التعلم من الاتصالات كفاءة شبكات ديب من البيانات اللامركزية .

قبل أن نبدأ

قبل أن نبدأ ، يرجى تشغيل ما يلي للتأكد من إعداد بيئتك بشكل صحيح. إذا كنت لا ترى تحية، يرجى الرجوع إلى تركيب دليل للتعليمات.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade tensorflow-model-optimization
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard

import functools

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from tensorflow_model_optimization.python.core.internal import tensor_encoding as te

تحقق مما إذا كان TFF يعمل.

@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
b'Hello, World!'

تجهيز بيانات الإدخال

في هذا القسم نقوم بتحميل مجموعة بيانات EMNIST المضمنة في TFF ومعالجتها مسبقًا. يرجى مراجعة الاتحادية التعلم للتصنيف صور البرنامج التعليمي لمزيد من التفاصيل حول EMNIST البيانات.

# This value only applies to EMNIST dataset, consider choosing appropriate
# values if switching to other datasets.
MAX_CLIENT_DATASET_SIZE = 418

CLIENT_EPOCHS_PER_ROUND = 1
CLIENT_BATCH_SIZE = 20
TEST_BATCH_SIZE = 500

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True)

def reshape_emnist_element(element):
  return (tf.expand_dims(element['pixels'], axis=-1), element['label'])

def preprocess_train_dataset(dataset):
  """Preprocessing function for the EMNIST training dataset."""
  return (dataset
          # Shuffle according to the largest client dataset
          .shuffle(buffer_size=MAX_CLIENT_DATASET_SIZE)
          # Repeat to do multiple local epochs
          .repeat(CLIENT_EPOCHS_PER_ROUND)
          # Batch to a fixed client batch size
          .batch(CLIENT_BATCH_SIZE, drop_remainder=False)
          # Preprocessing step
          .map(reshape_emnist_element))

emnist_train = emnist_train.preprocess(preprocess_train_dataset)

تحديد النموذج

نحن هنا تحديد نموذج keras على أساس orginial FedAvg CNN، ثم لف نموذج keras في مثيل tff.learning.Model بحيث يمكن استهلاكها من قبل TFF.

لاحظ أننا سوف نحتاج إلى وظيفة التي تنتج نموذجا بدلا من مجرد نموذج مباشرة. وبالإضافة إلى ذلك، فإن وظيفة لا يمكن فقط التقاط نموذج شيدت قبل، فإنه يجب خلق نموذج في السياق الذي يطلق عليه. والسبب هو أن TFF مصمم للانتقال إلى الأجهزة ، ويحتاج إلى التحكم في وقت إنشاء الموارد بحيث يمكن التقاطها وتعبئتها.

def create_original_fedavg_cnn_model(only_digits=True):
  """The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format = 'channels_last'

  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
      conv2d(filters=32),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10 if only_digits else 62),
      tf.keras.layers.Softmax(),
  ])

  return model

# Gets the type information of the input data. TFF is a strongly typed
# functional programming framework, and needs type information about inputs to 
# the model.
input_spec = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0]).element_spec

def tff_model_fn():
  keras_model = create_original_fedavg_cnn_model()
  return tff.learning.from_keras_model(
      keras_model=keras_model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

تدريب النموذج وإخراج مقاييس التدريب

نحن الآن جاهزون لإنشاء خوارزمية متوسطات موحدة وتدريب النموذج المحدد على مجموعة بيانات EMNIST.

أولا نحن بحاجة لبناء المتوسط خوارزمية الاتحادية باستخدام tff.learning.build_federated_averaging_process API.

federated_averaging = tff.learning.build_federated_averaging_process(
    model_fn=tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

لنقم الآن بتشغيل خوارزمية المتوسطات الموحدة. يبدو تنفيذ خوارزمية التعلم الفيدرالي من منظور TFF كما يلي:

  1. ابدأ الخوارزمية واحصل على حالة الخادم الأولية. تحتوي حالة الخادم على المعلومات الضرورية لتنفيذ الخوارزمية. تذكر ، نظرًا لأن TFF وظيفي ، فإن هذه الحالة تتضمن أي حالة مُحسِّن تستخدمها الخوارزمية (مثل مصطلحات الزخم) بالإضافة إلى معلمات النموذج نفسها - سيتم تمريرها كوسائط وإعادتها كنتائج من حسابات TFF.
  2. تنفيذ الخوارزمية جولة بعد جولة. في كل جولة ، سيتم إرجاع حالة خادم جديدة كنتيجة لتدريب كل عميل على النموذج على بياناته. عادة في جولة واحدة:
    1. بث الخادم النموذج لجميع العملاء المشاركين.
    2. يقوم كل عميل بعمل يعتمد على النموذج والبيانات الخاصة به.
    3. يقوم الخادم بتجميع كل النماذج لإنتاج حالة قطع تحتوي على نموذج جديد.

لمزيد من المعلومات، يرجى الاطلاع مخصص الاتحادية الخوارزميات، الجزء 2: تنفيذ اتحاد المتوسط التعليمي.

تتم كتابة مقاييس التدريب في دليل Tensorboard لعرضها بعد التدريب.

تحميل وظائف المرافق

def train(federated_averaging_process, num_rounds, num_clients_per_round, summary_writer):
  """Trains the federated averaging process and output metrics."""
  # Create a environment to get communication cost.
  environment = set_sizing_environment()

  # Initialize the Federated Averaging algorithm to get the initial server state.
  state = federated_averaging_process.initialize()

  with summary_writer.as_default():
    for round_num in range(num_rounds):
      # Sample the clients parcitipated in this round.
      sampled_clients = np.random.choice(
          emnist_train.client_ids,
          size=num_clients_per_round,
          replace=False)
      # Create a list of `tf.Dataset` instances from the data of sampled clients.
      sampled_train_data = [
          emnist_train.create_tf_dataset_for_client(client)
          for client in sampled_clients
      ]
      # Round one round of the algorithm based on the server state and client data
      # and output the new state and metrics.
      state, metrics = federated_averaging_process.next(state, sampled_train_data)

      # For more about size_info, please see https://www.tensorflow.org/federated/api_docs/python/tff/framework/SizeInfo
      size_info = environment.get_size_info()
      broadcasted_bits = size_info.broadcast_bits[-1]
      aggregated_bits = size_info.aggregate_bits[-1]

      print('round {:2d}, metrics={}, broadcasted_bits={}, aggregated_bits={}'.format(round_num, metrics, format_size(broadcasted_bits), format_size(aggregated_bits)))

      # Add metrics to Tensorboard.
      for name, value in metrics['train'].items():
          tf.summary.scalar(name, value, step=round_num)

      # Add broadcasted and aggregated data size to Tensorboard.
      tf.summary.scalar('cumulative_broadcasted_bits', broadcasted_bits, step=round_num)
      tf.summary.scalar('cumulative_aggregated_bits', aggregated_bits, step=round_num)
      summary_writer.flush()
# Clean the log directory to avoid conflicts.
try:
  tf.io.gfile.rmtree('/tmp/logs/scalars')
except tf.errors.OpError as e:
  pass  # Path doesn't exist

# Set up the log directory and writer for Tensorboard.
logdir = "/tmp/logs/scalars/original/"
summary_writer = tf.summary.create_file_writer(logdir)

train(federated_averaging_process=federated_averaging, num_rounds=10,
      num_clients_per_round=10, summary_writer=summary_writer)
round  0, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.07383774), ('loss', 2.3276227)])), ('stat', OrderedDict([('num_examples', 1097)]))]), broadcasted_bits=507.62Mibit, aggregated_bits=507.62Mibit
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.099585064), ('loss', 2.3152695)])), ('stat', OrderedDict([('num_examples', 964)]))]), broadcasted_bits=1015.24Mibit, aggregated_bits=1015.24Mibit
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09760766), ('loss', 2.3077576)])), ('stat', OrderedDict([('num_examples', 1045)]))]), broadcasted_bits=1.49Gibit, aggregated_bits=1.49Gibit
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.0963035), ('loss', 2.3066626)])), ('stat', OrderedDict([('num_examples', 1028)]))]), broadcasted_bits=1.98Gibit, aggregated_bits=1.98Gibit
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.10694184), ('loss', 2.3033001)])), ('stat', OrderedDict([('num_examples', 1066)]))]), broadcasted_bits=2.48Gibit, aggregated_bits=2.48Gibit
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.1185567), ('loss', 2.2999184)])), ('stat', OrderedDict([('num_examples', 970)]))]), broadcasted_bits=2.97Gibit, aggregated_bits=2.97Gibit
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.11751663), ('loss', 2.296883)])), ('stat', OrderedDict([('num_examples', 902)]))]), broadcasted_bits=3.47Gibit, aggregated_bits=3.47Gibit
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13063477), ('loss', 2.2990246)])), ('stat', OrderedDict([('num_examples', 1087)]))]), broadcasted_bits=3.97Gibit, aggregated_bits=3.97Gibit
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12742382), ('loss', 2.2971866)])), ('stat', OrderedDict([('num_examples', 1083)]))]), broadcasted_bits=4.46Gibit, aggregated_bits=4.46Gibit
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13555992), ('loss', 2.2934425)])), ('stat', OrderedDict([('num_examples', 1018)]))]), broadcasted_bits=4.96Gibit, aggregated_bits=4.96Gibit

ابدأ TensorBoard بدليل السجل الجذر المحدد أعلاه لعرض مقاييس التدريب. قد يستغرق تحميل البيانات بضع ثوانٍ. باستثناء الخسارة والدقة ، نقوم أيضًا بإخراج كمية البيانات التي يتم بثها والمجمعة. تشير البيانات المذاعة إلى الموترات التي يدفعها الخادم لكل عميل بينما تشير البيانات المجمعة إلى الموترات التي يعود كل عميل إلى الخادم.

%tensorboard --logdir /tmp/logs/scalars/ --port=0
Launching TensorBoard...
Reusing TensorBoard on port 34445 (pid 579503), started 1:53:14 ago. (Use '!kill 579503' to kill it.)
<IPython.core.display.Javascript at 0x7f9135ef1630>

بناء البث المخصص ووظيفة التجميع

الآن دعونا تنفيذ وظيفة لاستخدام خوارزميات الضغط الضياع على البيانات بث البيانات وتجميعها باستخدام tensor_encoding API.

أولاً ، نحدد وظيفتين:

  • broadcast_encoder_fn مما يخلق مثيل te.core.SimpleEncoder إلى التنسورات ترميز أو المتغيرات في الخادم اتصال العميل (بيانات البث).
  • mean_encoder_fn مما يخلق مثيل te.core.GatherEncoder إلى التنسورات ترميز أو المتغيرات في العميل إلى خادم communicaiton (بيانات تجميع).

من المهم ملاحظة أننا لا نطبق طريقة ضغط على النموذج بأكمله مرة واحدة. بدلاً من ذلك ، نقرر كيف (وما إذا كنا) نضغط كل متغير في النموذج بشكل مستقل. والسبب هو أن المتغيرات الصغيرة بشكل عام مثل التحيزات تكون أكثر حساسية لعدم الدقة ، وكونها صغيرة نسبيًا ، فإن المدخرات المحتملة في الاتصال تكون أيضًا صغيرة نسبيًا. ومن ثم فإننا لا نضغط المتغيرات الصغيرة افتراضيًا. في هذا المثال ، نطبق تكميمًا موحدًا على 8 بتات (256 دلوًا) لكل متغير يحتوي على أكثر من 10000 عنصر ، ونطبق الهوية فقط على المتغيرات الأخرى.

def broadcast_encoder_fn(value):
  """Function for building encoded broadcast."""
  spec = tf.TensorSpec(value.shape, value.dtype)
  if value.shape.num_elements() > 10000:
    return te.encoders.as_simple_encoder(
        te.encoders.uniform_quantization(bits=8), spec)
  else:
    return te.encoders.as_simple_encoder(te.encoders.identity(), spec)


def mean_encoder_fn(tensor_spec):
  """Function for building a GatherEncoder."""
  spec = tf.TensorSpec(tensor_spec.shape, tensor_spec.dtype)
  if tensor_spec.shape.num_elements() > 10000:
    return te.encoders.as_gather_encoder(
        te.encoders.uniform_quantization(bits=8), spec)
  else:
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)

يوفر TFF واجهات برمجة التطبيقات لتحويل وظيفة التشفير إلى تنسيق tff.learning.build_federated_averaging_process API يمكن أن تستهلك. باستخدام tff.learning.framework.build_encoded_broadcast_from_model و tff.aggregators.MeanFactory ، ونحن يمكن أن تخلق اثنين من الكائنات التي يمكن أن تنتقل إلى broadcast_process و model_update_aggregation_factory agruments من tff.learning.build_federated_averaging_process إلى إنشاء اتحاد متوسط خوارزميات مع خوارزمية ضغط الضياع.

encoded_broadcast_process = (
    tff.learning.framework.build_encoded_broadcast_process_from_model(
        tff_model_fn, broadcast_encoder_fn))

mean_factory = tff.aggregators.MeanFactory(
    tff.aggregators.EncodedSumFactory(mean_encoder_fn), # numerator
    tff.aggregators.EncodedSumFactory(mean_encoder_fn), # denominator
)

federated_averaging_with_compression = tff.learning.build_federated_averaging_process(
    tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    broadcast_process=encoded_broadcast_process,
    model_update_aggregation_factory=mean_factory)

تدريب النموذج مرة أخرى

لنقم الآن بتشغيل خوارزمية المتوسطات الموحدة الجديدة.

logdir_for_compression = "/tmp/logs/scalars/compression/"
summary_writer_for_compression = tf.summary.create_file_writer(
    logdir_for_compression)

train(federated_averaging_process=federated_averaging_with_compression, 
      num_rounds=10,
      num_clients_per_round=10,
      summary_writer=summary_writer_for_compression)
round  0, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.093), ('loss', 2.3194966)])), ('stat', OrderedDict([('num_examples', 1000)]))]), broadcasted_bits=146.46Mibit, aggregated_bits=146.46Mibit
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.10432034), ('loss', 2.3079953)])), ('stat', OrderedDict([('num_examples', 949)]))]), broadcasted_bits=292.92Mibit, aggregated_bits=292.93Mibit
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.07886754), ('loss', 2.3101337)])), ('stat', OrderedDict([('num_examples', 989)]))]), broadcasted_bits=439.38Mibit, aggregated_bits=439.39Mibit
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09774436), ('loss', 2.305069)])), ('stat', OrderedDict([('num_examples', 1064)]))]), broadcasted_bits=585.84Mibit, aggregated_bits=585.85Mibit
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09404097), ('loss', 2.302943)])), ('stat', OrderedDict([('num_examples', 1074)]))]), broadcasted_bits=732.30Mibit, aggregated_bits=732.32Mibit
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09), ('loss', 2.304385)])), ('stat', OrderedDict([('num_examples', 1000)]))]), broadcasted_bits=878.77Mibit, aggregated_bits=878.78Mibit
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14368932), ('loss', 2.2973824)])), ('stat', OrderedDict([('num_examples', 1030)]))]), broadcasted_bits=1.00Gibit, aggregated_bits=1.00Gibit
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12140871), ('loss', 2.2993405)])), ('stat', OrderedDict([('num_examples', 1079)]))]), broadcasted_bits=1.14Gibit, aggregated_bits=1.14Gibit
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13600783), ('loss', 2.2953267)])), ('stat', OrderedDict([('num_examples', 1022)]))]), broadcasted_bits=1.29Gibit, aggregated_bits=1.29Gibit
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13844621), ('loss', 2.295768)])), ('stat', OrderedDict([('num_examples', 1004)]))]), broadcasted_bits=1.43Gibit, aggregated_bits=1.43Gibit

ابدأ TensorBoard مرة أخرى لمقارنة مقاييس التدريب بين مرحلتين.

كما ترون في Tensorboard، هناك انخفاض كبير بين orginial و compression المنحنيات في broadcasted_bits و aggregated_bits المؤامرات بينما في loss و sparse_categorical_accuracy مؤامرة المنحنيين هي مماثلة إلى حد ما.

في الختام ، قمنا بتنفيذ خوارزمية ضغط يمكنها تحقيق أداء مشابه لخوارزمية المتوسطات الموحدة الأصلية بينما يتم تقليل تكلفة التعليم بشكل كبير.

%tensorboard --logdir /tmp/logs/scalars/ --port=0
Launching TensorBoard...
Reusing TensorBoard on port 34445 (pid 579503), started 1:54:12 ago. (Use '!kill 579503' to kill it.)
<IPython.core.display.Javascript at 0x7f9140eb5ef0>

تمارين

لتنفيذ خوارزمية ضغط مخصصة وتطبيقها على حلقة التدريب ، يمكنك:

  1. تنفيذ خوارزمية ضغط جديدة في فئة فرعية من EncodingStageInterface أو البديل أعم لها، AdaptiveEncodingStageInterface التالية هذا المثال .
  2. بناء الجديد Encoder ومتخصصون لمدة البث نموذج أو نموذج تحديث المتوسط .
  3. استخدام هذه الكائنات لبناء بأكمله حساب التدريب .

تشمل الأسئلة البحثية المفتوحة ذات القيمة المحتملة ما يلي: التكميم غير المنتظم ، والضغط بلا خسارة مثل تشفير هوفمان ، وآليات تكييف الضغط بناءً على المعلومات من جولات التدريب السابقة.

مواد القراءة الموصى بها: