تعلم موحد ذو نموذج كبير يتسم بكفاءة العميل من خلال التجميع الموحد المختار والمتناثر

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثب تحميل دفتر

وهذا يدل على تعليمي كيف TFF يمكن استخدامها لتدريب نموذج كبير جدا حيث كل جهاز العميل تنزيل فقط وتحديث جزء صغير من هذا النموذج، باستخدام tff.federated_select وتجميع متفرق. في حين أن هذا البرنامج التعليمي هو إلى حد ما قائمة بذاتها، و tff.federated_select تعليمي و العرف FL خوارزميات تعليمي يوفر مقدمات جيدة لبعض التقنيات المستخدمة هنا.

بشكل ملموس ، في هذا البرنامج التعليمي ، نأخذ في الاعتبار الانحدار اللوجستي للتصنيف متعدد العلامات ، والتنبؤ بأي "العلامات" المرتبطة بسلسلة نصية بناءً على تمثيل ميزة كيس من الكلمات. الأهم من ذلك، يتم التحكم في تكاليف الاتصالات وحساب العميل بواسطة ثابت ثابت ( MAX_TOKENS_SELECTED_PER_CLIENTوليس مقياس لحجم المفردات العام، والتي يمكن أن تكون كبيرة للغاية في ضبط العملية.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections
import itertools
import numpy as np

from typing import Callable, List, Tuple

import tensorflow as tf
import tensorflow_federated as tff
tff.backends.native.set_local_python_execution_context()

وسيكون لكل عميل federated_select الصفوف من الأوزان نموذجا للفي معظم هذه العديد من الرموز فريدة من نوعها. هذا الحدود العليا حجم النموذج المحلي العميل ومقدار الخادم -> العميل ( federated_select ) والعميل -> الخادم (federated_aggregate إجراء) الاتصالات.

يجب أن يستمر تشغيل هذا البرنامج التعليمي بشكل صحيح حتى إذا قمت بتعيين هذا على أنه صغير مثل 1 (ضمان عدم تحديد جميع الرموز المميزة من كل عميل) أو إلى قيمة كبيرة ، على الرغم من إمكانية تقارب النموذج.

MAX_TOKENS_SELECTED_PER_CLIENT = 6

نحدد أيضًا بعض الثوابت لأنواع مختلفة. لهذا colab، عربون هو معرف صحيح عن كلمة معينة بعد تحليل البيانات.

# There are some constraints on types
# here that will require some explicit type conversions:
#    - `tff.federated_select` requires int32
#    - `tf.SparseTensor` requires int64 indices.
TOKEN_DTYPE = tf.int64
SELECT_KEY_DTYPE = tf.int32

# Type for counts of token occurences.
TOKEN_COUNT_DTYPE = tf.int32

# A sparse feature vector can be thought of as a map
# from TOKEN_DTYPE to FEATURE_DTYPE. 
# Our features are {0, 1} indicators, so we could potentially
# use tf.int8 as an optimization.
FEATURE_DTYPE = tf.int32

إعداد المشكلة: مجموعة البيانات والنموذج

نقوم ببناء مجموعة بيانات لعبة صغيرة للتجريب السهل في هذا البرنامج التعليمي. ومع ذلك، فإن شكل بيانات متوافق مع اتحاد ستاكوفيرفلوو ، و قبل المعالجة و الهندسة المعمارية نموذج يتم اعتماد من مشكلة العلامة التنبؤ ستاكوفيرفلوو من التكيف الاتحادية الأمثل .

تحليل مجموعة البيانات والمعالجة المسبقة

NUM_OOV_BUCKETS = 1

BatchType = collections.namedtuple('BatchType', ['tokens', 'tags'])

def build_to_ids_fn(word_vocab: List[str],
                    tag_vocab: List[str]) -> Callable[[tf.Tensor], tf.Tensor]:
  """Constructs a function mapping examples to sequences of token indices."""
  word_table_values = np.arange(len(word_vocab), dtype=np.int64)
  word_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(word_vocab, word_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  tag_table_values = np.arange(len(tag_vocab), dtype=np.int64)
  tag_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(tag_vocab, tag_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  def to_ids(example):
    """Converts a Stack Overflow example to a bag-of-words/tags format."""
    sentence = tf.strings.join([example['tokens'], example['title']],
                               separator=' ')

    # We represent that label (output tags) densely.
    raw_tags = example['tags']
    tags = tf.strings.split(raw_tags, sep='|')
    tags = tag_table.lookup(tags)
    tags, _ = tf.unique(tags)
    tags = tf.one_hot(tags, len(tag_vocab) + NUM_OOV_BUCKETS)
    tags = tf.reduce_max(tags, axis=0)

    # We represent the features as a SparseTensor of {0, 1}s.
    words = tf.strings.split(sentence)
    tokens = word_table.lookup(words)
    tokens, _ = tf.unique(tokens)
    # Note:  We could choose to use the word counts as the feature vector
    # instead of just {0, 1} values (see tf.unique_with_counts).
    tokens = tf.reshape(tokens, shape=(tf.size(tokens), 1))
    tokens_st = tf.SparseTensor(
        tokens,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape=(len(word_vocab) + NUM_OOV_BUCKETS,))
    tokens_st = tf.sparse.reorder(tokens_st)

    return BatchType(tokens_st, tags)

  return to_ids
def build_preprocess_fn(word_vocab, tag_vocab):

  @tf.function
  def preprocess_fn(dataset):
    to_ids = build_to_ids_fn(word_vocab, tag_vocab)
    # We *don't* shuffle in order to make this colab deterministic for
    # easier testing and reproducibility.
    # But real-world training should use `.shuffle()`.
    return dataset.map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return preprocess_fn

مجموعة بيانات لعبة صغيرة

نقوم ببناء مجموعة بيانات لعبة صغيرة بمفردات عالمية من 12 كلمة و 3 عملاء. هذا المثال صغير مفيد لاختبار الحالات حافة (على سبيل المثال، لدينا اثنين من عملاء بأقل من MAX_TOKENS_SELECTED_PER_CLIENT = 6 الرموز متميزة، واحدة مع أكثر)، وتطوير التعليمات البرمجية.

ومع ذلك ، فإن حالات الاستخدام الواقعية لهذا النهج ستكون عبارة عن مفردات عالمية لعشرات الملايين أو أكثر ، وربما تظهر آلاف الرموز المميزة على كل عميل. لأن تنسيق البيانات هو نفسه، وتمديد لمشاكل اختبارات أكثر واقعية، على سبيل المثال tff.simulation.datasets.stackoverflow.load_data() مجموعة البيانات، يجب أن تكون واضحة.

أولاً ، نحدد كلمتنا ووسم المفردات.

# Features
FRUIT_WORDS = ['apple', 'orange', 'pear', 'kiwi']
VEGETABLE_WORDS = ['carrot', 'broccoli', 'arugula', 'peas']
FISH_WORDS = ['trout', 'tuna', 'cod', 'salmon']
WORD_VOCAB = FRUIT_WORDS + VEGETABLE_WORDS + FISH_WORDS

# Labels
TAG_VOCAB = ['FRUIT', 'VEGETABLE', 'FISH']

الآن ، قمنا بإنشاء 3 عملاء بمجموعات بيانات محلية صغيرة. إذا كنت تقوم بتشغيل هذا البرنامج التعليمي في colab ، فقد يكون من المفيد استخدام ميزة "الخلية المتطابقة في علامة التبويب" لتثبيت هذه الخلية وإخراجها من أجل تفسير / التحقق من إخراج الوظائف المطورة أدناه.

preprocess_fn = build_preprocess_fn(WORD_VOCAB, TAG_VOCAB)


def make_dataset(raw):
  d = tf.data.Dataset.from_tensor_slices(
      # Matches the StackOverflow formatting
      collections.OrderedDict(
          tokens=tf.constant([t[0] for t in raw]),
          tags=tf.constant([t[1] for t in raw]),
          title=['' for _ in raw]))
  d = preprocess_fn(d)
  return d


# 4 distinct tokens
CLIENT1_DATASET = make_dataset([
    ('apple orange apple orange', 'FRUIT'),
    ('carrot trout', 'VEGETABLE|FISH'),
    ('orange apple', 'FRUIT'),
    ('orange', 'ORANGE|CITRUS')  # 2 OOV tag
])

# 6 distinct tokens
CLIENT2_DATASET = make_dataset([
    ('pear cod', 'FRUIT|FISH'),
    ('arugula peas', 'VEGETABLE'),
    ('kiwi pear', 'FRUIT'),
    ('sturgeon', 'FISH'),  # OOV word
    ('sturgeon bass', 'FISH')  # 2 OOV words
])

# A client with all possible words & tags (13 distinct tokens).
# With MAX_TOKENS_SELECTED_PER_CLIENT = 6, we won't download the model
# slices for all tokens that occur on this client.
CLIENT3_DATASET = make_dataset([
    (' '.join(WORD_VOCAB + ['oovword']), '|'.join(TAG_VOCAB)),
    # Mathe the OOV token and 'salmon' occur in the largest number
    # of examples on this client:
    ('salmon oovword', 'FISH|OOVTAG')
])

print('Word vocab')
for i, word in enumerate(WORD_VOCAB):
  print(f'{i:2d} {word}')

print('\nTag vocab')
for i, tag in enumerate(TAG_VOCAB):
  print(f'{i:2d} {tag}')
Word vocab
 0 apple
 1 orange
 2 pear
 3 kiwi
 4 carrot
 5 broccoli
 6 arugula
 7 peas
 8 trout
 9 tuna
10 cod
11 salmon

Tag vocab
 0 FRUIT
 1 VEGETABLE
 2 FISH

حدد ثوابت للأعداد الأولية لمعالم الإدخال (الرموز المميزة / الكلمات) والتسميات (علامات النشر). لدينا مساحات الإدخال / الإخراج الفعلية NUM_OOV_BUCKETS = 1 أكبر لأننا إضافة OOV رمز / علامة.

NUM_WORDS = len(WORD_VOCAB) 
NUM_TAGS = len(TAG_VOCAB)

WORD_VOCAB_SIZE = NUM_WORDS + NUM_OOV_BUCKETS
TAG_VOCAB_SIZE = NUM_TAGS + NUM_OOV_BUCKETS

أنشئ إصدارات مجمعة من مجموعات البيانات ، ودُفعات فردية ، والتي ستكون مفيدة في اختبار الكود أثناء ذهابنا.

batched_dataset1 = CLIENT1_DATASET.batch(2)
batched_dataset2 = CLIENT2_DATASET.batch(3)
batched_dataset3 = CLIENT3_DATASET.batch(2)

batch1 = next(iter(batched_dataset1))
batch2 = next(iter(batched_dataset2))
batch3 = next(iter(batched_dataset3))

حدد نموذجًا بمدخلات متفرقة

نستخدم نموذج انحدار لوجستي مستقل بسيط لكل علامة.

def create_logistic_model(word_vocab_size: int, vocab_tags_size: int):

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(word_vocab_size,), sparse=True),
      tf.keras.layers.Dense(
          vocab_tags_size,
          activation='sigmoid',
          kernel_initializer=tf.keras.initializers.zeros,
          # For simplicity, don't use a bias vector; this means the model
          # is a single tensor, and we only need sparse aggregation of
          # the per-token slices of the model. Generalizing to also handle
          # other model weights that are fully updated 
          # (non-dense broadcast and aggregate) would be a good exercise.
          use_bias=False),
  ])

  return model

