การเรียนรู้แบบรวมกลุ่มแบบจำลองขนาดใหญ่ที่มีประสิทธิภาพของไคลเอ็นต์ผ่านการรวมแบบ federated_select และแบบกระจัดกระจาย

กวดวิชานี้แสดงให้เห็นว่าฉิบหายสามารถนำมาใช้ในการฝึกอบรมรุ่นที่มีขนาดใหญ่มากอุปกรณ์ลูกค้าแต่ละเพียงดาวน์โหลดและการปรับปรุงส่วนเล็ก ๆ ของรูปแบบการใช้ tff.federated_select และการรวมเบาบาง ในขณะที่การกวดวิชานี้เป็นธรรมที่ตนเองมีที่ tff.federated_select กวดวิชา และ กำหนดเอง FL ขั้นตอนวิธีการกวดวิชา ให้การแนะนำที่ดีในบางส่วนของเทคนิคที่ใช้ที่นี่

อย่างเป็นรูปธรรม ในบทช่วยสอนนี้ เราจะพิจารณาการถดถอยโลจิสติกสำหรับการจัดประเภทหลายป้ายกำกับ โดยคาดการณ์ว่า "แท็ก" ใดที่เกี่ยวข้องกับสตริงข้อความโดยพิจารณาจากการแสดงคุณลักษณะแบบทีละคำ ที่สำคัญการสื่อสารและการคำนวณฝั่งไคลเอ็นต์ค่าใช้จ่ายจะถูกควบคุมโดยคงที่คงที่ ( MAX_TOKENS_SELECTED_PER_CLIENT ) และไม่ได้ขนาดที่มีขนาดคำศัพท์โดยรวมซึ่งอาจจะมีขนาดใหญ่มากในการตั้งค่าการปฏิบัติ

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio
.apply()
import collections
import itertools
import numpy as np

from typing import Callable, List, Tuple

import tensorflow as tf
import tensorflow_federated as tff
tff
.backends.native.set_local_python_execution_context()

ลูกค้าแต่ละคนจะ federated_select แถวของน้ำหนักรุ่นสำหรับราชสกุลที่ไม่ซ้ำกันมากที่สุดนี้มาก นี้บนขอบเขตขนาดของรูปแบบท้องถิ่นของลูกค้าและปริมาณของเซิร์ฟเวอร์ -> ลูกค้า ( federated_select ) และลูกค้า -> เซิร์ฟเวอร์ (federated_aggregate ) การสื่อสารดำเนินการ

บทช่วยสอนนี้ควรยังคงทำงานอย่างถูกต้องแม้ว่าคุณจะตั้งค่านี้ให้มีขนาดเล็กเพียง 1 (ตรวจสอบให้แน่ใจว่าไม่ได้เลือกโทเค็นทั้งหมดจากไคลเอนต์แต่ละราย) หรือเป็นค่าที่มาก แม้ว่าการบรรจบกันของโมเดลอาจได้รับผลกระทบ

MAX_TOKENS_SELECTED_PER_CLIENT = 6

นอกจากนี้เรายังกำหนดค่าคงที่สองสามชนิดสำหรับประเภทต่างๆ สำหรับ Colab นี้โทเค็นเป็นตัวระบุจำนวนเต็มสำหรับคำเฉพาะหลังจากแยกชุดข้อมูล

# There are some constraints on types
# here that will require some explicit type conversions:
#    - `tff.federated_select` requires int32
#    - `tf.SparseTensor` requires int64 indices.
TOKEN_DTYPE
= tf.int64
SELECT_KEY_DTYPE
= tf.int32

# Type for counts of token occurences.
TOKEN_COUNT_DTYPE
= tf.int32

# A sparse feature vector can be thought of as a map
# from TOKEN_DTYPE to FEATURE_DTYPE.
# Our features are {0, 1} indicators, so we could potentially
# use tf.int8 as an optimization.
FEATURE_DTYPE
= tf.int32

กำลังตั้งค่าปัญหา: ชุดข้อมูลและรุ่น

เราสร้างชุดข้อมูลของเล่นขนาดเล็กสำหรับการทดลองอย่างง่ายในบทช่วยสอนนี้ อย่างไรก็ตามรูปแบบของชุดข้อมูลที่เข้ากันได้กับ สหพันธ์ StackOverflow และ ก่อนการประมวลผล และ สถาปัตยกรรมรูปแบบ ที่เป็นที่ยอมรับจากปัญหาแท็กคำทำนายของ StackOverflow Adaptive สหพันธ์การเพิ่มประสิทธิภาพ

การแยกวิเคราะห์ชุดข้อมูลและการประมวลผลล่วงหน้า

NUM_OOV_BUCKETS = 1

BatchType = collections.namedtuple('BatchType', ['tokens', 'tags'])

def build_to_ids_fn(word_vocab: List[str],
                    tag_vocab
: List[str]) -> Callable[[tf.Tensor], tf.Tensor]:
 
"""Constructs a function mapping examples to sequences of token indices."""
  word_table_values
= np.arange(len(word_vocab), dtype=np.int64)
  word_table
= tf.lookup.StaticVocabularyTable(
      tf
.lookup.KeyValueTensorInitializer(word_vocab, word_table_values),
      num_oov_buckets
=NUM_OOV_BUCKETS)

  tag_table_values
= np.arange(len(tag_vocab), dtype=np.int64)
  tag_table
= tf.lookup.StaticVocabularyTable(
      tf
.lookup.KeyValueTensorInitializer(tag_vocab, tag_table_values),
      num_oov_buckets
=NUM_OOV_BUCKETS)

 
def to_ids(example):
   
"""Converts a Stack Overflow example to a bag-of-words/tags format."""
    sentence
= tf.strings.join([example['tokens'], example['title']],
                               separator
=' ')

   
# We represent that label (output tags) densely.
    raw_tags
= example['tags']
    tags
= tf.strings.split(raw_tags, sep='|')
    tags
= tag_table.lookup(tags)
    tags
, _ = tf.unique(tags)
    tags
= tf.one_hot(tags, len(tag_vocab) + NUM_OOV_BUCKETS)
    tags
= tf.reduce_max(tags, axis=0)

   
# We represent the features as a SparseTensor of {0, 1}s.
    words
= tf.strings.split(sentence)
    tokens
= word_table.lookup(words)
    tokens
, _ = tf.unique(tokens)
   
# Note:  We could choose to use the word counts as the feature vector
   
# instead of just {0, 1} values (see tf.unique_with_counts).
    tokens
= tf.reshape(tokens, shape=(tf.size(tokens), 1))
    tokens_st
= tf.SparseTensor(
        tokens
,
        tf
.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape
=(len(word_vocab) + NUM_OOV_BUCKETS,))
    tokens_st
= tf.sparse.reorder(tokens_st)

   
return BatchType(tokens_st, tags)

 
return to_ids
def build_preprocess_fn(word_vocab, tag_vocab):

 
@tf.function
 
def preprocess_fn(dataset):
    to_ids
= build_to_ids_fn(word_vocab, tag_vocab)
   
# We *don't* shuffle in order to make this colab deterministic for
   
# easier testing and reproducibility.
   
# But real-world training should use `.shuffle()`.
   
return dataset.map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)

 
return preprocess_fn

ชุดข้อมูลของเล่นจิ๋ว

เราสร้างชุดข้อมูลของเล่นขนาดเล็กที่มีคำศัพท์สากล 12 คำและลูกค้า 3 ราย ตัวอย่างเล็ก ๆ นี้จะเป็นประโยชน์สำหรับการทดสอบกรณีขอบ (ตัวอย่างเช่นเรามีสองลูกค้าที่มีน้อยกว่า MAX_TOKENS_SELECTED_PER_CLIENT = 6 ราชสกุลที่แตกต่างกันและหนึ่งที่มีมากขึ้น) และการพัฒนารหัส

อย่างไรก็ตาม กรณีการใช้งานจริงของแนวทางนี้อาจเป็นคำศัพท์ทั่วโลกตั้งแต่ 10 ล้านคำขึ้นไป โดยอาจมีโทเค็นที่แตกต่างกันถึง 1,000 รายการปรากฏบนไคลเอนต์แต่ละราย เพราะรูปแบบของข้อมูลที่จะเหมือนกันขยายการแก้ปัญหา testbed สมจริงมากขึ้นเช่น tff.simulation.datasets.stackoverflow.load_data() ชุดข้อมูลที่ควรจะตรงไปตรงมา

อันดับแรก เรากำหนดคำศัพท์และแท็กคำศัพท์ของเรา

# Features
FRUIT_WORDS
= ['apple', 'orange', 'pear', 'kiwi']
VEGETABLE_WORDS
= ['carrot', 'broccoli', 'arugula', 'peas']
FISH_WORDS
= ['trout', 'tuna', 'cod', 'salmon']
WORD_VOCAB
= FRUIT_WORDS + VEGETABLE_WORDS + FISH_WORDS

# Labels
TAG_VOCAB
= ['FRUIT', 'VEGETABLE', 'FISH']

ตอนนี้ เราสร้างลูกค้า 3 รายด้วยชุดข้อมูลในเครื่องขนาดเล็ก หากคุณกำลังใช้งานบทช่วยสอนนี้ใน colab การใช้คุณลักษณะ "มิเรอร์เซลล์ในแท็บ" เพื่อตรึงเซลล์นี้และเอาต์พุตเพื่อตีความ/ตรวจสอบเอาต์พุตของฟังก์ชันที่พัฒนาขึ้นด้านล่างอาจเป็นประโยชน์

preprocess_fn = build_preprocess_fn(WORD_VOCAB, TAG_VOCAB)


def make_dataset(raw):
  d
= tf.data.Dataset.from_tensor_slices(
     
# Matches the StackOverflow formatting
      collections
.OrderedDict(
          tokens
=tf.constant([t[0] for t in raw]),
          tags
=tf.constant([t[1] for t in raw]),
          title
=['' for _ in raw]))
  d
= preprocess_fn(d)
 
return d


# 4 distinct tokens
CLIENT1_DATASET
= make_dataset([
   
('apple orange apple orange', 'FRUIT'),
   
('carrot trout', 'VEGETABLE|FISH'),
   
('orange apple', 'FRUIT'),
   
('orange', 'ORANGE|CITRUS')  # 2 OOV tag
])