دعنا نتأكد من أنها تعمل ، أولاً عن طريق عمل تنبؤات:

model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
p = model.predict(batch1.tokens)
print(p)
[[0.5 0.5 0.5 0.5]
 [0.5 0.5 0.5 0.5]]

وبعض التدريبات المركزية البسيطة:

model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.001),
              loss=tf.keras.losses.BinaryCrossentropy())
model.train_on_batch(batch1.tokens, batch1.tags)

اللبنات الأساسية للحساب الموحد

سننفذ نسخة بسيطة من المتوسط الاتحادية خوارزمية مع فارق أساسي أن كل جهاز بتحميل مجموعة فرعية فقط ذات الصلة من هذا النموذج، ويساهم التحديثات فقط إلى أن فرعية.

نحن نستخدم M كما اختصار ل MAX_TOKENS_SELECTED_PER_CLIENT . على مستوى عالٍ ، تتضمن إحدى جولات التدريب الخطوات التالية:

  1. يقوم كل عميل مشارك بمسح مجموعة البيانات المحلية الخاصة به ، وتحليل سلاسل الإدخال وتعيينها إلى الرموز المميزة الصحيحة (فهارس int). هذا يتطلب الوصول إلى (كبير) القاموس العالمي (وهذا يحتمل أن تجنبها باستخدام ميزة التجزئة التقنيات). ثم نحسب عدد المرات التي تحدث فيها كل علامة. إذا U تحدث الرموز الفريدة على الجهاز، علينا أن نختار num_actual_tokens = min(U, M) أكثر الرموز المتكررة للقطار.

  2. عملاء استخدام federated_select لاسترداد معاملات نموذجا لل num_actual_tokens اختيار الرموز من الخادم. كل شريحة النموذج هو ممتد من شكل (TAG_VOCAB_SIZE, ) ، وبالتالي فإن مجموع البيانات المرسلة إلى العميل على الأكثر من حجم TAG_VOCAB_SIZE * M (انظر الملاحظة أدناه).

  3. العملاء بناء تعيين global_token -> local_token حيث الرمز المميز المحلي (مؤشر كثافة العمليات) هو مؤشر الرمز المميز عالميا في قائمة الرموز المختارة.

  4. عملاء استخدام نسخة "صغيرة" من النموذج العالمي الذي لديه معاملات فقط في معظم M الرموز، من مجموعة [0, num_actual_tokens) . و global -> local يستخدم لرسم الخرائط لتهيئة المعلمات كثيفة من هذا النموذج من شرائح نموذج المحدد.

  5. عملاء تدريب نموذج المحلي باستخدام SGD على البيانات preprocessed مع global -> local رسم الخرائط.

  6. عملاء تحويل المعلمات من نموذج المحلي إلى IndexedSlices التحديثات باستخدام local -> global لرسم الخرائط لمؤشر الصفوف. يقوم الخادم بتجميع هذه التحديثات باستخدام تجميع إجمالي ضئيل.

  7. يأخذ الخادم النتيجة (الكثيفة) للتجميع أعلاه ، ويقسمها على عدد العملاء المشاركين ، ويطبق متوسط ​​التحديث الناتج على النموذج العام.

في هذا القسم ونحن بناء اللبنات الأساسية لهذه الخطوات، والتي سوف تكون مجتمعة في نهائي federated_computation أن يلتقط منطق الكامل للجولة تدريبية واحدة.

عدد الرموز العميل والتي تقرر شرائح نموذج ل federated_select

يحتاج كل جهاز إلى تحديد "شرائح" النموذج ذات الصلة بمجموعة بيانات التدريب المحلية الخاصة به. بالنسبة لمشكلتنا ، نقوم بذلك عن طريق حساب عدد الأمثلة التي تحتوي على كل رمز مميز في مجموعة بيانات تدريب العميل (بشكل متقطع!).

@tf.function
def token_count_fn(token_counts, batch):
  """Adds counts from `batch` to the running `token_counts` sum."""
  # Sum across the batch dimension.
  flat_tokens = tf.sparse.reduce_sum(
      batch.tokens, axis=0, output_is_sparse=True)
  flat_tokens = tf.cast(flat_tokens, dtype=TOKEN_COUNT_DTYPE)
  return tf.sparse.add(token_counts, flat_tokens)
# Simple tests
# Create the initial zero token counts using empty tensors.
initial_token_counts = tf.SparseTensor(
    indices=tf.zeros(shape=(0, 1), dtype=TOKEN_DTYPE),
    values=tf.zeros(shape=(0,), dtype=TOKEN_COUNT_DTYPE),
    dense_shape=(WORD_VOCAB_SIZE,))