# 6 distinct tokens
CLIENT2_DATASET
= make_dataset([
   
('pear cod', 'FRUIT|FISH'),
   
('arugula peas', 'VEGETABLE'),
   
('kiwi pear', 'FRUIT'),
   
('sturgeon', 'FISH'),  # OOV word
   
('sturgeon bass', 'FISH')  # 2 OOV words
])

# A client with all possible words & tags (13 distinct tokens).
# With MAX_TOKENS_SELECTED_PER_CLIENT = 6, we won't download the model
# slices for all tokens that occur on this client.
CLIENT3_DATASET
= make_dataset([
   
(' '.join(WORD_VOCAB + ['oovword']), '|'.join(TAG_VOCAB)),
   
# Mathe the OOV token and 'salmon' occur in the largest number
   
# of examples on this client:
   
('salmon oovword', 'FISH|OOVTAG')
])

print('Word vocab')
for i, word in enumerate(WORD_VOCAB):
 
print(f'{i:2d} {word}')

print('\nTag vocab')
for i, tag in enumerate(TAG_VOCAB):
 
print(f'{i:2d} {tag}')
Word vocab
 0 apple
 1 orange
 2 pear
 3 kiwi
 4 carrot
 5 broccoli
 6 arugula
 7 peas
 8 trout
 9 tuna
10 cod
11 salmon

Tag vocab
 0 FRUIT
 1 VEGETABLE
 2 FISH

กำหนดค่าคงที่สำหรับจำนวนข้อมูลดิบของคุณสมบัติอินพุต (โทเค็น/คำ) และป้ายกำกับ (แท็กโพสต์) พื้นที่อินพุต / เอาต์พุตที่เกิดขึ้นจริงของเรามี NUM_OOV_BUCKETS = 1 ขนาดใหญ่เพราะเราเพิ่ม OOV โทเค็น / แท็ก

NUM_WORDS = len(WORD_VOCAB) 
NUM_TAGS
= len(TAG_VOCAB)

WORD_VOCAB_SIZE
= NUM_WORDS + NUM_OOV_BUCKETS
TAG_VOCAB_SIZE
= NUM_TAGS + NUM_OOV_BUCKETS

สร้างเวอร์ชันแบตช์ของชุดข้อมูล และแต่ละแบตช์ ซึ่งจะเป็นประโยชน์ในการทดสอบโค้ดเมื่อเราดำเนินการ

batched_dataset1 = CLIENT1_DATASET.batch(2)
batched_dataset2
= CLIENT2_DATASET.batch(3)
batched_dataset3
= CLIENT3_DATASET.batch(2)

batch1
= next(iter(batched_dataset1))
batch2
= next(iter(batched_dataset2))
batch3
= next(iter(batched_dataset3))

กำหนดแบบจำลองที่มีอินพุตเบาบาง

เราใช้แบบจำลองการถดถอยโลจิสติกอิสระอย่างง่ายสำหรับแต่ละแท็ก

def create_logistic_model(word_vocab_size: int, vocab_tags_size: int):

  model
= tf.keras.models.Sequential([
      tf
.keras.layers.InputLayer(input_shape=(word_vocab_size,), sparse=True),
      tf
.keras.layers.Dense(
          vocab_tags_size
,
          activation
='sigmoid',
          kernel_initializer
=tf.keras.initializers.zeros,
         
# For simplicity, don't use a bias vector; this means the model
         
# is a single tensor, and we only need sparse aggregation of
         
# the per-token slices of the model. Generalizing to also handle
         
# other model weights that are fully updated
         
# (non-dense broadcast and aggregate) would be a good exercise.
          use_bias
=False),
 
])

 
return model

ตรวจสอบให้แน่ใจก่อนว่าได้ผล ขั้นแรก โดยการทำนาย:

model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
p
= model.predict(batch1.tokens)
print(p)
[[0.5 0.5 0.5 0.5]
 [0.5 0.5 0.5 0.5]]

และการฝึกอบรมแบบรวมศูนย์ง่ายๆ บางอย่าง:

model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.001),
              loss
=tf.keras.losses.BinaryCrossentropy())
model
.train_on_batch(batch1.tokens, batch1.tags)

การสร้างบล็อคสำหรับการคำนวณแบบรวมศูนย์

เราจะดำเนินการรุ่นที่เรียบง่ายของ Averaging สหพันธ์ อัลกอริทึมที่มีความแตกต่างที่สำคัญที่อุปกรณ์แต่ละเพียงดาวน์โหลดชุดย่อยที่เกี่ยวข้องของรูปแบบและมีเพียงส่วนช่วยในการปรับปรุงส่วนย่อยที่