client_token_counts = batched_dataset1.reduce(initial_token_counts,
                                              token_count_fn)
tokens = tf.reshape(client_token_counts.indices, (-1,)).numpy()
print('tokens:', tokens)
np.testing.assert_array_equal(tokens, [0, 1, 4, 8])
# The count is the number of *examples* in which the token/word
# occurs, not the total number of occurences, since we still featurize
# multiple occurences in the same example as a "1".
counts = client_token_counts.values.numpy()
print('counts:', counts)
np.testing.assert_array_equal(counts, [2, 3, 1, 1])
tokens: [0 1 4 8]
counts: [2 3 1 1]

سوف نقوم بتحديد معالم النموذج المقابلة ل MAX_TOKENS_SELECTED_PER_CLIENT في معظم الأحيان التي تحدث الرموز على الجهاز. إذا أقل من هذا العدد الكبير من الرموز تحدث على الجهاز، نحن وحة القائمة لتمكين استخدام federated_select .

لاحظ أن الاستراتيجيات الأخرى ربما تكون أفضل ، على سبيل المثال ، الاختيار العشوائي للرموز (ربما بناءً على احتمالية حدوثها). من شأن ذلك أن يضمن أن جميع شرائح النموذج (التي يمتلك العميل بيانات عنها) لديها بعض فرص التحديث.

@tf.function
def keys_for_client(client_dataset, max_tokens_per_client):
  """Computes a set of max_tokens_per_client keys."""
  initial_token_counts = tf.SparseTensor(
      indices=tf.zeros((0, 1), dtype=TOKEN_DTYPE),
      values=tf.zeros((0,), dtype=TOKEN_COUNT_DTYPE),
      dense_shape=(WORD_VOCAB_SIZE,))
  client_token_counts = client_dataset.reduce(initial_token_counts,
                                              token_count_fn)
  # Find the most-frequently occuring tokens
  tokens = tf.reshape(client_token_counts.indices, shape=(-1,))
  counts = client_token_counts.values
  perm = tf.argsort(counts, direction='DESCENDING')
  tokens = tf.gather(tokens, perm)
  counts = tf.gather(counts, perm)
  num_raw_tokens = tf.shape(tokens)[0]
  actual_num_tokens = tf.minimum(max_tokens_per_client, num_raw_tokens)
  selected_tokens = tokens[:actual_num_tokens]
  paddings = [[0, max_tokens_per_client - tf.shape(selected_tokens)[0]]]
  padded_tokens = tf.pad(selected_tokens, paddings=paddings)
  # Make sure the type is statically determined
  padded_tokens = tf.reshape(padded_tokens, shape=(max_tokens_per_client,))

  # We will pass these tokens as keys into `federated_select`, which
  # requires SELECT_KEY_DTYPE=tf.int32 keys.
  padded_tokens = tf.cast(padded_tokens, dtype=SELECT_KEY_DTYPE)
  return padded_tokens, actual_num_tokens
# Simple test

# Case 1: actual_num_tokens > max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 3)
assert tf.size(selected_tokens) == 3
assert actual_num_tokens == 3

# Case 2: actual_num_tokens < max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 10)
assert tf.size(selected_tokens) == 10
assert actual_num_tokens == 4

تعيين الرموز العالمية إلى الرموز المحلية

اختيار المذكورة أعلاه تعطينا مجموعة كثيفة من الرموز في نطاق [0, actual_num_tokens) التي سوف نستخدم لنموذج على الجهاز. ومع ذلك، فإن بيانات نقرأ له الرموز من مجموعة أكبر من ذلك بكثير العالمي المفردات [0, WORD_VOCAB_SIZE) .

وبالتالي ، نحتاج إلى تعيين الرموز المميزة العالمية إلى الرموز المميزة المحلية المقابلة لها. يتم إعطاء معرفات رمزية المحلية ببساطة عن طريق المؤشرات في selected_tokens موتر حسابها في الخطوة السابقة.

@tf.function
def map_to_local_token_ids(client_data, client_keys):
  global_to_local = tf.lookup.StaticHashTable(
      # Note int32 -> int64 maps are not supported
      tf.lookup.KeyValueTensorInitializer(
          keys=tf.cast(client_keys, dtype=TOKEN_DTYPE),
          # Note we need to use tf.shape, not the static 
          # shape client_keys.shape[0]
          values=tf.range(0, limit=tf.shape(client_keys)[0],
                          dtype=TOKEN_DTYPE)),
      # We use -1 for tokens that were not selected, which can occur for clients
      # with more than MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens.
      # We will simply remove these invalid indices from the batch below.
      default_value=-1)

  def to_local_ids(sparse_tokens):
    indices_t = tf.transpose(sparse_tokens.indices)
    batch_indices = indices_t[0]  # First column
    tokens = indices_t[1]  # Second column
    tokens = tf.map_fn(
        lambda global_token_id: global_to_local.lookup(global_token_id), tokens)
    # Remove tokens that aren't actually available (looked up as -1):
    available_tokens = tokens >= 0
    tokens = tokens[available_tokens]
    batch_indices = batch_indices[available_tokens]

    updated_indices = tf.transpose(
        tf.concat([[batch_indices], [tokens]], axis=0))
    st = tf.sparse.SparseTensor(
        updated_indices,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape=sparse_tokens.dense_shape)
    st = tf.sparse.reorder(st)
    return st

  return client_data.map(lambda b: BatchType(to_local_ids(b.tokens), b.tags))