เราใช้ M เป็นชวเลข MAX_TOKENS_SELECTED_PER_CLIENT ในระดับสูง การฝึกอบรมหนึ่งรอบเกี่ยวข้องกับขั้นตอนเหล่านี้:

  1. ลูกค้าที่เข้าร่วมแต่ละรายจะสแกนชุดข้อมูลในเครื่อง แยกวิเคราะห์สตริงอินพุตและจับคู่กับโทเค็นที่ถูกต้อง (ดัชนี int) เรื่องนี้ต้องมีการเข้าถึงทั่วโลก (ขนาดใหญ่) พจนานุกรม (นี่อาจจะหลีกเลี่ยงการใช้ คุณลักษณะคร่ำเครียด เทคนิค) จากนั้นเราจะนับเบา ๆ ว่าแต่ละโทเค็นเกิดขึ้นกี่ครั้ง ถ้า U ราชสกุลที่ไม่ซ้ำกันเกิดขึ้นบนอุปกรณ์เราเลือก num_actual_tokens = min(U, M) ราชสกุลที่พบบ่อยที่สุดในการฝึกอบรม

  2. ลูกค้าใช้ federated_select เพื่อดึงค่าสัมประสิทธิ์รุ่นสำหรับ num_actual_tokens เลือกราชสกุลจากเซิร์ฟเวอร์ แต่ละรุ่นฝานเป็นเมตริกซ์ของรูปร่าง (TAG_VOCAB_SIZE, ) เพื่อให้ข้อมูลทั้งหมดส่งไปยังลูกค้าที่ส่วนใหญ่ของขนาด TAG_VOCAB_SIZE * M (ดูหมายเหตุด้านล่าง)

  3. ลูกค้าสร้างการทำแผนที่ global_token -> local_token ที่โทเค็นท้องถิ่น (ดัชนี int) เป็นดัชนีของโทเค็นระดับโลกในรายการของสัญญาณที่เลือก

  4. ลูกค้าใช้ "เล็ก" รุ่นของรูปแบบระดับโลกที่มีเพียงค่าสัมประสิทธิ์สำหรับที่มากที่สุด M ราชสกุลจากช่วง [0, num_actual_tokens) global -> local การทำแผนที่จะใช้ในการเริ่มต้นพารามิเตอร์หนาแน่นของรุ่นนี้จากชิ้นรูปแบบที่เลือก

  5. ลูกค้าในการฝึกอบรมรุ่นท้องถิ่นของตนโดยใช้ SGD กับข้อมูล preprocessed กับ global -> local การทำแผนที่

  6. ลูกค้าหันพารามิเตอร์ของแบบจำลองในท้องถิ่นของตนลงใน IndexedSlices การปรับปรุงโดยใช้ local -> global ทำแผนที่เพื่อดัชนีแถว เซิร์ฟเวอร์รวมการอัปเดตเหล่านี้โดยใช้การรวมแบบกระจาย

  7. เซิร์ฟเวอร์ใช้ผลลัพธ์ (หนาแน่น) ของการรวมข้างต้น หารด้วยจำนวนไคลเอนต์ที่เข้าร่วม และใช้การอัปเดตโดยเฉลี่ยที่เป็นผลลัพธ์กับโมเดลส่วนกลาง

ในส่วนนี้เราสร้างอาคารตึกสำหรับขั้นตอนเหล่านี้ซึ่งจะนำมารวมกันในรอบสุดท้าย federated_computation ที่จับตรรกะเต็มรูปแบบของการฝึกอบรมรอบหนึ่ง

นับราชสกุลลูกค้าและตัดสินใจที่ชิ้นรูปแบบการ federated_select

อุปกรณ์แต่ละเครื่องจำเป็นต้องตัดสินใจว่า "ชิ้นส่วน" ของรุ่นใดที่เกี่ยวข้องกับชุดข้อมูลการฝึกอบรมในพื้นที่ สำหรับปัญหาของเรา เราทำสิ่งนี้โดย (อย่างถี่ถ้วน!) นับจำนวนตัวอย่างที่มีโทเค็นแต่ละตัวในชุดข้อมูลการฝึกลูกค้า

@tf.function
def token_count_fn(token_counts, batch):
 
"""Adds counts from `batch` to the running `token_counts` sum."""
 
# Sum across the batch dimension.
  flat_tokens
= tf.sparse.reduce_sum(
      batch
.tokens, axis=0, output_is_sparse=True)
  flat_tokens
= tf.cast(flat_tokens, dtype=TOKEN_COUNT_DTYPE)
 
return tf.sparse.add(token_counts, flat_tokens)
# Simple tests
# Create the initial zero token counts using empty tensors.
initial_token_counts
= tf.SparseTensor(
    indices
=tf.zeros(shape=(0, 1), dtype=TOKEN_DTYPE),
    values
=tf.zeros(shape=(0,), dtype=TOKEN_COUNT_DTYPE),
    dense_shape
=(WORD_VOCAB_SIZE,))

client_token_counts
= batched_dataset1.reduce(initial_token_counts,
                                              token_count_fn
)
tokens
= tf.reshape(client_token_counts.indices, (-1,)).numpy()
print('tokens:', tokens)
np
.testing.assert_array_equal(tokens, [0, 1, 4, 8])
# The count is the number of *examples* in which the token/word
# occurs, not the total number of occurences, since we still featurize
# multiple occurences in the same example as a "1".
counts
= client_token_counts.values.numpy()
print('counts:', counts)
np
.testing.assert_array_equal(counts, [2, 3, 1, 1])
tokens: [0 1 4 8]
counts: [2 3 1 1]

เราจะเลือกพารามิเตอร์แบบที่สอดคล้องกับ MAX_TOKENS_SELECTED_PER_CLIENT บ่อยที่สุดเกิดขึ้นราชสกุลบนอุปกรณ์ ถ้าน้อยกว่าโทเค็นมากนี้เกิดขึ้นบนอุปกรณ์ของเราแผ่นรายการที่จะช่วยให้การใช้งานของ federated_select

โปรดทราบว่ากลยุทธ์อื่นๆ อาจดีกว่า เช่น สุ่มเลือกโทเค็น (อาจขึ้นอยู่กับความน่าจะเป็นที่จะเกิดขึ้น) ซึ่งจะทำให้มั่นใจได้ว่าชิ้นส่วนทั้งหมดของโมเดล (ซึ่งไคลเอ็นต์มีข้อมูล) มีโอกาสได้รับการอัปเดต

@tf.function
def keys_for_client(client_dataset, max_tokens_per_client):
 
"""Computes a set of max_tokens_per_client keys."""
  initial_token_counts
= tf.SparseTensor(
      indices
=tf.zeros((0, 1), dtype=TOKEN_DTYPE),
      values
=tf.zeros((0,), dtype=TOKEN_COUNT_DTYPE),
      dense_shape
=(WORD_VOCAB_SIZE,))
  client_token_counts
= client_dataset.reduce(initial_token_counts,
                                              token_count_fn
)
 
# Find the most-frequently occuring tokens
  tokens
= tf.reshape(client_token_counts.indices, shape=(-1,))
  counts
= client_token_counts.values
  perm
= tf.argsort(counts, direction='DESCENDING')
  tokens
= tf.gather(tokens, perm)
  counts
= tf.gather(counts, perm)
  num_raw_tokens
= tf.shape(tokens)[0]
  actual_num_tokens
= tf.minimum(max_tokens_per_client, num_raw_tokens)
  selected_tokens
= tokens[:actual_num_tokens]
  paddings
= [[0, max_tokens_per_client - tf.shape(selected_tokens)[0]]]
  padded_tokens
= tf.pad(selected_tokens, paddings=paddings)
 
# Make sure the type is statically determined
  padded_tokens
= tf.reshape(padded_tokens, shape=(max_tokens_per_client,))

 
# We will pass these tokens as keys into `federated_select`, which
 
# requires SELECT_KEY_DTYPE=tf.int32 keys.
  padded_tokens
= tf.cast(padded_tokens, dtype=SELECT_KEY_DTYPE)
 
return padded_tokens, actual_num_tokens
# Simple test

# Case 1: actual_num_tokens > max_tokens_per_client
selected_tokens
, actual_num_tokens = keys_for_client(batched_dataset1, 3)
assert tf.size(selected_tokens) == 3
assert actual_num_tokens == 3

# Case 2: actual_num_tokens < max_tokens_per_client
selected_tokens
, actual_num_tokens = keys_for_client(batched_dataset1, 10)
assert tf.size(selected_tokens) == 10
assert actual_num_tokens == 4

แมปโทเค็นทั่วโลกกับโทเค็นในพื้นที่

การเลือกข้างต้นจะช่วยให้เราเป็นชุดที่มีความหนาแน่นของสัญญาณในช่วง [0, actual_num_tokens) ซึ่งเราจะใช้สำหรับรูปแบบบนอุปกรณ์ แต่ชุดข้อมูลที่เราอ่านมีราชสกุลจากช่วงคำศัพท์ระดับโลกมีขนาดใหญ่มาก [0, WORD_VOCAB_SIZE)

ดังนั้น เราจำเป็นต้องแมปโทเค็นทั่วโลกกับโทเค็นท้องถิ่นที่เกี่ยวข้อง รหัสโทเค็ท้องถิ่นจะได้รับเพียงโดยดัชนีลงใน selected_tokens เมตริกซ์คำนวณได้ในขั้นตอนก่อนหน้า

@tf.function
def map_to_local_token_ids(client_data, client_keys):
  global_to_local
= tf.lookup.StaticHashTable(
     
# Note int32 -> int64 maps are not supported
      tf
.lookup.KeyValueTensorInitializer(
          keys
=tf.cast(client_keys, dtype=TOKEN_DTYPE),
         
# Note we need to use tf.shape, not the static
         
# shape client_keys.shape[0]
          values
=tf.range(0, limit=tf.shape(client_keys)[0],
                          dtype
=TOKEN_DTYPE)),
     
# We use -1 for tokens that were not selected, which can occur for clients
     
# with more than MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens.
     
# We will simply remove these invalid indices from the batch below.
      default_value
=-1)

 
def to_local_ids(sparse_tokens):
    indices_t
= tf.transpose(sparse_tokens.indices)
    batch_indices
= indices_t[0]  # First column
    tokens
= indices_t[1]  # Second column
    tokens
= tf.map_fn(
       
lambda global_token_id: global_to_local.lookup(global_token_id), tokens)
   
# Remove tokens that aren't actually available (looked up as -1):
    available_tokens
= tokens >= 0
    tokens