# Simple test
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset3, MAX_TOKENS_SELECTED_PER_CLIENT)
client_keys = client_keys[:actual_num_tokens]

d = map_to_local_token_ids(batched_dataset3, client_keys)
batch  = next(iter(d))
all_tokens = tf.gather(batch.tokens.indices, indices=1, axis=1)
# Confirm we have local indices in the range [0, MAX):
assert tf.math.reduce_max(all_tokens) < MAX_TOKENS_SELECTED_PER_CLIENT
assert tf.math.reduce_max(all_tokens) >= 0

تدريب النموذج المحلي (الفرعي) على كل عميل

ملاحظة federated_select سيعود الشرائح المحددة باعتبارها tf.data.Dataset في نفس الترتيب كما مفاتيح الاختيار. لذلك ، نحدد أولاً وظيفة الأداة المساعدة لأخذ مجموعة البيانات هذه وتحويلها إلى موتر واحد كثيف يمكن استخدامه كأوزان نموذجية لنموذج العميل.

@tf.function
def slices_dataset_to_tensor(slices_dataset):
  """Convert a dataset of slices to a tensor."""
  # Use batching to gather all of the slices into a single tensor.
  d = slices_dataset.batch(MAX_TOKENS_SELECTED_PER_CLIENT,
                           drop_remainder=False)
  iter_d = iter(d)
  tensor = next(iter_d)
  # Make sure we have consumed everything
  opt = iter_d.get_next_as_optional()
  tf.Assert(tf.logical_not(opt.has_value()), data=[''], name='CHECK_EMPTY')
  return tensor
# Simple test
weights = np.random.random(
    size=(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE)).astype(np.float32)
model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(weights)
weights2 = slices_dataset_to_tensor(model_slices_as_dataset)
np.testing.assert_array_equal(weights, weights2)

لدينا الآن جميع المكونات التي نحتاجها لتحديد حلقة تدريب محلية بسيطة تعمل على كل عميل.

@tf.function
def client_train_fn(model, client_optimizer,
                    model_slices_as_dataset, client_data,
                    client_keys, actual_num_tokens):

  initial_model_weights = slices_dataset_to_tensor(model_slices_as_dataset)
  assert len(model.trainable_variables) == 1
  model.trainable_variables[0].assign(initial_model_weights)

  # Only keep the "real" (unpadded) keys.
  client_keys = client_keys[:actual_num_tokens]

  client_data = map_to_local_token_ids(client_data, client_keys)

  loss_fn = tf.keras.losses.BinaryCrossentropy()
  for features, labels in client_data:
    with tf.GradientTape() as tape:
      predictions = model(features)
      loss = loss_fn(labels, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    client_optimizer.apply_gradients(zip(grads, model.trainable_variables))

  model_weights_delta = model.trainable_weights[0] - initial_model_weights
  model_weights_delta = tf.slice(model_weights_delta, begin=[0, 0], 
                           size=[actual_num_tokens, -1])
  return client_keys, model_weights_delta
# Simple test
# Note if you execute this cell a second time, you need to also re-execute
# the preceeding cell to avoid "tf.function-decorated function tried to 
# create variables on non-first call" errors.
on_device_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                        TAG_VOCAB_SIZE)
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset2, MAX_TOKENS_SELECTED_PER_CLIENT)

model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(
    np.zeros((MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE),
             dtype=np.float32))

keys, delta = client_train_fn(
    on_device_model,
    client_optimizer,
    model_slices_as_dataset,
    client_data=batched_dataset3,
    client_keys=client_keys,
    actual_num_tokens=actual_num_tokens)

print(delta)

الشرائح المفهرسة المجمعة

نحن نستخدم tff.federated_aggregate لبناء مبلغ متفرق الاتحادية لل IndexedSlices . هذا التطبيق البسيط لديه القيد أن dense_shape هو معروف بشكل ثابت مقدما. لاحظ أيضا أن هذا المبلغ هو فقط شبه متناثر، بمعنى أن العميل -> اتصال الملقم هو متفرق، ولكن يحافظ على الخادم تمثيل كثيفة من المبلغ في accumulate و merge ، وإخراج هذا التمثيل كثيفة.