= tokens[available_tokens]
    batch_indices
= batch_indices[available_tokens]

    updated_indices
= tf.transpose(
        tf
.concat([[batch_indices], [tokens]], axis=0))
    st
= tf.sparse.SparseTensor(
        updated_indices
,
        tf
.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape
=sparse_tokens.dense_shape)
    st
= tf.sparse.reorder(st)
   
return st

 
return client_data.map(lambda b: BatchType(to_local_ids(b.tokens), b.tags))
# Simple test
client_keys
, actual_num_tokens = keys_for_client(
    batched_dataset3
, MAX_TOKENS_SELECTED_PER_CLIENT)
client_keys
= client_keys[:actual_num_tokens]

d
= map_to_local_token_ids(batched_dataset3, client_keys)
batch  
= next(iter(d))
all_tokens
= tf.gather(batch.tokens.indices, indices=1, axis=1)
# Confirm we have local indices in the range [0, MAX):
assert tf.math.reduce_max(all_tokens) < MAX_TOKENS_SELECTED_PER_CLIENT
assert tf.math.reduce_max(all_tokens) >= 0

ฝึกโมเดล (ย่อย) ในพื้นที่ของลูกค้าแต่ละราย

หมายเหตุ federated_select จะกลับชิ้นที่เลือกเป็น tf.data.Dataset ในลำดับเดียวกับปุ่มเลือก ดังนั้นเราจึงกำหนดฟังก์ชันยูทิลิตี้ก่อนเพื่อนำชุดข้อมูลดังกล่าวมาแปลงเป็นเทนเซอร์หนาแน่นตัวเดียว ซึ่งสามารถใช้เป็นน้ำหนักแบบจำลองของแบบจำลองไคลเอ็นต์ได้

@tf.function
def slices_dataset_to_tensor(slices_dataset):
 
"""Convert a dataset of slices to a tensor."""
 
# Use batching to gather all of the slices into a single tensor.
  d
= slices_dataset.batch(MAX_TOKENS_SELECTED_PER_CLIENT,
                           drop_remainder
=False)
  iter_d
= iter(d)
  tensor
= next(iter_d)
 
# Make sure we have consumed everything
  opt
= iter_d.get_next_as_optional()
  tf
.Assert(tf.logical_not(opt.has_value()), data=[''], name='CHECK_EMPTY')
 
return tensor
# Simple test
weights
= np.random.random(
    size
=(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE)).astype(np.float32)
model_slices_as_dataset
= tf.data.Dataset.from_tensor_slices(weights)
weights2
= slices_dataset_to_tensor(model_slices_as_dataset)
np
.testing.assert_array_equal(weights, weights2)

ตอนนี้เรามีส่วนประกอบทั้งหมดที่เราต้องการเพื่อกำหนดลูปการฝึกอบรมในพื้นที่อย่างง่าย ซึ่งจะทำงานบนไคลเอนต์แต่ละราย

@tf.function
def client_train_fn(model, client_optimizer,
                    model_slices_as_dataset
, client_data,
                    client_keys
, actual_num_tokens):

  initial_model_weights
= slices_dataset_to_tensor(model_slices_as_dataset)
 
assert len(model.trainable_variables) == 1
  model
.trainable_variables[0].assign(initial_model_weights)

 
# Only keep the "real" (unpadded) keys.
  client_keys
= client_keys[:actual_num_tokens]

  client_data
= map_to_local_token_ids(client_data, client_keys)

  loss_fn
= tf.keras.losses.BinaryCrossentropy()
 
for features, labels in client_data:
   
with tf.GradientTape() as tape:
      predictions
= model(features)
      loss
= loss_fn(labels, predictions)
    grads
= tape.gradient(loss, model.trainable_variables)
    client_optimizer
.apply_gradients(zip(grads, model.trainable_variables))

  model_weights_delta
= model.trainable_weights[0] - initial_model_weights
  model_weights_delta
= tf.slice(model_weights_delta, begin=[0, 0],
                           size
=[actual_num_tokens, -1])
 
return client_keys, model_weights_delta
# Simple test
# Note if you execute this cell a second time, you need to also re-execute
# the preceeding cell to avoid "tf.function-decorated function tried to
# create variables on non-first call" errors.
on_device_model
= create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                        TAG_VOCAB_SIZE
)
client_optimizer
= tf.keras.optimizers.SGD(learning_rate=0.001)
client_keys
, actual_num_tokens = keys_for_client(
    batched_dataset2
, MAX_TOKENS_SELECTED_PER_CLIENT)

model_slices_as_dataset
= tf.data.Dataset.from_tensor_slices(
    np
.zeros((MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE),
             dtype
=np.float32))

keys
, delta = client_train_fn(
    on_device_model
,
    client_optimizer
,
    model_slices_as_dataset
,
    client_data
=batched_dataset3,
    client_keys
=client_keys,
    actual_num_tokens
=actual_num_tokens)

print(delta)

รวม IndexedSlices

เราใช้ tff.federated_aggregate ที่จะสร้างผลรวมเบาบาง federated สำหรับ IndexedSlices การดำเนินงานที่เรียบง่ายนี้มีข้อ จำกัด ที่ว่า dense_shape เป็นที่รู้จักกันแบบคงที่ล่วงหน้า ยังทราบว่าผลรวมนี้เป็นเพียงกึ่งเบาบางในแง่ที่ว่าไคลเอนต์ -> การสื่อสารเซิร์ฟเวอร์จะเบาบาง แต่เซิร์ฟเวอร์ยังคงเป็นตัวแทนความหนาแน่นของจำนวนเงินใน accumulate และ merge และเอาท์พุทนี้เป็นตัวแทนความหนาแน่นสูง

def federated_indexed_slices_sum(slice_indices, slice_values, dense_shape):
 
"""
  Sumes IndexedSlices@CLIENTS to a dense @SERVER Tensor.

  Intermediate aggregation is performed by converting to a dense representation,
  which may not be suitable for all applications.

  Args:
    slice_indices: An IndexedSlices.indices tensor @CLIENTS.
    slice_values: An IndexedSlices.values tensor @CLIENTS.
    dense_shape: A statically known dense shape.

  Returns:
    A dense tensor placed @SERVER representing the sum of the client's
    IndexedSclies.
  """

  slices_dtype
= slice_values.type_signature.member.dtype
  zero
= tff.tf_computation(
     
lambda: tf.zeros(dense_shape, dtype=slices_dtype))()

 
@tf.function
 
def accumulate_slices(dense, client_value):
    indices
, slices = client_value
   
# There is no built-in way to add `IndexedSlices`, but
   
# tf.convert_to_tensor is a quick way to convert to a dense representation
   
# so we can add them.
   
return dense + tf.convert_to_tensor(
        tf
.IndexedSlices(slices, indices, dense_shape))


 
return tff.federated_aggregate(
     
(slice_indices, slice_values),
      zero
=zero,
      accumulate
=tff.tf_computation(accumulate_slices),
      merge
=tff.tf_computation(lambda d1, d2: tf.add(d1, d2, name='merge')),
      report
=tff.tf_computation(lambda d: d))

สร้างน้อยที่สุด federated_computation เป็นแบบทดสอบ

dense_shape = (6, 2)
indices_type
= tff.TensorType(tf.int64, (None,))
values_type
= tff.TensorType(tf.float32, (None, 2))
client_slice_type
= tff.type_at_clients(
   
(indices_type, values_type))

@tff.federated_computation(client_slice_type)
def test_sum_indexed_slices(indices_values_at_client):
  indices
, values = indices_values_at_client
 
return federated_indexed_slices_sum(indices, values, dense_shape)

print(test_sum_indexed_slices.type_signature)
({<int64[?],float32[?,2]>}@CLIENTS -> float32[6,2]@SERVER)
x = tf.IndexedSlices(
    values
=np.array([[2., 2.1], [0., 0.1], [1., 1.1], [5., 5.1]],
                    dtype
=np.float32),
    indices
=[2, 0, 1, 5],
    dense_shape
=dense_shape)
y
= tf.IndexedSlices(
    values
=np.array([[0., 0.3], [3.1, 3.2]], dtype=np.float32),
    indices
=[1, 3],
    dense_shape
=dense_shape)

# Sum one.
result
= test_sum_indexed_slices([(x.indices, x.values)])
np
.testing.assert_array_equal(tf.convert_to_tensor(x), result)

# Sum two.
expected
= [[0., 0.1], [1., 1.4], [2., 2.1], [3.1, 3.2], [0., 0.], [5., 5.1]]
result
= test_sum_indexed_slices([(x.indices, x.values), (y.indices, y.values)])
np
.testing.assert_array_almost_equal(expected, result)

วางไว้ทั้งหมดเข้าด้วยกันใน federated_computation

ตอนนี้เราใช้ฉิบหายที่จะผูกกันส่วนประกอบเป็น tff.federated_computation

DENSE_MODEL_SHAPE = (WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
client_data_type
= tff.SequenceType(batched_dataset1.element_spec)
model_type
= tff.TensorType(tf.float32, shape=DENSE_MODEL_SHAPE)

เราใช้ฟังก์ชันการฝึกเซิร์ฟเวอร์พื้นฐานโดยยึดตาม Federated Averaging โดยใช้การอัปเดตด้วยอัตราการเรียนรู้เซิร์ฟเวอร์ 1.0 เป็นสิ่งสำคัญที่เราต้องใช้การอัปเดต (เดลต้า) กับโมเดล แทนที่จะเพียงแค่หาค่าเฉลี่ยโมเดลที่ไคลเอ็นต์จัดหา มิฉะนั้น หากไคลเอ็นต์ไม่ได้ฝึกฝนชิ้นส่วนโมเดลที่ระบุในรอบที่กำหนด สัมประสิทธิ์ของโมเดลนั้นอาจเป็นศูนย์ ออก.

@tff.tf_computation
def server_update(current_model_weights, update_sum, num_clients):
  average_update
= update_sum / num_clients
 
return current_model_weights + average_update

เราจำเป็นต้องมีคู่มากขึ้น tff.tf_computation ส่วนประกอบ:

# Function to select slices from the model weights in federated_select:
select_fn
= tff.tf_computation(
   
lambda model_weights, index: tf.gather(model_weights, index))


# We need to wrap `client_train_fn` as a `tff.tf_computation`, making
# sure we do any operations that might construct `tf.Variable`s outside
# of the `tf.function` we are wrapping.
@tff.tf_computation
def client_train_fn_tff(model_slices_as_dataset, client_data, client_keys,
                        actual_num_tokens
):
 
# Note this is amaller than the global model, using
 
# MAX_TOKENS_SELECTED_PER_CLIENT which is much smaller than WORD_VOCAB_SIZE.
 
# W7e would like a model of size `actual_num_tokens`, but we
 
# can't build the model dynamically, so we will slice off the padded
 
# weights at the end.
  client_model
= create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                       TAG_VOCAB_SIZE
)
  client_optimizer