def federated_indexed_slices_sum(slice_indices, slice_values, dense_shape):
  """
  Sumes IndexedSlices@CLIENTS to a dense @SERVER Tensor.

  Intermediate aggregation is performed by converting to a dense representation,
  which may not be suitable for all applications.

  Args:
    slice_indices: An IndexedSlices.indices tensor @CLIENTS.
    slice_values: An IndexedSlices.values tensor @CLIENTS.
    dense_shape: A statically known dense shape.

  Returns:
    A dense tensor placed @SERVER representing the sum of the client's
    IndexedSclies.
  """
  slices_dtype = slice_values.type_signature.member.dtype
  zero = tff.tf_computation(
      lambda: tf.zeros(dense_shape, dtype=slices_dtype))()

  @tf.function
  def accumulate_slices(dense, client_value):
    indices, slices = client_value
    # There is no built-in way to add `IndexedSlices`, but 
    # tf.convert_to_tensor is a quick way to convert to a dense representation
    # so we can add them.
    return dense + tf.convert_to_tensor(
        tf.IndexedSlices(slices, indices, dense_shape))


  return tff.federated_aggregate(
      (slice_indices, slice_values),
      zero=zero,
      accumulate=tff.tf_computation(accumulate_slices),
      merge=tff.tf_computation(lambda d1, d2: tf.add(d1, d2, name='merge')),
      report=tff.tf_computation(lambda d: d))

بناء الحد الأدنى federated_computation كاختبار

dense_shape = (6, 2)
indices_type = tff.TensorType(tf.int64, (None,))
values_type = tff.TensorType(tf.float32, (None, 2))
client_slice_type = tff.type_at_clients(
    (indices_type, values_type))

@tff.federated_computation(client_slice_type)
def test_sum_indexed_slices(indices_values_at_client):
  indices, values = indices_values_at_client
  return federated_indexed_slices_sum(indices, values, dense_shape)

print(test_sum_indexed_slices.type_signature)
({<int64[?],float32[?,2]>}@CLIENTS -> float32[6,2]@SERVER)
x = tf.IndexedSlices(
    values=np.array([[2., 2.1], [0., 0.1], [1., 1.1], [5., 5.1]],
                    dtype=np.float32),
    indices=[2, 0, 1, 5],
    dense_shape=dense_shape)
y = tf.IndexedSlices(
    values=np.array([[0., 0.3], [3.1, 3.2]], dtype=np.float32),
    indices=[1, 3],
    dense_shape=dense_shape)

# Sum one.
result = test_sum_indexed_slices([(x.indices, x.values)])
np.testing.assert_array_equal(tf.convert_to_tensor(x), result)

# Sum two.
expected = [[0., 0.1], [1., 1.4], [2., 2.1], [3.1, 3.2], [0., 0.], [5., 5.1]]
result = test_sum_indexed_slices([(x.indices, x.values), (y.indices, y.values)])
np.testing.assert_array_almost_equal(expected, result)

وضع كل ذلك معا في federated_computation

نحن الآن يستخدم TFF لربط معا المكونات في tff.federated_computation .

DENSE_MODEL_SHAPE = (WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
client_data_type = tff.SequenceType(batched_dataset1.element_spec)
model_type = tff.TensorType(tf.float32, shape=DENSE_MODEL_SHAPE)

نحن نستخدم وظيفة تدريب أساسية للخادم تعتمد على المتوسطات الموحدة ، ونطبق التحديث بمعدل تعلم خادم 1.0. من المهم أن نطبق تحديثًا (دلتا) على النموذج ، بدلاً من مجرد حساب متوسط ​​النماذج التي يوفرها العميل ، وإلا إذا لم يتم تدريب شريحة معينة من النموذج من قبل أي عميل في جولة معينة ، يمكن أن تكون معاملاتها صفرية خارج.

@tff.tf_computation
def server_update(current_model_weights, update_sum, num_clients):
  average_update = update_sum / num_clients
  return current_model_weights + average_update

نحن بحاجة إلى أكثر زوجين tff.tf_computation عناصر هي:

# Function to select slices from the model weights in federated_select:
select_fn = tff.tf_computation(
    lambda model_weights, index: tf.gather(model_weights, index))


# We need to wrap `client_train_fn` as a `tff.tf_computation`, making
# sure we do any operations that might construct `tf.Variable`s outside
# of the `tf.function` we are wrapping.
@tff.tf_computation
def client_train_fn_tff(model_slices_as_dataset, client_data, client_keys,
                        actual_num_tokens):
  # Note this is amaller than the global model, using
  # MAX_TOKENS_SELECTED_PER_CLIENT which is much smaller than WORD_VOCAB_SIZE.
  # W7e would like a model of size `actual_num_tokens`, but we
  # can't build the model dynamically, so we will slice off the padded
  # weights at the end.
  client_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                       TAG_VOCAB_SIZE)
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  return client_train_fn(client_model, client_optimizer,
                         model_slices_as_dataset, client_data, client_keys,
                         actual_num_tokens)