= tf.keras.optimizers.SGD(learning_rate=0.1)
 
return client_train_fn(client_model, client_optimizer,
                         model_slices_as_dataset
, client_data, client_keys,
                         actual_num_tokens
)

@tff.tf_computation
def keys_for_client_tff(client_data):
 
return keys_for_client(client_data, MAX_TOKENS_SELECTED_PER_CLIENT)

ตอนนี้เราพร้อมที่จะรวบรวมชิ้นส่วนทั้งหมดแล้ว!

@tff.federated_computation(
    tff
.type_at_server(model_type), tff.type_at_clients(client_data_type))
def sparse_model_update(server_model, client_data):
  max_tokens
= tff.federated_value(MAX_TOKENS_SELECTED_PER_CLIENT, tff.SERVER)
  keys_at_clients
, actual_num_tokens = tff.federated_map(
      keys_for_client_tff
, client_data)

  model_slices
= tff.federated_select(keys_at_clients, max_tokens, server_model,
                                      select_fn
)

  update_keys
, update_slices = tff.federated_map(
      client_train_fn_tff
,
     
(model_slices, client_data, keys_at_clients, actual_num_tokens))

  dense_update_sum
= federated_indexed_slices_sum(update_keys, update_slices,
                                                  DENSE_MODEL_SHAPE
)
  num_clients
= tff.federated_sum(tff.federated_value(1.0, tff.CLIENTS))

  updated_server_model
= tff.federated_map(
      server_update
, (server_model, dense_update_sum, num_clients))

 
return updated_server_model


print(sparse_model_update.type_signature)
(<server_model=float32[13,4]@SERVER,client_data={<tokens=<indices=int64[?,2],values=int32[?],dense_shape=int64[2]>,tags=float32[?,4]>*}@CLIENTS> -> float32[13,4]@SERVER)

มาฝึกโมเดลกันเถอะ!

ตอนนี้เรามีฟังก์ชั่นการฝึกแล้ว มาลองดูกัน

server_model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
server_model
.compile(  # Compile to make evaluation easy.
    optimizer
=tf.keras.optimizers.Adagrad(learning_rate=0.0),  # Unused
    loss
=tf.keras.losses.BinaryCrossentropy(),
    metrics
=[
      tf
.keras.metrics.Precision(name='precision'),
      tf
.keras.metrics.AUC(name='auc'),
      tf
.keras.metrics.Recall(top_k=2, name='recall_at_2'),
 
])

def evaluate(model, dataset, name):
  metrics
= model.evaluate(dataset, verbose=0)
  metrics_str
= ', '.join([f'{k}={v:.2f}' for k, v in
                         
(zip(server_model.metrics_names, metrics))])
 
print(f'{name}: {metrics_str}')
print('Before training')
evaluate
(server_model, batched_dataset1, 'Client 1')
evaluate
(server_model, batched_dataset2, 'Client 2')
evaluate
(server_model, batched_dataset3, 'Client 3')

model_weights
= server_model.trainable_weights[0]

client_datasets
= [batched_dataset1, batched_dataset2, batched_dataset3]
for _ in range(10):  # Run 10 rounds of FedAvg
 
# We train on 1, 2, or 3 clients per round, selecting
 
# randomly.
  cohort_size
= np.random.randint(1, 4)
  clients
= np.random.choice([0, 1, 2], cohort_size, replace=False)
 
print('Training on clients', clients)
  model_weights
= sparse_model_update(
      model_weights
, [client_datasets[i] for i in clients])
server_model
.set_weights([model_weights])

print('After training')
evaluate
(server_model, batched_dataset1, 'Client 1')
evaluate
(server_model, batched_dataset2, 'Client 2')
evaluate
(server_model, batched_dataset3, 'Client 3')
Before training
Client 1: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.60
Client 2: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.50
Client 3: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.40
Training on clients [0 1]
Training on clients [0 2 1]
Training on clients [2 0]
Training on clients [1 0 2]
Training on clients [2]
Training on clients [2 0]
Training on clients [1 2 0]
Training on clients [0]
Training on clients [2]
Training on clients [1 2]
After training
Client 1: loss=0.67, precision=0.80, auc=0.91, recall_at_2=0.80
Client 2: loss=0.68, precision=0.67, auc=0.96, recall_at_2=1.00
Client 3: loss=0.65, precision=1.00, auc=0.93, recall_at_2=0.80