@tff.tf_computation
def keys_for_client_tff(client_data):
  return keys_for_client(client_data, MAX_TOKENS_SELECTED_PER_CLIENT)

نحن الآن جاهزون لتجميع كل القطع معًا!

@tff.federated_computation(
    tff.type_at_server(model_type), tff.type_at_clients(client_data_type))
def sparse_model_update(server_model, client_data):
  max_tokens = tff.federated_value(MAX_TOKENS_SELECTED_PER_CLIENT, tff.SERVER)
  keys_at_clients, actual_num_tokens = tff.federated_map(
      keys_for_client_tff, client_data)

  model_slices = tff.federated_select(keys_at_clients, max_tokens, server_model,
                                      select_fn)

  update_keys, update_slices = tff.federated_map(
      client_train_fn_tff,
      (model_slices, client_data, keys_at_clients, actual_num_tokens))

  dense_update_sum = federated_indexed_slices_sum(update_keys, update_slices,
                                                  DENSE_MODEL_SHAPE)
  num_clients = tff.federated_sum(tff.federated_value(1.0, tff.CLIENTS))

  updated_server_model = tff.federated_map(
      server_update, (server_model, dense_update_sum, num_clients))

  return updated_server_model


print(sparse_model_update.type_signature)
(<server_model=float32[13,4]@SERVER,client_data={<tokens=<indices=int64[?,2],values=int32[?],dense_shape=int64[2]>,tags=float32[?,4]>*}@CLIENTS> -> float32[13,4]@SERVER)

دعونا ندرب نموذجا!

الآن بعد أن أصبح لدينا وظيفة التدريب الخاصة بنا ، فلنجربها.

server_model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
server_model.compile(  # Compile to make evaluation easy.
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.0),  # Unused
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[ 
      tf.keras.metrics.Precision(name='precision'),
      tf.keras.metrics.AUC(name='auc'),
      tf.keras.metrics.Recall(top_k=2, name='recall_at_2'),
  ])

def evaluate(model, dataset, name):
  metrics = model.evaluate(dataset, verbose=0)
  metrics_str = ', '.join([f'{k}={v:.2f}' for k, v in 
                          (zip(server_model.metrics_names, metrics))])
  print(f'{name}: {metrics_str}')
print('Before training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')

model_weights = server_model.trainable_weights[0]

client_datasets = [batched_dataset1, batched_dataset2, batched_dataset3]
for _ in range(10):  # Run 10 rounds of FedAvg
  # We train on 1, 2, or 3 clients per round, selecting
  # randomly.
  cohort_size = np.random.randint(1, 4)
  clients = np.random.choice([0, 1, 2], cohort_size, replace=False)
  print('Training on clients', clients)
  model_weights = sparse_model_update(
      model_weights, [client_datasets[i] for i in clients])
server_model.set_weights([model_weights])

print('After training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')
Before training
Client 1: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.60
Client 2: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.50
Client 3: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.40
Training on clients [0 1]
Training on clients [0 2 1]
Training on clients [2 0]
Training on clients [1 0 2]
Training on clients [2]
Training on clients [2 0]
Training on clients [1 2 0]
Training on clients [0]
Training on clients [2]
Training on clients [1 2]
After training
Client 1: loss=0.67, precision=0.80, auc=0.91, recall_at_2=0.80
Client 2: loss=0.68, precision=0.67, auc=0.96, recall_at_2=1.00
Client 3: loss=0.65, precision=1.00, auc=0.93, recall_at_2=0.80